Skip to content

Instantly share code, notes, and snippets.

@berdario berdario/buckets.py
Created Aug 28, 2014

Embed
What would you like to do?
from __future__ import print_function
from numpy.random import binomial
from random import sample
from functools import partial
from operator import add, itemgetter
from itertools import takewhile
# Problem:
# You are given 10 opaque bags containing large numbers of red balls
# and blue balls in unspecified proportions. The balls are identical
# except for their color. Propose an algorithm for picking 100 balls
# from any of the bags that tries to maximize the number of blue balls
# picked. Of course each bag is identifiable but picking a ball from a
# given bag is random.
# Intuitively, a good approach is to go through all the bags/buckets,
# keep track of how many blue balls (1s) each one yields before a red
# ball (a 0) turns up, and then prioritize the "best" buckets.
# This is what the extractor function does:
def extractor(buckets, sort_fn, scale_fn, draws=100):
    """Draw balls from the buckets in prioritized rounds and count the blue ones.

    Every round ranks the buckets with ``sort_fn``, keeps the first
    ``scale_fn(step)`` of them and visits each kept bucket once. A visit
    keeps drawing from that bucket until it yields a red ball (a falsy
    result) or the draw budget is used up.

    Args:
        buckets: hashable zero-argument callables; each call draws one ball
            and returns 1 for a blue ball or 0 for a red one.
        sort_fn: called with the ``{bucket: (blue_total, -visits)}`` dict;
            must return a list of ``(bucket, score)`` pairs, best first.
        scale_fn: called with the 0-based round number; returns how many of
            the ranked buckets are visited in that round.
        draws: total number of balls to draw (100 by default).

    Returns:
        The total number of blue balls drawn; 0 when there are no buckets.
    """
    bucket_scores = {bucket: (0, 0) for bucket in buckets}
    if not bucket_scores:
        # Nothing to draw from; without this guard the loop below never ends.
        return 0

    i, step = 0, 0
    while i < draws:
        # Always keep at least one bucket: a scale of 0 would select nothing,
        # no ball would be drawn and the loop would spin forever.
        keep = max(1, scale_fn(step))
        for bucket, score in sort_fn(bucket_scores)[:keep]:
            if i >= draws:
                break
            tally = 0
            result = bucket()
            i += 1
            tally += result
            # Keep drawing from the same bucket while it yields blue balls.
            while result and i < draws:
                result = bucket()
                i += 1
                tally += result
            # Materialize the score as a tuple: on Python 3 ``map`` is lazy,
            # and a map object can be neither compared by sort_fn nor
            # unpacked more than once.
            bucket_scores[bucket] = tuple(map(add, score, (tally, -1)))
        step += 1
    return sum(tally for tally, _ in bucket_scores.values())
# In extractor, sort_fn ranks the buckets (best first) and scale_fn(step)
# says how many of the top-ranked buckets are kept in a round, dropping the worst.
def alg4(buckets, draws=100):
    """Draw balls round-robin, narrowing to the two best buckets and then one.

    Each round visits every remaining bucket once; a visit keeps drawing
    until a red ball (0) comes up or the draw budget is used up. Once more
    than 5 blue balls have been seen in total, only the two best-scoring
    buckets are kept; one round later only the single best one remains.

    Args:
        buckets: list of hashable zero-argument callables; each call draws
            one ball and returns 1 for blue or 0 for red.
        draws: total number of balls to draw (100 by default).

    Returns:
        The total number of blue balls drawn; 0 when there are no buckets.
    """
    bucket_scores = {bucket: 0 for bucket in buckets}
    if not bucket_scores:
        # Nothing to draw from; without this guard the loop below never ends.
        return 0

    i = 0
    while i < draws:
        for bucket in buckets:
            if i >= draws:
                break
            tally = 0
            result = bucket()
            i += 1
            tally += result
            # Keep drawing from the same bucket while it yields blue balls.
            while result and i < draws:
                result = bucket()
                i += 1
                tally += result
            bucket_scores[bucket] += tally
        if len(buckets) == 2:
            # Second narrowing step: commit to the single best bucket.
            buckets = [max(bucket_scores.items(), key=itemgetter(1))[0]]
        elif len(buckets) > 2 and sum(bucket_scores.values()) > 5:
            # First narrowing step, taken only once some evidence exists.
            top2 = sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)[:2]
            buckets = [bucket for bucket, _ in top2]
    return sum(bucket_scores.values())
def scaling(factors):
    """Build a step -> count schedule from ``(start_step, count)`` pairs.

    Args:
        factors: sequence of ``(start_step, count)`` pairs ordered by
            increasing ``start_step``.

    Returns:
        A function ``scale_fn(n)`` giving the ``count`` of the last pair whose
        ``start_step`` is <= ``n``. It raises ``IndexError`` when ``n`` lies
        before the first ``start_step``.
    """
    def scale_fn(n):
        # Index the pair instead of unpacking it in the lambda signature:
        # tuple parameters are a syntax error on Python 3.
        reached = list(takewhile(lambda factor: factor[0] <= n, factors))
        return reached[-1][1]
    return scale_fn
def alg1(buckets):
    """Visit every bucket for two rounds, then only the better half.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    # max(1, ...): with a single bucket n_buckets // 2 is 0, which would make
    # extractor select no bucket from round 2 on and never finish.
    schedule = scaling([(0, n_buckets), (2, max(1, n_buckets // 2))])
    return extractor(buckets, sort_fn, schedule)
def alg1b(buckets):
    """Visit every bucket for two rounds, then only the single best one.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    schedule = scaling([(0, n_buckets), (2, 1)])
    return extractor(buckets, sort_fn, schedule)
def alg1c(buckets):
    """Visit every bucket for one round, then only the two best ones.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    schedule = scaling([(0, n_buckets), (1, 2)])
    return extractor(buckets, sort_fn, schedule)
def alg1d(buckets):
    """Visit every bucket once, then the two best once, then only the best.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    schedule = scaling([(0, n_buckets), (1, 2), (2, 1)])
    return extractor(buckets, sort_fn, schedule)
def alg1e(buckets):
    """Visit every bucket for one round, then only the single best one.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    schedule = scaling([(0, n_buckets), (1, 1)])
    return extractor(buckets, sort_fn, schedule)
def alg1f(buckets):
    """Visit every bucket for two rounds, the two best once, then only the best.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    schedule = scaling([(0, n_buckets), (2, 2), (3, 1)])
    return extractor(buckets, sort_fn, schedule)
def alg2(buckets):
    """Like alg1, but rank buckets by fewest visits first, then by blue balls.

    All buckets are visited for two rounds, then only the better half. The
    sort key swaps the two components of the extractor score tuple, so the
    (negated) visit count dominates and the blue-ball total breaks ties.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(
            bucket_scores.items(),
            key=lambda x: (x[1][1], x[1][0]),
            reverse=True,
        )

    # max(1, ...): with a single bucket n_buckets // 2 is 0, which would make
    # extractor select no bucket from round 2 on and never finish.
    schedule = scaling([(0, n_buckets), (2, max(1, n_buckets // 2))])
    return extractor(buckets, sort_fn, schedule)
def alg3(buckets):
    """Visit all buckets for two rounds, then the best half, then the best quarter.

    Buckets are ranked by their extractor score tuple, highest first: most
    blue balls drawn, ties broken by fewer visits.

    Returns the number of blue balls drawn, or 0 when ``buckets`` is empty.
    """
    n_buckets = len(buckets)
    if n_buckets == 0:
        # extractor never terminates when it has nothing to draw from.
        return 0

    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)

    # max(1, ...): the integer divisions reach 0 for fewer than 2 (resp. 4)
    # buckets, which would make extractor select no bucket and never finish.
    schedule = scaling([
        (0, n_buckets),
        (2, max(1, n_buckets // 2)),
        (3, max(1, n_buckets // 4)),
    ])
    return extractor(buckets, sort_fn, schedule)
# Number of simulated games summed by each benchmark line at the bottom of the file.
N = 1000
def get_buckets(ps):
    """Return one bucket per probability in ``ps``, in random order.

    Each bucket is a zero-argument callable performing a single binomial
    trial with the given success probability, i.e. it draws one ball.
    """
    shuffled = sample(ps, len(ps))
    buckets = []
    for chance in shuffled:
        buckets.append(partial(binomial, 1, chance))
    return buckets
# Scenario factories: each call builds ten freshly shuffled buckets.
# positives: every bucket favours blue; negatives: every bucket favours red;
# mixed: blue probabilities spread from 0.1 to 0.9.
positives = partial(get_buckets, 2 * [0.6, 0.7, 0.75, 0.8, 0.9])
negatives = partial(get_buckets, 2 * [0.4, 0.3, 0.25, 0.2, 0.1])
mixed = partial(
    get_buckets,
    [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.25, 0.2, 0.1],
)

# Benchmark runs: total blue balls over N games for an algorithm/scenario pair.
# Uncomment a line to run that combination.
#print('alg1d pos', sum(alg1d(positives()) for _ in range(N)))
#print('alg1f pos', sum(alg1f(positives()) for _ in range(N)))
#print('alg4 pos', sum(alg4(positives()) for _ in range(N)))
#print('alg1d neg', sum(alg1d(negatives()) for _ in range(N)))
#print('alg1f neg', sum(alg1f(negatives()) for _ in range(N)))
#print('alg4 neg', sum(alg4(negatives()) for _ in range(N)))
print(
    'alg1d mix',
    sum(alg1d(mixed()) for _ in range(N)),
)
#print('alg1f mix', sum(alg1f(mixed()) for _ in range(N)))
#print('alg4 mix', sum(alg4(mixed()) for _ in range(N)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.