Skip to content

Instantly share code, notes, and snippets.

@ruandao
Last active August 29, 2015 13:56
Show Gist options
  • Save ruandao/9228082 to your computer and use it in GitHub Desktop.
Save ruandao/9228082 to your computer and use it in GitHub Desktop.
模拟退火:想用来解决文明盛世的超能饮料问题,不过效率跟遍历没差(看来只适合用来找近似最优解)
#!/usr/bin/python
# encoding:utf-8
import sys
import heapq
import math,random
import time
# Simulated-annealing outline:
#   1. start from a random solution;
#   2. derive a random neighbour from the current solution;
#   3. always move to a better neighbour; move to a worse one only with
#      probability exp(-delta / T), where the temperature T shrinks each step.
def annealing(sa, newF, T=10000.0, cool=0.95, minT=0.1):
    """Minimise a price by simulated annealing; return the best solution seen.

    sa:   starting solution, a (price, solution) tuple; lower price is better.
    newF: callable taking a (price, solution) tuple and returning a random
          neighbour in the same (price, solution) form.
    T:    starting temperature.  Converted to float so ``-d / T`` is a true
          division even when an int is passed (Python 2 floors int / int).
    cool: multiplicative cooling factor applied after every step; must be
          strictly between 0 and 1.
    minT: the search stops once the temperature is ``minT`` or lower.

    Prints "<current price> <best price>" after every step.  Returns the
    (price, solution) tuple with the lowest price encountered, which is
    ``sa`` itself when no neighbour was cheaper.

    Raises ValueError if ``cool`` is not strictly between 0 and 1 (with
    cool >= 1 the loop would never terminate).
    """
    if not 0 < cool < 1:
        raise ValueError("cool must be strictly between 0 and 1, got %r" % (cool,))
    T = float(T)
    chose = sa
    while T > minT:
        sb = newF(sa)
        d = sb[0] - sa[0]
        # Improvements are always taken (exp is not evaluated for them, so a
        # large negative d cannot overflow); worse moves with prob. exp(-d/T).
        if d < 0 or math.exp(-d / T) > random.random():
            sa = sb
            if sa[0] < chose[0]:
                chose = sa
        # Cool on every step.  Cooling only after accepted moves let a run of
        # rejections spin at the same temperature indefinitely.
        T *= cool
        print("%s %s" % (sa[0], chose[0]))
    return chose
def gR(max):
    """Return a random index in the half-open range [0, max).

    ``max`` must be at least 1; ``random.randint`` raises ValueError otherwise.
    """
    highest = max - 1
    return random.randint(0, highest)
# Full problem graph, filled by initDataFromFile:
#   nodes[x][y] = [price, machineName]  -- a directed edge x -> y.
nodes = {}
def getRandom():
    """
    Build a random solution that should pass through every node of ``nodes``.

    (Original note: get a random solution; it must be a loop through all points.)

    Phase 1 -- while some node is still unattached (``keys2``): pick a random
    attached node p and a random unattached node2, and add edge p -> node2
    if the full graph has it.
    Phase 2 -- once every node is attached: pick a node p that still has no
    outgoing edge (``needOut``) and a random node2 from ``endPoints``, and
    add p -> node2 if the full graph has it.
    The loop ends when every node is attached and has an outgoing edge.

    Returns (sum of the chosen edges' prices, g) where g[x][y] is True for
    every chosen edge x -> y.

    NOTE(review): Python 2 only -- indexes and slices ``nodes.keys()``.
    NOTE(review): phase 1 spins forever if no attached node has an edge to
    any unattached node (e.g. a node unreachable from the start node).
    NOTE(review): phase 2 restarts by recursing into getRandom(); the number
    of retries is unbounded and could exceed the recursion limit.
    """
    keys = nodes.keys()
    lenKey = len(keys)
    # Random start node: the first attached node and the first endpoint.
    node = keys[random.randint(0,lenKey-1)]
    keys2 = keys[:]  # nodes not yet attached to the solution
    keys2.remove(node)
    points = [node]  # nodes already attached
    needOut = [node]  # attached nodes that still have no outgoing edge
    g = {}  # chosen edges: g[x][y] = True
    # One-element list so the nested connect() can update the running total
    # (Python 2 has no `nonlocal`).
    sum = [0]
    endPoints = set([node])  # allowed targets in phase 2; starts with the start node
    def connect(x,y):
        # Add edge x -> y to g, keeping points / needOut / the price total in sync.
        if x in needOut:
            needOut.remove(x)
        if y not in points:
            needOut.append(y)
            points.append(y)
        if not g.get(x):
            g[x] = {}
        if g[x].get(y) is not True:
            sum[0] = sum[0] + nodes[x][y][0]
        g[x][y] = True
    def exist(x,y):
        # Truthy ([price, machineName]) if the full graph has edge x -> y, else None.
        return nodes[x].get(y)
    def addPath(x,y):
        # Mark every node on a path from x to y in g as an endpoint.
        # print "addPath %s->%s" % (x,y), " *" * 20
        def _findPath(x, prePoints, y):
            # Recursive enumeration of simple paths from x to y in g;
            # prePoints is the path walked so far.  Enumerates every simple
            # path, so it can be very slow on a dense g.
            if x in prePoints:
                return
            prePoints.append(x)
            # print "\t",prePoints
            if x == y:
                for point in prePoints:
                    endPoints.add(point)
                return
            if not g.get(x):
                return
            for point in g[x].keys():
                _findPath(point, prePoints[:], y)
        _findPath(x,[],y)
    while not(len(points) == lenKey and len(needOut) == 0):
        lenPoints = len(points)
        if keys2:
            # Phase 1: attach a random unattached node to a random attached one.
            p = points[gR(lenPoints)]
            # lenKey - lenPoints equals len(keys2) here.
            node2 = keys2[gR(lenKey - lenPoints)]
            if not exist(p,node2):
                continue
            else:
                keys2.remove(node2)
        else:
            # Phase 2: give a node without an outgoing edge an edge to an endpoint.
            p = needOut[gR(len(needOut))]
            node2 = list(endPoints)[gR(len(endPoints))]
            if not exist(p, node2):
                # Check whether any needOut node has a direct edge to an
                # endpoint; if none does, start over from scratch.
                # (p is reused as a loop variable below; harmless because
                # this branch always ends in return or continue.)
                allConnect = []
                for p in needOut:
                    allConnect.extend(nodes[p].keys())
                canConnect = False
                for p in endPoints:
                    if p in allConnect:
                        canConnect = True
                        break
                if not canConnect:
                    return getRandom()
                continue
            else:
                # Once p -> node2 is added below, every node on a g-path from
                # node2 to p lies on a cycle through node2: mark them endpoints.
                addPath(node2, p)
        if exist(p,node2):
            connect(p, node2)
        # print "connect %s->%s" %(p, node2), "\tneedOut: ", needOut
        # print "\tpoints", points
        # print "\tkeys2", keys2
        # print "\tendPoints", endPoints
    return (sum[0], g)
def newF(sa):
    """Return a random neighbour of the solution ``sa``.

    sa: (price, g) where g[x][y] is True for every edge x -> y of the
        solution and price is the sum of those edges' prices in ``nodes``.

    One solution edge (x, y) is removed -- drawn at random among the edges
    for which the full graph ``nodes`` has another route from x to y -- and
    random edges from ``nodes`` are then added until x can reach y again in
    the solution, so every route that used x -> y can still be completed.
    If no edge can be replaced this way, a fresh ``getRandom()`` solution is
    returned instead.

    Returns a new (price, g) tuple; ``sa`` and its graph are not modified.
    """
    price1, g1 = sa

    # Copy g1 so the caller's solution is untouched, collecting its edges.
    g = {}
    paths = []
    for x in g1:
        g[x] = {}
        for y in g1[x]:
            g[x][y] = True
            paths.append((x, y))

    def hasOtherPath(x, y):
        """True if ``nodes`` has a route from x to y that is not the edge x -> y.

        Iterative DFS with a visited set: O(V + E) instead of enumerating
        every simple path.
        """
        seen = set([x])
        stack = [x]
        while stack:
            p = stack.pop()
            for k in nodes.get(p, {}):
                if p == x and k == y:
                    continue  # the direct edge is the one being removed
                if k == y:
                    return True
                if k not in seen:
                    seen.add(k)
                    stack.append(k)
        return False

    # Draw solution edges in random order until one can be rerouted; every
    # edge (including the last one) gets checked.
    path = None
    while paths:
        candidate = paths.pop(gR(len(paths)))
        if hasOtherPath(candidate[0], candidate[1]):
            path = candidate
            break
    if path is None:
        return getRandom()

    x, y = path
    del g[x][y]
    cost = price1 - nodes[x][y][0]

    # s holds the nodes reachable from x in g.  It is kept closed under g's
    # edges, so an edge p -> p2 with p in s and p2 outside s is never in g yet.
    s = set()

    def absorb(start):
        """Add ``start`` and everything reachable from it in ``g`` to ``s``."""
        s.add(start)
        stack = [start]
        while stack:
            p = stack.pop()
            for k in g.get(p, {}):
                if k not in s:
                    s.add(k)
                    stack.append(k)

    absorb(x)
    # Grow s with random edges from the full graph until y is reachable again.
    # hasOtherPath guaranteed an alternative route, so suitable edges always
    # exist and the loop terminates with probability 1.  Sorting keeps the
    # random choices reproducible under a fixed seed.
    while y not in s:
        frontier = sorted(s)
        p = frontier[gR(len(frontier))]
        targets = sorted(nodes.get(p, {}))
        if not targets:
            continue
        p2 = targets[gR(len(targets))]
        if (p == x and p2 == y) or p2 in s:
            continue
        g.setdefault(p, {})[p2] = True
        cost += nodes[p][p2][0]
        absorb(p2)
    return (cost, g)
def initDataFromFile(f, graph=None):
    """Load directed, priced edges from ``f`` into ``graph`` (default: ``nodes``).

    f:     any iterable of text lines, e.g. an open file.  Each non-blank
           line holds at least four whitespace-separated fields (tabs or
           spaces):  machineName  from  to  price  -- price is an integer;
           any further fields are ignored.
    graph: dict to fill; defaults to the module-level ``nodes``.  Each edge
           is stored as graph[from][to] = [price, machineName]; a later line
           for the same (from, to) pair overwrites an earlier one.

    Blank lines are skipped.  Raises ValueError for a line with fewer than
    four fields or a non-integer price.
    """
    if graph is None:
        graph = nodes
    for lineno, line in enumerate(f, 1):
        fields = line.split()  # split() already treats tabs as separators
        if not fields:
            continue
        if len(fields) < 4:
            raise ValueError("line %d: expected 'machineName from to price', got %r"
                             % (lineno, line))
        machineName, x, y = fields[0], fields[1], fields[2]
        price = int(fields[3])
        graph.setdefault(x, {})[y] = [price, machineName]
def main(f, samples=100):
    """Load the graph from ``f`` and search for a cheap solution.

    Draws ``samples`` random solutions with getRandom() and starts
    annealing() from the cheapest one.  The population variance of the
    sampled prices is used as the starting temperature.  Prints that
    variance, then the best (price, graph) found.

    Raises ValueError if ``samples`` is less than 1.
    """
    if samples < 1:
        raise ValueError("samples must be at least 1, got %r" % (samples,))
    initDataFromFile(f)
    costs = []
    chose = None
    for _ in range(samples):
        sa = getRandom()
        costs.append(sa[0])
        if chose is None or sa[0] < chose[0]:
            chose = sa
    # Divide by the real sample count and in floating point: the old code
    # drew 99 samples, divided by 100, and used integer division on Python 2.
    n = float(len(costs))
    avg = sum(costs) / n
    variance = sum((c - avg) * (c - avg) for c in costs) / n
    print(variance)
    r = annealing(chose, newF, variance)
    if r[0] < chose[0]:
        chose = r
    print(chose)
if __name__ == '__main__':
    import sys
    if len(sys.argv) < 2:
        sys.exit("usage: %s DATA_FILE" % sys.argv[0])
    # The with-block closes the file even if main() raises.
    with open(sys.argv[1]) as f:
        main(f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment