fix #3 enable to set SSLContext as an option

This commit is contained in:
Naoyuki Kanezawa
2014-07-07 22:41:59 +09:00
parent b1e43ba1b2
commit fe2fd4413a
13 changed files with 208 additions and 36 deletions

View File

@@ -10,6 +10,7 @@ import com.github.nkzawa.parseqs.ParseQS;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import org.json.JSONException; import org.json.JSONException;
import javax.net.ssl.SSLContext;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.*; import java.util.*;
@@ -117,6 +118,7 @@ public class Socket extends Emitter {
/*package*/ Transport transport; /*package*/ Transport transport;
private Future pingTimeoutTimer; private Future pingTimeoutTimer;
private Future pingIntervalTimer; private Future pingIntervalTimer;
private SSLContext sslContext;
private ReadyState readyState; private ReadyState readyState;
private ScheduledExecutorService heartbeatScheduler = Executors.newSingleThreadScheduledExecutor(); private ScheduledExecutorService heartbeatScheduler = Executors.newSingleThreadScheduledExecutor();
@@ -165,6 +167,7 @@ public class Socket extends Emitter {
} }
this.secure = opts.secure; this.secure = opts.secure;
this.sslContext = opts.sslContext;
this.hostname = opts.hostname != null ? opts.hostname : "localhost"; this.hostname = opts.hostname != null ? opts.hostname : "localhost";
this.port = opts.port != 0 ? opts.port : (this.secure ? 443 : 80); this.port = opts.port != 0 ? opts.port : (this.secure ? 443 : 80);
this.query = opts.query != null ? this.query = opts.query != null ?
@@ -211,6 +214,7 @@ public class Socket extends Emitter {
} }
Transport.Options opts = new Transport.Options(); Transport.Options opts = new Transport.Options();
opts.sslContext = this.sslContext;
opts.hostname = this.hostname; opts.hostname = this.hostname;
opts.port = this.port; opts.port = this.port;
opts.secure = this.secure; opts.secure = this.secure;

View File

@@ -6,6 +6,7 @@ import com.github.nkzawa.engineio.parser.Packet;
import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.engineio.parser.Parser;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import javax.net.ssl.SSLContext;
import java.util.Map; import java.util.Map;
public abstract class Transport extends Emitter { public abstract class Transport extends Emitter {
@@ -39,6 +40,7 @@ public abstract class Transport extends Emitter {
protected String path; protected String path;
protected String hostname; protected String hostname;
protected String timestampParam; protected String timestampParam;
protected SSLContext sslContext;
protected Socket socket; protected Socket socket;
protected ReadyState readyState; protected ReadyState readyState;
@@ -51,6 +53,7 @@ public abstract class Transport extends Emitter {
this.query = opts.query; this.query = opts.query;
this.timestampParam = opts.timestampParam; this.timestampParam = opts.timestampParam;
this.timestampRequests = opts.timestampRequests; this.timestampRequests = opts.timestampRequests;
this.sslContext = opts.sslContext;
this.socket = opts.socket; this.socket = opts.socket;
} }
@@ -140,6 +143,7 @@ public abstract class Transport extends Emitter {
public int port; public int port;
public int policyPort; public int policyPort;
public Map<String, String> query; public Map<String, String> query;
public SSLContext sslContext;
protected Socket socket; protected Socket socket;
} }
} }

View File

@@ -4,6 +4,8 @@ package com.github.nkzawa.engineio.client.transports;
import com.github.nkzawa.emitter.Emitter; import com.github.nkzawa.emitter.Emitter;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import java.io.*; import java.io.*;
import java.net.HttpURLConnection; import java.net.HttpURLConnection;
import java.net.URL; import java.net.URL;
@@ -36,6 +38,7 @@ public class PollingXHR extends Polling {
opts = new Request.Options(); opts = new Request.Options();
} }
opts.uri = this.uri(); opts.uri = this.uri();
opts.sslContext = this.sslContext;
Request req = new Request(opts); Request req = new Request(opts);
@@ -141,15 +144,17 @@ public class PollingXHR extends Polling {
private static final ExecutorService xhrService = Executors.newCachedThreadPool(); private static final ExecutorService xhrService = Executors.newCachedThreadPool();
String method; private String method;
String uri; private String uri;
byte[] data; private byte[] data;
HttpURLConnection xhr; private SSLContext sslContext;
private HttpURLConnection xhr;
public Request(Options opts) { public Request(Options opts) {
this.method = opts.method != null ? opts.method : "GET"; this.method = opts.method != null ? opts.method : "GET";
this.uri = opts.uri; this.uri = opts.uri;
this.data = opts.data; this.data = opts.data;
this.sslContext = opts.sslContext;
} }
public void create() { public void create() {
@@ -164,6 +169,10 @@ public class PollingXHR extends Polling {
return; return;
} }
if (xhr instanceof HttpsURLConnection && this.sslContext != null) {
((HttpsURLConnection)xhr).setSSLSocketFactory(this.sslContext.getSocketFactory());
}
Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER); Map<String, String> headers = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
if ("POST".equals(this.method)) { if ("POST".equals(this.method)) {
@@ -293,6 +302,7 @@ public class PollingXHR extends Polling {
public String uri; public String uri;
public String method; public String method;
public byte[] data; public byte[] data;
public SSLContext sslContext;
} }
} }
} }

View File

@@ -6,6 +6,7 @@ import com.github.nkzawa.engineio.parser.Packet;
import com.github.nkzawa.engineio.parser.Parser; import com.github.nkzawa.engineio.parser.Parser;
import com.github.nkzawa.parseqs.ParseQS; import com.github.nkzawa.parseqs.ParseQS;
import com.github.nkzawa.thread.EventThread; import com.github.nkzawa.thread.EventThread;
import org.java_websocket.client.DefaultSSLWebSocketClientFactory;
import org.java_websocket.client.WebSocketClient; import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_17; import org.java_websocket.drafts.Draft_17;
import org.java_websocket.handshake.ServerHandshake; import org.java_websocket.handshake.ServerHandshake;
@@ -93,6 +94,9 @@ public class WebSocket extends Transport {
}); });
} }
}; };
if (this.sslContext != null) {
this.ws.setWebSocketFactory(new DefaultSSLWebSocketClientFactory(this.sslContext));
}
this.ws.connect(); this.ws.connect();
} catch (URISyntaxException e) { } catch (URISyntaxException e) {
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@@ -54,7 +54,7 @@ public class BinaryWSTest extends Connection {
} }
@Test(timeout = TIMEOUT) @Test(timeout = TIMEOUT)
public void receiveBinaryDataAndMultiplebyteUTF8String() throws InterruptedException { public void receiveBinaryDataAndMultibyteUTF8String() throws InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
final byte[] binaryData = new byte[5]; final byte[] binaryData = new byte[5];
for (int i = 0; i < binaryData.length; i++) { for (int i = 0; i < binaryData.length; i++) {

View File

@@ -24,7 +24,7 @@ public abstract class Connection {
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
serverProcess = Runtime.getRuntime().exec( serverProcess = Runtime.getRuntime().exec(
"node src/test/resources/index.js " + PORT, new String[] {"DEBUG=engine*"}); "node src/test/resources/server.js", createEnv());
serverService = Executors.newCachedThreadPool(); serverService = Executors.newCachedThreadPool();
serverOutout = serverService.submit(new Runnable() { serverOutout = serverService.submit(new Runnable() {
@Override @Override
@@ -70,4 +70,14 @@ public abstract class Connection {
serverService.shutdown(); serverService.shutdown();
serverService.awaitTermination(3000, TimeUnit.MILLISECONDS); serverService.awaitTermination(3000, TimeUnit.MILLISECONDS);
} }
Socket.Options createOptions() {
Socket.Options opts = new Socket.Options();
opts.port = PORT;
return opts;
}
String[] createEnv() {
return new String[] {"DEBUG=engine*", "PORT=" + PORT};
}
} }

View File

@@ -21,9 +21,7 @@ public class ConnectionTest extends Connection {
public void connectToLocalhost() throws InterruptedException { public void connectToLocalhost() throws InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
Socket.Options opts = new Socket.Options(); socket = new Socket(createOptions());
opts.port = PORT;
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -45,9 +43,7 @@ public class ConnectionTest extends Connection {
public void receiveMultibyteUTF8StringsWithPolling() throws InterruptedException { public void receiveMultibyteUTF8StringsWithPolling() throws InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
Socket.Options opts = new Socket.Options(); socket = new Socket(createOptions());
opts.port = PORT;
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -71,9 +67,7 @@ public class ConnectionTest extends Connection {
public void receiveEmoji() throws InterruptedException { public void receiveEmoji() throws InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
Socket.Options opts = new Socket.Options(); socket = new Socket(createOptions());
opts.port = PORT;
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -97,9 +91,7 @@ public class ConnectionTest extends Connection {
public void notSendPacketsIfSocketCloses() throws InterruptedException { public void notSendPacketsIfSocketCloses() throws InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
Socket.Options opts = new Socket.Options(); socket = new Socket(createOptions());
opts.port = PORT;
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {

View File

@@ -0,0 +1,122 @@
package com.github.nkzawa.engineio.client;
import com.github.nkzawa.emitter.Emitter;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.concurrent.CountDownLatch;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
@RunWith(JUnit4.class)
public class SSLConnectionTest extends Connection {
static {
// for test on localhost
javax.net.ssl.HttpsURLConnection.setDefaultHostnameVerifier(
new javax.net.ssl.HostnameVerifier(){
public boolean verify(String hostname, javax.net.ssl.SSLSession sslSession) {
return hostname.equals("localhost");
}
});
}
private Socket socket;
@Override
Socket.Options createOptions() {
Socket.Options opts = super.createOptions();
opts.secure = true;
return opts;
}
@Override
String[] createEnv() {
return new String[] {"DEBUG=engine*", "PORT=" + PORT, "SSL=1"};
}
SSLContext createSSLContext() throws GeneralSecurityException, IOException {
KeyStore ks = KeyStore.getInstance("JKS");
File file = new File("src/test/resources/keystore.jks");
ks.load(new FileInputStream(file), "password".toCharArray());
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, "password".toCharArray());
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
tmf.init(ks);
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
return sslContext;
}
@Test(timeout = TIMEOUT)
public void connect() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Socket.Options opts = createOptions();
opts.sslContext = createSSLContext();
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override
public void call(Object... args) {
socket.on(Socket.EVENT_MESSAGE, new Emitter.Listener() {
@Override
public void call(Object... args) {
assertThat((String)args[0], is("hi"));
socket.close();
latch.countDown();
}
});
}
}).on("error", new Emitter.Listener() {
@Override
public void call(Object... args) {
((Exception)args[0]).printStackTrace();
}
});
socket.open();
latch.await();
}
@Test(timeout = TIMEOUT)
public void upgrade() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Socket.Options opts = createOptions();
opts.sslContext = createSSLContext();
socket = new Socket(opts);
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override
public void call(Object... args) {
socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() {
@Override
public void call(Object... args) {
socket.send("hi");
socket.on(Socket.EVENT_MESSAGE, new Emitter.Listener() {
@Override
public void call(Object... args) {
assertThat((String) args[0], is("hi"));
socket.close();
latch.countDown();
}
});
}
});
}
});
socket.open();
latch.await();
}
}

View File

@@ -30,7 +30,7 @@ public class ServerConnectionTest extends Connection {
public void openAndClose() throws URISyntaxException, InterruptedException { public void openAndClose() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> events = new LinkedBlockingQueue<String>(); final BlockingQueue<String> events = new LinkedBlockingQueue<String>();
socket = new Socket("ws://localhost:" + PORT); socket = new Socket(createOptions());
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -53,7 +53,7 @@ public class ServerConnectionTest extends Connection {
public void messages() throws URISyntaxException, InterruptedException { public void messages() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> events = new LinkedBlockingQueue<String>(); final BlockingQueue<String> events = new LinkedBlockingQueue<String>();
socket = new Socket("ws://localhost:" + PORT); socket = new Socket(createOptions());
socket.on(Socket.EVENT_OPEN, new Emitter.Listener() { socket.on(Socket.EVENT_OPEN, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -76,7 +76,7 @@ public class ServerConnectionTest extends Connection {
public void handshake() throws URISyntaxException, InterruptedException { public void handshake() throws URISyntaxException, InterruptedException {
final Semaphore semaphore = new Semaphore(0); final Semaphore semaphore = new Semaphore(0);
socket = new Socket("ws://localhost:" + PORT); socket = new Socket(createOptions());
socket.on(Socket.EVENT_HANDSHAKE, new Emitter.Listener() { socket.on(Socket.EVENT_HANDSHAKE, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -102,7 +102,7 @@ public class ServerConnectionTest extends Connection {
public void upgrade() throws URISyntaxException, InterruptedException { public void upgrade() throws URISyntaxException, InterruptedException {
final BlockingQueue<Object[]> events = new LinkedBlockingQueue<Object[]>(); final BlockingQueue<Object[]> events = new LinkedBlockingQueue<Object[]>();
socket = new Socket("ws://localhost:" + PORT); socket = new Socket(createOptions());
socket.on(Socket.EVENT_UPGRADING, new Emitter.Listener() { socket.on(Socket.EVENT_UPGRADING, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -136,10 +136,10 @@ public class ServerConnectionTest extends Connection {
public void pollingHeaders() throws URISyntaxException, InterruptedException { public void pollingHeaders() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> messages = new LinkedBlockingQueue<String>(); final BlockingQueue<String> messages = new LinkedBlockingQueue<String>();
Socket.Options opts = new Socket.Options(); Socket.Options opts = createOptions();
opts.transports = new String[] {Polling.NAME}; opts.transports = new String[] {Polling.NAME};
socket = new Socket("ws://localhost:" + PORT, opts); socket = new Socket(opts);
socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() { socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -172,10 +172,10 @@ public class ServerConnectionTest extends Connection {
public void websocketHandshakeHeaders() throws URISyntaxException, InterruptedException { public void websocketHandshakeHeaders() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> messages = new LinkedBlockingQueue<String>(); final BlockingQueue<String> messages = new LinkedBlockingQueue<String>();
Socket.Options opts = new Socket.Options(); Socket.Options opts = createOptions();
opts.transports = new String[] {WebSocket.NAME}; opts.transports = new String[] {WebSocket.NAME};
socket = new Socket("ws://localhost:" + PORT, opts); socket = new Socket(opts);
socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() { socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() {
@Override @Override
public void call(Object... args) { public void call(Object... args) {
@@ -210,10 +210,7 @@ public class ServerConnectionTest extends Connection {
EventThread.exec(new Runnable() { EventThread.exec(new Runnable() {
@Override @Override
public void run() { public void run() {
Socket.Options opts = new Socket.Options(); final Socket socket = new Socket(createOptions());
opts.port = PORT;
final Socket socket = new Socket(opts);
socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() { socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() {
@Override @Override
@@ -246,10 +243,7 @@ public class ServerConnectionTest extends Connection {
EventThread.exec(new Runnable() { EventThread.exec(new Runnable() {
@Override @Override
public void run() { public void run() {
Socket.Options opts = new Socket.Options(); final Socket socket = new Socket(createOptions());
opts.port = PORT;
final Socket socket = new Socket(opts);
socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() { socket.on(Socket.EVENT_UPGRADE, new Emitter.Listener() {
@Override @Override

View File

@@ -0,0 +1,10 @@
-----BEGIN CERTIFICATE-----
MIIBfDCCASYCCQDTnGd/oOyF1DANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJB
VTETMBEGA1UECBMKU29tZS1TdGF0ZTEhMB8GA1UEChMYSW50ZXJuZXQgV2lkZ2l0
cyBQdHkgTHRkMB4XDTE0MDcwNzEzMTUzN1oXDTQxMTEyMTEzMTUzN1owRTELMAkG
A1UEBhMCQVUxEzARBgNVBAgTClNvbWUtU3RhdGUxITAfBgNVBAoTGEludGVybmV0
IFdpZGdpdHMgUHR5IEx0ZDBcMA0GCSqGSIb3DQEBAQUAA0sAMEgCQQC6sdeFPlqk
5Pap9woFx1RO05gLidw4MNcL+ZRSxy/sNeE4PhT/RLFcEvnXiHc92wT8YB5Z+WCM
k/jRQ0q19PNPAgMBAAEwDQYJKoZIhvcNAQEFBQADQQCnmm1N/yZiMBZw2JDfbsx3
ecc0BGQ2BwWQuGHzP07TMi1AuOyNZSczl907OphYb9iRC8shZ4O+oXjQAuGTQ1Hp
-----END CERTIFICATE-----

View File

@@ -0,0 +1,9 @@
-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBALqx14U+WqTk9qn3CgXHVE7TmAuJ3Dgw1wv5lFLHL+w14Tg+FP9E
sVwS+deIdz3bBPxgHln5YIyT+NFDSrX0808CAwEAAQJAIdwLSIEsk2drTRwe1zl1
ku5RTxZruE0zU1qqifDSQjab1StAK1tapxBVRlRlyLCfD704UClsU8sjGtq0Nh6n
kQIhAO2YJM1g0w9bWYet3zC2UdEASPzaQ7llpZmc51NRBx2NAiEAyShICAaclEuy
wwuD4hibV+b6I8CLYoyPBo32EaceN0sCIQCUed6NxfM/houlgV+Xtmfcnzv9X3yx
EDdzjpz08Q7sRQIgZFv1fBOYYSBXQppnJRFzx2pUmCvDHtrTrMh84RfIqnsCIQCf
JjNXXxOaHn1PNZpi6EHReiFQmy1Swt+AxpTsKixsfA==
-----END RSA PRIVATE KEY-----

Binary file not shown.

View File

@@ -1,8 +1,19 @@
var http = require('http').Server(); var fs = require('fs');
var engine = require('engine.io'); var engine = require('engine.io');
var http;
if (process.env.SSL) {
http = require('https').createServer({
key: fs.readFileSync(__dirname + '/key.pem'),
cert: fs.readFileSync(__dirname + '/cert.pem')
});
} else {
http = require('http').createServer();
}
var server = engine.attach(http, {pingInterval: 500}); var server = engine.attach(http, {pingInterval: 500});
var port = parseInt(process.argv[2], 10) || 3000 var port = process.env.PORT || 3000
http.listen(port, function() { http.listen(port, function() {
console.log('Engine.IO server listening on port', port); console.log('Engine.IO server listening on port', port);
}); });
@@ -17,6 +28,8 @@ server.on('connection', function(socket) {
socket.on('error', function(err) { socket.on('error', function(err) {
throw err; throw err;
}); });
}).on('error', function(err) {
console.error(err);
}); });
var handleRequest = server.handleRequest; var handleRequest = server.handleRequest;