feat: add an extraHeaders option

Similar to the option of the JS client:

```java
opts = new Socket.Options();
opts.extraHeaders = singletonMap("authorization", singletonList("bearer abcd"));
socket = new Socket(opts);
```

Note: the refactor of the options (similar to [1]) will be done in a
future step.

[1] 5f47a50ee5
This commit is contained in:
Damien Arrachequesne
2020-12-11 12:24:07 +01:00
parent 41f89a38b7
commit dfe65e3b3b
5 changed files with 85 additions and 1 deletions

View File

@@ -130,6 +130,7 @@ public class Socket extends Emitter {
private Future pingTimeoutTimer; private Future pingTimeoutTimer;
private okhttp3.WebSocket.Factory webSocketFactory; private okhttp3.WebSocket.Factory webSocketFactory;
private okhttp3.Call.Factory callFactory; private okhttp3.Call.Factory callFactory;
private final Map<String, List<String>> extraHeaders;
private ReadyState readyState; private ReadyState readyState;
private ScheduledExecutorService heartbeatScheduler; private ScheduledExecutorService heartbeatScheduler;
@@ -221,6 +222,7 @@ public class Socket extends Emitter {
} }
webSocketFactory = defaultOkHttpClient; webSocketFactory = defaultOkHttpClient;
} }
this.extraHeaders = opts.extraHeaders;
} }
public static void setDefaultOkHttpWebSocketFactory(okhttp3.WebSocket.Factory factory) { public static void setDefaultOkHttpWebSocketFactory(okhttp3.WebSocket.Factory factory) {
@@ -293,6 +295,7 @@ public class Socket extends Emitter {
opts.policyPort = options != null ? options.policyPort : this.policyPort; opts.policyPort = options != null ? options.policyPort : this.policyPort;
opts.callFactory = options != null ? options.callFactory : this.callFactory; opts.callFactory = options != null ? options.callFactory : this.callFactory;
opts.webSocketFactory = options != null ? options.webSocketFactory : this.webSocketFactory; opts.webSocketFactory = options != null ? options.webSocketFactory : this.webSocketFactory;
opts.extraHeaders = this.extraHeaders;
Transport transport; Transport transport;
if (WebSocket.NAME.equals(name)) { if (WebSocket.NAME.equals(name)) {

View File

@@ -1,6 +1,7 @@
package io.socket.engineio.client; package io.socket.engineio.client;
import java.util.List;
import java.util.Map; import java.util.Map;
import io.socket.emitter.Emitter; import io.socket.emitter.Emitter;
@@ -43,6 +44,7 @@ public abstract class Transport extends Emitter {
protected ReadyState readyState; protected ReadyState readyState;
protected WebSocket.Factory webSocketFactory; protected WebSocket.Factory webSocketFactory;
protected Call.Factory callFactory; protected Call.Factory callFactory;
protected Map<String, List<String>> extraHeaders;
public Transport(Options opts) { public Transport(Options opts) {
this.path = opts.path; this.path = opts.path;
@@ -55,6 +57,7 @@ public abstract class Transport extends Emitter {
this.socket = opts.socket; this.socket = opts.socket;
this.webSocketFactory = opts.webSocketFactory; this.webSocketFactory = opts.webSocketFactory;
this.callFactory = opts.callFactory; this.callFactory = opts.callFactory;
this.extraHeaders = opts.extraHeaders;
} }
protected Transport onError(String msg, Exception desc) { protected Transport onError(String msg, Exception desc) {
@@ -146,5 +149,6 @@ public abstract class Transport extends Emitter {
protected Socket socket; protected Socket socket;
public WebSocket.Factory webSocketFactory; public WebSocket.Factory webSocketFactory;
public Call.Factory callFactory; public Call.Factory callFactory;
public Map<String, List<String>> extraHeaders;
} }
} }

View File

@@ -43,6 +43,7 @@ public class PollingXHR extends Polling {
} }
opts.uri = this.uri(); opts.uri = this.uri();
opts.callFactory = this.callFactory; opts.callFactory = this.callFactory;
opts.extraHeaders = this.extraHeaders;
Request req = new Request(opts); Request req = new Request(opts);
@@ -72,6 +73,7 @@ public class PollingXHR extends Polling {
Request.Options opts = new Request.Options(); Request.Options opts = new Request.Options();
opts.method = "POST"; opts.method = "POST";
opts.data = data; opts.data = data;
opts.extraHeaders = this.extraHeaders;
Request req = this.request(opts); Request req = this.request(opts);
final PollingXHR self = this; final PollingXHR self = this;
req.on(Request.EVENT_SUCCESS, new Emitter.Listener() { req.on(Request.EVENT_SUCCESS, new Emitter.Listener() {
@@ -150,6 +152,7 @@ public class PollingXHR extends Polling {
private String data; private String data;
private Call.Factory callFactory; private Call.Factory callFactory;
private Map<String, List<String>> extraHeaders;
private Response response; private Response response;
private Call requestCall; private Call requestCall;
@@ -158,13 +161,16 @@ public class PollingXHR extends Polling {
this.uri = opts.uri; this.uri = opts.uri;
this.data = opts.data; this.data = opts.data;
this.callFactory = opts.callFactory != null ? opts.callFactory : new OkHttpClient(); this.callFactory = opts.callFactory != null ? opts.callFactory : new OkHttpClient();
this.extraHeaders = opts.extraHeaders;
} }
public void create() { public void create() {
final Request self = this; final Request self = this;
if (LOGGABLE_FINE) logger.fine(String.format("xhr open %s: %s", this.method, this.uri)); if (LOGGABLE_FINE) logger.fine(String.format("xhr open %s: %s", this.method, this.uri));
Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER); Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER);
if (this.extraHeaders != null) {
headers.putAll(this.extraHeaders);
}
if ("POST".equals(this.method)) { if ("POST".equals(this.method)) {
headers.put("Content-type", new LinkedList<String>(Collections.singletonList(TEXT_CONTENT_TYPE))); headers.put("Content-type", new LinkedList<String>(Collections.singletonList(TEXT_CONTENT_TYPE)));
} }
@@ -255,6 +261,7 @@ public class PollingXHR extends Polling {
public String method; public String method;
public String data; public String data;
public Call.Factory callFactory; public Call.Factory callFactory;
public Map<String, List<String>> extraHeaders;
} }
} }
} }

View File

@@ -35,6 +35,9 @@ public class WebSocket extends Transport {
protected void doOpen() { protected void doOpen() {
Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER); Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER);
if (this.extraHeaders != null) {
headers.putAll(this.extraHeaders);
}
this.emit(EVENT_REQUEST_HEADERS, headers); this.emit(EVENT_REQUEST_HEADERS, headers);
final WebSocket self = this; final WebSocket self = this;

View File

@@ -10,11 +10,14 @@ import org.junit.runners.JUnit4;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.notNullValue;
@@ -168,6 +171,38 @@ public class ServerConnectionTest extends Connection {
socket.close(); socket.close();
} }
@Test(timeout = TIMEOUT)
public void pollingHeaders_withExtraHeadersOption() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> messages = new LinkedBlockingQueue<String>();
Socket.Options opts = createOptions();
opts.transports = new String[] {Polling.NAME};
opts.extraHeaders = singletonMap("X-EngineIO", singletonList("bar"));
socket = new Socket(opts);
socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() {
@Override
public void call(Object... args) {
Transport transport = (Transport)args[0];
transport.on(Transport.EVENT_RESPONSE_HEADERS, new Emitter.Listener() {
@Override
public void call(Object... args) {
@SuppressWarnings("unchecked")
Map<String, List<String>> headers = (Map<String, List<String>>)args[0];
List<String> values = headers.get("X-EngineIO");
messages.offer(values.get(0));
messages.offer(values.get(1));
}
});
}
});
socket.open();
assertThat(messages.take(), is("hi"));
assertThat(messages.take(), is("bar"));
socket.close();
}
@Test(timeout = TIMEOUT) @Test(timeout = TIMEOUT)
public void websocketHandshakeHeaders() throws URISyntaxException, InterruptedException { public void websocketHandshakeHeaders() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> messages = new LinkedBlockingQueue<String>(); final BlockingQueue<String> messages = new LinkedBlockingQueue<String>();
@@ -206,6 +241,38 @@ public class ServerConnectionTest extends Connection {
socket.close(); socket.close();
} }
@Test(timeout = TIMEOUT)
public void websocketHandshakeHeaders_withExtraHeadersOption() throws URISyntaxException, InterruptedException {
final BlockingQueue<String> messages = new LinkedBlockingQueue<String>();
Socket.Options opts = createOptions();
opts.transports = new String[] {WebSocket.NAME};
opts.extraHeaders = singletonMap("X-EngineIO", singletonList("bar"));
socket = new Socket(opts);
socket.on(Socket.EVENT_TRANSPORT, new Emitter.Listener() {
@Override
public void call(Object... args) {
Transport transport = (Transport)args[0];
transport.on(Transport.EVENT_RESPONSE_HEADERS, new Emitter.Listener() {
@Override
public void call(Object... args) {
@SuppressWarnings("unchecked")
Map<String, List<String>> headers = (Map<String, List<String>>)args[0];
List<String> values = headers.get("X-EngineIO");
messages.offer(values.get(0));
messages.offer(values.get(1));
}
});
}
});
socket.open();
assertThat(messages.take(), is("hi"));
assertThat(messages.take(), is("bar"));
socket.close();
}
@Test(timeout = TIMEOUT) @Test(timeout = TIMEOUT)
public void rememberWebsocket() throws InterruptedException { public void rememberWebsocket() throws InterruptedException {
final BlockingQueue<Object> values = new LinkedBlockingQueue<Object>(); final BlockingQueue<Object> values = new LinkedBlockingQueue<Object>();