Created
January 12, 2016 04:18
-
-
Save jliszka/6dd4afe75b8b0e3cd08b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import permutations | |
import pprint | |
class Strategy:
    """Base class for stopping strategies in a "pick a low card" game.

    A deal is an ordering of the card values ``0 .. num_cards - 1``. After
    each card is shown the strategy may accept it; if it accepts none, it is
    left with the last card. The strategies in this file try to end up with
    a low value, so a lower ``eval()`` result is better.

    Subclasses override ``cards`` to decide when to stop, and may override
    ``filter`` to restrict which deals are averaged over in ``eval``.
    """

    def cards(self, cs):
        """Return True to accept the current card.

        ``cs`` is the list of cards seen so far (current card last), shifted
        so that its minimum is 0. The base strategy never stops early.
        """
        return False

    def filter(self, p):
        """Return True if deal ``p`` should count towards ``eval``."""
        return True

    def evalPermutation(self, p):
        """Return the card value this strategy ends up with for deal ``p``.

        Every card except the last is offered to ``cards``; if none is
        accepted, the final card ``p[-1]`` is forced.
        """
        # The last position is not a decision point: it must be taken.
        for i in range(len(p) - 1):
            prefix = p[:i + 1]
            low = min(prefix)
            # Only relative values are revealed to the strategy.
            if self.cards([c - low for c in prefix]):
                return p[i]
        return p[-1]

    def eval(self, num_cards=5):
        """Return the average result over every deal accepted by ``filter``.

        Args:
            num_cards: number of distinct card values (``0 .. num_cards - 1``);
                defaults to 5, the size used throughout this file.

        Raises:
            ValueError: if ``filter`` rejects every deal, so there is nothing
                to average.
        """
        total = 0
        count = 0
        for p in permutations(range(num_cards)):
            if self.filter(p):
                count += 1
                total += self.evalPermutation(p)
        if count == 0:
            raise ValueError("filter() rejected every permutation; nothing to average")
        return total / count
class First(Strategy):
    """Naive baseline: accept whichever card is dealt first."""

    def cards(self, cs):
        # The contents of the prefix are irrelevant; stop immediately.
        accept = True
        return accept
class NextBest(Strategy):
    """Skip the first card, then take the first card that is the lowest so far."""

    def cards(self, cs):
        # Never stop on the very first card.
        if len(cs) <= 1:
            return False
        # A shifted value of 0 means the current card equals the prefix minimum.
        return cs[-1] == 0
class NextBest2(Strategy):
    """Stop once the seen cards span at least 3 and the current card is near the low end."""

    def cards(self, cs):
        # Require a spread of at least 3 between the highest and lowest seen.
        if max(cs) < 3:
            return False
        # Current card is at most 1 above the lowest card seen.
        return cs[-1] <= 1
# Cache of Opt.eval() results keyed by Opt.filterSet (the normalized, sorted
# set of cards already seen). Shared by every Opt instance.
memo = {}


class Opt(Strategy):
    """Look-ahead stopping strategy.

    An ``Opt`` built with ``filterSet``/``count`` averages only over deals
    whose first ``count`` cards normalize to ``filterSet``, and reports each
    result relative to the minimum of those first cards. ``cards`` accepts
    the current card when it is lower than the expected result of
    continuing, which is computed by a nested ``Opt`` conditioned on the
    cards seen so far.
    """

    def __init__(self, filterSet=(), count=0):
        # Coerce to a tuple so it is hashable (used as a memo key) and
        # compares equal to the tuples produced by normalize().
        self.filterSet = tuple(filterSet)
        self.count = count

    def normalize(self, p):
        """Return ``p`` shifted so its minimum is 0, sorted, as a tuple.

        Always returns a tuple (``()`` for empty input) so the result is
        hashable and directly comparable with ``filterSet``.
        """
        if not p:
            return ()
        low = min(p)
        return tuple(sorted(c - low for c in p))

    def filter(self, p):
        """Keep only deals whose first ``count`` cards match ``filterSet``."""
        if self.count == 0:
            return True
        return self.normalize(p[0:self.count]) == self.filterSet

    def evalPermutation(self, p):
        """Return the chosen card relative to the minimum of the first ``count`` cards."""
        r = super().evalPermutation(p)
        if self.count == 0:
            return r
        return r - min(p[0:self.count])

    def cards(self, cs):
        """Accept the current card if it beats the expected value of continuing."""
        # The first `count` cards are the conditioning prefix; this instance
        # models play *after* them, so it never stops inside that prefix.
        if len(cs) <= self.count:
            return False
        # Expected (relative) result of passing on the current card, given
        # everything seen so far.
        s = Opt(self.normalize(cs), len(cs))
        return cs[-1] < s.eval()

    def eval(self):
        """Return the average result for this ``filterSet``, memoized in ``memo``."""
        if self.filterSet in memo:
            return memo[self.filterSet]
        r = super().eval()
        memo[self.filterSet] = r
        return r
class Opt2(Strategy):
    """Hand-written decision table for 5 cards.

    The author's recorded output shows it reaching the same average as
    ``Opt`` (0.9).
    """

    def cards(self, cs):
        size = len(cs)
        top = max(cs)
        last = cs[-1]
        # Seen cards span >= 3 and the current card is within 1 of the lowest.
        if top >= 3 and last <= 1:
            return True
        # Seen {0, 1, 2} and the current card is the lowest.
        if size == 3 and top == 2 and last == 0:
            return True
        if size == 4 and top == 4:
            if 1 in cs:
                # Seen {0, 1, 2, 4} or {0, 1, 3, 4}.
                return last <= 2
            # Seen {0, 2, 3, 4}.
            return last == 0
        return False
# Print each strategy's average result, in order. The trailing comments are
# the values the author recorded for each strategy.
for _strategy in (
    First(),      # 2.0
    NextBest(),   # 1.1
    NextBest2(),  # 0.95
    Opt(),        # 0.9
    Opt2(),       # 0.9
):
    print(_strategy.eval())
del _strategy

# Dump the Opt look-ahead cache populated by Opt().eval() above.
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(memo)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment