diff --git a/pom.xml b/pom.xml index 9c4f82d..735d9b5 100644 --- a/pom.xml +++ b/pom.xml @@ -47,11 +47,6 @@ json 20090211 - - org.java-websocket - Java-WebSocket - 1.3.0 - junit junit @@ -64,6 +59,11 @@ 1.3 test + + com.squareup.okhttp + okhttp-ws + 2.3.0 + diff --git a/src/main/java/com/github/nkzawa/engineio/client/Socket.java b/src/main/java/com/github/nkzawa/engineio/client/Socket.java index 4da61a6..73de492 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/Socket.java +++ b/src/main/java/com/github/nkzawa/engineio/client/Socket.java @@ -10,6 +10,7 @@ import com.github.nkzawa.parseqs.ParseQS; import com.github.nkzawa.thread.EventThread; import org.json.JSONException; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import java.net.URI; import java.net.URISyntaxException; @@ -121,6 +122,7 @@ public class Socket extends Emitter { private Future pingTimeoutTimer; private Future pingIntervalTimer; private SSLContext sslContext; + private HostnameVerifier hostnameVerifier; private ReadyState readyState; private ScheduledExecutorService heartbeatScheduler; @@ -197,6 +199,7 @@ public class Socket extends Emitter { opts.transports : new String[]{Polling.NAME, WebSocket.NAME})); this.policyPort = opts.policyPort != 0 ? opts.policyPort : 843; this.rememberUpgrade = opts.rememberUpgrade; + this.hostnameVerifier = opts.hostnameVerifier; } /** @@ -254,6 +257,7 @@ public class Socket extends Emitter { opts.timestampParam = this.timestampParam; opts.policyPort = this.policyPort; opts.socket = this; + opts.hostnameVerifier = this.hostnameVerifier; Transport transport; if (WebSocket.NAME.equals(name)) { diff --git a/src/main/java/com/github/nkzawa/engineio/client/Transport.java b/src/main/java/com/github/nkzawa/engineio/client/Transport.java index 8277075..2e12bad 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/Transport.java +++ b/src/main/java/com/github/nkzawa/engineio/client/Transport.java @@ -6,6 +6,7 @@ import com.github.nkzawa.engineio.parser.Packet; import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.thread.EventThread; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import java.util.Map; @@ -42,6 +43,7 @@ public abstract class Transport extends Emitter { protected String timestampParam; protected SSLContext sslContext; protected Socket socket; + protected HostnameVerifier hostnameVerifier; protected ReadyState readyState; @@ -55,6 +57,7 @@ public abstract class Transport extends Emitter { this.timestampRequests = opts.timestampRequests; this.sslContext = opts.sslContext; this.socket = opts.socket; + this.hostnameVerifier = opts.hostnameVerifier; } protected Transport onError(String msg, Exception desc) { @@ -144,6 +147,7 @@ public abstract class Transport extends Emitter { public int policyPort = -1; public Map query; public SSLContext sslContext; + public HostnameVerifier hostnameVerifier; protected Socket socket; } } diff --git a/src/main/java/com/github/nkzawa/engineio/client/transports/PollingXHR.java b/src/main/java/com/github/nkzawa/engineio/client/transports/PollingXHR.java index c937fc9..f82ef49 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/transports/PollingXHR.java +++ b/src/main/java/com/github/nkzawa/engineio/client/transports/PollingXHR.java @@ -4,6 +4,7 @@ package com.github.nkzawa.engineio.client.transports; import com.github.nkzawa.emitter.Emitter; import com.github.nkzawa.thread.EventThread; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import java.io.*; @@ -37,6 +38,7 @@ public class PollingXHR extends Polling { } opts.uri = this.uri(); opts.sslContext = this.sslContext; + opts.hostnameVerifier = this.hostnameVerifier; Request req = new Request(opts); @@ -148,12 +150,14 @@ public class PollingXHR extends Polling { private SSLContext sslContext; private HttpURLConnection xhr; + private HostnameVerifier hostnameVerifier; public Request(Options opts) { this.method = opts.method != null ? opts.method : "GET"; this.uri = opts.uri; this.data = opts.data; this.sslContext = opts.sslContext; + this.hostnameVerifier = opts.hostnameVerifier; } public void create() { @@ -170,8 +174,13 @@ public class PollingXHR extends Polling { xhr.setConnectTimeout(10000); - if (xhr instanceof HttpsURLConnection && this.sslContext != null) { - ((HttpsURLConnection)xhr).setSSLSocketFactory(this.sslContext.getSocketFactory()); + if (xhr instanceof HttpsURLConnection) { + if (this.sslContext != null) { + ((HttpsURLConnection)xhr).setSSLSocketFactory(this.sslContext.getSocketFactory()); + } + if (this.hostnameVerifier != null) { + ((HttpsURLConnection)xhr).setHostnameVerifier(this.hostnameVerifier); + } } Map headers = new TreeMap(String.CASE_INSENSITIVE_ORDER); @@ -317,6 +326,7 @@ public class PollingXHR extends Polling { public String method; public byte[] data; public SSLContext sslContext; + public HostnameVerifier hostnameVerifier; } } } diff --git a/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java b/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java index f7a8c88..a11439e 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java +++ b/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java @@ -6,22 +6,33 @@ import com.github.nkzawa.engineio.parser.Packet; import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.parseqs.ParseQS; import com.github.nkzawa.thread.EventThread; -import org.java_websocket.client.DefaultSSLWebSocketClientFactory; -import org.java_websocket.client.WebSocketClient; -import org.java_websocket.drafts.Draft_17; -import org.java_websocket.handshake.ServerHandshake; +import com.squareup.okhttp.Headers; +import com.squareup.okhttp.OkHttpClient; +import com.squareup.okhttp.Request; +import com.squareup.okhttp.Response; +import com.squareup.okhttp.ws.WebSocket.PayloadType; +import com.squareup.okhttp.ws.WebSocketCall; +import com.squareup.okhttp.ws.WebSocketListener; -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.ByteBuffer; -import java.util.*; +import java.io.IOException; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; + +import javax.net.ssl.SSLSocketFactory; + +import okio.Buffer; +import okio.BufferedSource; + +import static com.squareup.okhttp.ws.WebSocket.PayloadType.BINARY; +import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT; public class WebSocket extends Transport { public static final String NAME = "websocket"; - - private WebSocketClient ws; - + private com.squareup.okhttp.ws.WebSocket ws; + private WebSocketCall wsCall; public WebSocket(Options opts) { super(opts); @@ -37,70 +48,98 @@ public class WebSocket extends Transport { this.emit(EVENT_REQUEST_HEADERS, headers); final WebSocket self = this; - try { - this.ws = new WebSocketClient(new URI(this.uri()), new Draft_17(), headers, 0) { - @Override - public void onOpen(final ServerHandshake serverHandshake) { - EventThread.exec(new Runnable() { - @Override - public void run() { - Map headers = new TreeMap(String.CASE_INSENSITIVE_ORDER); - Iterator it = serverHandshake.iterateHttpFields(); - while (it.hasNext()) { - String field = it.next(); - if (field == null) continue; - headers.put(field, serverHandshake.getFieldValue(field)); - } - self.emit(EVENT_RESPONSE_HEADERS, headers); - - self.onOpen(); - } - }); - } - @Override - public void onClose(int i, String s, boolean b) { - EventThread.exec(new Runnable() { - @Override - public void run() { - self.onClose(); - } - }); - } - @Override - public void onMessage(final String s) { - EventThread.exec(new Runnable() { - @Override - public void run() { - self.onData(s); - } - }); - } - @Override - public void onMessage(final ByteBuffer s) { - EventThread.exec(new Runnable() { - @Override - public void run() { - self.onData(s.array()); - } - }); - } - @Override - public void onError(final Exception e) { - EventThread.exec(new Runnable() { - @Override - public void run() { - self.onError("websocket error", e); - } - }); - } - }; - if (this.sslContext != null) { - this.ws.setWebSocketFactory(new DefaultSSLWebSocketClientFactory(this.sslContext)); - } - this.ws.connect(); - } catch (URISyntaxException e) { - throw new RuntimeException(e); + final OkHttpClient client = new OkHttpClient(); + if (this.sslContext != null) { + SSLSocketFactory factory = sslContext.getSocketFactory();// (SSLSocketFactory) SSLSocketFactory.getDefault(); + client.setSslSocketFactory(factory); } + if (this.hostnameVerifier != null) { + client.setHostnameVerifier(this.hostnameVerifier); + } + Request.Builder builder = new Request.Builder().url(uri()); + for (Map.Entry entry : headers.entrySet()) { + builder.addHeader(entry.getKey(), entry.getValue()); + } + final Request request = builder.build(); + (wsCall = WebSocketCall.create(client, request)).enqueue(new WebSocketListener() { + @Override + public void onOpen(com.squareup.okhttp.ws.WebSocket webSocket, Request request, Response response) throws IOException { + ws = webSocket; + final Map headers = new TreeMap(String.CASE_INSENSITIVE_ORDER); + Headers responseHeaders = response.headers(); + for (int i = 0, size = responseHeaders.size(); i < size; i++) { + headers.put(responseHeaders.name(i), responseHeaders.value(i)); + } + EventThread.exec(new Runnable() { + @Override + public void run() { + self.emit(EVENT_RESPONSE_HEADERS, headers); + self.onOpen(); + } + }); + } + + @Override + public void onMessage(BufferedSource payload, final PayloadType type) throws IOException { + Object data = null; + switch (type) { + case TEXT: + data = payload.readUtf8(); + break; + case BINARY: + data = payload.readByteArray(); + break; + default: + EventThread.exec(new Runnable() { + @Override + public void run() { + self.onError("Unknown payload type: " + type, new IllegalStateException()); + } + }); + } + payload.close(); + final Object finalData = data; + EventThread.exec(new Runnable() { + @Override + public void run() { + if (finalData == null) { + return; + } + if (finalData instanceof String) { + self.onData((String) finalData); + } else { + self.onData((byte[]) finalData); + } + } + }); + + } + + @Override + public void onPong(Buffer payload) { + } + + @Override + public void onClose(int code, String reason) { + EventThread.exec(new Runnable() { + @Override + public void run() { + self.onClose(); + } + }); + } + + @Override + public void onFailure(final IOException e) { + EventThread.exec(new Runnable() { + @Override + public void run() { + self.onError("websocket error", e); + } + }); + } + }); + client.getDispatcher().getExecutorService().shutdown(); } protected void write(Packet[] packets) { @@ -110,10 +149,14 @@ public class WebSocket extends Transport { Parser.encodePacket(packet, new Parser.EncodeCallback() { @Override public void call(Object packet) { - if (packet instanceof String) { - self.ws.send((String) packet); - } else if (packet instanceof byte[]) { - self.ws.send((byte[]) packet); + try { + if (packet instanceof String) { + self.ws.sendMessage(TEXT, new Buffer().writeUtf8((String) packet)); + } else if (packet instanceof byte[]) { + self.ws.sendMessage(BINARY, new Buffer().write((byte[]) packet)); + } + } catch (IOException e) { + self.onError("websocket error", e); } } }); @@ -138,8 +181,17 @@ public class WebSocket extends Transport { } protected void doClose() { - if (this.ws != null) { - this.ws.close(); + if (wsCall != null) { + wsCall.cancel(); + wsCall = null; + } + if (ws != null) { + try { + ws.close(1000, ""); + } catch (IOException e) { + onError("doClose error", e); + } + ws = null; } } @@ -173,4 +225,4 @@ public class WebSocket extends Transport { return true; } -} +} \ No newline at end of file diff --git a/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java b/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java index c4c0e5c..3ec7148 100644 --- a/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java +++ b/src/test/java/com/github/nkzawa/engineio/client/SSLConnectionTest.java @@ -6,6 +6,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.TrustManagerFactory; @@ -23,15 +24,11 @@ import static org.junit.Assert.assertThat; @RunWith(JUnit4.class) public class SSLConnectionTest extends Connection { - static { - // for test on localhost - javax.net.ssl.HttpsURLConnection.setDefaultHostnameVerifier( - new javax.net.ssl.HostnameVerifier(){ - public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) { - return hostname.equals("localhost"); - } - }); - } + static HostnameVerifier hostnameVerifier = new javax.net.ssl.HostnameVerifier(){ + public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) { + return hostname.equals("localhost"); + } + }; private Socket socket; @@ -74,6 +71,7 @@ public class SSLConnectionTest extends Connection { Socket.Options opts = createOptions(); opts.sslContext = createSSLContext(); + opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier; socket = new Socket(opts); socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { @Override @@ -98,6 +96,7 @@ public class SSLConnectionTest extends Connection { Socket.Options opts = createOptions(); opts.sslContext = createSSLContext(); + opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier; socket = new Socket(opts); socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { @Override @@ -127,7 +126,9 @@ public class SSLConnectionTest extends Connection { final BlockingQueue values = new LinkedBlockingQueue(); Socket.setDefaultSSLContext(createSSLContext()); - socket = new Socket(createOptions()); + Socket.Options opts = createOptions(); + opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier; + socket = new Socket(opts); socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { @Override public void call(Object... args) {