update parser to not utf8 encode for string payloads

This commit is contained in:
nkzawa
2017-07-13 12:47:00 +09:00
parent 8b07bbd9f0
commit 1b3c795210
4 changed files with 197 additions and 103 deletions

View File

@@ -183,10 +183,16 @@ abstract public class Polling extends Transport {
} }
}; };
Parser.encodePayload(packets, new Parser.EncodeCallback<byte[]>() { Parser.encodePayload(packets, new Parser.EncodeCallback() {
@Override @Override
public void call(byte[] data) { public void call(Object data) {
self.doWrite(data, callbackfn); if (data instanceof byte[]) {
self.doWrite((byte[])data, callbackfn);
} else if (data instanceof String) {
self.doWrite((String)data, callbackfn);
} else {
logger.warning("Unexpected data: " + data);
}
} }
}); });
} }
@@ -220,5 +226,7 @@ abstract public class Polling extends Transport {
abstract protected void doWrite(byte[] data, Runnable fn); abstract protected void doWrite(byte[] data, Runnable fn);
abstract protected void doWrite(String data, Runnable fn);
abstract protected void doPoll(); abstract protected void doPoll();
} }

View File

@@ -8,6 +8,7 @@ import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger; import java.util.logging.Logger;
import io.socket.emitter.Emitter; import io.socket.emitter.Emitter;
@@ -26,6 +27,8 @@ public class PollingXHR extends Polling {
private static final Logger logger = Logger.getLogger(PollingXHR.class.getName()); private static final Logger logger = Logger.getLogger(PollingXHR.class.getName());
private static boolean LOGGABLE_FINE = logger.isLoggable(Level.FINE);
public PollingXHR(Transport.Options opts) { public PollingXHR(Transport.Options opts) {
super(opts); super(opts);
} }
@@ -66,6 +69,15 @@ public class PollingXHR extends Polling {
@Override @Override
protected void doWrite(byte[] data, final Runnable fn) { protected void doWrite(byte[] data, final Runnable fn) {
this.doWrite((Object) data, fn);
}
@Override
protected void doWrite(String data, final Runnable fn) {
this.doWrite((Object) data, fn);
}
private void doWrite(Object data, final Runnable fn) {
Request.Options opts = new Request.Options(); Request.Options opts = new Request.Options();
opts.method = "POST"; opts.method = "POST";
opts.data = data; opts.data = data;
@@ -140,13 +152,17 @@ public class PollingXHR extends Polling {
public static final String EVENT_ERROR = "error"; public static final String EVENT_ERROR = "error";
public static final String EVENT_REQUEST_HEADERS = "requestHeaders"; public static final String EVENT_REQUEST_HEADERS = "requestHeaders";
public static final String EVENT_RESPONSE_HEADERS = "responseHeaders"; public static final String EVENT_RESPONSE_HEADERS = "responseHeaders";
private static final String BINARY_CONTENT_TYPE = "application/octet-stream"; private static final String BINARY_CONTENT_TYPE = "application/octet-stream";
private static final String TEXT_CONTENT_TYPE = "text/plain;charset=UTF-8";
private static final MediaType BINARY_MEDIA_TYPE = MediaType.parse(BINARY_CONTENT_TYPE);
private static final MediaType TEXT_MEDIA_TYPE = MediaType.parse(TEXT_CONTENT_TYPE);
private String method; private String method;
private String uri; private String uri;
// data is always a binary private Object data;
private byte[] data;
private Call.Factory callFactory; private Call.Factory callFactory;
private Response response; private Response response;
@@ -161,28 +177,42 @@ public class PollingXHR extends Polling {
public void create() { public void create() {
final Request self = this; final Request self = this;
logger.fine(String.format("xhr open %s: %s", this.method, this.uri)); if (LOGGABLE_FINE) logger.fine(String.format("xhr open %s: %s", this.method, this.uri));
Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER); Map<String, List<String>> headers = new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER);
if ("POST".equals(this.method)) { if ("POST".equals(this.method)) {
if (this.data instanceof byte[]) {
headers.put("Content-type", new LinkedList<String>(Collections.singletonList(BINARY_CONTENT_TYPE))); headers.put("Content-type", new LinkedList<String>(Collections.singletonList(BINARY_CONTENT_TYPE)));
} else {
headers.put("Content-type", new LinkedList<String>(Collections.singletonList(TEXT_CONTENT_TYPE)));
}
} }
headers.put("Accept", new LinkedList<String>(Collections.singletonList("*/*"))); headers.put("Accept", new LinkedList<String>(Collections.singletonList("*/*")));
self.onRequestHeaders(headers); this.onRequestHeaders(headers);
if (LOGGABLE_FINE) {
logger.fine(String.format("sending xhr with url %s | data %s", this.uri,
this.data instanceof byte[] ? Arrays.toString((byte[]) this.data) : this.data));
}
logger.fine(String.format("sending xhr with url %s | data %s", this.uri, Arrays.toString(this.data)));
okhttp3.Request.Builder requestBuilder = new okhttp3.Request.Builder(); okhttp3.Request.Builder requestBuilder = new okhttp3.Request.Builder();
for (Map.Entry<String, List<String>> header : headers.entrySet()) { for (Map.Entry<String, List<String>> header : headers.entrySet()) {
for (String v : header.getValue()){ for (String v : header.getValue()){
requestBuilder.addHeader(header.getKey(), v); requestBuilder.addHeader(header.getKey(), v);
} }
} }
RequestBody body = null;
if (this.data instanceof byte[]) {
body = RequestBody.create(BINARY_MEDIA_TYPE, (byte[])this.data);
} else if (this.data instanceof String) {
body = RequestBody.create(TEXT_MEDIA_TYPE, (String)this.data);
}
okhttp3.Request request = requestBuilder okhttp3.Request request = requestBuilder
.url(HttpUrl.parse(self.uri)) .url(HttpUrl.parse(self.uri))
.method(self.method, (self.data != null) ? .method(self.method, body)
RequestBody.create(MediaType.parse(BINARY_CONTENT_TYPE), self.data) : null)
.build(); .build();
requestCall = callFactory.newCall(request); requestCall = callFactory.newCall(request);
@@ -255,7 +285,7 @@ public class PollingXHR extends Polling {
public String uri; public String uri;
public String method; public String method;
public byte[] data; public Object data;
public Call.Factory callFactory; public Call.Factory callFactory;
} }
} }

View File

@@ -35,6 +35,11 @@ public class Parser {
private static Packet<String> err = new Packet<String>(Packet.ERROR, "parser error"); private static Packet<String> err = new Packet<String>(Packet.ERROR, "parser error");
private static UTF8.Options utf8Options = new UTF8.Options();
static {
utf8Options.strict = false;
}
private Parser() {} private Parser() {}
@@ -55,7 +60,7 @@ public class Parser {
String encoded = String.valueOf(packets.get(packet.type)); String encoded = String.valueOf(packets.get(packet.type));
if (null != packet.data) { if (null != packet.data) {
encoded += utf8encode ? UTF8.encode(String.valueOf(packet.data)) : String.valueOf(packet.data); encoded += utf8encode ? UTF8.encode(String.valueOf(packet.data), utf8Options) : String.valueOf(packet.data);
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@@ -89,7 +94,7 @@ public class Parser {
if (utf8decode) { if (utf8decode) {
try { try {
data = UTF8.decode(data); data = UTF8.decode(data, utf8Options);
} catch (UTF8Exception e) { } catch (UTF8Exception e) {
return err; return err;
} }
@@ -113,7 +118,40 @@ public class Parser {
return new Packet<byte[]>(packetslist.get(type), intArray); return new Packet<byte[]>(packetslist.get(type), intArray);
} }
public static void encodePayload(Packet[] packets, EncodeCallback<byte[]> callback) throws UTF8Exception { public static void encodePayload(Packet[] packets, EncodeCallback callback) throws UTF8Exception {
for (Packet packet : packets) {
if (packet.data instanceof byte[]) {
@SuppressWarnings("unchecked")
EncodeCallback<byte[]> _callback = (EncodeCallback<byte[]>) callback;
encodePayloadAsBinary(packets, _callback);
return;
}
}
if (packets.length == 0) {
callback.call("0:");
return;
}
final StringBuilder result = new StringBuilder();
for (Packet packet : packets) {
encodePacket(packet, false, new EncodeCallback() {
@Override
public void call(Object message) {
result.append(setLengthHeader((String)message));
}
});
}
callback.call(result.toString());
}
private static String setLengthHeader(String message) {
return message.length() + ":" + message;
}
private static void encodePayloadAsBinary(Packet[] packets, EncodeCallback<byte[]> callback) throws UTF8Exception {
if (packets.length == 0) { if (packets.length == 0) {
callback.call(new byte[0]); callback.call(new byte[0]);
return; return;
@@ -122,7 +160,19 @@ public class Parser {
final ArrayList<byte[]> results = new ArrayList<byte[]>(packets.length); final ArrayList<byte[]> results = new ArrayList<byte[]>(packets.length);
for (Packet packet : packets) { for (Packet packet : packets) {
encodePacket(packet, true, new EncodeCallback() { encodeOneBinaryPacket(packet, new EncodeCallback<byte[]>() {
@Override
public void call(byte[] data) {
results.add(data);
}
});
}
callback.call(Buffer.concat(results.toArray(new byte[results.size()][])));
}
private static void encodeOneBinaryPacket(Packet p, final EncodeCallback<byte[]> doneCallback) throws UTF8Exception {
encodePacket(p, true, new EncodeCallback() {
@Override @Override
public void call(Object packet) { public void call(Object packet) {
if (packet instanceof String) { if (packet instanceof String) {
@@ -134,7 +184,7 @@ public class Parser {
sizeBuffer[i + 1] = (byte)Character.getNumericValue(encodingLength.charAt(i)); sizeBuffer[i + 1] = (byte)Character.getNumericValue(encodingLength.charAt(i));
} }
sizeBuffer[sizeBuffer.length - 1] = (byte)255; sizeBuffer[sizeBuffer.length - 1] = (byte)255;
results.add(Buffer.concat(new byte[][] {sizeBuffer, stringToByteArray((String)packet)})); doneCallback.call(Buffer.concat(new byte[][] {sizeBuffer, stringToByteArray((String)packet)}));
return; return;
} }
@@ -145,14 +195,11 @@ public class Parser {
sizeBuffer[i + 1] = (byte)Character.getNumericValue(encodingLength.charAt(i)); sizeBuffer[i + 1] = (byte)Character.getNumericValue(encodingLength.charAt(i));
} }
sizeBuffer[sizeBuffer.length - 1] = (byte)255; sizeBuffer[sizeBuffer.length - 1] = (byte)255;
results.add(Buffer.concat(new byte[][] {sizeBuffer, (byte[])packet})); doneCallback.call(Buffer.concat(new byte[][] {sizeBuffer, (byte[])packet}));
} }
}); });
} }
callback.call(Buffer.concat(results.toArray(new byte[results.size()][])));
}
public static void decodePayload(String data, DecodePayloadCallback<String> callback) { public static void decodePayload(String data, DecodePayloadCallback<String> callback) {
if (data == null || data.length() == 0) { if (data == null || data.length() == 0) {
callback.call(err, 0, 1); callback.call(err, 0, 1);
@@ -165,7 +212,9 @@ public class Parser {
if (':' != chr) { if (':' != chr) {
length.append(chr); length.append(chr);
} else { continue;
}
int n; int n;
try { try {
n = Integer.parseInt(length.toString()); n = Integer.parseInt(length.toString());
@@ -183,20 +232,21 @@ public class Parser {
} }
if (msg.length() != 0) { if (msg.length() != 0) {
Packet<String> packet = decodePacket(msg, true); Packet<String> packet = decodePacket(msg, false);
if (err.type.equals(packet.type) && err.data.equals(packet.data)) { if (err.type.equals(packet.type) && err.data.equals(packet.data)) {
callback.call(err, 0, 1); callback.call(err, 0, 1);
return; return;
} }
boolean ret = callback.call(packet, i + n, l); boolean ret = callback.call(packet, i + n, l);
if (!ret) return; if (!ret) {
return;
}
} }
i += n; i += n;
length = new StringBuilder(); length = new StringBuilder();
} }
}
if (length.length() > 0) { if (length.length() > 0) {
callback.call(err, 0, 1); callback.call(err, 0, 1);
@@ -210,23 +260,17 @@ public class Parser {
while (bufferTail.capacity() > 0) { while (bufferTail.capacity() > 0) {
StringBuilder strLen = new StringBuilder(); StringBuilder strLen = new StringBuilder();
boolean isString = (bufferTail.get(0) & 0xFF) == 0; boolean isString = (bufferTail.get(0) & 0xFF) == 0;
boolean numberTooLong = false;
for (int i = 1; ; i++) { for (int i = 1; ; i++) {
int b = bufferTail.get(i) & 0xFF; int b = bufferTail.get(i) & 0xFF;
if (b == 255) break; if (b == 255) break;
// supports only integer // supports only integer
if (strLen.length() > MAX_INT_CHAR_LENGTH) { if (strLen.length() > MAX_INT_CHAR_LENGTH) {
numberTooLong = true; callback.call(err, 0, 1);
break; return;
} }
strLen.append(b); strLen.append(b);
} }
if (numberTooLong) {
@SuppressWarnings("unchecked")
DecodePayloadCallback<String> tempCallback = callback;
tempCallback.call(err, 0, 1);
return;
}
bufferTail.position(strLen.length() + 1); bufferTail.position(strLen.length() + 1);
bufferTail = bufferTail.slice(); bufferTail = bufferTail.slice();

View File

@@ -156,6 +156,19 @@ public class ParserTest {
}); });
} }
@Test
public void encodingStringMessageWithLoneSurrogatesReplacedByUFFFD() throws UTF8Exception {
String data = "\uDC00\uD834\uDF06\uDC00 \uD800\uD835\uDF07\uD800";
encodePacket(new Packet<String>(Packet.MESSAGE, data), true, new EncodeCallback<String>() {
@Override
public void call(String encoded) {
Packet<String> p = decodePacket(encoded, true);
assertThat(p.type, is(Packet.MESSAGE));
assertThat(p.data, is("\uFFFD\uD834\uDF06\uFFFD \uFFFD\uD835\uDF07\uFFFD"));
}
});
}
@Test @Test
public void decodeEmptyPayload() { public void decodeEmptyPayload() {
Packet<String> p = decodePacket((String)null); Packet<String> p = decodePacket((String)null);
@@ -186,20 +199,20 @@ public class ParserTest {
@Test @Test
public void encodePayloads() throws UTF8Exception { public void encodePayloads() throws UTF8Exception {
encodePayload(new Packet[]{new Packet(Packet.PING), new Packet(Packet.PONG)}, new EncodeCallback<byte[]>() { encodePayload(new Packet[]{new Packet(Packet.PING), new Packet(Packet.PONG)}, new EncodeCallback<String>() {
@Override @Override
public void call(byte[] data) { public void call(String data) {
assertThat(data, isA(byte[].class)); assertThat(data, isA(String.class));
} }
}); });
} }
@Test @Test
public void encodeAndDecodePayloads() throws UTF8Exception { public void encodeAndDecodePayloads() throws UTF8Exception {
encodePayload(new Packet[] {new Packet<String>(Packet.MESSAGE, "a")}, new EncodeCallback<byte[]>() { encodePayload(new Packet[] {new Packet<String>(Packet.MESSAGE, "a")}, new EncodeCallback<String>() {
@Override @Override
public void call(byte[] data) { public void call(String data) {
decodePayload(data, new DecodePayloadCallback() { decodePayload(data, new DecodePayloadCallback<String>() {
@Override @Override
public boolean call(Packet packet, int index, int total) { public boolean call(Packet packet, int index, int total) {
boolean isLast = index + 1 == total; boolean isLast = index + 1 == total;
@@ -209,10 +222,10 @@ public class ParserTest {
}); });
} }
}); });
encodePayload(new Packet[]{new Packet<String>(Packet.MESSAGE, "a"), new Packet(Packet.PING)}, new EncodeCallback<byte[]>() { encodePayload(new Packet[]{new Packet<String>(Packet.MESSAGE, "a"), new Packet(Packet.PING)}, new EncodeCallback<String>() {
@Override @Override
public void call(byte[] data) { public void call(String data) {
decodePayload(data, new DecodePayloadCallback() { decodePayload(data, new DecodePayloadCallback<String>() {
@Override @Override
public boolean call(Packet packet, int index, int total) { public boolean call(Packet packet, int index, int total) {
boolean isLast = index + 1 == total; boolean isLast = index + 1 == total;
@@ -230,10 +243,10 @@ public class ParserTest {
@Test @Test
public void encodeAndDecodeEmptyPayloads() throws UTF8Exception { public void encodeAndDecodeEmptyPayloads() throws UTF8Exception {
encodePayload(new Packet[] {}, new EncodeCallback<byte[]>() { encodePayload(new Packet[] {}, new EncodeCallback<String>() {
@Override @Override
public void call(byte[] data) { public void call(String data) {
decodePayload(data, new DecodePayloadCallback() { decodePayload(data, new DecodePayloadCallback<String>() {
@Override @Override
public boolean call(Packet packet, int index, int total) { public boolean call(Packet packet, int index, int total) {
assertThat(packet.type, is(Packet.OPEN)); assertThat(packet.type, is(Packet.OPEN));
@@ -246,6 +259,19 @@ public class ParserTest {
}); });
} }
@Test
public void notUTF8EncodeWhenDealingWithStringsOnly() throws UTF8Exception {
encodePayload(new Packet[] {
new Packet(Packet.MESSAGE, "€€€"),
new Packet(Packet.MESSAGE, "α")
}, new EncodeCallback<String>() {
@Override
public void call(String data) {
assertThat(data, is("4:4€€€2:4α"));
}
});
}
@Test @Test
public void decodePayloadBadFormat() { public void decodePayloadBadFormat() {
decodePayload("1!", new DecodePayloadCallback<String>() { decodePayload("1!", new DecodePayloadCallback<String>() {
@@ -328,20 +354,6 @@ public class ParserTest {
}); });
} }
@Test
public void decodePayloadInvalidUTF8() {
decodePayload("2:4\uffff", new DecodePayloadCallback<String>() {
@Override
public boolean call(Packet<String> packet, int index, int total) {
boolean isLast = index + 1 == total;
assertThat(packet.type, is(Packet.ERROR));
assertThat(packet.data, is(ERROR_DATA));
assertThat(isLast, is(true));
return true;
}
});
}
@Test @Test
public void encodeBinaryMessage() throws UTF8Exception { public void encodeBinaryMessage() throws UTF8Exception {
final byte[] data = new byte[5]; final byte[] data = new byte[5];