Skip to content

Instantly share code, notes, and snippets.

@rgrig
Created September 22, 2014 07:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rgrig/8e96144698db15df929b to your computer and use it in GitHub Desktop.
Save rgrig/8e96144698db15df929b to your computer and use it in GitHub Desktop.
Some rather bad Euclidean TSP solver
#!/usr/bin/env python3
from copy import copy
from math import atan2, pi, sqrt
from random import randrange, seed, shuffle
from sys import exit, stderr
from time import time
# Tolerance used by the floating-point comparisons throughout the solver.
eps = 1.0e-9
def length(A, B):
    """Return the Euclidean distance between the 2-d points A and B."""
    dx = A[0] - B[0]
    dy = A[1] - B[1]
    return sqrt(dx * dx + dy * dy)
def cost(xys, tour):
    """Return the length of the closed tour visiting xys in the order given by tour."""
    # Start with the closing edge from the last point back to the first.
    total = length(xys[tour[0]], xys[tour[-1]])
    for prev, cur in zip(tour, tour[1:]):
        total += length(xys[prev], xys[cur])
    return total
def neighbors_graph(d, xys):
    """Build a symmetric neighbour graph over the points xys.

    For each point, a square box centred on it is grown by doubling and
    then refined with 8 bisection steps until it holds at least d points,
    the point itself included (all n points when d > n). Every other
    point in the final box becomes a neighbour, in both directions. Box
    counts and queries use a 2-d tree stored implicitly in the index
    array ``kdtree``.

    Returns G, where G[i] lists the neighbours of point i by increasing
    distance (ties broken by index).
    """
    start = time()
    n = len(xys)
    kdtree = list(range(n))

    def kselect(a, b, c, m):
        # Randomised quickselect on kdtree[a:b] by coordinate c. Afterwards
        # no entry before m has a larger coordinate than kdtree[m], and no
        # entry after m has a smaller one. It is a loop rather than
        # recursion, so unlucky pivots cannot hit the recursion limit.
        while True:
            assert a <= m and m < b
            i = randrange(a, b)
            kdtree[i], kdtree[a] = kdtree[a], kdtree[i]
            # Three-way partition around the pivot kdtree[a]:
            # [a+1, i) < pivot, [i, j) == pivot, [j, k) > pivot.
            i = j = k = a + 1
            while k < b:
                D = xys[kdtree[k]][c] - xys[kdtree[a]][c]
                if D > 0:
                    k += 1
                elif D == 0:
                    kdtree[k], kdtree[j] = kdtree[j], kdtree[k]
                    j += 1
                    k += 1
                else:
                    # NOTE: parallel assignment fails when i=j!=k
                    t = kdtree[k]
                    kdtree[k] = kdtree[j]
                    kdtree[j] = kdtree[i]
                    kdtree[i] = t
                    i += 1
                    j += 1
                    k += 1
            i -= 1
            kdtree[a], kdtree[i] = kdtree[i], kdtree[a]
            if m < i:
                b = i
            elif j <= m:
                a = j
            else:
                return

    def kdsort(a, b, c):
        # Lay out kdtree[a:b] as an implicit 2-d tree. The median by
        # coordinate c sits in the middle, and each half is built
        # recursively, splitting on the other coordinate.
        if b - a < 2:
            return
        m = (a + b) // 2
        kselect(a, b, c, m)
        kdsort(a, m, 1 - c)
        kdsort(m + 1, b, 1 - c)

    # Instrumentation printed at the end: number of kdquery/kdcount calls,
    # and the total number of recursive visits made by each.
    query_calls = query_visits = count_visits = count_calls = 0

    def inside(p, ll, ur):
        # True when p lies in the box [ll, ur], with eps tolerance.
        return (ll[0] - eps < p[0] and p[0] < ur[0] + eps
                and ll[1] - eps < p[1] and p[1] < ur[1] + eps)

    def kdquery_rec(r, a, b, c, ll, ur):
        # Append to r every point of the subtree kdtree[a:b] (split on
        # coordinate c) that lies in the box; skip halves that cannot.
        nonlocal query_visits
        query_visits += 1
        if a == b:
            return
        m = (a + b) // 2
        p = xys[kdtree[m]]
        if ll[c] - eps < p[c]:
            kdquery_rec(r, a, m, 1 - c, ll, ur)
        if inside(p, ll, ur):
            r.append(kdtree[m])
        if p[c] < ur[c] + eps:
            kdquery_rec(r, m + 1, b, 1 - c, ll, ur)

    def kdquery(ll, ur):
        # Indices of all points in the box with lower-left corner ll and
        # upper-right corner ur.
        nonlocal query_calls
        query_calls += 1
        r = []
        kdquery_rec(r, 0, n, 0, ll, ur)
        return r

    def kdcount_rec(a, b, c, ll, ur, bbll, bbur):
        # Count the points of kdtree[a:b] inside the box. bbll/bbur bound
        # the subtree, so a subtree wholly inside the box is counted
        # without being visited.
        nonlocal count_visits
        count_visits += 1
        if a == b:
            return 0
        if (ll[0] - eps < bbll[0] and ll[1] - eps < bbll[1]
                and bbur[0] < ur[0] + eps and bbur[1] < ur[1] + eps):
            return b - a
        m = (a + b) // 2
        p = xys[kdtree[m]]
        r = 0
        if ll[c] - eps < p[c]:
            bbur2 = copy(bbur)
            bbur2[c] = min(bbur2[c], p[c])
            r += kdcount_rec(a, m, 1 - c, ll, ur, bbll, bbur2)
        if inside(p, ll, ur):
            r += 1
        if p[c] < ur[c] + eps:
            bbll2 = copy(bbll)
            bbll2[c] = max(bbll2[c], p[c])
            r += kdcount_rec(m + 1, b, 1 - c, ll, ur, bbll2, bbur)
        return r

    def kdcount(ll, ur):
        # Number of points in the box with corners ll and ur.
        nonlocal count_calls
        count_calls += 1
        inf = float('inf')
        return kdcount_rec(0, n, 0, ll, ur, [-inf, -inf], [inf, inf])

    kdsort(0, n, 0)
    print('kdtree built in {:.3f} seconds'.format(time() - start))
    G = [set() for _ in range(n)]
    xs = [x for x, _ in xys]
    ys = [y for _, y in xys]
    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    # The box always contains the point itself. If there are fewer than d
    # points in total, the best achievable count is n.
    target = min(d, n)
    # Starting half-width: the box that would hold about d points if the
    # points were spread evenly over their bounding box.
    h0 = sqrt(width * height * d / n)
    if h0 <= 0:
        # Collinear points (or a single point) give a zero-area bounding
        # box. A zero half-width would never grow by doubling, so fall back
        # to a 1-d estimate.
        h0 = max(width, height) * d / n or 1.0
    for i in range(n):
        x, y = xys[i]
        l, h = 0, h0
        # Grow the box until it holds enough points...
        while kdcount((x - h, y - h), (x + h, y + h)) < target:
            l = h
            h += h
        # ...then bisect between l (too small) and h (big enough).
        for _ in range(8):
            m = (h + l) / 2
            if kdcount((x - m, y - m), (x + m, y + m)) >= target:
                h = m
            else:
                l = m
        js = kdquery((x - h, y - h), (x + h, y + h))
        assert len(js) >= target
        for j in js:
            if i != j:
                G[i].add(j)
                G[j].add(i)
    for i in range(n):
        G[i] = [j for _, j in sorted((length(xys[i], xys[j]), j) for j in G[i])]
    Gsz = sum(len(nbrs) for nbrs in G)
    print('cnt(kdquery)', query_calls, 'avg(kdquery_rec)', query_visits / query_calls,
          'cnt(kdcount)', count_calls, 'avg(kdcount_rec)', count_visits / count_calls)
    print('neighbors graph of size {} built in {:.3f} seconds'.format(Gsz, time() - start))
    return G
def kruskal(G, xys):
    """Return a minimum spanning tree of the neighbour graph G as adjacency lists.

    Edges of G are weighted by the Euclidean distance between the points in
    xys. If G is disconnected, a message is printed and the program exits
    with status 1.
    """
    start = time()
    n = len(xys)
    # Each undirected edge once (i < j), cheapest first.
    es = sorted((length(xys[i], xys[j]), i, j)
                for i in range(n) for j in G[i] if i < j)
    rep = list(range(n))

    def find(x):
        # Root of x's set; every node on the path is re-pointed at the root.
        path = []
        while rep[x] != x:
            path.append(x)
            x = rep[x]
        for y in path:
            rep[y] = x
        return x

    def union(x, y):
        # Link the two sets in a random direction.
        x, y = (x, y) if randrange(2) == 0 else (y, x)
        rep[find(x)] = find(y)

    fs = []
    for _, u, v in es:
        if find(u) != find(v):
            fs.append((u, v))
            union(u, v)
    print('len(fs)', len(fs), 'n', n)
    if len(fs) != n - 1:
        print('cannot find mst: disconnected graph')
        exit(1)
    mst = [[] for _ in range(n)]
    for x, y in fs:
        mst[x].append(y)
        mst[y].append(x)
    print('kruskal done in {:.3f} seconds'.format(time() - start))
    return mst
def traverse_tree(xys, T):
    """Return a visiting order for the nodes of tree T, found by walking around it in the plane.

    Starting at node 0, the walk repeatedly leaves the current node by the
    untraversed directed edge that comes first when turning
    counter-clockwise from the edge it arrived by. Nodes are listed in order
    of first visit.

    If the walk gets stuck before every node has been seen (every edge out
    of the current node already used), the full walk is written to the file
    'bad_trail' and the program exits with status 2.
    """
    start = time()
    n = len(T)
    y = 0
    tour = [y]
    trail = [y]  # every node stepped on, repeats included
    seen = set(tour)
    seen_edge = set()
    # Arrival direction for the start node: as if coming from x = -inf.
    p = (float('-inf'), 0)
    while len(seen) < n:
        q = xys[y]
        back = atan2(p[1] - q[1], p[0] - q[0])
        best_alpha = float('inf')
        best_z = None
        for z in T[y]:
            if (y, z) in seen_edge:
                continue
            r = xys[z]
            # Counter-clockwise turn from the way back to the edge towards z.
            # A turn of 0 (straight back) is lifted by 2*pi, so it is tried last.
            alpha = atan2(r[1] - q[1], r[0] - q[0]) - back
            while alpha > 2 * pi:
                alpha -= 2 * pi
            while alpha < eps:
                alpha += 2 * pi
            if alpha < best_alpha:
                best_alpha = alpha
                best_z = z
        if best_z is None:
            # Every edge out of y has already been walked: the walk cannot finish.
            with open('bad_trail', 'w') as f:
                for x in trail:
                    f.write('{} {}\n'.format(xys[x][0], xys[x][1]))
            print('oops: nonterminating mst tour')
            exit(2)
        if best_z not in seen:
            seen.add(best_z)
            tour.append(best_z)
        trail.append(best_z)
        seen_edge.add((y, best_z))
        p = q
        y = best_z
    print('mst tour found in {:.3f} seconds'.format(time() - start))
    return tour
# m = number of segments
def swap_segments(duration, m, G, xys, tour):
    """Apply the best improving reconnection of m tour segments that is found.

    The tour is cut at m edges whose end positions are stored in ``split``:
    for every k, the edge between tour positions split[2k] and split[2k+1]
    is removed. The pieces are then re-joined as described by a segment
    permutation ``sp`` from generate_segperm: for odd i, positions
    split[sp[i]] and split[sp[(i+1) % 2m]] become adjacent.

    Candidate splits are enumerated under a constraint schedule from
    make_schedule. Permutations whose schedule needs more than one
    unconstrained endpoint are skipped. No new permutation is started once
    duration seconds have passed. Only the single best candidate is
    applied; a new tour list is returned.

    m: number of segments (at least 2).
    """
    assert 2 <= m
    start = time()
    n = len(G)
    # G is indexed by point id, while split values are tour positions.
    # pos_of maps a point id back to its position in the tour.
    pos_of = [0] * n
    for k, v in enumerate(tour):
        pos_of[v] = k
    # record splits
    # d: how many nearest neighbours are tried for each neighbour-constrained
    # endpoint (presumably sized so that n * d**(m-1) stays near 2**20).
    d = int((2**20 / n)**(1/(m-1)))

    def generate_segperm(xs):
        # Yield sequences of the 2m endpoint labels that extend xs. Each
        # chosen label x must not be adjacent (mod 2m) to the label before
        # it, and is followed by its partner ((x+1)^1)-1. Complete
        # sequences ending on 2m-2 are rejected.
        if len(xs) == 2 * m:
            if xs[-1] != 2 * m - 2:
                yield xs
        else:
            for x in range(2 * m):
                if x in xs:
                    continue
                if x == (xs[-1] + 1) % (m + m) or x == (xs[-1] - 1) % (m + m):
                    continue
                yield from generate_segperm(xs + [x, ((x+1)^1)-1])

    def make_schedule(sp):
        # Order in which to fix the 2m cut endpoints for permutation sp.
        # Each op is (y, what, x):
        # (y, 0, _) means [y] unconstrained
        # (y, 1, x) means [y] := [x]+1
        # (y, 2, x) means [y] := [x]-1
        # (y, 3, x) means [y] neighbor of [x]
        # (y, 4, x) means check [y] < [x]
        # (y, 5, x) means check [y] > [x]
        # TODO: this doesn't generate segments that begin exactly at 0
        schedule = []
        que1 = []
        que2 = []
        fixed = set()

        def add_to_schedule(op):
            y = op[0]
            if y in fixed:
                return
            fixed.add(y)
            schedule.append(op)
            # Keep positions ordered like their labels: compare with the
            # nearest already-fixed labels below and above y.
            a, b = -1, 2 * m + 1
            for z in fixed:
                if z < y:
                    a = max(a, z)
                if y < z:
                    b = min(b, z)
            if a != -1:
                schedule.append((y, 5, a))
            if b != 2 * m + 1:
                schedule.append((y, 4, b))
            # The other end of y's cut edge is next to it in the tour.
            z = y ^ 1
            que1.append((z, (1 if y < z else 2), y))
            # The endpoint joined to y in the new tour should be a neighbour.
            i = sp.index(y)
            z = sp[(((i + 1) ^ 1) - 1) % (m + m)]
            que2.append((z, 3, y))

        while len(fixed) < m + m:
            if que1 == [] and que2 == []:
                # Nothing can be derived: leave the lowest unfixed label free.
                for y in range(2 * m):
                    if y not in fixed:
                        add_to_schedule((y, 0, None))
                        break
            while not (que1 == [] and que2 == []):
                # Adjacency constraints (que1) go before neighbour ones (que2).
                if que1 == []:
                    op, que2 = que2[0], que2[1:]
                else:
                    op, que1 = que1[0], que1[1:]
                add_to_schedule(op)
        return schedule

    def generate_splits(schedule, i, split):
        # Depth-first enumeration of all position assignments that satisfy
        # schedule[i:]; yields a copy of each complete assignment.
        if i == len(schedule):
            yield copy(split)
        else:
            y, what, x = schedule[i]
            if what == 0:
                for z in range(n):
                    split[y] = z
                    yield from generate_splits(schedule, i + 1, split)
            elif what == 1:
                split[y] = (split[x] + 1) % n
                yield from generate_splits(schedule, i + 1, split)
            elif what == 2:
                split[y] = (split[x] - 1) % n
                yield from generate_splits(schedule, i + 1, split)
            elif what == 3:
                # Neighbours of the point at position split[x], mapped back
                # from point ids to tour positions.
                for z in G[tour[split[x]]][:d]:
                    split[y] = pos_of[z]
                    yield from generate_splits(schedule, i + 1, split)
            elif what == 4:
                if split[y] <= split[x]:
                    yield from generate_splits(schedule, i + 1, split)
            elif what == 5:
                if split[y] >= split[x]:
                    yield from generate_splits(schedule, i + 1, split)
            else:
                assert False, 'unknown schedule op {}'.format(what)

    splits = []
    for sp in generate_segperm([2*m-1, 0]):
        if time() - start > duration:
            break
        schedule = make_schedule(sp)
        # Each free endpoint multiplies the search by n; allow at most one.
        free = sum(1 for _, what, _ in schedule if what == 0)
        if free > 1:
            continue
        for split in generate_splits(schedule, 0, [None] * (2 * m)):
            penalty = 0
            for i in range(0, m + m, 2):
                penalty -= length(xys[tour[split[i]]], xys[tour[split[i+1]]])
            for i in range(1, m + m, 2):
                penalty += length(xys[tour[split[sp[i]]]], xys[tour[split[sp[(i+1)%(m+m)]]]])
            if penalty < -eps:
                splits.append((penalty, sp, split))
    # choose splits
    splits.sort()

    def apply_split(x, s):
        # New position of tour position x after applying s = (sp, split):
        # walk the segments in their new order, each possibly reversed.
        sp, split = s
        ss = [0, split[sp[1]]] + [split[i] for i in sp[2:]] + [split[sp[0]], n-1]
        ss = [(ss[i], ss[i+1]) for i in range(0, len(ss), 2)]
        offset = 0
        for a, b in ss:
            if (a <= x and x <= b) or (b <= x and x <= a):
                return offset + abs(x - a)
            offset += abs(b - a) + 1
        assert False

    len_splits = len(splits)
    good_splits = []
    if False:
        # TODO: disabled attempt at applying several non-overlapping splits.
        splits = splits[:2048] # because what follows is quadratic
        seen = set()
        for p, sp, split in splits:
            if len(set(split)&seen) > 0:
                continue
            seen |= set(split)
            for s in good_splits:
                split = [apply_split(x, s) for x in split]
            good_splits.append((sp, split))
    else:
        for _, sp, split in splits[:1]:
            good_splits.append((sp, split))
    print('chose {} splits (of {} segments) out of {}, found in {:.3f} seconds'.format(len(good_splits), m, len_splits, time() - start))
    # apply splits
    new_tour = [None] * n
    for x in range(n):
        y = x
        for s in good_splits:
            y = apply_split(y, s)
        new_tour[y] = tour[x]
    return new_tour
def swap_edges2(duration, G, xys, tour):
    """Apply a batch of improving 2-opt moves to tour and return the new tour.

    Move (i, j), with tour positions i < j, replaces edges (i, i+1) and
    (j, j+1) with (i, j) and (i+1, j+1) by reversing positions i+1..j.
    Candidates pair each position with the positions of its point's
    neighbours in G. Improving moves are taken greedily, best first,
    skipping any that reuse an endpoint position. The candidate scan stops
    after duration seconds.
    """
    n = len(tour)
    start = time()
    # G is indexed by point id, but moves are expressed in tour positions.
    pos_of = [0] * n
    for k, v in enumerate(tour):
        pos_of[v] = k
    # record swaps
    swaps = []
    for i in range(n):
        if time() - start > duration:
            break
        for j in (pos_of[v] for v in G[tour[i]]):
            if j < i:
                continue
            penalty = 0
            penalty -= length(xys[tour[i]], xys[tour[(i+1)%n]])
            penalty -= length(xys[tour[j]], xys[tour[(j+1)%n]])
            penalty += length(xys[tour[i]], xys[tour[j]])
            penalty += length(xys[tour[(i+1)%n]], xys[tour[(j+1)%n]])
            if penalty < -eps:
                swaps.append((penalty, i, j))
    # choose swaps
    swaps.sort()
    good_swaps = []
    seen = set()
    for _, i, j in swaps:
        if i in seen or j in seen:
            continue
        seen.add(i)
        seen.add(j)
        good_swaps.append((i, j))
    print('chose {} swaps out of {}, found in {:.3f} seconds'.format(len(good_swaps), len(swaps), time() - start))
    # Re-express each later move in the positions produced by the earlier ones.
    # An edge starting at position ii inside a reversed range (i, j) now starts
    # at i + j - ii. When both ends are reflected their order flips, so the pair
    # is re-sorted; otherwise the move would be a no-op when applied below.
    # NOTE(review): moves are only required not to share an endpoint; if two
    # reversal ranges cross, the combined gain may differ from the sum of
    # their penalties.
    m = len(good_swaps)
    for x in range(m):
        i, j = good_swaps[x]
        for y in range(x+1, m):
            ii, jj = good_swaps[y]
            if i < ii < j:
                ii = i + j - ii
            if i < jj < j:
                jj = i + j - jj
            good_swaps[y] = (min(ii, jj), max(ii, jj))
    # apply swaps
    new_tour = [None] * n
    for x in range(n):
        y = x
        for i, j in good_swaps:
            if i < y <= j:
                y = i + j + 1 - y
        new_tour[y] = tour[x]
    return new_tour
# Similar to swap_edges2, but these swaps only make sense to be applied
# in pairs, because otherwise they disconnect the tour.
def swap_edges4(duration, G, xys, tour):
    """Apply one improving pair of disconnecting 2-exchanges to tour; return the new tour.

    A single exchange at tour positions (i, j) replaces edges (i, i+1) and
    (j, j+1) with (i, j+1) and (j, i+1), which on its own splits the tour in
    two. Two interleaved exchanges (i, k) and (j, l), with i < j < k < l,
    together move segments (i, j], (j, k], (k, l] into the order (k, l],
    (j, k], (i, j], none of them reversed.

    Candidates pair each position with the positions of its point's
    neighbours in G. Only the 512 best single exchanges are paired, and only
    the first pair found is applied. The candidate scan stops after
    duration seconds.
    """
    n = len(G)
    start = time()
    # G is indexed by point id, but exchanges are expressed in tour positions.
    pos_of = [0] * n
    for k, v in enumerate(tour):
        pos_of[v] = k
    # Record all exchanges, improving or not: a pair can improve even when
    # one half alone does not.
    swaps = []
    for i in range(n):
        if time() - start > duration:
            break
        for j in (pos_of[v] for v in G[tour[i]]):
            if j < i:
                continue # should be symmetric
            if j == i or j == (i+1)%n or (j+1)%n == i:
                continue # TODO: think about these cases
            penalty = 0
            penalty -= length(xys[tour[i]], xys[tour[(i+1)%n]])
            penalty -= length(xys[tour[j]], xys[tour[(j+1)%n]])
            penalty += length(xys[tour[i]], xys[tour[(j+1)%n]])
            penalty += length(xys[tour[j]], xys[tour[(i+1)%n]])
            swaps.append((penalty, i, j))

    def apply_swap(x, s):
        # New position of tour position x after the exchange pair
        # s = (i, j, k, l): (i, j], (j, k], (k, l] become (k, l], (j, k], (i, j].
        i, j, k, l = s
        assert i < j and j < k and k < l
        if i < x and x <= j:
            return i + (l - k) + (k - j) + (x - i)
        elif j < x and x <= k:
            return i + (l - k) + (x - j)
        elif k < x and x <= l:
            return i + (x - k)
        else:
            return x

    # choose swap pairs
    swaps.sort()
    m = len(swaps)
    swaps = swaps[:512] # because the algo below is cubic
    good_swaps = []
    while len(swaps) > 0:
        p0, i, k = swaps[0]
        if p0 > -eps:
            break
        ii, kk = i, k
        for s in good_swaps:
            ii = apply_swap(ii, s)
            kk = apply_swap(kk, s)
        pairfound = False
        for p1, j, l in swaps[1:]:
            if p0 + p1 > -eps:
                break
            jj, ll = j, l
            for s in good_swaps:
                jj = apply_swap(jj, s)
                ll = apply_swap(ll, s)
            if (ii < jj and jj < kk and kk < ll) or (jj < ii and ii < ll and ll < kk):
                if jj < ii:
                    ii, jj, kk, ll = jj, ii, ll, kk
                good_swaps.append((ii, jj, kk, ll))
                # Drop every exchange touching these positions or their neighbours.
                ends = (i, j, k, l)
                bad = set(ends)
                bad |= {(e + 1) % n for e in ends} | {(e - 1) % n for e in ends}
                swaps = [(p, a, b) for p, a, b in swaps if a not in bad and b not in bad]
                pairfound = True
                break
        if not pairfound:
            swaps = swaps[1:]
    print('chose {} swap pairs out of {} swaps, found in {:.3f} seconds'.format(len(good_swaps), m, time() - start))
    # apply swaps (only the first pair)
    good_swaps = good_swaps[:1]
    new_tour = [None] * n
    for x in range(n):
        y = x
        for s in good_swaps:
            y = apply_swap(y, s)
        new_tour[y] = tour[x]
    return new_tour
def lg(x):
    """Return floor(log2(x)) for x >= 1, and 0 for smaller x."""
    # Gallop: find the first k with 2**(2**k) > x.
    k = 0
    while (1 << (1 << k)) <= x:
        k += 1
    lo, hi = (1 << k) >> 1, 1 << k
    # Binary search for the largest e in [lo, hi) with 2**e <= x.
    while hi - lo != 1:
        mid = (lo + hi) >> 1
        if (1 << mid) > x:
            hi = mid
        else:
            lo = mid
    return lo
def exact_path(src, xys, tgt):
    """Return the cheapest order for visiting all points of xys on a path from src to tgt.

    Uses dynamic programming over subsets. State (i, s) keeps the cheapest
    path that starts at src, visits exactly the points in bitmask s, and
    ends at point i. At most 16 points are accepted.

    Returns a list of indices into xys. An empty xys yields [].
    """
    n = len(xys)
    assert (0 <= n and n <= 16)
    if n == 0:
        return []
    now = { (i, 1 << i) : (length(src, xys[i]), [i]) for i in range(n) }
    for _ in range(1, n):
        nxt = {}
        for (i, s), (c, p) in now.items():
            assert p[-1] == i
            mj = 1
            while True:
                # Move the single bit mj up to the lowest bit at or above it
                # that is not in s.
                mj += s
                mj = ~s & mj
                j = lg(mj)
                if j >= n:
                    break
                assert i != j
                sj = s | mj
                cj = c + length(xys[i], xys[j])
                pj = p + [j]
                if (j, sj) in nxt:
                    # On a tie, keep the path found first.
                    (ocj, opj) = nxt[(j, sj)]
                    if ocj <= cj:
                        cj = ocj
                        pj = opj
                nxt[(j, sj)] = (cj, pj)
                mj <<= 1
        now = nxt
    best_c = float('inf')
    best_p = None
    for (i, _), (c, p) in now.items():
        c += length(xys[i], tgt)
        if c < best_c:
            best_c = c
            best_p = p
    return best_p
def local_exact(duration, window, xys, tour):
    """Re-optimise windows of consecutive tour positions exactly, in place.

    Windows of ``window`` positions (clamped to n - 2) start at every
    offset, visited in random order. Each window's points are rearranged
    into the order exact_path finds best, given the fixed points just
    before and after the window. Stops once duration seconds have passed.
    Modifies tour in place and also returns it.
    """
    n = len(tour)
    window = min(window, n - 2)
    if window < 1:
        # An empty window has nothing to reorder, so there is no path to
        # look up (this is the case for tours with at most 2 points).
        return tour
    start = time()
    cnt = 0
    improvements = 0
    offsets = list(range(n))
    shuffle(offsets)
    for offset in offsets:
        if time() - start > duration:
            break
        points = [xys[tour[(offset + i) % n]] for i in range(window)]
        p = exact_path(xys[tour[(offset-1)%n]], points, xys[tour[(offset+window)%n]])
        if p != list(range(len(p))):
            improvements += 1
            part = [tour[(offset+p[i]) % n] for i in range(window)]
            for i in range(window):
                tour[(offset+i)%n] = part[i]
        cnt += 1
    print('apply {} improvements found by {} runs of local_exact in {:.3f} seconds'.format(improvements, cnt, time() - start))
    return tour
def teleport(duration, G, xys, tour):
    """Move single points to better places in tour; return the new tour.

    Teleport (i, j) removes the point at tour position i and re-inserts it
    between positions j and j+1. Candidate targets j are the positions of
    the point's neighbours in G. Targets 0-4 positions before i are skipped.

    Improving teleports are applied greedily, best first, skipping any whose
    positions (i-1, i, j) overlap those of an already chosen one. The
    candidate scan stops after duration seconds.
    """
    start = time()
    n = len(tour)
    # G is indexed by point id, but teleports are expressed in tour positions.
    pos_of = [0] * n
    for k, v in enumerate(tour):
        pos_of[v] = k
    # record teleports
    teleports = []
    for i in range(n):
        if time() - start > duration:
            break
        for j in (pos_of[v] for v in G[tour[i]]):
            if (i - j) % n < 5:
                continue # will be improved (better) by local_exact
            penalty = 0
            penalty += length(xys[tour[(i - 1) % n]], xys[tour[(i + 1) % n]])
            penalty += length(xys[tour[j]], xys[tour[i]])
            penalty += length(xys[tour[i]], xys[tour[(j + 1) % n]])
            penalty -= length(xys[tour[(i - 1) % n]], xys[tour[i]])
            penalty -= length(xys[tour[i]], xys[tour[(i + 1) % n]])
            penalty -= length(xys[tour[j]], xys[tour[(j + 1) % n]])
            if penalty < -eps:
                teleports.append((penalty, i, j))
    # choose teleports
    teleports.sort()
    good_teleports = []
    seen = set()
    for _, i, j in teleports:
        if (i-1)%n in seen or i in seen or j in seen:
            continue
        seen.add((i-1)%n)
        seen.add(i)
        seen.add(j)
        good_teleports.append((i, j))
    print('chose {} teleports out of {}, found in {:.3f} seconds'.format(len(good_teleports), len(teleports), time() - start))
    # Re-express each later teleport in the positions left by the earlier
    # ones. Moving a point from i forward to j shifts (i, j) down by one;
    # moving it backward shifts (j, i) up by one.
    m = len(good_teleports)
    for x in range(m):
        i, j = good_teleports[x]
        for y in range(x+1, m):
            ii, jj = good_teleports[y]
            assert ii != i and ii != j and jj != i and jj != j
            if i < ii and ii < j:
                ii -= 1
            elif j < ii and ii < i:
                ii += 1
            if i < jj and jj < j:
                jj -= 1
            elif j < jj and jj < i:
                jj += 1
            good_teleports[y] = (ii, jj)
    # apply teleports
    new_tour = [None] * n
    for x in range(n):
        y = x
        for i, j in good_teleports:
            assert i != j
            if i < j:
                if y == i:
                    y = j
                elif i < y and y <= j:
                    y -= 1
            else:
                if y == i:
                    y = j + 1
                elif j < y and y < i:
                    y += 1
        new_tour[y] = tour[x]
    return new_tour
def tsp(xys):
    """Heuristically solve the Euclidean TSP on the distinct points xys.

    Builds a start tour by walking a minimum spanning tree of a neighbour
    graph. It then cycles through the local-search passes indefinitely,
    writing the tour cost to stderr after each pass. The search ends when
    it is interrupted (KeyboardInterrupt), and the current tour (a list of
    point indices) is returned.
    """
    seed(123)
    n = len(xys)
    assert len(set(xys)) == n
    G = neighbors_graph(max(2, min(n - 1, 1000000 // n)), xys)
    tour = traverse_tree(xys, kruskal(G, xys))
    unlimited = float('inf')
    later_passes = (
        lambda t: swap_edges2(unlimited, G, xys, t),
        lambda t: swap_edges4(unlimited, G, xys, t),
        lambda t: teleport(unlimited, G, xys, t),
        lambda t: local_exact(30, 10, xys, t),
    )

    def report(t):
        stderr.write('cost {}\n'.format(cost(xys, t)))

    try:
        report(tour)
        while True:
            for segments in range(2, 7):
                tour = swap_segments(30, segments, G, xys, tour)
                report(tour)
            for improve in later_passes:
                tour = improve(tour)
                report(tour)
    except KeyboardInterrupt:
        pass
    return tour
def solveIt(inputData):
    """Solve a TSP instance given as text and return the answer as text.

    The input's first line is the point count. Each following line starts
    with the point's x and y coordinates. The answer is "<tour length> 0"
    followed by a line with the tour's point indices separated by spaces.
    As a side effect, the tour's coordinates (closed back to its first
    point) are written to the file 'after'.
    """
    rows = inputData.split('\n')
    count = int(rows[0])
    points = [(float(fields[0]), float(fields[1]))
              for fields in (rows[idx].split() for idx in range(1, count + 1))]
    solution = tsp(points)
    obj = cost(points, solution)
    with open('after', 'w') as tmp:
        for idx in [*solution, solution[0]]:
            x, y = points[idx]
            tmp.write('{} {}\n'.format(x, y))
    header = str(obj) + ' ' + str(0) + '\n'
    return header + ' '.join(str(v) for v in solution)
import sys
if __name__ == '__main__':
    if len(sys.argv) > 1:
        # Best-effort run log. The path only exists on the original author's
        # machine, so failing to write it must not stop the solver.
        try:
            with open('/home/rg/temp/solver.log','a') as log:
                log.write('{:010.3f} {}\n'.format(time(), sys.argv[1]))
        except OSError:
            pass
        fileLocation = sys.argv[1].strip()
        with open(fileLocation, 'r') as inputDataFile:
            inputData = inputDataFile.read()
        print(solveIt(inputData))
    else:
        print('This test requires an input file. Please select one from the data directory. (i.e. python solver.py ./data/tsp_51_1)')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment