Skip to content

Instantly share code, notes, and snippets.

@robertcampion
Created August 8, 2021 02:27
Show Gist options
  • Save robertcampion/c4ad00b4413fa2e59b027c1e6654cd77 to your computer and use it in GitHub Desktop.
Save robertcampion/c4ad00b4413fa2e59b027c1e6654cd77 to your computer and use it in GitHub Desktop.
import sys

import z3
# Puzzle parameters.
N = 8  # number of athletes
M = 3  # number of medals
R = 3  # number of rounds

solver = z3.Solver()

# rank[i][r] is the (unknown) rank of athlete i in round r, for the first R-1
# rounds only.  The last round has no variables here: athlete i currently
# holds place i+1 in it among those who have finished, and the last athlete
# (index N-1) has yet to compete.
# Names carry "_" separators so they stay unique for any N and R: z3 treats
# same-named constants as one variable, and without a separator e.g.
# (i=1, k=11) and (i=11, k=1) would both produce "m111".
rank = [[z3.Int(f'r_{i}_{r}') for r in range(R - 1)] for i in range(N)]

# score[i][k] is the total score (product of ranks over all R rounds; lower is
# better) of athlete i if the last athlete finishes the last round in place
# k+1.  The last athlete's final-round rank is then k+1; every other athlete i
# keeps place i+1 if it is ahead of the newcomer (i < k), or is pushed down to
# place i+2 otherwise (k <= i contributes 1).
score = [
    [
        z3.Product(rs) * (k + 1 if i == N - 1 else i + 1 + (k <= i))
        for k in range(N)
    ]
    for i, rs in enumerate(rank)
]

# medal[i][k] is whether athlete i wins a medal if the last athlete finishes
# the last round in place k+1.
medal = [[z3.Bool(f'm_{i}_{k}') for k in range(N)] for i in range(N)]
# Every rank recorded for the first R-1 rounds is a place from 1 to N.
for place in (p for row in rank for p in row):
    solver.add(1 <= place)
    solver.add(place <= N)

# Within each of those rounds, no two athletes share a place.
for same_round in zip(*rank):
    solver.add(z3.Distinct(*same_round))
# Whatever place the last athlete takes (one column of `medal` per case),
# exactly M athletes win medals.  Each flag is weighted 1 in a pseudo-Boolean
# equality, which is equivalent to z3.Sum(flags) == M.
for case_flags in zip(*medal):
    weights = [1] * len(case_flags)
    solver.add(z3.PbEq(list(zip(case_flags, weights)), M))
# Medals must follow the scores (lower is better).  If athlete i medals in
# case k while athlete j does not, i's score must be strictly lower than j's,
# except that a tie is allowed when i > j -- i.e. ties go to the
# higher-indexed athlete.
# NOTE(review): this index-based tie-break is the script's modelling choice;
# confirm it matches the intended tie-break rule.
for i in range(N):
    others = [other for other in range(N) if other != i]
    for j in others:
        tie_goes_to_i = i > j
        for k in range(N):
            i_medals_over_j = z3.And(medal[i][k], z3.Not(medal[j][k]))
            mine, theirs = score[i][k], score[j][k]
            i_ranks_ahead = mine <= theirs if tie_goes_to_i else mine < theirs
            solver.add(z3.Implies(i_medals_over_j, i_ranks_ahead))
# No athlete's outcome is settled yet: across the N possible places of the
# last athlete, each athlete medals in at least one case and misses out in at
# least one other (between 1 and N-1 medal cases, inclusive).
for athlete_flags in medal:
    weighted = [(flag, 1) for flag in athlete_flags]
    solver.add(z3.PbGe(weighted, 1))
    solver.add(z3.PbLe(weighted, N - 1))
# Each place of the last athlete yields a different set of medalists: for
# every pair of cases, at least one athlete's medal status differs.
for first in range(N):
    for second in range(first + 1, N):
        differs = [row[first] != row[second] for row in medal]
        solver.add(z3.Or(differs))
# Solve; verbose=1 makes z3 print progress information while it searches.
z3.set_option(verbose=1)
result = solver.check()
if result != z3.sat:
    # unsat: no arrangement meets every constraint.  unknown: z3 gave up
    # (scores are products of integer variables, i.e. nonlinear integer
    # arithmetic), so also report z3's reason.
    print(result)
    if result == z3.unknown:
        print(solver.reason_unknown())
    # sys.exit rather than the `exit` helper injected by the site module,
    # which is meant for the interactive shell and is missing under
    # `python -S`.
    sys.exit(1)
model = solver.model()

# One line per athlete: ranks in the first R-1 rounds, then the current
# last-round place i+1 (for the last athlete, who has not competed yet, this
# column is simply N).
for i, rs in enumerate(rank):
    print(*(model.eval(r) for r in rs), i + 1)

# One line per case k (last athlete finishes in place k+1): each athlete's
# score, wrapped in '*' when that athlete medals and in spaces otherwise.
for ss, ms in zip(zip(*score), zip(*medal)):
    print(*(
        '{1}{0}{1}'.format(model.eval(s),
                           '*' if z3.is_true(model.eval(m)) else ' ')
        for s, m in zip(ss, ms)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment