Merge pull request #14 from b95505017/websocket_okhttp

Replace Java-WebSocket to OkHttp WebSocket
This commit is contained in:
Naoyuki Kanezawa
2015-05-02 14:27:08 +09:00
6 changed files with 169 additions and 98 deletions

10
pom.xml
View File

@@ -47,11 +47,6 @@
<artifactId>json</artifactId> <artifactId>json</artifactId>
<version>20090211</version> <version>20090211</version>
</dependency> </dependency>
<dependency>
<groupId>org.java-websocket</groupId>
<artifactId>Java-WebSocket</artifactId>
<version>1.3.0</version>
</dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
<artifactId>junit</artifactId> <artifactId>junit</artifactId>
@@ -64,6 +59,11 @@
<version>1.3</version> <version>1.3</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.squareup.okhttp</groupId>
<artifactId>okhttp-ws</artifactId>
<version>2.3.0</version>
</dependency>
</dependencies> </dependencies>
<distributionManagement> <distributionManagement>

View File

@@ -10,6 +10,7 @@ import com.github.nkzawa.parseqs.ParseQS;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import org.json.JSONException; import org.json.JSONException;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
@@ -121,6 +122,7 @@ public class Socket extends Emitter {
private Future pingTimeoutTimer; private Future pingTimeoutTimer;
private Future pingIntervalTimer; private Future pingIntervalTimer;
private SSLContext sslContext; private SSLContext sslContext;
private HostnameVerifier hostnameVerifier;
private ReadyState readyState; private ReadyState readyState;
private ScheduledExecutorService heartbeatScheduler; private ScheduledExecutorService heartbeatScheduler;
@@ -197,6 +199,7 @@ public class Socket extends Emitter {
opts.transports : new String[]{Polling.NAME, WebSocket.NAME})); opts.transports : new String[]{Polling.NAME, WebSocket.NAME}));
this.policyPort = opts.policyPort != 0 ? opts.policyPort : 843; this.policyPort = opts.policyPort != 0 ? opts.policyPort : 843;
this.rememberUpgrade = opts.rememberUpgrade; this.rememberUpgrade = opts.rememberUpgrade;
this.hostnameVerifier = opts.hostnameVerifier;
} }
/** /**
@@ -254,6 +257,7 @@ public class Socket extends Emitter {
opts.timestampParam = this.timestampParam; opts.timestampParam = this.timestampParam;
opts.policyPort = this.policyPort; opts.policyPort = this.policyPort;
opts.socket = this; opts.socket = this;
opts.hostnameVerifier = this.hostnameVerifier;
Transport transport; Transport transport;
if (WebSocket.NAME.equals(name)) { if (WebSocket.NAME.equals(name)) {

View File

@@ -6,6 +6,7 @@ import com.github.nkzawa.engineio.parser.Packet;
import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.engineio.parser.Parser;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import java.util.Map; import java.util.Map;
@@ -42,6 +43,7 @@ public abstract class Transport extends Emitter {
protected String timestampParam; protected String timestampParam;
protected SSLContext sslContext; protected SSLContext sslContext;
protected Socket socket; protected Socket socket;
protected HostnameVerifier hostnameVerifier;
protected ReadyState readyState; protected ReadyState readyState;
@@ -55,6 +57,7 @@ public abstract class Transport extends Emitter {
this.timestampRequests = opts.timestampRequests; this.timestampRequests = opts.timestampRequests;
this.sslContext = opts.sslContext; this.sslContext = opts.sslContext;
this.socket = opts.socket; this.socket = opts.socket;
this.hostnameVerifier = opts.hostnameVerifier;
} }
protected Transport onError(String msg, Exception desc) { protected Transport onError(String msg, Exception desc) {
@@ -144,6 +147,7 @@ public abstract class Transport extends Emitter {
public int policyPort = -1; public int policyPort = -1;
public Map<String, String> query; public Map<String, String> query;
public SSLContext sslContext; public SSLContext sslContext;
public HostnameVerifier hostnameVerifier;
protected Socket socket; protected Socket socket;
} }
} }

View File

@@ -4,6 +4,7 @@ package com.github.nkzawa.engineio.client.transports;
import com.github.nkzawa.emitter.Emitter; import com.github.nkzawa.emitter.Emitter;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import java.io.*; import java.io.*;
@@ -37,6 +38,7 @@ public class PollingXHR extends Polling {
} }
opts.uri = this.uri(); opts.uri = this.uri();
opts.sslContext = this.sslContext; opts.sslContext = this.sslContext;
opts.hostnameVerifier = this.hostnameVerifier;
Request req = new Request(opts); Request req = new Request(opts);
@@ -148,12 +150,14 @@ public class PollingXHR extends Polling {
private SSLContext sslContext; private SSLContext sslContext;
private HttpURLConnection xhr; private HttpURLConnection xhr;
private HostnameVerifier hostnameVerifier;
public Request(Options opts) { public Request(Options opts) {
this.method = opts.method != null ? opts.method : "GET"; this.method = opts.method != null ? opts.method : "GET";
this.uri = opts.uri; this.uri = opts.uri;
this.data = opts.data; this.data = opts.data;
this.sslContext = opts.sslContext; this.sslContext = opts.sslContext;
this.hostnameVerifier = opts.hostnameVerifier;
} }
public void create() { public void create() {
@@ -170,8 +174,13 @@ public class PollingXHR extends Polling {
xhr.setConnectTimeout(10000); xhr.setConnectTimeout(10000);
if (xhr instanceof HttpsURLConnection && this.sslContext != null) { if (xhr instanceof HttpsURLConnection) {
((HttpsURLConnection)xhr).setSSLSocketFactory(this.sslContext.getSocketFactory()); if (this.sslContext != null) {
((HttpsURLConnection)xhr).setSSLSocketFactory(this.sslContext.getSocketFactory());
}
if (this.hostnameVerifier != null) {
((HttpsURLConnection)xhr).setHostnameVerifier(this.hostnameVerifier);
}
} }
Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER); Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
@@ -317,6 +326,7 @@ public class PollingXHR extends Polling {
public String method; public String method;
public byte[] data; public byte[] data;
public SSLContext sslContext; public SSLContext sslContext;
public HostnameVerifier hostnameVerifier;
} }
} }
} }

View File

@@ -6,22 +6,33 @@ import com.github.nkzawa.engineio.parser.Packet;
import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.engineio.parser.Parser;
import com.github.nkzawa.parseqs.ParseQS; import com.github.nkzawa.parseqs.ParseQS;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import org.java_websocket.client.DefaultSSLWebSocketClientFactory; import com.squareup.okhttp.Headers;
import org.java_websocket.client.WebSocketClient; import com.squareup.okhttp.OkHttpClient;
import org.java_websocket.drafts.Draft_17; import com.squareup.okhttp.Request;
import org.java_websocket.handshake.ServerHandshake; import com.squareup.okhttp.Response;
import com.squareup.okhttp.ws.WebSocket.PayloadType;
import com.squareup.okhttp.ws.WebSocketCall;
import com.squareup.okhttp.ws.WebSocketListener;
import java.net.URI; import java.io.IOException;
import java.net.URISyntaxException; import java.util.Date;
import java.nio.ByteBuffer; import java.util.HashMap;
import java.util.*; import java.util.Map;
import java.util.TreeMap;
import javax.net.ssl.SSLSocketFactory;
import okio.Buffer;
import okio.BufferedSource;
import static com.squareup.okhttp.ws.WebSocket.PayloadType.BINARY;
import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT;
public class WebSocket extends Transport { public class WebSocket extends Transport {
public static final String NAME = "websocket"; public static final String NAME = "websocket";
private com.squareup.okhttp.ws.WebSocket ws;
private WebSocketClient ws; private WebSocketCall wsCall;
public WebSocket(Options opts) { public WebSocket(Options opts) {
super(opts); super(opts);
@@ -37,70 +48,98 @@ public class WebSocket extends Transport {
this.emit(EVENT_REQUEST_HEADERS, headers); this.emit(EVENT_REQUEST_HEADERS, headers);
final WebSocket self = this; final WebSocket self = this;
try { final OkHttpClient client = new OkHttpClient();
this.ws = new WebSocketClient(new URI(this.uri()), new Draft_17(), headers, 0) { if (this.sslContext != null) {
@Override SSLSocketFactory factory = sslContext.getSocketFactory();// (SSLSocketFactory) SSLSocketFactory.getDefault();
public void onOpen(final ServerHandshake serverHandshake) { client.setSslSocketFactory(factory);
EventThread.exec(new Runnable() {
@Override
public void run() {
Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
Iterator<String> it = serverHandshake.iterateHttpFields();
while (it.hasNext()) {
String field = it.next();
if (field == null) continue;
headers.put(field, serverHandshake.getFieldValue(field));
}
self.emit(EVENT_RESPONSE_HEADERS, headers);
self.onOpen();
}
});
}
@Override
public void onClose(int i, String s, boolean b) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onClose();
}
});
}
@Override
public void onMessage(final String s) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onData(s);
}
});
}
@Override
public void onMessage(final ByteBuffer s) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onData(s.array());
}
});
}
@Override
public void onError(final Exception e) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onError("websocket error", e);
}
});
}
};
if (this.sslContext != null) {
this.ws.setWebSocketFactory(new DefaultSSLWebSocketClientFactory(this.sslContext));
}
this.ws.connect();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
} }
if (this.hostnameVerifier != null) {
client.setHostnameVerifier(this.hostnameVerifier);
}
Request.Builder builder = new Request.Builder().url(uri());
for (Map.Entry<String, String> entry : headers.entrySet()) {
builder.addHeader(entry.getKey(), entry.getValue());
}
final Request request = builder.build();
(wsCall = WebSocketCall.create(client, request)).enqueue(new WebSocketListener() {
@Override
public void onOpen(com.squareup.okhttp.ws.WebSocket webSocket, Request request, Response response) throws IOException {
ws = webSocket;
final Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
Headers responseHeaders = response.headers();
for (int i = 0, size = responseHeaders.size(); i < size; i++) {
headers.put(responseHeaders.name(i), responseHeaders.value(i));
}
EventThread.exec(new Runnable() {
@Override
public void run() {
self.emit(EVENT_RESPONSE_HEADERS, headers);
self.onOpen();
}
});
}
@Override
public void onMessage(BufferedSource payload, final PayloadType type) throws IOException {
Object data = null;
switch (type) {
case TEXT:
data = payload.readUtf8();
break;
case BINARY:
data = payload.readByteArray();
break;
default:
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onError("Unknown payload type: " + type, new IllegalStateException());
}
});
}
payload.close();
final Object finalData = data;
EventThread.exec(new Runnable() {
@Override
public void run() {
if (finalData == null) {
return;
}
if (finalData instanceof String) {
self.onData((String) finalData);
} else {
self.onData((byte[]) finalData);
}
}
});
}
@Override
public void onPong(Buffer payload) {
}
@Override
public void onClose(int code, String reason) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onClose();
}
});
}
@Override
public void onFailure(final IOException e) {
EventThread.exec(new Runnable() {
@Override
public void run() {
self.onError("websocket error", e);
}
});
}
});
client.getDispatcher().getExecutorService().shutdown();
} }
protected void write(Packet[] packets) { protected void write(Packet[] packets) {
@@ -110,10 +149,14 @@ public class WebSocket extends Transport {
Parser.encodePacket(packet, new Parser.EncodeCallback() { Parser.encodePacket(packet, new Parser.EncodeCallback() {
@Override @Override
public void call(Object packet) { public void call(Object packet) {
if (packet instanceof String) { try {
self.ws.send((String) packet); if (packet instanceof String) {
} else if (packet instanceof byte[]) { self.ws.sendMessage(TEXT, new Buffer().writeUtf8((String) packet));
self.ws.send((byte[]) packet); } else if (packet instanceof byte[]) {
self.ws.sendMessage(BINARY, new Buffer().write((byte[]) packet));
}
} catch (IOException e) {
self.onError("websocket error", e);
} }
} }
}); });
@@ -138,8 +181,17 @@ public class WebSocket extends Transport {
} }
protected void doClose() { protected void doClose() {
if (this.ws != null) { if (wsCall != null) {
this.ws.close(); wsCall.cancel();
wsCall = null;
}
if (ws != null) {
try {
ws.close(1000, "");
} catch (IOException e) {
onError("doClose error", e);
}
ws = null;
} }
} }

View File

@@ -6,6 +6,7 @@ import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.JUnit4; import org.junit.runners.JUnit4;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactory;
@@ -23,15 +24,11 @@ import static org.junit.Assert.assertThat;
@RunWith(JUnit4.class) @RunWith(JUnit4.class)
public class SSLConnectionTest extends Connection { public class SSLConnectionTest extends Connection {
static { static HostnameVerifier hostnameVerifier = new javax.net.ssl.HostnameVerifier(){
// for test on localhost public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) {
javax.net.ssl.HttpsURLConnection.setDefaultHostnameVerifier( return hostname.equals("localhost");
new javax.net.ssl.HostnameVerifier(){ }
public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) { };
return hostname.equals("localhost");
}
});
}
private Socket socket; private Socket socket;
@@ -74,6 +71,7 @@ public class SSLConnectionTest extends Connection {
Socket.Options opts = createOptions(); Socket.Options opts = createOptions();
opts.sslContext = createSSLContext(); opts.sslContext = createSSLContext();
opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier;
socket = new Socket(opts); socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
@@ -98,6 +96,7 @@ public class SSLConnectionTest extends Connection {
Socket.Options opts = createOptions(); Socket.Options opts = createOptions();
opts.sslContext = createSSLContext(); opts.sslContext = createSSLContext();
opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier;
socket = new Socket(opts); socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
@@ -127,7 +126,9 @@ public class SSLConnectionTest extends Connection {
final BlockingQueue<Object> values = new LinkedBlockingQueue<Object>(); final BlockingQueue<Object> values = new LinkedBlockingQueue<Object>();
Socket.setDefaultSSLContext(createSSLContext()); Socket.setDefaultSSLContext(createSSLContext());
socket = new Socket(createOptions()); Socket.Options opts = createOptions();
opts.hostnameVerifier = SSLConnectionTest.hostnameVerifier;
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {