Skip to content

Instantly share code, notes, and snippets.

@pursuingpareto
Created March 17, 2016 22:18
Show Gist options
  • Save pursuingpareto/b15f1197d96b1a2bbc48 to your computer and use it in GitHub Desktop.
Save pursuingpareto/b15f1197d96b1a2bbc48 to your computer and use it in GitHub Desktop.
def score(bracket, results, filt,
          teams_remaining, blacklist, round_num=0):
    """
    Recursively calculate the score of a prediction bracket against a
    results bracket.

    Each call scores one tournament round, then recurses on the next
    round with the bits of the scored round shifted away.

    - bracket
        A bitstring (int) representing a prediction bracket. For a 64
        team tournament this would be 63 bits: the current round lives
        in the least significant bits, followed by the next round, etc.
    - results
        A bitstring representing the actual outcome of a tournament,
        encoded the same way as bracket.
    - filt
        With the exception of the first round in a tournament, it's not
        possible to score a round by just comparing the bits in bracket
        to the bits in results. For example, correctly predicting the
        championship game requires not only the correct bit for that
        game, but also the correct prediction for all the games the
        tournament winner had won before the final round.
        The filt parameter is a one-time pre-computed bitstring used to
        indicate which games in a round must be correctly predicted in
        order to correctly predict successive games. For a 64 team
        tournament filt would contain 62 bits.
    - teams_remaining
        This is a recursive function where each call represents another
        tournament round. teams_remaining gives the number of teams left
        in the tournament as of this function call. It should be a power
        of two, since every round halves it.
    - blacklist
        A sequence of N bits where N is the number of games in the
        current round. It uses the accuracy of predictions from previous
        rounds to "remember" which games are possible to correctly
        predict. When calling the score function initially these bits
        should all be set to 1.
    - round_num
        A number representing the current round. For a 64 team
        tournament this would take the values 0, 1, 2, 3, 4, and 5.
        A correct pick in round r is worth 2 ** r points.

    Returns the total points earned from the current round onward.
    """
    # One team left means there is a winner and nothing more to score.
    # Using <= also stops a 0 argument from recursing forever.
    if teams_remaining <= 1:
        return 0
    # Floor division keeps these values ints on Python 3 too; true
    # division would make round_mask a float and break the & below.
    num_games = teams_remaining // 2
    # round_mask has 1s in exactly the bits of the current round.
    round_mask = (1 << num_games) - 1
    # The current round is encoded in the num_games least significant
    # bits of bracket, results and filt. Peel them off and shift the
    # later rounds down for the recursive call.
    round_predictions = bracket & round_mask
    bracket >>= num_games
    round_results = results & round_mask
    results >>= num_games
    round_filter = filt & round_mask
    filt >>= num_games
    # XNOR: a bit is 1 where the prediction matches the result. ~ on a
    # Python int is negative (infinitely many leading 1s), so mask it to
    # the round to keep out-of-round bits away from popcount.
    overlap = ~(round_predictions ^ round_results) & round_mask
    # In all rounds except the first, overlap alone overestimates a
    # bracket's correctness. The blacklist clears games that could not
    # have been predicted correctly given the earlier rounds.
    scores = overlap & blacklist
    # Each correct pick is worth 2 ** round_num points, so later rounds
    # weigh more heavily than earlier ones.
    points = popcount(scores) << round_num
    # Keep only the correct picks for games that, per the pre-computed
    # filter, must be right for the following game to be scoreable.
    relevant_scores = scores & round_filter
    # Games 2k and 2k+1 of this round pair up into bit k of the next
    # round's blacklist: game k stays scoreable if either survived.
    odd_position_bits, even_position_bits = get_odds_and_evens(relevant_scores)
    blacklist = odd_position_bits | even_position_bits
    return points + score(bracket, results, filt,
                          num_games, blacklist, round_num + 1)
def get_odds_and_evens(bits):
    """
    Split a bitstring into its odd-position and even-position bits, each
    packed down into a contiguous integer.

    Positions are counted from 0 at the least significant bit. Bit 2k+1
    of `bits` becomes bit k of the FIRST return value, and bit 2k becomes
    bit k of the SECOND. The return order is kept for existing callers.

    Each 32-bit chunk is separated with a shift-and-mask "unshuffle"
    network. Within a chunk this moves even-position bits to the low
    half and odd-position bits to the high half. Chunks are then
    concatenated, so inputs of any width are supported.

    Raises ValueError for negative input, which has no finite bit pattern.
    """
    if bits < 0:
        raise ValueError("bits must be a non-negative integer")
    odd_position_bits = 0
    even_position_bits = 0
    shift = 0
    while True:
        chunk = bits & 0xffffffff
        t = (chunk ^ (chunk >> 1)) & 0x22222222
        chunk ^= t ^ (t << 1)
        t = (chunk ^ (chunk >> 2)) & 0x0c0c0c0c
        chunk ^= t ^ (t << 2)
        t = (chunk ^ (chunk >> 4)) & 0x00f000f0
        chunk ^= t ^ (t << 4)
        t = (chunk ^ (chunk >> 8)) & 0x0000ff00
        chunk ^= t ^ (t << 8)
        # High half: the chunk's 16 odd-position bits; low half: its 16
        # even-position bits. Each chunk contributes 16 bits to each output.
        odd_position_bits |= (chunk >> 16) << shift
        even_position_bits |= (chunk & 0xffff) << shift
        bits >>= 32
        if not bits:
            return odd_position_bits, even_position_bits
        shift += 16
def popcount(x):
    """
    Count the number of 1 bits in a non-negative integer bitstring.

    Works for integers of any width. A fixed 64-bit SWAR reduction would
    silently ignore bits above position 63.

    Raises ValueError for negative input, whose two's-complement form has
    infinitely many 1 bits.
    """
    if x < 0:
        raise ValueError("popcount is undefined for negative integers")
    # bin() renders x in base 2, so counting '1' characters counts set bits.
    return bin(x).count("1")
# this function can be slow since it would only be called once
# when the tournament is over.
def make_mask_binary(results, teams_remaining):
    """
    Build the `filt` bitstring that score() expects from a results bracket.

    Game k of a later round is played between the winners of games 2k
    and 2k+1 of the round before it. Its result bit says which feeder
    game produced the team that advanced: 1 selects feeder game 2k, and
    0 selects feeder game 2k+1. The filter marks that feeder. For every
    game after the first round it holds exactly one set bit in the
    matching pair of positions 2k / 2k+1.

    Every round has half as many games as the one before. So if the
    post-first-round result bits are numbered consecutively
    (j = 0, 1, 2, ...), result bit j maps to filter bits 2j / 2j+1 for
    all rounds at once.

    - results
        The actual-results bitstring, with the first round in the
        teams_remaining // 2 least significant bits.
    - teams_remaining
        Number of teams in the tournament (e.g. 64).

    Returns an int of at most 2 * (teams_remaining // 2 - 1) bits
    (62 for a 64 team tournament).
    """
    first_round_games = teams_remaining // 2
    # Only the rounds after the first decide which feeder games matter,
    # so drop the first round's bits.
    future_rounds = results >> first_round_games
    mask = 0
    # One result bit per game after the first round.
    for j in range(first_round_games - 1):
        if (future_rounds >> j) & 1:
            mask |= 1 << (2 * j)
        else:
            mask |= 1 << (2 * j + 1)
    return mask
def make_test(bracket, results, N=64):
    """
    Build the initial (blacklist, filt) pair needed to call score().

    - bracket
        Not used; accepted so existing callers keep working.
    - results
        The actual-results bitstring the filter is derived from.
    - N
        Number of teams in the tournament.

    The starting blacklist has a 1 bit for every first-round game, since
    any first-round game can be predicted correctly.
    """
    # (1 << k) - 1 is k one-bits. Floor division keeps k an int on
    # Python 3, where "1" * (N / 2) would raise TypeError.
    blacklist = (1 << (N // 2)) - 1
    filt = make_mask_binary(results, N)
    return blacklist, filt
def test():
    """
    Self-check score() against brackets with hand-computed totals.

    Uses a 64 team tournament in which every result bit is 1. A perfect
    bracket is then worth 6 rounds * 32 points = 192.
    """
    # random is needed for the randomized check below and was never imported.
    import random

    results = int("1" * 63, 2)

    # Perfect bracket: full marks.
    bracket = int("1" * 63, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 192

    # Only the first round right: 32 points.
    bracket = int("0" * 31 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 32

    # Second round entirely wrong: later picks cannot score even though
    # their bits match the results.
    bracket = int("1" * 15 + "0" * 16 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 32

    # Second round: only games at even positions right. Those are the
    # feeder games whose winners advanced, so later rounds are unaffected.
    bracket = int("1" * 15 + "01" * 8 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == (192 - 2*8)

    # Second round: only games at odd positions right. Their winners never
    # advanced, so nothing after the second round scores.
    bracket = int("1" * 15 + "10" * 8 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == (32 + 2*8)

    # A bracket identical to the results always earns full marks. Any
    # filter with exactly one bit per feeder pair produces that, so the
    # filter built above is reused for each random results.
    for _ in range(10):
        bracket = random.getrandbits(63)
        results = bracket
        assert score(bracket, results, filt, 64, blacklist) == 192
    print("tests pass")
# Run the self-checks whenever this file is executed or imported.
# NOTE(review): consider guarding with `if __name__ == "__main__":` so that
# importing this module does not run (and print from) the tests.
test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment