Skip to content

Instantly share code, notes, and snippets.

@shrkw
Last active August 29, 2015 14:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shrkw/9f184af5b54cd91bb789 to your computer and use it in GitHub Desktop.
Save shrkw/9f184af5b54cd91bb789 to your computer and use it in GitHub Desktop.
Codeiqの牛乳配達ルート問題 ref: http://qiita.com/shrkw/items/47effa195bd722488084
#!/bin/env python
# coding:utf-8
import csv
import sys
"""
根ノードから再帰的に深さ優先で木構造を作成しつつ探索も行う。
コスト集計も中途で行い、最低コストを上回った時点で以降の子孫ノードの探索を中止する。
木構造のプログラムを書くのに慣れていないため、実直にコーディングした。
"""
class Accumulator(object):
    """Holds the best (cheapest) complete result found so far.

    Attributes:
        sum: total cost of the best route; defaults to ``sys.maxsize`` so
            that any real route compares as cheaper.
        route: the nodes making up that route.
    """

    def __init__(self, sum=sys.maxsize, route=None):
        self.sum = sum
        # Fresh list per instance: a ``[]`` default would be a single list
        # shared by every Accumulator created without an explicit route.
        self.route = [] if route is None else route
class Node(object):
    """One step of a route in the search tree.

    Attributes (all taken from the constructor arguments):
        id: node id of this step.
        parent: the preceding Node, or None for the root.
        cost_from_root: accumulated cost from the root up to this node.
        children: Nodes explored after this one. Defaults to a new empty
            list; an explicitly passed value (including ``None``) is stored
            unchanged.
    """

    # Marks "no children argument given". This lets each Node get its own
    # fresh list by default while an explicit ``None`` is still kept as-is.
    _NO_CHILDREN = object()

    def __init__(self, id, parent=None, cost_from_root=0, children=_NO_CHILDREN):
        self.id = id
        self.parent = parent
        self.cost_from_root = cost_from_root
        # A shared ``[]`` default would make every Node that omits
        # ``children`` use one and the same list, so appending a child to
        # any of them would add it to all of them.
        self.children = [] if children is Node._NO_CHILDREN else children

    def __repr__(self):
        return "%i, cost: %i -> %s\n" % (self.id, self.cost_from_root, repr(self.children))
class DeliveryCostCalculator(object):
    """Find the cheapest round trip from node 0 through every other node.

    The cost matrix is read from a CSV file. ``cost_table[a][b]`` is used as
    the cost of going from node ``a`` to node ``b``. The search builds a
    tree of ``Node`` objects depth-first from node 0 and keeps the best
    complete route in ``self.acc``. It stops expanding a branch once the
    branch's running cost exceeds that best total.
    """

    def __init__(self, filename):
        self.filename = filename
        self.cost_table = self.get_table()
        # Best complete route so far; starts "infinitely" expensive.
        self.acc = Accumulator(sys.maxsize, [])

    def get_table(self):
        """Read ``self.filename`` as CSV and return a list of int rows.

        Blank lines (e.g. a trailing empty line) are skipped. Otherwise they
        would become empty rows and inflate ``len(cost_table)``, which is
        used as the node count.
        """
        with open(self.filename, 'r') as f:
            return [[int(col) for col in row] for row in csv.reader(f) if row]

    def calc_total_cost(self, current):
        """Close the route at ``current`` by going back to node 0.

        The closing leg is added as a child of ``current``. If the resulting
        round trip is cheaper than the best so far, its total cost and its
        node path (root first) are recorded in ``self.acc``.
        """
        tmp = Node(0, current, current.cost_from_root + self.cost_table[current.id][0], None)
        current.children.append(tmp)
        if tmp.cost_from_root < self.acc.sum:
            self.acc.sum = tmp.cost_from_root
            # Follow parent links back to the root, then flip to root-first.
            route = []
            node = tmp
            while node is not None:
                route.append(node)
                node = node.parent
            route.reverse()
            self.acc.route = route

    def main(self):
        """Run the search and return ``(min_cost, route)``.

        ``route`` lists the ``Node`` objects from the starting node 0 to the
        final return to node 0.
        """
        def _f(slot, current):
            # ``slot`` holds the node ids still to be visited after ``current``.
            if not slot:
                self.calc_total_cost(current)
                return
            for i in slot:
                # Register the next node as a child (with its own children list).
                tmp = Node(i, current, current.cost_from_root + self.cost_table[current.id][i], [])
                if self.acc.sum < tmp.cost_from_root:
                    # This branch already costs more than the best route, so
                    # skip it. Keep trying the remaining siblings, though: a
                    # different next node may still be cheaper.
                    continue
                current.children.append(tmp)
                # Recurse with the chosen node removed from the remaining set.
                rest = list(slot)
                rest.remove(i)
                _f(rest, tmp)

        _f(list(range(1, len(self.cost_table))), Node(0, None, 0, []))
        return self.acc.sum, self.acc.route
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: %s COST_TABLE.csv" % sys.argv[0])
    calculator = DeliveryCostCalculator(sys.argv[1])
    # ``total`` rather than ``sum`` so the builtin is not shadowed.
    total, route = calculator.main()
    print(total)
    # Node ids are 0-based internally; print them 1-based.
    print(" -> ".join([str(n.id + 1) for n in route]))
#!/bin/env python
# coding:utf-8
import csv
import sys
from itertools import permutations
def get_table(filename):
    """Read a CSV cost matrix from *filename* as a list of int rows.

    Blank lines (e.g. a trailing empty line) are skipped. Otherwise they
    would become empty rows and inflate ``len(cost_table)``, which the
    caller uses as the node count.
    """
    with open(filename, 'r') as f:
        return [[int(col) for col in row] for row in csv.reader(f) if row]
def main(filename):
    """Brute-force the cheapest round trip through every node in *filename*.

    Every ordering of nodes ``1..N-1`` is tried, with node 0 added at both
    ends. A route is abandoned as soon as its running cost exceeds the best
    total found so far.

    Returns:
        ``(min_cost, min_route)``, where ``min_route`` is a tuple of node ids
        that starts and ends with 0.
    """
    cost_table = get_table(filename)
    min_cost = sys.maxsize
    min_route = ()
    for p in permutations(range(1, len(cost_table))):
        # add an initial and a last node
        route = (0,) + p + (0,)
        total_cost = 0
        # add up the cost of each consecutive (current, next) leg
        for current, nxt in zip(route, route[1:]):
            total_cost += cost_table[current][nxt]
            if min_cost < total_cost:
                break
        if total_cost < min_cost:
            min_cost = total_cost
            min_route = route
    return min_cost, min_route
if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: %s COST_TABLE.csv" % sys.argv[0])
    c, r = main(sys.argv[1])
    print(c)
    # Node ids are 0-based internally; print them 1-based.
    print(" -> ".join([str(n + 1) for n in r]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment