From dfe65e3b3b5eab4c3fddb9dfbf53d684fe461043 Mon Sep 17 00:00:00 2001 From: Damien Arrachequesne Date: Fri, 11 Dec 2020 12:24:07 +0100 Subject: [PATCH] feat: add an extraHeaders option Similar to the option of the JS client: ```java opts = new Socket.Options(); opts.extraHeaders = singletonMap("authorization", singletonList("bearer abcd")); socket = new Socket(opts); ``` Note: the refactor of the options (similar to [1]) will be done in a future step. [1] https://github.com/socketio/engine.io-client/commit/5f47a50ee5dc47962f3823f4e8dde0b4b407eccd --- .../io/socket/engineio/client/Socket.java | 3 + .../io/socket/engineio/client/Transport.java | 4 ++ .../client/transports/PollingXHR.java | 9 ++- .../engineio/client/transports/WebSocket.java | 3 + .../engineio/client/ServerConnectionTest.java | 67 +++++++++++++++++++ 5 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/socket/engineio/client/Socket.java b/src/main/java/io/socket/engineio/client/Socket.java index 9b0423e..2d86e31 100644 --- a/src/main/java/io/socket/engineio/client/Socket.java +++ b/src/main/java/io/socket/engineio/client/Socket.java @@ -130,6 +130,7 @@ public class Socket extends Emitter { private Future pingTimeoutTimer; private okhttp3.WebSocket.Factory webSocketFactory; private okhttp3.Call.Factory callFactory; + private final Map> extraHeaders; private ReadyState readyState; private ScheduledExecutorService heartbeatScheduler; @@ -221,6 +222,7 @@ public class Socket extends Emitter { } webSocketFactory = defaultOkHttpClient; } + this.extraHeaders = opts.extraHeaders; } public static void setDefaultOkHttpWebSocketFactory(okhttp3.WebSocket.Factory factory) { @@ -293,6 +295,7 @@ public class Socket extends Emitter { opts.policyPort = options != null ? options.policyPort : this.policyPort; opts.callFactory = options != null ? options.callFactory : this.callFactory; opts.webSocketFactory = options != null ? options.webSocketFactory : this.webSocketFactory; + opts.extraHeaders = this.extraHeaders; Transport transport; if (WebSocket.NAME.equals(name)) { diff --git a/src/main/java/io/socket/engineio/client/Transport.java b/src/main/java/io/socket/engineio/client/Transport.java index 56b921c..7a8ce5c 100644 --- a/src/main/java/io/socket/engineio/client/Transport.java +++ b/src/main/java/io/socket/engineio/client/Transport.java @@ -1,6 +1,7 @@ package io.socket.engineio.client; +import java.util.List; import java.util.Map; import io.socket.emitter.Emitter; @@ -43,6 +44,7 @@ public abstract class Transport extends Emitter { protected ReadyState readyState; protected WebSocket.Factory webSocketFactory; protected Call.Factory callFactory; + protected Map> extraHeaders; public Transport(Options opts) { this.path = opts.path; @@ -55,6 +57,7 @@ public abstract class Transport extends Emitter { this.socket = opts.socket; this.webSocketFactory = opts.webSocketFactory; this.callFactory = opts.callFactory; + this.extraHeaders = opts.extraHeaders; } protected Transport onError(String msg, Exception desc) { @@ -146,5 +149,6 @@ public abstract class Transport extends Emitter { protected Socket socket; public WebSocket.Factory webSocketFactory; public Call.Factory callFactory; + public Map> extraHeaders; } } diff --git a/src/main/java/io/socket/engineio/client/transports/PollingXHR.java b/src/main/java/io/socket/engineio/client/transports/PollingXHR.java index 64b3cda..11ec6b3 100644 --- a/src/main/java/io/socket/engineio/client/transports/PollingXHR.java +++ b/src/main/java/io/socket/engineio/client/transports/PollingXHR.java @@ -43,6 +43,7 @@ public class PollingXHR extends Polling { } opts.uri = this.uri(); opts.callFactory = this.callFactory; + opts.extraHeaders = this.extraHeaders; Request req = new Request(opts); @@ -72,6 +73,7 @@ public class PollingXHR extends Polling { Request.Options opts = new Request.Options(); opts.method = "POST"; opts.data = data; + opts.extraHeaders = this.extraHeaders; Request req = this.request(opts); final PollingXHR self = this; req.on(Request.EVENT_SUCCESS, new Emitter.Listener() { @@ -150,6 +152,7 @@ public class PollingXHR extends Polling { private String data; private Call.Factory callFactory; + private Map> extraHeaders; private Response response; private Call requestCall; @@ -158,13 +161,16 @@ public class PollingXHR extends Polling { this.uri = opts.uri; this.data = opts.data; this.callFactory = opts.callFactory != null ? opts.callFactory : new OkHttpClient(); + this.extraHeaders = opts.extraHeaders; } public void create() { final Request self = this; if (LOGGABLE_FINE) logger.fine(String.format("xhr open %s: %s", this.method, this.uri)); Map> headers = new TreeMap>(String.CASE_INSENSITIVE_ORDER); - + if (this.extraHeaders != null) { + headers.putAll(this.extraHeaders); + } if ("POST".equals(this.method)) { headers.put("Content-type", new LinkedList(Collections.singletonList(TEXT_CONTENT_TYPE))); } @@ -255,6 +261,7 @@ public class PollingXHR extends Polling { public String method; public String data; public Call.Factory callFactory; + public Map> extraHeaders; } } } diff --git a/src/main/java/io/socket/engineio/client/transports/WebSocket.java b/src/main/java/io/socket/engineio/client/transports/WebSocket.java index 74cf65f..c2153c3 100644 --- a/src/main/java/io/socket/engineio/client/transports/WebSocket.java +++ b/src/main/java/io/socket/engineio/client/transports/WebSocket.java @@ -35,6 +35,9 @@ public class WebSocket extends Transport { protected void doOpen() { Map> headers = new TreeMap>(String.CASE_INSENSITIVE_ORDER); + if (this.extraHeaders != null) { + headers.putAll(this.extraHeaders); + } this.emit(EVENT_REQUEST_HEADERS, headers); final WebSocket self = this; diff --git a/src/test/java/io/socket/engineio/client/ServerConnectionTest.java b/src/test/java/io/socket/engineio/client/ServerConnectionTest.java index 2beb744..e7a9e73 100644 --- a/src/test/java/io/socket/engineio/client/ServerConnectionTest.java +++ b/src/test/java/io/socket/engineio/client/ServerConnectionTest.java @@ -10,11 +10,14 @@ import org.junit.runners.JUnit4; import java.net.URISyntaxException; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.notNullValue; @@ -168,6 +171,38 @@ public class ServerConnectionTest extends Connection { socket.close(); } + @Test(timeout = TIMEOUT) + public void pollingHeaders_withExtraHeadersOption() throws URISyntaxException, InterruptedException { + final BlockingQueue messages = new LinkedBlockingQueue(); + + Socket.Options opts = createOptions(); + opts.transports = new String[] {Polling.NAME}; + opts.extraHeaders = singletonMap("X-EngineIO", singletonList("bar")); + + socket = new Socket(opts); + socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() { + @Override + public void call(Object... args) { + Transport transport = (Transport)args[0]; + transport.on(Transport.EVENT_RESPONSE_HEADERS, new Emitter.Listener() { + @Override + public void call(Object... args) { + @SuppressWarnings("unchecked") + Map> headers = (Map>)args[0]; + List values = headers.get("X-EngineIO"); + messages.offer(values.get(0)); + messages.offer(values.get(1)); + } + }); + } + }); + socket.open(); + + assertThat(messages.take(), is("hi")); + assertThat(messages.take(), is("bar")); + socket.close(); + } + @Test(timeout = TIMEOUT) public void websocketHandshakeHeaders() throws URISyntaxException, InterruptedException { final BlockingQueue messages = new LinkedBlockingQueue(); @@ -206,6 +241,38 @@ public class ServerConnectionTest extends Connection { socket.close(); } + @Test(timeout = TIMEOUT) + public void websocketHandshakeHeaders_withExtraHeadersOption() throws URISyntaxException, InterruptedException { + final BlockingQueue messages = new LinkedBlockingQueue(); + + Socket.Options opts = createOptions(); + opts.transports = new String[] {WebSocket.NAME}; + opts.extraHeaders = singletonMap("X-EngineIO", singletonList("bar")); + + socket = new Socket(opts); + socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() { + @Override + public void call(Object... args) { + Transport transport = (Transport)args[0]; + transport.on(Transport.EVENT_RESPONSE_HEADERS, new Emitter.Listener() { + @Override + public void call(Object... args) { + @SuppressWarnings("unchecked") + Map> headers = (Map>)args[0]; + List values = headers.get("X-EngineIO"); + messages.offer(values.get(0)); + messages.offer(values.get(1)); + } + }); + } + }); + socket.open(); + + assertThat(messages.take(), is("hi")); + assertThat(messages.take(), is("bar")); + socket.close(); + } + @Test(timeout = TIMEOUT) public void rememberWebsocket() throws InterruptedException { final BlockingQueue values = new LinkedBlockingQueue();