Skip to content

Instantly share code, notes, and snippets.

@xire-
Created March 20, 2018 16:38
Show Gist options
  • Save xire-/00beccec7e99e69a13df7fccae40d253 to your computer and use it in GitHub Desktop.
Save xire-/00beccec7e99e69a13df7fccae40d253 to your computer and use it in GitHub Desktop.
payment solver
#!/usr/bin/env python3
# -*- coding: utf8 -*-
from collections import defaultdict
from itertools import permutations
class PaymentSolver:
    """Track pairwise debts between a fixed set of people and simplify them.

    Debts are antisymmetric: if X owes Y N euro, then Y owes X -N euro, so
    only one entry per unordered pair is stored, keyed with the smaller name
    first. Index with ``solver[debtor, creditor]`` to read or write a debt.
    """

    # Amounts whose magnitude is at or below this are treated as zero.
    # Float arithmetic on euro amounts (e.g. 0.1 + 0.2 - 0.3) leaves tiny
    # residues that would otherwise be printed as "owe X 0.00 EUR" and would
    # trigger pointless extra simplification steps.
    EPS = 1e-9

    def __init__(self, people):
        """Create an empty ledger for the collection *people*.

        *people* is stored as given and used for membership checks and, via
        ``sorted``, for iteration order, so its members must be orderable.
        """
        # (x, y) -> float for x < y; the debt y -> x is the negation.
        self.x_owe_y = defaultdict(float)
        self.people = people

    def _check_pair(self, x, y):
        """Assert that *x* and *y* are two distinct known people."""
        assert x in self.people
        assert y in self.people
        assert x != y

    def _snap(self, value):
        """Return *value*, or 0.0 when it is within EPS of zero."""
        return 0.0 if abs(value) <= self.EPS else value

    def __getitem__(self, idx):
        """Return how much idx[0] owes idx[1] (negative if idx[0] is owed)."""
        x, y = idx
        self._check_pair(x, y)
        # .get instead of [] so that reading does not insert zero entries
        # into the defaultdict.
        if x > y:
            return -self.x_owe_y.get((y, x), 0.0)
        return self.x_owe_y.get((x, y), 0.0)

    def __setitem__(self, idx, value):
        """Set how much idx[0] owes idx[1], normalizing the key order."""
        x, y = idx
        self._check_pair(x, y)
        if x > y:
            # Store under (smaller, larger); flip the sign to keep meaning.
            x, y, value = y, x, -value
        self.x_owe_y[x, y] = value

    def __str__(self):
        """Render, per person, the debts they must pay and their net balance.

        Only positive debts (money this person must pay) are listed. The
        TOTAL OWED line is the net sum over all counterparts, so it is
        negative for someone who is owed more than they owe.
        """
        lines = []
        for p1 in sorted(self.people):
            lines.append(f'___{p1:_<30s}')
            tot = 0.
            for p2 in sorted(self.people):
                if p1 == p2:
                    continue
                owed = self[p1, p2]
                tot += owed
                # Ignore float residue so it is not shown as "0.00 EUR".
                if owed > self.EPS:
                    lines.append(f'owe {p2} {owed:.2f} EUR')
            lines.append('---')
            # Snap so a residue of -1e-17 is not printed as "-0.00".
            lines.append(f'TOTAL OWED: {self._snap(tot):.2f} EUR')
            lines.append('')
        return '\n'.join(lines)

    def paid_for(self, who, total, group):
        """Record that *who* paid *total* euro on behalf of *group*.

        The amount is split evenly over the members of *group* (a repeated
        member counts as several shares). Every member other than *who* then
        owes *who* their share. *who* does not have to belong to *group*.

        Raises:
            ValueError: if *group* is empty.
        """
        # Materialize once: *group* may be a one-shot iterator, which the
        # membership check below would otherwise exhaust before len().
        group = list(group)
        assert who in self.people
        assert all(p in self.people for p in group)
        if not group:
            raise ValueError('group must contain at least one person')
        total_each = total / len(group)
        for p in group:
            if p == who:
                continue  # the payer's own share is not a debt
            self[p, who] += total_each

    def split_money(self, who, total, group):
        """Record that *who* splits *total* euro with *group*.

        This is ``paid_for`` with a negated amount: each member of *group*
        other than *who* ends up being owed total / len(group) by *who*.
        """
        self.paid_for(who, -total, group)

    def _simplify_triangle(self):
        """Collapse debt chains A -> B -> C into a direct payment A -> C.

        If A owes C x (x >= 0), A owes B y and B owes C z, then with
        m = min(y, z) the debts become:
            A -> C : x + m
            A -> B : y - m
            B -> C : z - m
        Every net balance is unchanged and the total amount owed drops by m.
        Repeats until no such triangle remains.

        Returns:
            True if at least one triangle was simplified.
        """
        eps = self.EPS
        changed = False
        while True:
            fixpoint = True
            for a, b, c in permutations(sorted(self.people), 3):
                x = self[a, c]
                y = self[a, b]
                z = self[b, c]
                if x > -eps and y > eps and z > eps:
                    fixpoint = False
                    changed = True
                    m = min(y, z)
                    self[a, c] = x + m
                    # One of these becomes exactly zero; the other is snapped
                    # so float leftovers do not survive as phantom debts.
                    self[a, b] = self._snap(y - m)
                    self[b, c] = self._snap(z - m)
            if fixpoint:
                return changed

    def _simplify_cross(self):
        """Remove one debt from a 2x2 "cross" of debts.

        If A owes C x, A owes D y, B owes C z and B owes D w (all positive),
        then with m = min(y, z) the debts become:
            A -> C : x + m
            A -> D : y - m
            B -> C : z - m
            B -> D : w + m
        Net balances are unchanged and at least one of A -> D / B -> C drops
        to zero, so fewer payments are needed. Repeats until no cross remains.

        Returns:
            True if at least one cross was simplified.
        """
        eps = self.EPS
        changed = False
        while True:
            fixpoint = True
            for a, b, c, d in permutations(sorted(self.people), 4):
                x = self[a, c]
                y = self[a, d]
                z = self[b, c]
                w = self[b, d]
                if x > eps and y > eps and z > eps and w > eps:
                    fixpoint = False
                    changed = True
                    m = min(y, z)
                    self[a, c] = x + m
                    self[a, d] = self._snap(y - m)
                    self[b, c] = self._snap(z - m)
                    self[b, d] = w + m
            if fixpoint:
                return changed

    def simplify_payments(self):
        """Reduce the debts to fewer/smaller payments without changing balances.

        A cross simplification can zero a debt that was blocking a triangle
        simplification, so the two passes alternate until the cross pass finds
        nothing left to do. This terminates for two reasons. Each triangle step
        lowers the total amount owed by more than EPS. Each cross step lowers
        the number of nonzero debts.
        """
        while True:
            self._simplify_triangle()
            if not self._simplify_cross():
                break
# ================================================================================
def main():
    """Build a sample ledger for seven friends, simplify it and print it."""
    friends = {'Marco', 'Palma', 'Mauro', 'Gaspa', 'Cesco', 'Benve', 'Lorenzo'}
    solver = PaymentSolver(friends)

    # Each entry: (who paid, amount in euro, who the payment covered).
    expenses = [
        ('Lorenzo', 45.00, friends),
        ('Gaspa', 12.00, {'Gaspa', 'Cesco', 'Benve'}),
        ('Marco', 7.20, friends - {'Palma'}),
    ]
    for payer, amount, covered in expenses:
        solver.paid_for(who=payer, total=amount, group=covered)

    # Manual tweaks: wipe one debt entirely, then raise another one.
    solver['Cesco', 'Gaspa'] = 0.0
    solver['Benve', 'Gaspa'] += 10.0

    solver.simplify_payments()
    print(solver)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment