Skip to content

Instantly share code, notes, and snippets.

@linxGnu
Created June 9, 2018 04:51
Show Gist options
  • Save linxGnu/4fbdb795f5b0daaad30c1f19eb4a6683 to your computer and use it in GitHub Desktop.
Save linxGnu/4fbdb795f5b0daaad30c1f19eb4a6683 to your computer and use it in GitHub Desktop.
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
/**
 * Micro-benchmark comparing two weighted round-robin endpoint selectors.
 *
 * <p>{@link WeightedRoundRobinStrategy} ("current") replays the weighted rotation on every
 * call; {@link NewWeightedRoundRobinStrategy} ("new") pre-computes weight groups once and
 * answers each call with a binary search. {@link #main(String[])} runs both against several
 * weight distributions and prints elapsed time and throughput.
 */
public class bench {
    /** Number of {@code select()} calls timed per algorithm in each scenario. */
    public static int numRequest = 100000;
    /** Number of endpoints generated for each scenario. */
    public static int numEndpoints = 500;

    /**
     * Runs every benchmark scenario in sequence.
     *
     * @param args ignored
     * @throws Exception if a selector returns {@code null} during a run
     */
    public static void main(String[] args) throws Exception {
        // Unseeded on purpose (as in the original): random scenarios differ between runs.
        Random rand = new Random();
        bench("normal round robin, all weight: 3", id ->
                new Endpoint("127.0.0.1", id + 1, 3));
        bench("randomly, mainly weight: 1, max weight: 10", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + (id % 50 == 0 ? rand.nextInt(10) : 0)));
        bench("mainly weight: 1, max weight: 30", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + (id % 50 == 0 ? 29 : 0)));
        bench("randomly, max weight: 10", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(10)));
        bench("randomly, max weight: 100", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(100)));
        bench("randomly, max weight: 200", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(200)));
        bench("all weights are unique", id ->
                new Endpoint("127.0.0.1", id + 1, id + 1));
    }

    /**
     * Builds {@link #numEndpoints} endpoints with {@code eg}, times both algorithms for
     * {@link #numRequest} selections each, and prints the results to standard output.
     *
     * @param title scenario name printed in the report header
     * @param eg    factory invoked with ids {@code 0 .. numEndpoints - 1}
     * @throws Exception if a selector returns {@code null} during a run
     */
    public static void bench(String title, EndpointGenerator eg) throws Exception {
        Endpoint[] testCase = new Endpoint[numEndpoints];
        for (int i = 0; i < testCase.length; i++) {
            testCase[i] = eg.generate(i);
        }
        List<Endpoint> endpoints = new ArrayList<>(Arrays.asList(testCase));

        System.out.println("[Bench] " + title);
        long runtime1 = benchCurrentAlgorithm(endpoints, numRequest);
        long runtime2 = benchNewAlgorithm(endpoints, numRequest);

        // Nanoseconds -> milliseconds; requests per second = requests * 1000 / milliseconds.
        double rt1 = runtime1 / 1000000.0;
        System.out.printf("Current algorithm: %.3f(ms) %.3f req/s\n", rt1, (numRequest * 1000.0 / rt1));
        double rt2 = runtime2 / 1000000.0;
        System.out.printf("New algorithm: %.3f(ms) %.3f req/s\n", rt2, (numRequest * 1000.0 / rt2));
        System.out.println("---------------------------------------------------------------------");
    }

    /**
     * Times selector construction plus {@code numberOfRequest} selections with the current
     * algorithm.
     *
     * <p>The loop is intentionally not shared with {@link #benchNewAlgorithm}: each method keeps
     * its own {@code select()} call site, exactly as in the original measurement.
     *
     * @param endpoints       endpoints to select from
     * @param numberOfRequest number of {@code select()} calls to perform
     * @return elapsed wall-clock time in nanoseconds, including selector construction
     * @throws Exception (an {@link IllegalStateException}) if a selection returns {@code null}
     */
    public static long benchCurrentAlgorithm(List<Endpoint> endpoints, int numberOfRequest) throws Exception {
        long startTime = System.nanoTime();
        EndpointSelector selector = WeightedRoundRobinStrategy.newSelector(endpoints);
        for (int i = 0; i < numberOfRequest; i++) {
            if (selector.select() == null) {
                throw new IllegalStateException("Error when selecting");
            }
        }
        return System.nanoTime() - startTime;
    }

    /**
     * Times selector construction plus {@code numberOfRequest} selections with the new
     * algorithm.
     *
     * @param endpoints       endpoints to select from
     * @param numberOfRequest number of {@code select()} calls to perform
     * @return elapsed wall-clock time in nanoseconds, including selector construction
     * @throws Exception (an {@link IllegalStateException}) if a selection returns {@code null}
     */
    public static long benchNewAlgorithm(List<Endpoint> endpoints, int numberOfRequest) throws Exception {
        long startTime = System.nanoTime();
        EndpointSelector selector = NewWeightedRoundRobinStrategy.newSelector(endpoints);
        for (int i = 0; i < numberOfRequest; i++) {
            if (selector.select() == null) {
                throw new IllegalStateException("Error when selecting");
            }
        }
        return System.nanoTime() - startTime;
    }

    //----------------------------- All in one -------------------------------------

    /** Creates the endpoint for a given zero-based id. */
    @FunctionalInterface
    public interface EndpointGenerator {
        Endpoint generate(int id);
    }

    /** Picks the endpoint that should serve the next request. */
    public interface EndpointSelector {
        /** @return the chosen endpoint, or {@code null} when nothing is selectable */
        Endpoint select();
    }

    /** Immutable host/port pair with a selection weight. */
    public static class Endpoint {
        private final int weight;
        private final String host;
        private final int port;

        public Endpoint(String host, int port, int weight) {
            this.host = host;
            this.port = port;
            this.weight = weight;
        }

        public int weight() {
            return weight;
        }

        public String host() {
            return host;
        }

        public int port() {
            return port;
        }
    }

    /** Baseline ("current") strategy; its selection logic is kept as-is for the comparison. */
    static final class WeightedRoundRobinStrategy {
        private WeightedRoundRobinStrategy() {
        }

        public static EndpointSelector newSelector(List<Endpoint> endpoints) {
            return new WeightedRoundRobinSelector(endpoints);
        }

        /**
         * A weighted round robin select strategy.
         *
         * <p>For example, with node a, b and c:
         * <ul>
         * <li>if endpoint weights are 1,1,1 (or 2,2,2), then select result is abc abc ...</li>
         * <li>if endpoint weights are 1,2,3 (or 2,4,6), then select result is abcbcc(or abcabcbcbccc) ...</li>
         * <li>if endpoint weights are 3,5,7, then select result is abcabcabcbcbccc abcabcabcbcbccc ...</li>
         * </ul>
         */
        private static final class WeightedRoundRobinSelector implements EndpointSelector {
            private final AtomicInteger sequence = new AtomicInteger();
            // Never reassigned in this file; both strategies read their snapshot through a
            // volatile field so that part of the per-call cost is the same on both sides.
            private volatile EndpointsAndWeights endpointsAndWeights;

            WeightedRoundRobinSelector(List<Endpoint> endpoints) {
                endpointsAndWeights = new EndpointsAndWeights(endpoints);
            }

            @Override
            public Endpoint select() {
                final int currentSequence = sequence.getAndIncrement();
                return endpointsAndWeights.selectEndpoint(currentSequence);
            }

            private static final class EndpointsAndWeights {
                private final List<Endpoint> endpoints;
                private final boolean weighted;
                private final int maxWeight;
                private final int totalWeight;

                EndpointsAndWeights(List<Endpoint> endpoints) {
                    int minWeight = Integer.MAX_VALUE;
                    int maxWeight = Integer.MIN_VALUE;
                    int totalWeight = 0;
                    for (Endpoint endpoint : endpoints) {
                        final int weight = endpoint.weight();
                        minWeight = Math.min(minWeight, weight);
                        maxWeight = Math.max(maxWeight, weight);
                        totalWeight += weight;
                    }
                    this.endpoints = endpoints;
                    this.maxWeight = maxWeight;
                    this.totalWeight = totalWeight;
                    weighted = minWeight != maxWeight;
                }

                /**
                 * Replays the rotation from the start until {@code currentSequence % totalWeight}
                 * slots have been consumed. The weights array is rebuilt on every call; that
                 * per-call cost is what the benchmark measures, so it is left untouched.
                 *
                 * <p>A negative {@code currentSequence} yields a negative remainder that never
                 * reaches zero, so the call falls through to the plain round-robin line below.
                 */
                Endpoint selectEndpoint(int currentSequence) {
                    if (endpoints.isEmpty()) {
                        return null;
                    }
                    if (weighted) {
                        final int[] weights = endpoints.stream()
                                .mapToInt(Endpoint::weight)
                                .toArray();
                        int mod = currentSequence % totalWeight;
                        for (int i = 0; i < maxWeight; i++) {
                            for (int j = 0; j < weights.length; j++) {
                                if (mod == 0 && weights[j] > 0) {
                                    return endpoints.get(j);
                                }
                                if (weights[j] > 0) {
                                    weights[j]--;
                                    mod--;
                                }
                            }
                        }
                    }
                    return endpoints.get(Math.abs(currentSequence % endpoints.size()));
                }
            }
        }
    }

    /** Candidate ("new") strategy: groups endpoints by weight once, then binary-searches. */
    static final class NewWeightedRoundRobinStrategy {
        private NewWeightedRoundRobinStrategy() {
        }

        public static EndpointSelector newSelector(List<Endpoint> endpoints) {
            return new NewWeightedRoundRobinSelector(endpoints);
        }

        /**
         * A weighted round robin select strategy.
         *
         * <p>Endpoints with a weight of zero or less are never selected. The remaining
         * endpoints are ordered by ascending weight, then host, then port. For example, with
         * node a, b and c in that order:
         * <ul>
         * <li>if endpoint weights are 1,1,1 (or 2,2,2), then select result is abc abc ...</li>
         * <li>if endpoint weights are 1,2,3 (or 2,4,6), then select result is abcbcc(or abcabcbcbccc) ...</li>
         * <li>if endpoint weights are 3,5,7, then select result is abcabcabcbcbccc abcabcabcbcbccc ...</li>
         * </ul>
         */
        private static final class NewWeightedRoundRobinSelector implements EndpointSelector {
            private final AtomicInteger sequence = new AtomicInteger();
            // Never reassigned in this file; both strategies read their snapshot through a
            // volatile field so that part of the per-call cost is the same on both sides.
            private volatile EndpointsAndWeights endpointsAndWeights;

            NewWeightedRoundRobinSelector(List<Endpoint> endpoints) {
                endpointsAndWeights = new EndpointsAndWeights(endpoints);
            }

            @Override
            public Endpoint select() {
                final int currentSequence = sequence.getAndIncrement();
                return endpointsAndWeights.selectEndpoint(currentSequence);
            }

            private static final class EndpointsAndWeights {
                // Null hosts sort first instead of throwing NullPointerException mid-sort.
                private static final Comparator<String> NULL_SAFE_HOST_ORDER =
                        Comparator.nullsFirst(Comparator.naturalOrder());
                private static final Comparator<Endpoint> SELECTION_ORDER = Comparator
                        .comparingInt(Endpoint::weight)
                        .thenComparing(Endpoint::host, NULL_SAFE_HOST_ORDER)
                        .thenComparingInt(Endpoint::port);

                /**
                 * A run of equally weighted endpoints in the sorted list.
                 * {@code accumulatedWeight} is the number of slots in one full cycle that are
                 * handed out before every endpoint of this weight is exhausted.
                 */
                private static final class EndpointsGroupByWeight {
                    final int startIndex;
                    final int weight;
                    final long accumulatedWeight;

                    EndpointsGroupByWeight(int startIndex, int weight, long accumulatedWeight) {
                        this.startIndex = startIndex;
                        this.weight = weight;
                        this.accumulatedWeight = accumulatedWeight;
                    }
                }

                private final List<Endpoint> endpoints;
                private final boolean weighted;
                private final long totalWeight; // prevent overflow by using long
                private final EndpointsGroupByWeight[] endpointsGroupByWeight;

                EndpointsAndWeights(List<Endpoint> endpoints) {
                    // Only endpoints with weight > 0 take part in selection.
                    final List<Endpoint> selectable = endpoints.stream()
                            .filter(endpoint -> endpoint.weight() > 0)
                            .sorted(SELECTION_ORDER)
                            .collect(Collectors.toList());
                    final int count = selectable.size();

                    final List<EndpointsGroupByWeight> groups = new ArrayList<>();
                    long runningTotal = 0;
                    int previousWeight = 0; // every selectable weight is > 0, so the first endpoint opens a group
                    int index = 0;
                    for (Endpoint endpoint : selectable) {
                        final int weight = endpoint.weight();
                        if (weight != previousWeight) {
                            // Each of the (count - index) endpoints from here on contributes
                            // one slot per extra unit of weight over the previous group.
                            runningTotal += (long) (weight - previousWeight) * (long) (count - index);
                            groups.add(new EndpointsGroupByWeight(index, weight, runningTotal));
                            previousWeight = weight;
                        }
                        index++;
                    }

                    this.endpoints = selectable;
                    this.endpointsGroupByWeight = groups.toArray(new EndpointsGroupByWeight[0]);
                    this.totalWeight = runningTotal;
                    // Decided on the filtered list: a single group means plain round robin.
                    this.weighted = groups.size() > 1;
                }

                /**
                 * Maps a sequence number onto the endpoint owning that slot of the cycle.
                 *
                 * @param currentSequence monotonically increasing counter; may be negative
                 *                        after the counter wraps around
                 * @return the selected endpoint, or {@code null} if no endpoint has weight &gt; 0
                 */
                Endpoint selectEndpoint(int currentSequence) {
                    if (endpoints.isEmpty()) {
                        return null;
                    }
                    final int numberEndpoints = endpoints.size();
                    if (!weighted) {
                        // floorMod keeps the rotation moving forward for negative sequences.
                        return endpoints.get(Math.floorMod(currentSequence, numberEndpoints));
                    }

                    final long slot = Math.floorMod((long) currentSequence, totalWeight);

                    // Find the first group whose accumulated weight exceeds the slot. The last
                    // group always qualifies because slot < totalWeight.
                    int low = 0;
                    int high = endpointsGroupByWeight.length - 1;
                    while (low < high) {
                        final int mid = (low + high) >>> 1;
                        if (endpointsGroupByWeight[mid].accumulatedWeight > slot) {
                            high = mid;
                        } else {
                            low = mid + 1;
                        }
                    }

                    final EndpointsGroupByWeight group = endpointsGroupByWeight[low];
                    final long phaseStart = low == 0 ? 0L : endpointsGroupByWeight[low - 1].accumulatedWeight;
                    // Within this phase, only endpoints from the group's start index onward rotate.
                    final int rotating = numberEndpoints - group.startIndex;
                    return endpoints.get(group.startIndex + (int) ((slot - phaseStart) % rotating));
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment