diff --git a/src/main/java/io/socket/backo/Backoff.java b/src/main/java/io/socket/backo/Backoff.java index 4252266..24cf41e 100644 --- a/src/main/java/io/socket/backo/Backoff.java +++ b/src/main/java/io/socket/backo/Backoff.java @@ -1,5 +1,8 @@ package io.socket.backo; +import java.math.BigDecimal; +import java.math.BigInteger; + public class Backoff { private long ms = 100; @@ -11,17 +14,16 @@ public class Backoff { public Backoff() {} public long duration() { - long ms = this.ms * (long) Math.pow(this.factor, this.attempts++); + BigInteger ms = BigInteger.valueOf(this.ms) + .multiply(BigInteger.valueOf(this.factor).pow(this.attempts++)); if (jitter != 0.0) { double rand = Math.random(); - int deviation = (int) Math.floor(rand * this.jitter * ms); - ms = (((int) Math.floor(rand * 10)) & 1) == 0 ? ms - deviation : ms + deviation; + BigInteger deviation = BigDecimal.valueOf(rand) + .multiply(BigDecimal.valueOf(jitter)) + .multiply(new BigDecimal(ms)).toBigInteger(); + ms = (((int) Math.floor(rand * 10)) & 1) == 0 ? ms.subtract(deviation) : ms.add(deviation); } - if (ms < this.ms) { - // overflow happened - ms = Long.MAX_VALUE; - } - return Math.min(ms, this.max); + return ms.min(BigInteger.valueOf(this.max)).longValue(); } public void reset() { diff --git a/src/test/java/io/socket/backo/BackoffTest.java b/src/test/java/io/socket/backo/BackoffTest.java index 5da719b..a268829 100644 --- a/src/test/java/io/socket/backo/BackoffTest.java +++ b/src/test/java/io/socket/backo/BackoffTest.java @@ -2,6 +2,9 @@ package io.socket.backo; import org.junit.Test; +import java.math.BigDecimal; +import java.math.BigInteger; + import static org.junit.Assert.assertTrue; public class BackoffTest { @@ -22,14 +25,23 @@ public class BackoffTest { @Test public void durationOverflow() { - Backoff b = new Backoff(); - b.setMin(100); - b.setMax(10000); - b.setJitter(1.0); + for (int i = 0; i < 10; i++) { + Backoff b = new Backoff(); + b.setMin(100); + b.setMax(10000); + b.setJitter(0.5); - for (int i = 0; i < 100; i++) { - long duration = b.duration(); - assertTrue(100 <= duration && duration <= 10000); + // repeats to make it overflow (a long can have 2 ** 63 - 1) + for (int j = 0; j < 100; j++) { + BigInteger ms = BigInteger.valueOf(100).multiply(BigInteger.valueOf(2).pow(j)); + BigInteger deviation = new BigDecimal(ms).multiply(BigDecimal.valueOf(0.5)).toBigInteger(); + BigInteger duration = BigInteger.valueOf(b.duration()); + + BigInteger min = ms.subtract(deviation).min(BigInteger.valueOf(10000)); + BigInteger max = ms.add(deviation).min(BigInteger.valueOf(10001)); + assertTrue(min + " <= " + duration + " < " + max, + min.compareTo(duration) <= 0 && max.compareTo(duration) == 1); + } } } }