Skip to content

Instantly share code, notes, and snippets.

@benallard
Last active March 29, 2020 14:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benallard/da0318b36ea81b940ca64ac2c9e5cba2 to your computer and use it in GitHub Desktop.
Save benallard/da0318b36ea81b940ca64ac2c9e5cba2 to your computer and use it in GitHub Desktop.
Travelling Salesman based on historical data
+ABCDE
+CBDF
+ACDF
+BEACD
+ABDE
ABCDEF?
-ABCDE
BCDEF?
import random
import sys
# public domain
class Entry(object):
    """Running average of the scores observed for one transition.

    ``add`` records one more observation, ``sub`` removes one, and ``value``
    is the current mean (0 when no observations remain).
    """

    def __init__(self, score):
        # An Entry is created together with its first observation.
        self.score = score
        self.amount = 1

    def add(self, score):
        """Record one more observation of *score*."""
        self.amount += 1
        self.score += score

    def sub(self, score):
        """Remove one observation of *score*; a no-op when none are left."""
        if self.amount > 0:
            self.amount -= 1
            self.score -= score
            if self.amount == 0:
                # Nothing is left to average: drop any residue left behind by
                # an unmatched ``sub`` so it cannot skew later observations.
                self.score = 0

    @property
    def value(self):
        """Mean score over the remaining observations, or 0 if there are none."""
        if self.amount == 0:
            return 0
        return self.score / self.amount
class Model(object):
    """Learns station orderings from past routes and predicts new ones.

    ``self.data`` maps the key ``frm + to`` to an ``Entry`` holding the
    running average score for visiting ``to`` soon after ``frm``.  The
    pseudo-station ``START`` ('0') stands for the beginning of a route.
    Keys are built by string concatenation and the matrix helpers read
    ``key[0]`` / ``key[1]``, so station names are assumed to be single
    characters.
    """

    # Score credited to the 1st, 2nd, 3rd and 4th station following ``frm``
    # in a learned route; stations further ahead receive nothing.
    SCORES = (8, 4, 2, 1)

    # "Previous station" used at the start of every route.
    START = '0'

    def __init__(self):
        self.data = dict()

    def learn(self, route):
        """Credit every (station, upcoming station) pair of *route*."""
        self._walk(route, True)

    def forget(self, route):
        """Undo the scores that ``learn(route)`` would have added."""
        self._walk(route, False)

    def _walk(self, route, add):
        # For each position, score the stations from here onward against the
        # station just before it (START for the first station).
        prev = self.START
        for i, station in enumerate(route):
            self._process(prev, route[i:], add)
            prev = station

    def predict(self, stations):
        """Greedily order *stations*, print the resulting route and return it.

        Starting from START, repeatedly pick the remaining station with the
        highest learned score from the previously chosen station.  Candidates
        are shuffled first so ties are broken at random.
        """
        remaining = list(stations)
        random.shuffle(remaining)
        prev = self.START
        res = []
        while remaining:
            best = max(remaining, key=lambda to: self.__get(prev, to))
            remaining.remove(best)
            res.append(best)
            # Continue the tour from the station just chosen.
            prev = best
        route = ''.join(res)
        print(route)
        return route

    def _process(self, frm, tos, add=True):
        """Add (or, with ``add=False``, remove) scores for ``frm -> tos[i]``.

        ``tos[0]`` gets ``SCORES[0]``, ``tos[1]`` gets ``SCORES[1]``, and so
        on; elements of *tos* beyond ``len(SCORES)`` are ignored.
        """
        update = self.__add if add else self.__sub
        for to, score in zip(tos, self.SCORES):
            update(frm, to, score)

    def __get(self, frm, to):
        # Transitions never learned score 0.
        entry = self.data.get(frm + to)
        return 0 if entry is None else entry.value

    def __add(self, frm, to, score):
        key = frm + to
        entry = self.data.get(key)
        if entry is None:
            self.data[key] = Entry(score)
        else:
            entry.add(score)

    def __sub(self, frm, to, score):
        # Forgetting a transition that was never learned is ignored.
        entry = self.data.get(frm + to)
        if entry is not None:
            entry.sub(score)

    def __str__(self):
        """Render the score matrix: one row per ``frm``, one column per ``to``.

        Every column is 4 characters wide; missing transitions are blank.
        """
        tos = self.__tos()
        lines = ['   ' + ''.join('{:>4}'.format(to) for to in tos)]
        for frm in self.__froms():
            cells = []
            for to in tos:
                entry = self.data.get(frm + to)
                if entry is None:
                    cells.append('    ')
                else:
                    cells.append('{:4.1f}'.format(entry.value))
            lines.append(' ' + frm + ' ' + ''.join(cells))
        return '\n'.join(lines) + '\n'

    def __froms(self):
        # Distinct source stations, sorted.
        return sorted({key[0] for key in self.data})

    def __tos(self):
        # Distinct destination stations, sorted.
        return sorted({key[1] for key in self.data})
def main(filename):
    """Replay a command file against a fresh Model, then print the model.

    Each stripped line is interpreted as:
      "+ROUTE"     -> learn ROUTE
      "-ROUTE"     -> forget ROUTE
      "STATIONS?"  -> predict (and print) an order for STATIONS
    Any other line is echoed unchanged.  Returns 0 for use with sys.exit().
    """
    model = Model()
    with open(filename) as handle:
        for raw in handle:
            command = raw.strip()
            head, rest = command[:1], command[1:]
            if head == "+":
                model.learn(rest)
            elif head == "-":
                model.forget(rest)
            elif command.endswith("?"):
                model.predict(command[:-1])
            else:
                print(command)
    print(model)
    return 0
if __name__ == "__main__":
    # The script takes the command file to replay as its first argument;
    # fail with a usage message instead of an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("usage: {} FILE".format(sys.argv[0]))
    sys.exit(main(sys.argv[1]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment