diff --git a/src/main/java/com/github/nkzawa/engineio/client/Socket.java b/src/main/java/com/github/nkzawa/engineio/client/Socket.java index 7ef488b..a70e94f 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/Socket.java +++ b/src/main/java/com/github/nkzawa/engineio/client/Socket.java @@ -115,7 +115,7 @@ public class Socket extends Emitter { private List transports; private List upgrades; private Map query; - private LinkedList writeBuffer = new LinkedList(); + /*package*/ LinkedList writeBuffer = new LinkedList(); private LinkedList callbackBuffer = new LinkedList(); /*package*/ Transport transport; private Future pingTimeoutTimer; @@ -325,6 +325,7 @@ public class Socket extends Emitter { logger.fine(String.format("probe transport '%s' pong", name)); self.upgrading = true; self.emit(EVENT_UPGRADING, transport[0]); + if (null == transport) return; Socket.priorWebsocketSuccess = WebSocket.NAME.equals(transport[0].name); logger.fine(String.format("pausing current transport '%s'", self.transport.name)); @@ -332,9 +333,7 @@ public class Socket extends Emitter { @Override public void run() { if (failed[0]) return; - if (ReadyState.CLOSED == self.readyState || ReadyState.CLOSING == self.readyState) { - return; - } + if (ReadyState.CLOSED == self.readyState) return; logger.fine("changing transport and sending upgrade packet"); @@ -667,6 +666,10 @@ public class Socket extends Emitter { } private void sendPacket(Packet packet, Runnable fn) { + if (ReadyState.CLOSING == this.readyState || ReadyState.CLOSED == this.readyState) { + return; + } + if (fn == null) { // ConcurrentLinkedList does not permit `null`. fn = noop; @@ -688,11 +691,55 @@ public class Socket extends Emitter { @Override public void run() { if (Socket.this.readyState == ReadyState.OPENING || Socket.this.readyState == ReadyState.OPEN) { - Socket.this.onClose("forced close"); - logger.fine("socket closing - telling transport to close"); - Socket.this.transport.close(); - } + Socket.this.readyState = ReadyState.CLOSING; + final Socket self = Socket.this; + + final Runnable close = new Runnable() { + @Override + public void run() { + self.onClose("forced close"); + logger.fine("socket closing - telling transport to close"); + self.transport.close(); + } + }; + + final Listener[] cleanupAndClose = new Listener[1]; + cleanupAndClose[0] = new Listener() { + @Override + public void call(Object ...args) { + self.off(EVENT_UPGRADE, cleanupAndClose[0]); + self.off(EVENT_UPGRADE_ERROR, cleanupAndClose[0]); + close.run(); + } + }; + + final Runnable waitForUpgrade = new Runnable() { + @Override + public void run() { + // wait for updade to finish since we can't send packets while pausing a transport + self.once(EVENT_UPGRADE, cleanupAndClose[0]); + self.once(EVENT_UPGRADE_ERROR, cleanupAndClose[0]); + } + }; + + if (Socket.this.writeBuffer.size() > 0) { + Socket.this.once(EVENT_DRAIN, new Listener() { + @Override + public void call(Object... args) { + if (Socket.this.upgrading) { + waitForUpgrade.run(); + } else { + close.run(); + } + } + }); + } else if (Socket.this.upgrading) { + waitForUpgrade.run(); + } else { + close.run(); + } + } } }); return this; @@ -710,7 +757,7 @@ public class Socket extends Emitter { } private void onClose(String reason, Exception desc) { - if (this.readyState == ReadyState.OPENING || this.readyState == ReadyState.OPEN) { + if (ReadyState.OPENING == this.readyState || ReadyState.OPEN == this.readyState || ReadyState.CLOSING == this.readyState) { logger.fine(String.format("socket close with reason: %s", reason)); final Socket self = this; diff --git a/src/test/java/com/github/nkzawa/engineio/client/ConnectionTest.java b/src/test/java/com/github/nkzawa/engineio/client/ConnectionTest.java index 391fff1..a1f5c6a 100644 --- a/src/test/java/com/github/nkzawa/engineio/client/ConnectionTest.java +++ b/src/test/java/com/github/nkzawa/engineio/client/ConnectionTest.java @@ -104,6 +104,7 @@ public class ConnectionTest extends Connection { } }); socket.close(); + socket.send("hi"); Timer timer = new Timer(); timer.schedule(new TimerTask() { @Override @@ -117,4 +118,128 @@ public class ConnectionTest extends Connection { socket.open(); assertThat((Boolean)values.take(), is(true)); } + + @Test(timeout = TIMEOUT) + public void deferCloseWhenUpgrading() throws InterruptedException { + final BlockingQueue values = new LinkedBlockingQueue(); + + socket = new Socket(createOptions()); + socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { + @Override + public void call(Object... args) { + final boolean[] upgraded = new boolean[] {false}; + socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() { + @Override + public void call(Object... args) { + upgraded[0] = true; + } + }).on(Socket.EVENT_UPGRADING, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.on(Socket.EVENT_CLOSE, new Emitter.Listener() { + @Override + public void call(Object... args) { + values.offer(upgraded[0]); + } + }); + socket.close(); + } + }); + } + }); + socket.open(); + assertThat((Boolean)values.take(), is(true)); + } + + @Test(timeout = TIMEOUT) + public void closeOnUpgradeErrorIfClosingIsDeferred() throws InterruptedException { + final BlockingQueue values = new LinkedBlockingQueue(); + + socket = new Socket(createOptions()); + socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { + @Override + public void call(Object... args) { + final boolean[] upgradError = new boolean[] {false}; + socket.on(Socket.EVENT_UPGRADE_ERROR, new Emitter.Listener() { + @Override + public void call(Object... args) { + upgradError[0] = true; + } + }).on(Socket.EVENT_UPGRADING, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.on(Socket.EVENT_CLOSE, new Emitter.Listener() { + @Override + public void call(Object... args) { + values.offer(upgradError[0]); + } + }); + socket.close(); + socket.transport.onError("upgrade error", new Exception()); + } + }); + } + }); + socket.open(); + assertThat((Boolean) values.take(), is(true)); + } + + public void notSendPacketsIfClosingIsDeferred() throws InterruptedException { + final BlockingQueue values = new LinkedBlockingQueue(); + + socket = new Socket(createOptions()); + socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { + @Override + public void call(Object... args) { + final boolean[] noPacket = new boolean[] {true}; + socket.on(Socket.EVENT_UPGRADING, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.on(Socket.EVENT_PACKET_CREATE, new Emitter.Listener() { + @Override + public void call(Object... args) { + noPacket[0] = false; + } + }); + socket.close(); + socket.send("hi"); + } + }); + new Timer().schedule(new TimerTask() { + @Override + public void run() { + values.offer(noPacket[0]); + } + }, 1200); + } + }); + socket.open(); + assertThat((Boolean) values.take(), is(true)); + } + + @Test(timeout = TIMEOUT) + public void sendAllBufferedPacketsIfClosingIsDeferred() throws InterruptedException { + final BlockingQueue values = new LinkedBlockingQueue(); + + socket = new Socket(createOptions()); + socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.on(Socket.EVENT_UPGRADING, new Emitter.Listener() { + @Override + public void call(Object... args) { + socket.send("hi"); + socket.close(); + } + }).on(Socket.EVENT_CLOSE, new Emitter.Listener() { + @Override + public void call(Object... args) { + values.offer(socket.writeBuffer.size()); + } + }); + } + }); + socket.open(); + assertThat((Integer) values.take(), is(0)); + } }