Skip to content

Instantly share code, notes, and snippets.

@pitzcarraldo
Created October 21, 2015 00:39
Show Gist options
  • Save pitzcarraldo/ed7cdda48b7f66f66257 to your computer and use it in GitHub Desktop.
Save pitzcarraldo/ed7cdda48b7f66f66257 to your computer and use it in GitHub Desktop.
MAB (Multi-Armed Bandits) Java Implementation
/**
 * Calculates the impression ratio (an integer percentage) of each arm from its
 * {@code ArmStatistic}.
 *
 * <p>The logic of this class is based on the R code in the article below (Figure 4):
 * http://mktg455cnu.net/wp-content/uploads/2014/10/scott.pdf
 *
 * @author Minkyu Cho(mrnoname@naver.com)
 */
@Slf4j
@Component
public class RatioCalculator {
    /** Multiplier for the number of simulation rounds: rounds = number of arms * this value. */
    private static final int SIMULATE_COUNT_BASE = 100;

    /** Total that the computed ratios are scaled and corrected to (they are percentages). */
    private static final int PERCENT_TOTAL = 100;

    /**
     * Calculates the impression ratio of every arm in {@code statistics}.
     *
     * <p>In R language:
     * <pre>
     * compute.win.prob(clicks, views, simulateCount);
     * </pre>
     * Clicks and views come from each ArmStatistic; simulateCount is
     * number of arms * SIMULATE_COUNT_BASE.
     *
     * @param statistics the statistics of all arms; {@code null} or empty yields an empty map
     * @return map of armId to impression ratio, ordered by ratio descending; arms that never
     *         won a simulation round are appended at the end with ratio 0
     */
    public Map<Long, Integer> computeImpressionRatio(List<ArmStatistic> statistics) {
        if (statistics == null || statistics.isEmpty()) {
            Map<Long, Integer> emptyRatioMap = Maps.newLinkedHashMap();
            return emptyRatioMap;
        }
        int simulateCount = statistics.size() * SIMULATE_COUNT_BASE;
        Multimap<Long, Double> simulateMap = makeSimulateMap(statistics, simulateCount);
        Map<Long, Integer> ratioMap = sortAndFillPercentage(computeImpressionRatio(simulateMap));
        // Arms without a single win are absent from the win counts; report them explicitly as 0.
        for (ArmStatistic statistic : statistics) {
            if (!ratioMap.containsKey(statistic.getArmId())) {
                ratioMap.put(statistic.getArmId(), 0);
            }
        }
        return ratioMap;
    }

    /**
     * Converts the simulated draws into the percentage of rounds won by each item.
     *
     * <p>In R language:
     * <pre>
     * prob.winner <- function(post){
     *   k <- ncol(post)
     *   w <- table(factor(max.col(post), levels=1:k))
     *   return (w/sum(w))
     * };
     * </pre>
     *
     * <p>Each percentage is rounded half-up to a whole number, so the values may not add up to
     * exactly {@code PERCENT_TOTAL}; {@link #fillPercentageToOneHundred(Map)} corrects that.
     *
     * @param simulateMap draws per item, as produced by {@link #makeSimulateMap(List, int)}
     * @return map of itemId to rounded win percentage; contains only items with at least one win
     */
    private Map<Long, Integer> computeImpressionRatio(Multimap<Long, Double> simulateMap) {
        Map<Long, AtomicLong> winCountMap = getWinCountMap(getWinItems(simulateMap));
        long totalWins = 0;
        for (AtomicLong winCount : winCountMap.values()) {
            totalWins += winCount.longValue();
        }
        Map<Long, Integer> winners = Maps.newHashMap();
        if (totalWins <= 0) {
            return winners;
        }
        for (Map.Entry<Long, AtomicLong> entry : winCountMap.entrySet()) {
            long wins = entry.getValue().longValue();
            // Integer form of round-half-up(PERCENT_TOTAL * wins / totalWins) for non-negative values.
            long percent = (2L * PERCENT_TOTAL * wins + totalWins) / (2L * totalWins);
            winners.put(entry.getKey(), (int) percent);
        }
        return winners;
    }

    /**
     * Builds the simulation map from the click and view counts of each arm.
     * <pre>
     * {item1 : [betaValue1, betaValue2 ... ~ simulateCount]}
     * {item2 : [betaValue1, betaValue2 ... ~ simulateCount]}
     * ...
     * betaValue = rbeta(clicks + 1, view - click + 1) > Estimated Conversion Rate.
     * </pre>
     *
     * <p>In R language:
     * <pre>
     * sim.post <- function(clicks , views, simulateCount){
     *   nItems <- length(clicks)
     *   simulateMatrix <- matrix(nrow=simulateCount, ncol=nItems)
     *   no <- views-clicks
     *   for(i in 1:nItems)
     *     simulateMatrix[,i]<-rbeta(simulateCount,clicks[i]+1,no[i]+1)
     *   return(simulateMatrix)
     * };
     * </pre>
     *
     * <p>A draw that fails with an exception is logged and recorded as 0.
     * NOTE(review): views smaller than clicks would give a non-positive beta parameter;
     * presumably callers guarantee views >= clicks - confirm.
     *
     * @param stats statistics of all arms
     * @param simulateCount number of draws to generate per arm
     * @return simulateMap holding {@code simulateCount} draws for every armId
     */
    private Multimap<Long, Double> makeSimulateMap(List<ArmStatistic> stats, int simulateCount) {
        Multimap<Long, Double> simulateMap = ArrayListMultimap.create();
        for (ArmStatistic stat : stats) {
            Long armId = stat.getArmId();
            BetaDistribution beta = new BetaDistributionImpl(stat.getClick() + 1, stat.getView() - stat.getClick() + 1);
            for (int i = 0; i < simulateCount; i++) {
                double betaValue = 0;
                try {
                    // Inverse-CDF sampling: map a uniform random number through the beta quantile function.
                    betaValue = beta.inverseCumulativeProbability(Math.random());
                } catch (Exception e) {
                    // Best effort: keep the round with a draw of 0 instead of aborting the calculation.
                    log.error(e.getMessage(), e);
                }
                simulateMap.put(armId, betaValue);
            }
        }
        return simulateMap;
    }

    /**
     * Picks, for every simulation round, the item with the highest betaValue.
     * <pre>
     * orderedDrawMap > items of one round sorted by betaValue.
     * { simulation 1 : {{betaValue : itemId}, ... ascending order}}
     * { simulation 2 : {{betaValue : itemId}, ... }}
     * ...
     * winItems > the item with the highest betaValue of each round.
     * </pre>
     *
     * <p>In R language:
     * <pre>
     * w <- table(factor(max.col(post), levels=1:k))
     * </pre>
     *
     * <p>If two items have exactly the same betaValue in a round, the one inserted later
     * replaces the earlier one in that round's map.
     *
     * @param simulateMap draws per item; the i-th draw of every item belongs to round i
     * @return one winning itemId per simulation round
     */
    private List<Long> getWinItems(Multimap<Long, Double> simulateMap) {
        Map<Integer, TreeMap<Double, Long>> orderedDrawMap = Maps.newHashMap();
        for (Long itemId : simulateMap.keySet()) {
            int simulation = 0;
            for (Double draw : simulateMap.get(itemId)) {
                TreeMap<Double, Long> drawsOfSimulation = orderedDrawMap.get(simulation);
                if (drawsOfSimulation == null) {
                    drawsOfSimulation = Maps.newTreeMap();
                    orderedDrawMap.put(simulation, drawsOfSimulation);
                }
                drawsOfSimulation.put(draw, itemId);
                simulation++;
            }
        }
        List<Long> winItems = Lists.newArrayList();
        for (TreeMap<Double, Long> drawsOfSimulation : orderedDrawMap.values()) {
            // The TreeMap is sorted ascending by betaValue, so its last entry is the round winner.
            winItems.add(drawsOfSimulation.lastEntry().getValue());
        }
        return winItems;
    }

    /**
     * Counts how many simulation rounds each item won.
     *
     * <p>In R language:
     * <pre>
     * max.col(post)
     * </pre>
     *
     * @param winItems the winning itemId of every simulation round
     * @return winCountMap of itemId to number of rounds won
     */
    private Map<Long, AtomicLong> getWinCountMap(List<Long> winItems) {
        Map<Long, AtomicLong> winCountMap = Maps.newHashMap();
        for (Long itemId : winItems) {
            AtomicLong winCount = winCountMap.get(itemId);
            if (winCount == null) {
                winCount = new AtomicLong(0);
                winCountMap.put(itemId, winCount);
            }
            winCount.incrementAndGet();
        }
        return winCountMap;
    }

    /**
     * Sorts the ratios in descending order and corrects them so that they add up to
     * {@code PERCENT_TOTAL}.
     *
     * @param ratioMap map of itemId to ratio
     * @return sortedRatioMap a new insertion-ordered map, highest ratio first
     */
    private Map<Long, Integer> sortAndFillPercentage(Map<Long, Integer> ratioMap) {
        Map<Long, Integer> sortedMap = sortByDesc(ratioMap);
        fillPercentageToOneHundred(sortedMap);
        return sortedMap;
    }

    /**
     * Sorts a map by value in descending order.
     *
     * @param unsortedMap map of itemId to ratio
     * @return sortedMap a new insertion-ordered map, highest value first
     */
    private Map<Long, Integer> sortByDesc(Map<Long, Integer> unsortedMap) {
        List<Map.Entry<Long, Integer>> entries = Lists.newArrayList(unsortedMap.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<Long, Integer>>() {
            @Override
            public int compare(Map.Entry<Long, Integer> left, Map.Entry<Long, Integer> right) {
                // Reversed operands give descending order.
                return right.getValue().compareTo(left.getValue());
            }
        });
        Map<Long, Integer> sortedMap = Maps.newLinkedHashMap();
        for (Map.Entry<Long, Integer> entry : entries) {
            sortedMap.put(entry.getKey(), entry.getValue());
        }
        return sortedMap;
    }

    /**
     * Adjusts the ratios in place so that they add up to {@code PERCENT_TOTAL}.
     *
     * <p>A shortfall is added to the last (smallest) entry. An excess is removed starting at the
     * last entry and continuing towards the first, never pushing an entry below zero; applying
     * the whole excess to the last entry alone could produce a negative ratio (for example
     * 99/1/1/1 would become 99/1/1/-1).
     *
     * <p>Nothing is changed when the current sum is zero or negative.
     *
     * @param sortedMap ratios ordered descending by value; modified in place
     */
    private void fillPercentageToOneHundred(Map<Long, Integer> sortedMap) {
        int sum = 0;
        for (Integer ratio : sortedMap.values()) {
            sum += ratio;
        }
        if (sum <= 0) {
            return;
        }
        int gap = PERCENT_TOTAL - sum;
        if (gap == 0) {
            return;
        }
        // sum > 0 guarantees at least one entry, so the key list is never empty here.
        List<Long> keys = Lists.newArrayList(sortedMap.keySet());
        if (gap > 0) {
            Long lastKey = keys.get(keys.size() - 1);
            sortedMap.put(lastKey, sortedMap.get(lastKey) + gap);
            return;
        }
        int excess = -gap;
        for (int i = keys.size() - 1; i >= 0 && excess > 0; i--) {
            Long key = keys.get(i);
            int current = sortedMap.get(key);
            if (current <= 0) {
                continue;
            }
            int taken = Math.min(current, excess);
            sortedMap.put(key, current - taken);
            excess -= taken;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment