Skip to content

Instantly share code, notes, and snippets.

@sathish316
Last active September 27, 2021 23:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sathish316/4a725234334c8c6ba6290a61f3f1a462 to your computer and use it in GitHub Desktop.
Save sathish316/4a725234334c8c6ba6290a61f3f1a462 to your computer and use it in GitHub Desktop.
package com.algos.shuffle;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * Weighted random selection plus several interchangeable weighted-shuffle implementations
 * (drawing every index without replacement) over a fixed array of weights.
 *
 * <p>In every shuffle, each successive position is drawn with probability proportional to the
 * weight of the entries not drawn yet, so zero-weight entries end up after all positive-weight
 * ones (up to floating-point rounding when weights are fractional). The variants differ only in
 * the data structures they use; {@link #main} times them against each other.
 */
public class WeightedRandom {
    /** Prefix sums of {@link #weights}: {@code cdf[i] == weights[0] + ... + weights[i]}. */
    private final double[] cdf;
    /** Private copy of the constructor argument. */
    private final double[] weights;

    /**
     * Creates a sampler over a copy of {@code weights}.
     *
     * @param weights finite, non-negative weights; may be empty, in which case every shuffle
     *     returns an empty result
     * @throws NullPointerException if {@code weights} is null
     * @throws IllegalArgumentException if any weight is negative, NaN or infinite
     */
    public WeightedRandom(double[] weights) {
        this.weights = Objects.requireNonNull(weights, "weights").clone();
        for (double w : this.weights) {
            // A negative or non-finite weight makes the prefix sums non-monotone (or NaN), which
            // silently breaks every binary search below.
            if (!(w >= 0) || Double.isInfinite(w)) {
                throw new IllegalArgumentException("weights must be finite and non-negative, got " + w);
            }
        }
        cdf = this.weights.clone();
        for (int i = 1; i < cdf.length; i++) {
            cdf[i] += cdf[i - 1];
        }
    }

    /**
     * Returns the first index {@code i < to} with {@code a[i] > key}, or {@code to} if there is
     * none. {@code a[0..to)} must be non-decreasing.
     *
     * <p>Unlike {@link Arrays#binarySearch(double[], double)}, which may return any one of several
     * equal entries, the index returned here always satisfies {@code a[i] > a[i - 1]} (for a
     * non-negative key), so it never lands on a zero-width slot.
     */
    private static int upperBound(double[] a, int to, double key) {
        int lo = 0;
        int hi = to;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (a[mid] > key) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    /**
     * Draws one index with probability proportional to its weight.
     *
     * @return an index in {@code [0, n)}; a zero-weight index is returned only when all weights
     *     are zero, in which case the index is chosen uniformly
     * @throws IllegalStateException if there are no weights
     */
    public int select() {
        int n = cdf.length;
        if (n == 0) {
            throw new IllegalStateException("no weights to select from");
        }
        int index = upperBound(cdf, n, Math.random() * cdf[n - 1]);
        // upperBound returns n only when the total weight is zero.
        return index < n ? index : (int) (Math.random() * n);
    }

    /**
     * Weighted shuffle that rescans a linked list of the remaining entries for every draw.
     * O(n^2) time.
     *
     * @return a permutation of {@code 0..n-1}
     */
    public int[] weightedShuffle() {
        int n = weights.length;
        int[] ret = new int[n];
        List<IndexValPair> remaining = new LinkedList<>();
        for (int i = 0; i < n; i++) {
            remaining.add(new IndexValPair(i, weights[i]));
        }
        double sum = n == 0 ? 0.0 : cdf[n - 1];
        for (int pos = 0; pos < n; pos++) {
            double remain = Math.random() * sum;
            Iterator<IndexValPair> it = remaining.iterator();
            IndexValPair chosen;
            // Walk until the running weight passes the target. If rounding left `sum` slightly
            // above the true remaining total, the walk can reach the end first; the last entry
            // is then taken rather than calling next() past the end.
            do {
                chosen = it.next();
                remain -= chosen.val;
            } while (remain >= 0 && it.hasNext());
            it.remove();
            ret[pos] = chosen.index;
            sum -= chosen.val;
        }
        return ret;
    }

    /**
     * Weighted shuffle that keeps all n slots in one prefix-sum array and collapses each drawn
     * slot to zero width, so a draw is a binary search plus an O(n) update. O(n^2) time.
     *
     * @return a permutation of {@code 0..n-1}
     */
    public int[] weightedShuffle2() {
        int n = weights.length;
        int[] ret = new int[n];
        double[] mcdf = cdf.clone();
        boolean[] taken = new boolean[n];
        int firstFree = 0; // lowest index not yet drawn; only ever moves forward
        for (int i = 0; i < n; i++) {
            int index = upperBound(mcdf, n, Math.random() * mcdf[n - 1]);
            if (index >= n || taken[index]) {
                // Only zero-weight entries remain (total is 0), or rounding left a drawn slot a
                // sliver of width: take the lowest undrawn index so the result stays a permutation.
                while (taken[firstFree]) {
                    firstFree++;
                }
                index = firstFree;
            }
            taken[index] = true;
            ret[i] = index;
            for (int j = index + 1; j < n; j++) {
                mcdf[j] -= weights[index];
            }
            mcdf[index] = index == 0 ? 0 : mcdf[index - 1];
        }
        return ret;
    }

    /**
     * Weighted shuffle that keeps the prefix sums of only the undrawn entries, compacted into
     * {@code mcdf[0..n)}, with {@code ids} mapping each compacted slot back to its original
     * index. O(n^2) time.
     *
     * @return a permutation of {@code 0..n-1}
     */
    public int[] weightedShuffle3() {
        int len = weights.length;
        int[] ret = new int[len];
        double[] mcdf = cdf.clone();
        int[] ids = new int[len];
        for (int k = 0; k < len; k++) {
            ids[k] = k;
        }
        for (int n = len; n > 0; n--) {
            int pos = upperBound(mcdf, n, Math.random() * mcdf[n - 1]);
            if (pos >= n) {
                // Remaining total is zero (or rounds to it): only zero-weight entries are left.
                pos = n - 1;
            }
            ret[len - n] = ids[pos];
            double width = mcdf[pos] - (pos == 0 ? 0.0 : mcdf[pos - 1]);
            // Close the gap: shift later slots left and drop the drawn weight from their sums.
            for (int j = pos; j < n - 1; j++) {
                mcdf[j] = mcdf[j + 1] - width;
                ids[j] = ids[j + 1];
            }
        }
        return ret;
    }

    /** An original index paired with its weight, used by {@link #weightedShuffle()}. */
    private static class IndexValPair {
        private final int index;
        private final double val;

        private IndexValPair(int index, double val) {
            this.index = index;
            this.val = val;
        }
    }

    /** An original index ({@code id}) paired with its weight, as returned by the list shuffles. */
    static class Element implements Comparable<Element> {
        int id;
        Double weight;

        public Element(int id, double weight) {
            this.id = id;
            this.weight = weight;
        }

        public int getId() {
            return id;
        }

        public double getWeight() {
            return weight;
        }

        /**
         * Orders by weight, then by id. The id tie-break matters: {@link TreeSet} treats elements
         * that compare equal as duplicates, so without it {@link WeightedRandom#weightedShuffle5()}
         * would drop every element that shares a weight with another.
         */
        @Override
        public int compareTo(Element o) {
            int byWeight = weight.compareTo(o.weight);
            return byWeight != 0 ? byWeight : Integer.compare(id, o.id);
        }
    }

    /** Wraps each weight in a fresh {@link Element} carrying its original index. */
    private List<Element> elements() {
        List<Element> out = new ArrayList<>(weights.length);
        for (int i = 0; i < weights.length; i++) {
            out.add(new Element(i, weights[i]));
        }
        return out;
    }

    /**
     * Empties {@code options} into a new list, each step drawing an element with probability
     * proportional to its weight among those left. O(n^2) time for the collections used here.
     */
    private static List<Element> drawAll(Collection<Element> options, Random random) {
        List<Element> shuffled = new ArrayList<>(options.size());
        double totalWeight = options.stream().mapToDouble(Element::getWeight).sum();
        while (!options.isEmpty()) {
            double target = random.nextDouble() * totalWeight;
            Iterator<Element> it = options.iterator();
            Element chosen;
            // If rounding left totalWeight above the true remaining sum, fall through to the last
            // element; re-drawing instead would loop forever once only zero weights remain.
            do {
                chosen = it.next();
                target -= chosen.getWeight();
            } while (target >= 0 && it.hasNext());
            it.remove();
            shuffled.add(chosen);
            totalWeight -= chosen.getWeight();
        }
        return shuffled;
    }

    /**
     * Weighted shuffle by linear scans over an {@link ArrayList}. O(n^2) time.
     *
     * @return every element exactly once, in weighted-random order
     */
    public List<Element> weightedShuffle4() {
        return drawAll(elements(), new Random());
    }

    /**
     * Weighted shuffle by linear scans over a {@link TreeSet} (ascending weight, then id).
     * O(n^2) time.
     *
     * @return every element exactly once, in weighted-random order
     */
    public List<Element> weightedShuffle5() {
        return drawAll(new TreeSet<>(elements()), new Random());
    }

    /** Pairs an item with a sort key ({@code weight}); ordered by ascending key. */
    static class ItemWeight<T> implements Comparable<ItemWeight<T>> {
        T element;
        Double weight;

        public ItemWeight(T element, Double weight) {
            this.element = element;
            this.weight = weight;
        }

        @Override
        public int compareTo(ItemWeight<T> o) {
            return this.weight.compareTo(o.weight);
        }
    }

    /**
     * Weighted shuffle by sorting random keys (Efraimidis–Spirakis): element i gets key
     * {@code -ln(u_i) / w_i} with {@code u_i} uniform in [0, 1), and elements are returned in
     * ascending key order. This is the log form of sorting by {@code u_i^(1/w_i)} descending and
     * avoids pow underflow for tiny weights. Zero weights get key +Infinity and sort last.
     * O(n log n) time.
     *
     * @return every element exactly once, in weighted-random order
     */
    public List<Element> weightedShuffle6() {
        //reference: https://www.generacodice.com/en/articolo/2544892/how-to-implement-a-weighted-shuffle
        // Do not negate u inside pow to get an ascending order: Math.pow of a negative base with a
        // non-integer exponent is NaN, which makes the order ignore the weights.
        Random random = new Random();
        List<Element> items = elements();
        return IntStream.range(0, items.size())
                .mapToObj(i -> new ItemWeight<>(i, -Math.log(random.nextDouble()) / weights[i]))
                .sorted()
                .map(key -> items.get(key.element))
                .collect(Collectors.toList());
    }

    /**
     * Prints a selection histogram for random integer weights, then times each shuffle variant.
     * Timings are wall-clock milliseconds with no JIT warm-up, so treat them as rough.
     */
    public static void main(String[] args) {
        // Alternative small input: double[] candidates = new double[] {0.1D, 0.5D, 0.3D, 0.1D};
        int count = 500;
        double[] candidates = new double[count];
        for (int i = 0; i < count; i++) {
            // Integer weights in [0, count): deliberately includes zeros and repeated weights.
            candidates[i] = (int) (Math.random() * count);
        }
        WeightedRandom wrs = new WeightedRandom(candidates);
        int[] cnt = new int[candidates.length];
        int total = 100000;
        for (int i = 0; i < total; i++) {
            cnt[wrs.select()]++;
        }
        for (int i = 0; i < candidates.length; i++) {
            System.out.println(candidates[i] + ": " + cnt[i] + " / " + total);
        }
        bench("wrs.weightedShuffle", total, wrs::weightedShuffle);
        bench("wrs.weightedShuffle2", total, wrs::weightedShuffle2);
        bench("wrs.weightedShuffle3", total, wrs::weightedShuffle3);
        bench("wrs.weightedShuffle4", total, wrs::weightedShuffle4);
        bench("wrs.weightedShuffle5", total, wrs::weightedShuffle5);
        bench("wrs.weightedShuffle6", total, wrs::weightedShuffle6);
    }

    /** Runs {@code task} {@code runs} times and prints "label: elapsedMillis". */
    private static void bench(String label, int runs, Runnable task) {
        long start = System.currentTimeMillis();
        for (int i = 0; i < runs; i++) {
            task.run();
        }
        System.out.println(label + ": " + (System.currentTimeMillis() - start));
    }
}
@sathish316
Copy link
Author

Benchmarks for 100000 runs with count=500 (wall-clock milliseconds; weightedShuffle5 was not reported):

wrs.weightedShuffle: 13161
wrs.weightedShuffle2: 10888
wrs.weightedShuffle3: 11258
wrs.weightedShuffle4: 16967
wrs.weightedShuffle6: 8541

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment