Skip to content

Instantly share code, notes, and snippets.

@ruandao
Created February 28, 2014 05:17
Show Gist options
  • Save ruandao/9265733 to your computer and use it in GitHub Desktop.
Save ruandao/9265733 to your computer and use it in GitHub Desktop.
超能饮料,暴力破解
#!/usr/bin/python
# encoding:utf-8
"""
一个点过的所有环里面有很多点集是重复的,只需要保留最小的那个
譬如C1C2C3C1, C1C3C2C1 这两个环是不同的,点集却是相同的,当然造价一般是不同的,那么只需要保留造价最低的一个就可以了
找出所有的过某一点的所有环,数据存储格式如下:
d = {
“C1”: [],
“C2”: [],
}
一个环里面所需要的数据:
价格,路径,点集, 点集hash
先获取所有点所过的所有环,
然后任选一点,取他所有环,组成优先队列
先从优先队列中pop出总价最小的环C
将C与C上的每个点的所有环分别合并组成C1,C2,C3...
将这些Cn 放入队列
两个环之间的合并:
设C1,C2
如果C2的点集在C1上,那么就没有合并的必要了,直接进行下一个合并
将C1上的路径集减去C2上的路径集,然后计算新路径集的造价P3
将P1 + P3 作为新环的造价, 将两个点的点集合并,并计算新的点集的hash值
还需要维护一份点集列表
如果一个环的点集比别的点集小,价格又比别的点集高,
我觉得这个环就没有计算的必要了
过13个点的最低时耗时15分钟... 下面数据样本结果为 1539828
M1 C1 C2 94945
M2 C2 C1 173017
M3 C1 C3 331586
M4 C3 C1 938698
M5 C1 C4 175677
M6 C4 C1 718488
M7 C1 C5 751710
M8 C5 C1 818761
M9 C1 C6 340500
M10 C6 C1 642094
M11 C1 C7 957227
M12 C7 C1 262932
M13 C1 C8 153792
M14 C8 C1 537394
M15 C1 C9 473263
M16 C9 C1 175835
M17 C1 C10 287474
M18 C10 C1 183148
M19 C1 C11 826289
M20 C11 C1 707737
M21 C1 C12 213644
M22 C12 C1 758400
M23 C1 C13 418286
M24 C13 C1 272166
M25 C2 C3 190798
M26 C3 C2 677571
M27 C2 C4 231720
M28 C4 C2 16053
M29 C2 C5 863568
M30 C5 C2 662507
M31 C2 C6 115848
M32 C6 C2 957513
M33 C2 C7 834524
M34 C7 C2 446435
M35 C2 C8 896211
M36 C8 C2 10201
M37 C2 C9 164923
M38 C9 C2 647920
M39 C2 C10 827962
M40 C10 C2 504423
M41 C2 C11 290014
M42 C11 C2 785188
M43 C2 C12 766356
M44 C12 C2 442806
M45 C2 C13 322581
M46 C13 C2 239618
M47 C3 C4 617641
M48 C4 C3 609056
M49 C3 C5 421767
M50 C5 C3 443930
M51 C3 C6 316793
M52 C6 C3 634411
M53 C3 C7 202329
M54 C7 C3 734079
M55 C3 C8 905577
M56 C8 C3 392127
M57 C3 C9 411650
M58 C9 C3 137297
M59 C3 C10 407181
M60 C10 C3 275217
M61 C3 C11 798804
M62 C11 C3 522030
M63 C3 C12 232730
M64 C12 C3 633328
M65 C3 C13 967465
M66 C13 C3 128940
M67 C4 C5 642529
M68 C5 C4 132387
M69 C4 C6 775860
M70 C6 C4 470491
M71 C4 C7 635811
M72 C7 C4 65873
M73 C4 C8 255679
M74 C8 C4 402166
M75 C4 C9 507679
M76 C9 C4 577261
M77 C4 C10 640784
M78 C10 C4 125320
M79 C4 C11 186317
M80 C11 C4 62551
M81 C4 C12 568250
M82 C12 C4 502110
M83 C4 C13 695962
M84 C13 C4 769580
M85 C5 C6 236189
M86 C6 C5 601539
M87 C5 C7 161707
M88 C7 C5 646839
M89 C5 C8 737836
M90 C8 C5 567888
M91 C5 C9 921057
M92 C9 C5 536640
M93 C5 C10 89918
M94 C10 C5 153786
M95 C5 C11 169967
M96 C11 C5 57382
M97 C5 C12 281726
M98 C12 C5 811497
M99 C5 C13 188770
M100 C13 C5 57586
M101 C6 C7 281988
M102 C7 C6 823582
M103 C6 C8 122459
M104 C8 C6 536667
M105 C6 C9 225747
M106 C9 C6 629139
M107 C6 C10 113928
M108 C10 C6 865532
M109 C6 C11 753460
M110 C11 C6 299245
M111 C6 C12 927083
M112 C12 C6 321710
M113 C6 C13 800356
M114 C13 C6 623045
M115 C7 C8 91290
M116 C8 C7 36544
M117 C7 C9 224583
M118 C9 C7 251997
M119 C7 C10 682384
M120 C10 C7 961420
M121 C7 C11 818886
M122 C11 C7 603440
M123 C7 C12 498059
M124 C12 C7 907804
M125 C7 C13 756227
M126 C13 C7 667027
M127 C8 C9 964187
M128 C9 C8 37952
M129 C8 C10 478524
M130 C10 C8 152957
M131 C8 C11 94539
M132 C11 C8 759512
M133 C8 C12 975539
M134 C12 C8 215998
M135 C8 C13 296179
M136 C13 C8 201286
M137 C9 C10 844138
M138 C10 C9 409107
M139 C9 C11 66818
M140 C11 C9 597598
M141 C9 C12 707353
M142 C12 C9 992902
M143 C9 C13 918308
M144 C13 C9 507708
M145 C10 C11 615946
M146 C11 C10 9597
M147 C10 C12 543253
M148 C12 C10 839530
M149 C10 C13 260595
M150 C13 C10 225637
M151 C11 C12 800949
M152 C12 C11 79481
M153 C11 C13 828077
M154 C13 C11 299008
M155 C12 C13 986286
M156 C13 C12 584304
"""
import sys
import heapq
nodes = {} # [x][y]: [price, machine, k] 所有设备的图案
nodes2 = {}
allMachines = [] #[(price, x, y, machine),]
# calculatePoints[endPoint][pointSetHash] = (sum, pathsSet, pointSet)
calculatePoints = {} # 保存,过某一点的所有环,将有相同点集的环视为同一种环
def hashSet(pointSet):
h = []
for x in pointSet:
heapq.heappush(h, x)
return "".join(h)
def findCycles():
# 从一个点集出发,找到所有的单环
keys = nodes.keys()
# start # 从start出发完成剩下的路线,回到起点
# prePoints # 前面经过的点
# taboo # 禁忌表,用来存储已经经过的点(不包括起点/终点)
# endPoint # 终点
def _findCycle(startPoint, prePoints, taboo, endPoint):
if startPoint in taboo: return
if startPoint == endPoint:
start, sum, pathsSet, pointSet = None, 0, set(), set()
for to in prePoints:
if start is not None:
pathsSet.add((start, to))
sum += nodes[start][to][0]
start = to
pointSet.add(start)
hash = hashSet(pointSet)
if calculatePoints[endPoint].get(hash) \
and calculatePoints[endPoint][hash][0] < sum:
return
calculatePoints[endPoint][hash] = (sum, pathsSet, pointSet)
return
taboo.add(startPoint)
for x in nodes[startPoint].keys():
# [:] 很耗性能的, 改为append, pop
# 上面对应的in 的操作也要改
if x in taboo: continue
# print " %s->%s " % (startPoint, x)
prePoints.append(x)
_findCycle(x, prePoints, taboo, endPoint)
prePoints.pop()
# print "out %s->%s" % (startPoint, x)
taboo.remove(startPoint)
def _getSumOfPaths(paths):
sum = 0
for x,y in paths:
sum += nodes[x][y][0]
return sum
# 对所有点计算过该点的所有环
for point in keys:
calculatePoints[point] = {}
for x in nodes[point].keys():
_findCycle(x, [point, x], set(), point)
selectPoint = keys[0]
allCycles = []
allCyclesDic = {}
pointSetList = {} # 点集列表, [n]: [(pointSet, price)]
def _shouldUse(hash, pointSet, price):
# 如果一个环的点集在另外一个环中的点集中(真包含),且要价比另外一个环要高,那么这个环显然是不靠谱的环
setLen = len(pointSet)
for n,d in pointSetList.items():
if n > setLen:
for setItem,price1 in d.values():
if len(pointSet - setItem) == 0 and price >= price1:
return False
if not pointSetList.get(setLen): pointSetList[setLen] = {}
if not pointSetList[setLen].get(hash) or pointSetList[setLen][hash][1] > price:
pointSetList[setLen][hash] = (pointSet, price)
return True
for hash,(price, pathsSet, pointSet) in calculatePoints[selectPoint].items():
if not _shouldUse(hash, pointSet, price): continue
allCyclesDic[hash] = (price, pathsSet, pointSet)
heapq.heappush(allCycles, (price, hash))
n12Chose = None
calculateCycles = {}
lenKeys = len(keys)
while allCycles:
price, hash = heapq.heappop(allCycles)
price, pathsSet, pointSet = allCyclesDic[hash]
if len(pointSet) == lenKeys:
if n12Chose[0] > price: n12Chose = (price, pathsSet)
break
# 如果已经计算过的相同点集比这个点集造价还小,那么不用再算了
if calculateCycles.get(hash) and calculateCycles[hash] < price: continue
calculateCycles[hash] = price
# print price
for point in pointSet:
for price2, pathsSet2, pointSet2 in calculatePoints[point].values():
newPointSet = pointSet2 - pointSet
# 如果引入这个新环,不能带来点的增加,那么没有必要引入
if len(newPointSet) == 0: continue
newPathsSet = pathsSet2 - pathsSet
price3 = price + _getSumOfPaths(newPathsSet)
pathsSet3 = newPathsSet | pathsSet
pointSet3 = newPointSet | pointSet
hash3 = hashSet(pointSet3)
# 如果已经计算过的相同点集比这个点集造价还小,那么不用再算了
if calculateCycles.get(hash3) and calculateCycles[hash3] < price3: continue
if len(pointSet3) == lenKeys and (n12Chose is None or n12Chose[0] > price3):
n12Chose = (price3, pathsSet3)
elif n12Chose is not None and n12Chose[0] < price3: continue
else:
if not allCyclesDic.get(hash3) or price3 < allCyclesDic[hash3][0]:
# 如果这个点集在其他点集中,且这个点集造价更高,那么也不用算了
if not _shouldUse(hash3, pointSet3, price3): continue
allCyclesDic[hash3] = (price3, pathsSet3, pointSet3)
heapq.heappush(allCycles, (price3, hash3))
print n12Chose[0]
def calMachineName(paths):
h = []
for x,y in paths:
heapq.heappush(h, int(nodes[x][y][1][1:]))
return h
h = calMachineName(n12Chose[1])
while h:
print heapq.heappop(h),
def flody():
keys = nodes.keys()
for k in keys:
for x in keys:
if x == k:continue
for y in nodes[x].keys():
if y == k or y==x:continue
if not nodes[x].get(k) or not nodes[k].get(y):continue
sum = nodes[x][k][0] + nodes[k][y][0]
if nodes[x][y][0] > sum:
nodes[x][y] = [nodes[x][y][0], None, k]
def clearRedunary():
keys = nodes.keys()
for x in keys:
for y in nodes[x].keys():
if nodes[x][y][2] is not None:
# print "clean %s->%s k: %s" % (x,y, nodes[x][y][2])
del nodes[x][y]
def process():
count = len(allMachines)
count2 = 0
while count2 != count:
count2 = count
# # 求出任意两点间的最短距离
flody()
# 剔除冗余路线,如果x->y < x->k->y 那么删除x->y
clearRedunary()
copyPrice()
count = countPaths()
# 现在就来吧
findCycles()
def copyPrice():
for x in nodes.keys():
for y in nodes[x].keys():
nodes[x][y] = [nodes2[x][y][0], nodes2[x][y][1], None]
def countPaths():
count = 0
for x in nodes.keys():
for y in nodes[x].keys():
count += 1
return count
def output():
sum = 0
machine = []
for start in nodes.keys():
for end in nodes[start].keys():
sum += nodes[start][end][0]
machineName = nodes[start][end][1][1:]
hadBeenInsert = False
for index,name in enumerate(machine):
if int(machineName) < int(name):
machine.insert(index, machineName)
hadBeenInsert = True
break
if not hadBeenInsert:
machine.append(machineName)
print sum
print " ".join(machine)
pass
def initDataFromFile(f):
for line in f:
line = line.replace("\t", " ").split()
# from, to, price, machineName
machineName = line[0]
x = line[1]
y = line[2]
price = int(line[3])
if not nodes.get(x):
nodes[x] = {}
nodes[x][y] = [price, machineName, None]
if not nodes2.get(x):
nodes2[x] = {}
nodes2[x][y] = [price, machineName, None]
allMachines.append((price,x,y,machineName))
def log(str):
# print str
pass
def main(f):
# 初始化数据
# 计算任意两点的最短距离
# 处理图
# 输出结果
initDataFromFile(f)
log("init finish")
process()
log("process finish")
# output()
log("output finish")
if __name__ == '__main__':
import sys
filename = sys.argv[1]
f = open(filename)
main(f)
f.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment