Last active
September 27, 2021 23:20
-
-
Save sathish316/4a725234334c8c6ba6290a61f3f1a462 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.algos.shuffle; | |
import java.util.*; | |
import java.util.stream.Collectors; | |
import java.util.stream.IntStream; | |
import java.util.stream.Stream; | |
/**
 * Weighted random selection and weighted shuffling over a fixed array of non-negative weights.
 *
 * <p>Item {@code i} has weight {@code weights[i]}. {@link #select()} draws one index with
 * probability proportional to its weight. The {@code weightedShuffle*} methods are alternative
 * implementations of the same operation: draw every index without replacement, each draw
 * proportional to the weights of the items not drawn yet. They are kept side by side so that
 * {@link #main(String[])} can benchmark them.
 *
 * <p>Zero-weight items are never drawn while any positive-weight item remains. Once only
 * zero-weight items are left, every shuffle appends them in ascending index order.
 *
 * <p>No method mutates the arrays stored by the constructor; each works on its own copies.
 */
public class WeightedRandom {
    /** Inclusive prefix sums of {@link #weights}: {@code cdf[i] = weights[0] + ... + weights[i]}. */
    private final double[] cdf;
    /** Defensive copy of the weights passed to the constructor. */
    private final double[] weights;

    /**
     * Creates a sampler over the given weights.
     *
     * @param weights weight of each item; the array is copied, so later changes to it have no
     *     effect
     * @throws NullPointerException if {@code weights} is null
     * @throws IllegalArgumentException if any weight is negative, NaN or infinite, or if the
     *     weights sum to infinity
     */
    public WeightedRandom(double[] weights) {
        this.weights = Objects.requireNonNull(weights, "weights").clone();
        for (int i = 0; i < this.weights.length; i++) {
            double w = this.weights[i];
            // Negative or non-finite weights make the prefix sums non-monotone or NaN, which
            // silently breaks every binary search and linear scan below.
            if (!(w >= 0.0) || Double.isInfinite(w)) {
                throw new IllegalArgumentException(
                        "weights[" + i + "] must be finite and non-negative, was " + w);
            }
        }
        cdf = this.weights.clone();
        for (int i = 1; i < cdf.length; i++) {
            cdf[i] += cdf[i - 1];
        }
        if (cdf.length > 0 && Double.isInfinite(cdf[cdf.length - 1])) {
            throw new IllegalArgumentException("sum of weights overflows to infinity");
        }
    }

    /**
     * Draws one index with probability proportional to its weight.
     *
     * <p>A zero-weight index is never returned unless every weight is zero. In that case an
     * index is chosen uniformly at random.
     *
     * @return an index in {@code [0, weights.length)}
     * @throws IllegalStateException if the sampler was built from an empty array
     */
    public int select() {
        int n = cdf.length;
        if (n == 0) {
            throw new IllegalStateException("no weights to select from");
        }
        double total = cdf[n - 1];
        if (!(total > 0.0)) {
            return Math.min((int) (Math.random() * n), n - 1);
        }
        // Take the first index whose prefix sum exceeds the target. Unlike Arrays.binarySearch,
        // this is well-defined when zero weights create runs of equal prefix sums, and it can
        // never land on a zero-width slot.
        return upperBound(cdf, n, uniformBelow(total, Math.random()));
    }

    /**
     * Weighted shuffle that binary-searches a mutable copy of the prefix sums. It was originally
     * noted by the author as the best-performing variant.
     *
     * <p>After each draw, the drawn item's weight is zeroed and the prefix sums from that index
     * on are recomputed by addition (O(n) per draw, like the original subtraction loop). Because
     * of this, a drawn slot has exactly zero width and can never be drawn again.
     *
     * @return a permutation of {@code 0 .. weights.length - 1}
     */
    public int[] weightedShuffle2() {
        int size = weights.length;
        int[] ret = new int[size];
        double[] live = weights.clone(); // weight of each not-yet-drawn item; 0 once drawn
        boolean[] taken = new boolean[size];
        double[] mcdf = cdf.clone(); // invariant: mcdf[j] == live[0] + ... + live[j]
        int filled = 0;
        while (filled < size) {
            double total = mcdf[size - 1];
            if (!(total > 0.0)) {
                // Only zero-weight items remain: append them in index order.
                for (int k = 0; k < size; k++) {
                    if (!taken[k]) {
                        ret[filled++] = k;
                    }
                }
                break;
            }
            int index = upperBound(mcdf, size, uniformBelow(total, Math.random()));
            ret[filled++] = index;
            taken[index] = true;
            live[index] = 0.0;
            for (int j = index; j < size; j++) {
                mcdf[j] = (j == 0 ? 0.0 : mcdf[j - 1]) + live[j];
            }
        }
        return ret;
    }

    /**
     * Weighted shuffle that keeps the not-yet-drawn ids packed at the front of an array, with
     * the prefix sums of their weights alongside.
     *
     * <p>After each draw, the tail is shifted left over the drawn slot and its prefix sums are
     * recomputed.
     *
     * @return a permutation of {@code 0 .. weights.length - 1}
     */
    public int[] weightedShuffle3() {
        int size = weights.length;
        int[] ret = new int[size];
        int[] ids = new int[size]; // ids[0..n) are the original indices still in play
        for (int i = 0; i < size; i++) {
            ids[i] = i;
        }
        double[] mcdf = cdf.clone(); // invariant: mcdf[j] == sum of weights[ids[0..j]]
        int filled = 0;
        for (int n = size; n > 0; n--) {
            double total = mcdf[n - 1];
            if (!(total > 0.0)) {
                // Only zero-weight items remain; compaction preserved their index order.
                System.arraycopy(ids, 0, ret, filled, n);
                break;
            }
            int pos = upperBound(mcdf, n, uniformBelow(total, Math.random()));
            ret[filled++] = ids[pos];
            for (int j = pos; j < n - 1; j++) {
                ids[j] = ids[j + 1];
                mcdf[j] = (j == 0 ? 0.0 : mcdf[j - 1]) + weights[ids[j]];
            }
        }
        return ret;
    }

    /** An (index, weight) pair used by the linked-list shuffle. */
    private static class IndexValPair {
        private final int index;
        private final double val;

        private IndexValPair(int index, double val) {
            this.index = index;
            this.val = val;
        }
    }

    /**
     * Weighted shuffle by repeated linear scans of a linked list, removing each drawn item in
     * O(1) through the iterator.
     *
     * <p>The remaining total is re-summed every round instead of being maintained by
     * subtraction. This keeps it exactly equal to what the scan accumulates, so rounding drift
     * can neither run the scan off the end of the list nor bias the draw.
     *
     * @return a permutation of {@code 0 .. weights.length - 1}
     */
    public int[] weightedShuffle() {
        int[] ret = new int[weights.length];
        List<IndexValPair> pool = new LinkedList<>();
        for (int i = 0; i < weights.length; i++) {
            pool.add(new IndexValPair(i, weights[i]));
        }
        int filled = 0;
        while (!pool.isEmpty()) {
            double sum = 0.0;
            for (IndexValPair p : pool) {
                sum += p.val;
            }
            if (!(sum > 0.0)) {
                // Only zero-weight items remain: append them in list (index) order.
                for (IndexValPair p : pool) {
                    ret[filled++] = p.index;
                }
                break;
            }
            double target = uniformBelow(sum, Math.random());
            double acc = 0.0;
            for (Iterator<IndexValPair> it = pool.iterator(); it.hasNext(); ) {
                IndexValPair p = it.next();
                acc += p.val;
                // A strict comparison means a zero-weight item (which leaves acc unchanged) is
                // never chosen.
                if (acc > target) {
                    ret[filled++] = p.index;
                    it.remove();
                    break;
                }
            }
        }
        return ret;
    }

    /** An item id paired with its weight, as returned by the list-based shuffles. */
    static class Element implements Comparable<Element> {
        int id;
        Double weight;

        public Element(int id, double weight) {
            this.id = id;
            this.weight = weight;
        }

        public int getId() {
            return id;
        }

        public double getWeight() {
            return weight;
        }

        /**
         * Orders by weight only. Distinct elements with equal weight compare as 0, so do not
         * rely on this ordering alone in sorted sets.
         */
        @Override
        public int compareTo(Element o) {
            return weight.compareTo(o.weight);
        }
    }

    /**
     * Weighted shuffle by linear scans of an {@link ArrayList} of elements.
     *
     * @return every element, ids {@code 0 .. weights.length - 1}, in weighted random order
     */
    public List<Element> weightedShuffle4() {
        List<Element> pool = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; i++) {
            pool.add(new Element(i, weights[i]));
        }
        return drawSequentially(pool, new Random());
    }

    /**
     * Weighted shuffle by linear scans of a {@link TreeSet} ordered by weight, then id.
     *
     * @return every element, ids {@code 0 .. weights.length - 1}, in weighted random order
     */
    public List<Element> weightedShuffle5() {
        // Element.compareTo looks at weight only, so a plain TreeSet<Element> treated
        // equal-weight elements as duplicates and silently dropped them. Break ties on id.
        Comparator<Element> byWeightThenId =
                Comparator.comparingDouble(Element::getWeight).thenComparingInt(Element::getId);
        Set<Element> pool = new TreeSet<>(byWeightThenId);
        for (int i = 0; i < weights.length; i++) {
            pool.add(new Element(i, weights[i]));
        }
        return drawSequentially(pool, new Random());
    }

    /**
     * Repeatedly removes one element from {@code pool} and returns the elements in draw order.
     * Each element is chosen with probability proportional to its weight among the elements
     * still in the pool.
     *
     * <p>Once only zero-weight elements remain, they are appended in the pool's iteration order.
     * The pool is left empty.
     */
    private static List<Element> drawSequentially(Collection<Element> pool, Random random) {
        List<Element> drawn = new ArrayList<>(pool.size());
        while (!pool.isEmpty()) {
            // Re-summing each round keeps the total exactly equal to the scan's final
            // accumulator, so the scan below always selects an element.
            double total = 0.0;
            for (Element e : pool) {
                total += e.getWeight();
            }
            if (!(total > 0.0)) {
                drawn.addAll(pool);
                pool.clear();
                break;
            }
            double target = uniformBelow(total, random.nextDouble());
            double acc = 0.0;
            for (Iterator<Element> it = pool.iterator(); it.hasNext(); ) {
                Element candidate = it.next();
                acc += candidate.getWeight();
                if (acc > target) {
                    drawn.add(candidate);
                    it.remove();
                    break;
                }
            }
        }
        return drawn;
    }

    /** Pairs a value with a numeric sort key; compares by key only. */
    static class ItemWeight<T> implements Comparable<ItemWeight<T>> {
        T element;
        Double weight;

        public ItemWeight(T element, Double weight) {
            this.element = element;
            this.weight = weight;
        }

        @Override
        public int compareTo(ItemWeight<T> o) {
            return this.weight.compareTo(o.weight);
        }
    }

    /**
     * Weighted shuffle in one sort: each item gets a random key and items are sorted by key.
     *
     * @return every element, ids {@code 0 .. weights.length - 1}, in weighted random order
     */
    public List<Element> weightedShuffle6() {
        //reference: https://www.generacodice.com/en/articolo/2544892/how-to-implement-a-weighted-shuffle
        Random random = new Random();
        List<Element> items = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; i++) {
            items.add(new Element(i, weights[i]));
        }
        // The sort is stable, so zero-weight items (all keyed +infinity) end up last, in index
        // order.
        return IntStream.range(0, items.size())
                .mapToObj(i -> new ItemWeight<>(i, raceKey(weights[i], random)))
                .sorted()
                .map(key -> items.get(key.element))
                .collect(Collectors.toList());
    }

    /**
     * Returns a sort key for one item such that sorting all keys ascending yields a weighted
     * random permutation.
     *
     * <p>The key is {@code -ln(u) / weight}, which is the Efraimidis-Spirakis key
     * {@code u^(1/weight)} (sorted descending) taken in log space. The previous
     * {@code Math.pow(-u, 1/weight)} put the minus sign inside {@code pow}, which yields NaN for
     * non-integer exponents. Working in log space also keeps small weights from underflowing to 0.
     */
    private static double raceKey(double weight, Random random) {
        if (weight == 0.0) {
            return Double.POSITIVE_INFINITY;
        }
        // 1 - nextDouble() lies in (0, 1], so the logarithm is finite.
        return -Math.log(1.0 - random.nextDouble()) / weight;
    }

    /**
     * Returns the first index {@code i < len} with {@code sorted[i] > key}, or {@code len} if
     * there is none. {@code sorted[0..len)} must be non-decreasing.
     */
    private static int upperBound(double[] sorted, int len, double key) {
        int lo = 0;
        int hi = len;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (sorted[mid] > key) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    /**
     * Scales {@code unit}, drawn from {@code [0, 1)}, to {@code [0, total)}.
     *
     * <p>The product can round up to exactly {@code total}; clamping keeps the result strictly
     * below it, so a search over prefix sums always finds an item. Requires {@code total > 0}.
     */
    private static double uniformBelow(double total, double unit) {
        return Math.min(unit * total, Math.nextDown(total));
    }

    /**
     * Prints the empirical distribution of {@link #select()} over random integer weights.
     * Then times 100000 calls of each shuffle variant, printing elapsed milliseconds.
     */
    public static void main(String[] args) {
        // A smaller hand-written input for experiments: {0.1, 0.5, 0.3, 0.1}.
        int count = 500;
        double[] candidates = new double[count];
        for (int i = 0; i < count; i++) {
            // Integer weights in [0, count); zero is possible and exercises the zero-weight paths.
            candidates[i] = (int) (Math.random() * count);
        }
        WeightedRandom wrs = new WeightedRandom(candidates);
        int[] cnt = new int[candidates.length];
        int total = 100000;
        for (int i = 0; i < total; i++) {
            cnt[wrs.select()]++;
        }
        for (int i = 0; i < candidates.length; i++) {
            System.out.println(candidates[i] + ": " + cnt[i] + " / " + total);
        }
        benchmark("wrs.weightedShuffle", total, wrs::weightedShuffle);
        benchmark("wrs.weightedShuffle2", total, wrs::weightedShuffle2);
        benchmark("wrs.weightedShuffle3", total, wrs::weightedShuffle3);
        benchmark("wrs.weightedShuffle4", total, wrs::weightedShuffle4);
        benchmark("wrs.weightedShuffle5", total, wrs::weightedShuffle5);
        benchmark("wrs.weightedShuffle6", total, wrs::weightedShuffle6);
    }

    /** Runs {@code task} {@code runs} times and prints "label: elapsedMillis". */
    private static void benchmark(String label, int runs, Runnable task) {
        long start = System.nanoTime(); // monotonic, unlike currentTimeMillis
        for (int i = 0; i < runs; i++) {
            task.run();
        }
        System.out.println(label + ": " + (System.nanoTime() - start) / 1_000_000L);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
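Benchmarks for 100000 runs of each method with count=500 weights. Figures are total elapsed milliseconds, as printed by main. No figure was reported for weightedShuffle5.
wrs.weightedShuffle: 13161
wrs.weightedShuffle2: 10888
wrs.weightedShuffle3: 11258
wrs.weightedShuffle4: 16967
wrs.weightedShuffle6: 8541
Note: the weightedShuffle6 figure was measured with sort keys computed as Math.pow(-u, 1/w). With the integer weights generated in main, every weight of 2 or more produces a NaN key, so this timing does not reflect a working weighted shuffle.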