Skip to content

Instantly share code, notes, and snippets.

@laudiacay
Created February 13, 2024 02:48
Show Gist options
  • Save laudiacay/8f7dcee06ad29a0179f2d95171285aae to your computer and use it in GitHub Desktop.
Save laudiacay/8f7dcee06ad29a0179f2d95171285aae to your computer and use it in GitHub Desktop.
emily account problem
# Cache shared by solve_case: maps an (n, majority) state to the payoff
# value computed for it, so each state is solved only once.
memoize_case_max_payoff = dict()


def find_max_entropy_case(n: int):
    """Return ceil(n / 2).

    This is the size of the larger color group when ``n`` items are split
    between two colors as evenly as possible (e.g. 52 -> 26, 5 -> 3).
    """
    half, leftover = divmod(n, 2)
    # An odd count leaves one extra item, which goes to the larger group.
    return half + 1 if leftover else half
def solve_case(n: int, majority: int) -> float:
    """Return the guaranteed payoff multiplier for a two-color guessing game.

    The state is ``n`` remaining items of two colors, ``majority`` of which
    share one color. Each round the player bets a fraction ``x`` of their
    wealth on the color of the next item: a correct guess multiplies wealth
    by ``1 + x`` and a wrong one by ``1 - x``. The strategy modelled here
    always bets on the majority color. It picks ``x`` so that both outcomes
    leave the player equally well off, and this function returns the
    resulting worst-case multiplier.

    Results are memoized in the module-level ``memoize_case_max_payoff``
    dict, keyed by ``(n, majority)``.

    Args:
        n: number of items left.
        majority: count of one color. A value below ``n / 2`` is treated as
            the minority count and mirrored to ``n - majority``.

    Returns:
        The guaranteed multiplier. It is an ``int`` (``2**n``) when every
        remaining item has the same color, and a ``float`` otherwise.

    Raises:
        ValueError: if ``majority`` exceeds ``n``. A negative ``majority``
            also ends up here after mirroring.
    """
    key = (n, majority)
    # we already saw this case!
    if key in memoize_case_max_payoff:
        return memoize_case_max_payoff[key]
    if majority > n:
        # `raise "<str>"` is itself a TypeError in Python 3 and hid this
        # message; raise a real exception type instead.
        raise ValueError("majority cannot be greater than n, you goofed somewhere")

    if majority == n:
        # Base case: everything is one color, so bet everything on every draw
        # and double each time. This also covers n == 0 and n == 1.
        payoff = 2**n
    elif majority < n / 2:
        # The caller passed the minority count; mirror it so both halves of
        # the state space share one set of cached results.
        payoff = solve_case(n, n - majority)
    else:
        # Always bet on the majority color.
        #   a: payoff after a correct guess -> state (n - 1, majority - 1)
        #   b: payoff after a wrong guess   -> state (n - 1, majority)
        # Betting x = (b - a) / (a + b) makes (1 + x) * a == (1 - x) * b,
        # so the guaranteed payoff is 2ab / (a + b).
        a = solve_case(n - 1, majority - 1)
        b = solve_case(n - 1, majority)
        payoff = 2 * a * b / (a + b)

    memoize_case_max_payoff[key] = payoff
    return payoff
def solve_for_n(n: int):
    """Solve the game for ``n`` items split as evenly as possible.

    The larger color group has ``find_max_entropy_case(n)`` items
    (ceil(n / 2)); the result is whatever ``solve_case`` returns for that
    state.
    """
    return solve_case(n, find_max_entropy_case(n))
if __name__ == "__main__":
    # Guaranteed payoff multiplier for 52 items split 26/26. Guarded so that
    # importing this module does not run the computation or print.
    print(solve_for_n(52))
    # output is 9.08132954942779 (== 2**52 / C(52, 26))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment