support the rememberUpgrade option

This commit is contained in:
Naoyuki Kanezawa
2014-03-25 23:21:26 +09:00
parent 777ea28fda
commit 9d6751d6b4
7 changed files with 175 additions and 20 deletions

View File

@@ -91,10 +91,13 @@ public abstract class Socket extends Emitter {
*/ */
public static final int protocol = Parser.protocol; public static final int protocol = Parser.protocol;
public static boolean priorWebsocketSuccess = false;
private boolean secure; private boolean secure;
private boolean upgrade; private boolean upgrade;
private boolean timestampRequests; private boolean timestampRequests;
private boolean upgrading; private boolean upgrading;
private boolean rememberUpgrade;
private int port; private int port;
private int policyPort; private int policyPort;
private int prevBufferLen; private int prevBufferLen;
@@ -109,7 +112,7 @@ public abstract class Socket extends Emitter {
private Map<String, String> query; private Map<String, String> query;
private LinkedList<Packet> writeBuffer = new LinkedList<Packet>(); private LinkedList<Packet> writeBuffer = new LinkedList<Packet>();
private LinkedList<Runnable> callbackBuffer = new LinkedList<Runnable>(); private LinkedList<Runnable> callbackBuffer = new LinkedList<Runnable>();
private Transport transport; /*package*/ Transport transport;
private Future pingTimeoutTimer; private Future pingTimeoutTimer;
private Future pingIntervalTimer; private Future pingIntervalTimer;
@@ -117,6 +120,10 @@ public abstract class Socket extends Emitter {
private ScheduledExecutorService heartbeatScheduler = Executors.newSingleThreadScheduledExecutor(); private ScheduledExecutorService heartbeatScheduler = Executors.newSingleThreadScheduledExecutor();
public Socket() {
this(new Options());
}
/** /**
* Creates a socket. * Creates a socket.
* *
@@ -167,6 +174,7 @@ public abstract class Socket extends Emitter {
this.transports = new ArrayList<String>(Arrays.asList(opts.transports != null ? this.transports = new ArrayList<String>(Arrays.asList(opts.transports != null ?
opts.transports : new String[]{Polling.NAME, WebSocket.NAME})); opts.transports : new String[]{Polling.NAME, WebSocket.NAME}));
this.policyPort = opts.policyPort != 0 ? opts.policyPort : 843; this.policyPort = opts.policyPort != 0 ? opts.policyPort : 843;
this.rememberUpgrade = opts.rememberUpgrade;
} }
/** /**
@@ -176,7 +184,12 @@ public abstract class Socket extends Emitter {
EventThread.exec(new Runnable() { EventThread.exec(new Runnable() {
@Override @Override
public void run() { public void run() {
String transportName = Socket.this.transports.get(0); String transportName;
if (Socket.this.rememberUpgrade && Socket.priorWebsocketSuccess && Socket.this.transports.contains(WebSocket.NAME)) {
transportName = WebSocket.NAME;
} else {
transportName = Socket.this.transports.get(0);
}
Socket.this.readyState = ReadyState.OPENING; Socket.this.readyState = ReadyState.OPENING;
Transport transport = Socket.this.createTransport(transportName); Transport transport = Socket.this.createTransport(transportName);
Socket.this.setTransport(transport); Socket.this.setTransport(transport);
@@ -204,6 +217,7 @@ public abstract class Socket extends Emitter {
opts.timestampRequests = this.timestampRequests; opts.timestampRequests = this.timestampRequests;
opts.timestampParam = this.timestampParam; opts.timestampParam = this.timestampParam;
opts.policyPort = this.policyPort; opts.policyPort = this.policyPort;
opts.socket = this;
if (WebSocket.NAME.equals(name)) { if (WebSocket.NAME.equals(name)) {
return new WebSocket(opts); return new WebSocket(opts);
@@ -256,6 +270,8 @@ public abstract class Socket extends Emitter {
final boolean[] failed = new boolean[] {false}; final boolean[] failed = new boolean[] {false};
final Socket self = this; final Socket self = this;
Socket.priorWebsocketSuccess = false;
final Listener onerror = new Listener() { final Listener onerror = new Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -292,6 +308,7 @@ public abstract class Socket extends Emitter {
logger.fine(String.format("probe transport '%s' pong", name)); logger.fine(String.format("probe transport '%s' pong", name));
self.upgrading = true; self.upgrading = true;
self.emit(EVENT_UPGRADING, transport[0]); self.emit(EVENT_UPGRADING, transport[0]);
Socket.priorWebsocketSuccess = WebSocket.NAME.equals(transport[0].name);
logger.fine(String.format("pausing current transport '%s'", self.transport.name)); logger.fine(String.format("pausing current transport '%s'", self.transport.name));
((Polling)self.transport).pause(new Runnable() { ((Polling)self.transport).pause(new Runnable() {
@@ -304,10 +321,10 @@ public abstract class Socket extends Emitter {
logger.fine("changing transport and sending upgrade packet"); logger.fine("changing transport and sending upgrade packet");
transport[0].off(Transport.EVENT_ERROR, onerror); transport[0].off(Transport.EVENT_ERROR, onerror);
self.emit(EVENT_UPGRADE, transport[0]);
self.setTransport(transport[0]); self.setTransport(transport[0]);
Packet packet = new Packet(Packet.UPGRADE); Packet packet = new Packet(Packet.UPGRADE);
transport[0].send(new Packet[]{packet}); transport[0].send(new Packet[]{packet});
self.emit(EVENT_UPGRADE, transport[0]);
transport[0] = null; transport[0] = null;
self.upgrading = false; self.upgrading = false;
self.flush(); self.flush();
@@ -356,6 +373,7 @@ public abstract class Socket extends Emitter {
private void onOpen() { private void onOpen() {
logger.fine("socket open"); logger.fine("socket open");
this.readyState = ReadyState.OPEN; this.readyState = ReadyState.OPEN;
Socket.priorWebsocketSuccess = WebSocket.NAME.equals(this.transport.name);
this.emit(EVENT_OPEN); this.emit(EVENT_OPEN);
this.onopen(); this.onopen();
this.flush(); this.flush();
@@ -574,6 +592,7 @@ public abstract class Socket extends Emitter {
private void onError(Exception err) { private void onError(Exception err) {
logger.fine(String.format("socket error %s", err)); logger.fine(String.format("socket error %s", err));
Socket.priorWebsocketSuccess = false;
this.emit(EVENT_ERROR, err); this.emit(EVENT_ERROR, err);
this.onerror(err); this.onerror(err);
this.onClose("transport error", err); this.onClose("transport error", err);
@@ -605,23 +624,23 @@ public abstract class Socket extends Emitter {
} }
}); });
// ensure transport won't stay open
this.transport.close();
// ignore further transport communication // ignore further transport communication
this.transport.off(); this.transport.off();
// set ready state // set ready state
ReadyState prev = this.readyState;
this.readyState = ReadyState.CLOSED; this.readyState = ReadyState.CLOSED;
// clear session id // clear session id
this.id = null; this.id = null;
// emit events // emit close events
if (prev == ReadyState.OPEN) {
this.emit(EVENT_CLOSE, reason, desc); this.emit(EVENT_CLOSE, reason, desc);
this.onclose(); this.onclose();
} }
} }
}
/*package*/ List<String > filterUpgrades(List<String> upgrades) { /*package*/ List<String > filterUpgrades(List<String> upgrades) {
List<String> filteredUpgrades = new ArrayList<String>(); List<String> filteredUpgrades = new ArrayList<String>();
@@ -653,6 +672,7 @@ public abstract class Socket extends Emitter {
*/ */
public boolean upgrade = true; public boolean upgrade = true;
public boolean rememberUpgrade;
public String host; public String host;
public String query; public String query;

View File

@@ -36,6 +36,7 @@ public abstract class Transport extends Emitter {
protected String path; protected String path;
protected String hostname; protected String hostname;
protected String timestampParam; protected String timestampParam;
protected Socket socket;
protected ReadyState readyState; protected ReadyState readyState;
@@ -47,6 +48,7 @@ public abstract class Transport extends Emitter {
this.query = opts.query; this.query = opts.query;
this.timestampParam = opts.timestampParam; this.timestampParam = opts.timestampParam;
this.timestampRequests = opts.timestampRequests; this.timestampRequests = opts.timestampRequests;
this.socket = opts.socket;
} }
protected Transport onError(String msg, Exception desc) { protected Transport onError(String msg, Exception desc) {
@@ -131,5 +133,6 @@ public abstract class Transport extends Emitter {
public int port; public int port;
public int policyPort; public int policyPort;
public Map<String, String> query; public Map<String, String> query;
protected Socket socket;
} }
} }

View File

@@ -244,11 +244,13 @@ public class PollingXHR extends Polling {
} }
private void cleanup() { private void cleanup() {
if (xhr != null) { if (xhr == null) {
return;
}
xhr.disconnect(); xhr.disconnect();
xhr = null; xhr = null;
} }
}
public void abort() { public void abort() {
this.cleanup(); this.cleanup();

View File

@@ -22,7 +22,7 @@ public class WebSocket extends Transport {
public static final String NAME = "websocket"; public static final String NAME = "websocket";
private WebSocketClient socket; private WebSocketClient ws;
public WebSocket(Options opts) { public WebSocket(Options opts) {
@@ -40,7 +40,7 @@ public class WebSocket extends Transport {
final WebSocket self = this; final WebSocket self = this;
try { try {
this.socket = new WebSocketClient(new URI(this.uri()), new Draft_17(), headers, 0) { this.ws = new WebSocketClient(new URI(this.uri()), new Draft_17(), headers, 0) {
@Override @Override
public void onOpen(final ServerHandshake serverHandshake) { public void onOpen(final ServerHandshake serverHandshake) {
EventThread.exec(new Runnable() { EventThread.exec(new Runnable() {
@@ -87,7 +87,7 @@ public class WebSocket extends Transport {
}); });
} }
}; };
this.socket.connect(); this.ws.connect();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@@ -97,7 +97,7 @@ public class WebSocket extends Transport {
final WebSocket self = this; final WebSocket self = this;
this.writable = false; this.writable = false;
for (Packet packet : packets) { for (Packet packet : packets) {
this.socket.send(Parser.encodePacket(packet)); this.ws.send(Parser.encodePacket(packet));
} }
final Runnable ondrain = new Runnable() { final Runnable ondrain = new Runnable() {
@@ -119,8 +119,8 @@ public class WebSocket extends Transport {
} }
protected void doClose() { protected void doClose() {
if (this.socket != null) { if (this.ws != null) {
this.socket.close(); this.ws.close();
} }
} }

View File

@@ -311,4 +311,127 @@ public class ServerConnectionTest {
assertThat(messages.take(), is("foo")); assertThat(messages.take(), is("foo"));
socket.close(); socket.close();
} }
@Test(timeout = TIMEOUT)
public void rememberWebsocket() throws URISyntaxException, InterruptedException {
final Semaphore semaphore = new Semaphore(0);
EventThread.exec(new Runnable() {
@Override
public void run() {
Socket.Options opts = new Socket.Options();
opts.port = PORT;
final Socket socket = new Socket(opts) {
@Override
public void onopen() {
}
@Override
public void onmessage(String data) {
}
@Override
public void onclose() {
}
@Override
public void onerror(Exception err) {
}
};
socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() {
@Override
public void call(Object... args) {
Transport transport = (Transport) args[0];
socket.close();
if (WebSocket.NAME.equals(transport.name)) {
Socket.Options opts = new Socket.Options();
opts.port = PORT;
opts.rememberUpgrade = true;
final Socket socket2 = new Socket(opts) {
@Override
public void onopen() {
}
@Override
public void onmessage(String data) {
}
@Override
public void onclose() {
}
@Override
public void onerror(Exception err) {
}
};
socket2.open();
assertThat(socket2.transport.name, is(WebSocket.NAME));
}
semaphore.release();
}
});
socket.open();
assertThat(socket.transport.name, is(Polling.NAME));
}
});
semaphore.acquire();
}
@Test(timeout = TIMEOUT)
public void notRememberWebsocket() throws URISyntaxException, InterruptedException {
final Semaphore semaphore = new Semaphore(0);
EventThread.exec(new Runnable() {
@Override
public void run() {
Socket.Options opts = new Socket.Options();
opts.port = PORT;
final Socket socket = new Socket(opts) {
@Override
public void onopen() {}
@Override
public void onmessage(String data) {}
@Override
public void onclose() {}
@Override
public void onerror(Exception err) {}
};
socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() {
@Override
public void call(Object... args) {
Transport transport = (Transport)args[0];
socket.close();
if (WebSocket.NAME.equals(transport.name)) {
Socket.Options opts = new Socket.Options();
opts.port = PORT;
opts.rememberUpgrade = false;
final Socket socket2 = new Socket(opts) {
@Override
public void onopen() {}
@Override
public void onmessage(String data) {}
@Override
public void onclose() {}
@Override
public void onerror(Exception err) {}
};
socket2.open();
assertThat(socket2.transport.name, is(not(WebSocket.NAME)));
}
semaphore.release();
}
});
socket.open();
assertThat(socket.transport.name, is(Polling.NAME));
}
});
semaphore.acquire();
}
} }

View File

@@ -12,6 +12,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Timer; import java.util.Timer;
import java.util.TimerTask; import java.util.TimerTask;
import java.util.concurrent.Semaphore;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
@@ -42,13 +43,14 @@ public class SocketTest {
} }
/** /**
* should not emit close on incorrect connection. * should emit close on incorrect connection.
* *
* @throws URISyntaxException * @throws URISyntaxException
* @throws InterruptedException * @throws InterruptedException
*/ */
@Test @Test
public void socketClosing() throws URISyntaxException, InterruptedException { public void socketClosing() throws URISyntaxException, InterruptedException {
final Semaphore semaphore = new Semaphore(0);
Socket socket = new Socket("ws://0.0.0.0:8080") { Socket socket = new Socket("ws://0.0.0.0:8080") {
@Override @Override
public void onopen() {} public void onopen() {}
@@ -68,7 +70,8 @@ public class SocketTest {
timer.schedule(new TimerTask() { timer.schedule(new TimerTask() {
@Override @Override
public void run() { public void run() {
assertThat(closed[0], is(false)); assertThat(closed[0], is(true));
semaphore.release();
} }
}, 20); }, 20);
} }
@@ -81,5 +84,7 @@ public class SocketTest {
} }
}); });
socket.open(); socket.open();
semaphore.acquire();
} }
} }

View File

@@ -14,6 +14,8 @@ import static org.junit.Assert.assertThat;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class TransportTest { public class TransportTest {
// NOTE: tests for the rememberUpgrade option are on ServerConnectionTest.
@Test @Test
public void uri() { public void uri() {
Transport.Options opt = new Transport.Options(); Transport.Options opt = new Transport.Options();