Created
June 9, 2018 04:51
-
-
Save linxGnu/4fbdb795f5b0daaad30c1f19eb4a6683 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.*; | |
import java.util.concurrent.atomic.AtomicInteger; | |
import java.util.stream.Collectors; | |
/**
 * Micro-benchmark comparing two weighted round-robin endpoint selectors.
 *
 * <p>{@link WeightedRoundRobinStrategy} is the baseline ("current") algorithm, which walks the
 * endpoint weights on every selection. {@link NewWeightedRoundRobinStrategy} precomputes groups of
 * equal weight once and locates the selected endpoint with a binary search. For each weight
 * distribution, both algorithms are driven with the same endpoint list. The elapsed time and
 * throughput of each are printed.
 */
public class bench {
    /** Number of timed selections performed per algorithm and scenario. */
    public static int numRequest = 100000;
    /** Number of endpoints generated per scenario. */
    public static int numEndpoints = 500;
    /**
     * Number of untimed selections performed per algorithm before measuring. This keeps JIT
     * compilation from being billed to whichever algorithm is timed first.
     */
    public static int numWarmUpRequest = 10000;

    /**
     * Runs every benchmark scenario in turn.
     *
     * <p>The randomized scenarios draw weights from an unseeded {@link Random}, so their weights
     * differ between runs. Within one scenario, both algorithms see the same endpoints.
     */
    public static void main(String[] args) throws Exception {
        Random rand = new Random();
        bench("normal round robin, all weight: 3", id ->
                new Endpoint("127.0.0.1", id + 1, 3));
        bench("randomly, mainly weight: 1, max weight: 10", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + (id % 50 == 0 ? rand.nextInt(10) : 0)));
        bench("mainly weight: 1, max weight: 30", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + (id % 50 == 0 ? 29 : 0)));
        bench("randomly, max weight: 10", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(10)));
        bench("randomly, max weight: 100", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(100)));
        bench("randomly, max weight: 200", id ->
                new Endpoint("127.0.0.1", id + 1, 1 + rand.nextInt(200)));
        bench("all weights are unique", id ->
                new Endpoint("127.0.0.1", id + 1, id + 1));
    }

    /**
     * Runs one benchmark scenario and prints the results of both algorithms.
     *
     * @param title human-readable scenario name, printed as a header
     * @param eg    creates the endpoint for each id in {@code [0, numEndpoints)}
     * @throws Exception if either selector returns {@code null}
     */
    public static void bench(String title, EndpointGenerator eg) throws Exception {
        Endpoint[] testCase = new Endpoint[numEndpoints];
        for (int i = 0; i < testCase.length; i++) {
            testCase[i] = eg.generate(i);
        }
        List<Endpoint> endpoints = new ArrayList<>(Arrays.asList(testCase));
        System.out.println("[Bench] " + title);

        // Untimed warm-up of both code paths. Without it, the first algorithm measured also pays
        // for interpretation and JIT compilation, which skews the comparison against it.
        benchCurrentAlgorithm(endpoints, numWarmUpRequest);
        benchNewAlgorithm(endpoints, numWarmUpRequest);

        long runtime1 = benchCurrentAlgorithm(endpoints, numRequest);
        long runtime2 = benchNewAlgorithm(endpoints, numRequest);
        double rt1 = runtime1 / 1000000.0;
        System.out.printf("Current algorithm: %.3f(ms) %.3f req/s%n", rt1, (numRequest * 1000.0 / rt1));
        double rt2 = runtime2 / 1000000.0;
        System.out.printf("New algorithm: %.3f(ms) %.3f req/s%n", rt2, (numRequest * 1000.0 / rt2));
        System.out.println("---------------------------------------------------------------------");
    }

    /**
     * Times selector construction plus {@code numberOfRequest} selections with the baseline
     * algorithm.
     *
     * @return elapsed wall-clock time in nanoseconds
     * @throws Exception if a selection returns {@code null}
     */
    public static long benchCurrentAlgorithm(List<Endpoint> endpoints, int numberOfRequest) throws Exception {
        long startTime = System.nanoTime();
        EndpointSelector selector = WeightedRoundRobinStrategy.newSelector(endpoints);
        for (int i = 0; i < numberOfRequest; i++) {
            Endpoint tmp = selector.select();
            if (tmp == null) {
                throw new Exception("Error when selecting");
            }
        }
        return System.nanoTime() - startTime;
    }

    /**
     * Times selector construction plus {@code numberOfRequest} selections with the new algorithm.
     *
     * @return elapsed wall-clock time in nanoseconds
     * @throws Exception if a selection returns {@code null}
     */
    public static long benchNewAlgorithm(List<Endpoint> endpoints, int numberOfRequest) throws Exception {
        long startTime = System.nanoTime();
        EndpointSelector selector = NewWeightedRoundRobinStrategy.newSelector(endpoints);
        for (int i = 0; i < numberOfRequest; i++) {
            Endpoint tmp = selector.select();
            if (tmp == null) {
                throw new Exception("Error when selecting");
            }
        }
        return System.nanoTime() - startTime;
    }

    //----------------------------- All in one -------------------------------------

    /** Creates the endpoint for a given zero-based id. */
    @FunctionalInterface
    public interface EndpointGenerator {
        Endpoint generate(int id);
    }

    /** Picks the next endpoint; returns {@code null} when there is nothing to select. */
    @FunctionalInterface
    public interface EndpointSelector {
        Endpoint select();
    }

    /** An immutable host/port pair with an integer selection weight. */
    public static class Endpoint {
        private final int weight;
        private final String host;
        private final int port;

        public Endpoint(String host, int port, int weight) {
            this.host = host;
            this.port = port;
            this.weight = weight;
        }

        public int weight() {
            return weight;
        }

        public String host() {
            return host;
        }

        public int port() {
            return port;
        }
    }

    /** The baseline ("current") weighted round-robin strategy. */
    final static class WeightedRoundRobinStrategy {
        public static EndpointSelector newSelector(List<Endpoint> endpoints) {
            return new WeightedRoundRobinSelector(endpoints);
        }

        /**
         * A weighted round robin select strategy.
         *
         * <p>For example, with node a, b and c (in list order):
         * <ul>
         * <li>if endpoint weights are 1,1,1 (or 2,2,2), then select result is abc abc ...</li>
         * <li>if endpoint weights are 1,2,3 (or 2,4,6), then select result is abcbcc (or abcabcbcbccc) ...</li>
         * <li>if endpoint weights are 3,5,7, then select result is abcabcabcbcbccc abcabcabcbcbccc ...</li>
         * </ul>
         */
        private static final class WeightedRoundRobinSelector implements EndpointSelector {
            private final List<Endpoint> endpoints;
            private final AtomicInteger sequence = new AtomicInteger();
            private volatile EndpointsAndWeights endpointsAndWeights;

            WeightedRoundRobinSelector(List<Endpoint> endpoints) {
                this.endpoints = endpoints;
                endpointsAndWeights = new EndpointsAndWeights(endpoints);
            }

            @Override
            public Endpoint select() {
                final int currentSequence = sequence.getAndIncrement();
                return endpointsAndWeights.selectEndpoint(currentSequence);
            }

            private static final class EndpointsAndWeights {
                private final List<Endpoint> endpoints;
                private final boolean weighted;
                private final int maxWeight;
                // NOTE(review): int sum, so it can overflow for very large weights. Left as-is
                // because this class is the baseline being measured.
                private final int totalWeight;

                EndpointsAndWeights(List<Endpoint> endpoints) {
                    int minWeight = Integer.MAX_VALUE;
                    int maxWeight = Integer.MIN_VALUE;
                    int totalWeight = 0;
                    for (Endpoint endpoint : endpoints) {
                        final int weight = endpoint.weight();
                        minWeight = Math.min(minWeight, weight);
                        maxWeight = Math.max(maxWeight, weight);
                        totalWeight += weight;
                    }
                    this.endpoints = endpoints;
                    this.maxWeight = maxWeight;
                    this.totalWeight = totalWeight;
                    weighted = minWeight != maxWeight;
                }

                /**
                 * Selects the endpoint for {@code currentSequence}, or {@code null} if the list is
                 * empty.
                 *
                 * <p>Each call copies the weights into a fresh array. It then walks them round by
                 * round, decrementing every still-positive weight, until the slot for
                 * {@code currentSequence % totalWeight} is reached. This per-call allocation and
                 * walk is the cost the benchmark measures.
                 */
                Endpoint selectEndpoint(int currentSequence) {
                    if (endpoints.isEmpty()) {
                        return null;
                    }
                    if (weighted) {
                        final int[] weights = endpoints.stream()
                                .mapToInt(Endpoint::weight)
                                .toArray();
                        // NOTE(review): throws ArithmeticException if totalWeight is 0 (e.g.
                        // weights -1 and 1). Baseline behavior, intentionally not changed.
                        int mod = currentSequence % totalWeight;
                        for (int i = 0; i < maxWeight; i++) {
                            for (int j = 0; j < weights.length; j++) {
                                if (mod == 0 && weights[j] > 0) {
                                    return endpoints.get(j);
                                }
                                if (weights[j] > 0) {
                                    weights[j]--;
                                    mod--;
                                }
                            }
                        }
                    }
                    return endpoints.get(Math.abs(currentSequence % endpoints.size()));
                }
            }
        }
    }

    /** The proposed weighted round-robin strategy using precomputed weight groups. */
    final static class NewWeightedRoundRobinStrategy {
        public static EndpointSelector newSelector(List<Endpoint> endpoints) {
            return new NewWeightedRoundRobinSelector(endpoints);
        }

        /**
         * A weighted round robin select strategy.
         *
         * <p>Endpoints with a non-positive weight are never selected. The rest are ordered by
         * weight, then host, then port, so the order below assumes a, b and c sort in that order.
         *
         * <p>For example, with node a, b and c:
         * <ul>
         * <li>if endpoint weights are 1,1,1 (or 2,2,2), then select result is abc abc ...</li>
         * <li>if endpoint weights are 1,2,3 (or 2,4,6), then select result is abcbcc (or abcabcbcbccc) ...</li>
         * <li>if endpoint weights are 3,5,7, then select result is abcabcabcbcbccc abcabcabcbcbccc ...</li>
         * </ul>
         */
        private static final class NewWeightedRoundRobinSelector implements EndpointSelector {
            private final List<Endpoint> endpoints;
            private final AtomicInteger sequence = new AtomicInteger();
            private volatile EndpointsAndWeights endpointsAndWeights;

            NewWeightedRoundRobinSelector(List<Endpoint> endpoints) {
                this.endpoints = endpoints;
                endpointsAndWeights = new EndpointsAndWeights(this.endpoints);
            }

            @Override
            public Endpoint select() {
                final int currentSequence = sequence.getAndIncrement();
                return endpointsAndWeights.selectEndpoint(currentSequence);
            }

            private static final class EndpointsAndWeights {
                private final List<Endpoint> endpoints;
                private final boolean weighted;
                private final long totalWeight; // long to avoid overflow of the weight sum

                /**
                 * A run of sorted endpoints sharing one weight. The group covers every endpoint
                 * from {@code startIndex} to the end of the sorted list.
                 */
                private static final class EndpointsGroupByWeight {
                    final int startIndex;
                    final int weight;
                    /** Total number of selections in this group and all lighter groups. */
                    final long accumulatedWeight;

                    EndpointsGroupByWeight(int startIndex, int weight, long accumulatedWeight) {
                        this.startIndex = startIndex;
                        this.weight = weight;
                        this.accumulatedWeight = accumulatedWeight;
                    }
                }

                private final EndpointsGroupByWeight[] endpointsGroupByWeight;

                EndpointsAndWeights(List<Endpoint> endpoints) {
                    // Min/max span all endpoints (non-positive weights included). They only decide
                    // whether the weighted path is needed at all.
                    int minWeight = Integer.MAX_VALUE;
                    int maxWeight = Integer.MIN_VALUE;
                    for (Endpoint endpoint : endpoints) {
                        final int weight = endpoint.weight();
                        minWeight = Math.min(minWeight, weight);
                        maxWeight = Math.max(maxWeight, weight);
                    }

                    // Only endpoints with weight > 0 can be selected. Sorting by weight makes each
                    // group of equal weight contiguous. Host and port break ties, so the order does
                    // not depend on the input order.
                    final List<Endpoint> endps = endpoints.stream()
                            .filter(endpoint -> endpoint.weight() > 0)
                            .sorted(Comparator.comparingInt(Endpoint::weight)
                                    .thenComparing(Endpoint::host)
                                    .thenComparingInt(Endpoint::port))
                            .collect(Collectors.toList());
                    final int count = endps.size();

                    // One group per distinct weight w_k, starting at index i. Every endpoint from
                    // i to the end has weight >= w_k. The group contributes (w_k - w_{k-1})
                    // selections for each of those (count - i) endpoints.
                    final List<EndpointsGroupByWeight> groups = new ArrayList<>();
                    EndpointsGroupByWeight currentGroup = null;
                    long totalWeight = 0;
                    for (int i = 0; i < count; i++) {
                        final int weight = endps.get(i).weight();
                        if (currentGroup == null || currentGroup.weight != weight) {
                            final int previousWeight = currentGroup == null ? 0 : currentGroup.weight;
                            totalWeight += (long) (weight - previousWeight) * (long) (count - i);
                            currentGroup = new EndpointsGroupByWeight(i, weight, totalWeight);
                            groups.add(currentGroup);
                        }
                    }

                    this.endpoints = endps;
                    this.endpointsGroupByWeight = groups.toArray(new EndpointsGroupByWeight[0]);
                    this.totalWeight = totalWeight;
                    this.weighted = minWeight != maxWeight;
                }

                /**
                 * Selects the endpoint for {@code currentSequence}, or {@code null} if no endpoint
                 * has a positive weight.
                 */
                Endpoint selectEndpoint(int currentSequence) {
                    if (endpoints.isEmpty()) {
                        return null;
                    }
                    final int numberEndpoints = endpoints.size();
                    if (!weighted) {
                        return endpoints.get(Math.abs(currentSequence % numberEndpoints));
                    }

                    // Position of this request within one full cycle of totalWeight selections.
                    final long mod = Math.abs((long) currentSequence % totalWeight);
                    if (mod < endpointsGroupByWeight[0].accumulatedWeight) {
                        // The lightest group includes every endpoint, so this is a plain round robin.
                        return endpoints.get((int) (mod % numberEndpoints));
                    }

                    // Find the last group with accumulatedWeight <= mod. The invariant is
                    // accumulatedWeight[left] <= mod < accumulatedWeight[right]; the last group's
                    // accumulatedWeight equals totalWeight, which is > mod.
                    int left = 0;
                    int right = endpointsGroupByWeight.length - 1;
                    while (right - left > 1) {
                        final int mid = (left + right) >>> 1;
                        if (endpointsGroupByWeight[mid].accumulatedWeight <= mod) {
                            left = mid;
                        } else {
                            right = mid;
                        }
                    }

                    // The request falls in group (left + 1): cycle over that group's endpoints.
                    // The int cast is safe: mod <= 2^31 and accumulatedWeight >= 1.
                    final EndpointsGroupByWeight group = endpointsGroupByWeight[left + 1];
                    final int indexInPart = (int) (mod - endpointsGroupByWeight[left].accumulatedWeight);
                    final int groupSize = numberEndpoints - group.startIndex;
                    return endpoints.get(group.startIndex + indexInPart % groupSize);
                }
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment