Skip to content

Instantly share code, notes, and snippets.

@hkurokawa
Created June 6, 2021 03:16
Show Gist options
  • Save hkurokawa/a9c88cdaa4460bafc0730ccf6e6c9d6c to your computer and use it in GitHub Desktop.
Save hkurokawa/a9c88cdaa4460bafc0730ccf6e6c9d6c to your computer and use it in GitHub Desktop.
Traveling Salesman Problem Solver
#!/usr/bin/env python3
import math
import random
import sys
import time
from common import print_tour, read_input
# Wall-clock budgets, expressed in seconds because they are compared against
# differences of time.monotonic() readings.
_SECONDS_PER_MINUTE = 60
TIME_LIMIT_TWO_OPT = 10 * _SECONDS_PER_MINUTE  # budget for one two_opt() call
TIME_LIMIT_THREE_OPT = 10 * _SECONDS_PER_MINUTE  # budget for one three_opt() call
TIME_LIMIT_ALL = 100 * _SECONDS_PER_MINUTE  # budget for the restart loop in solve()
def distance(city1, city2):
    """Return the Euclidean distance between two cities.

    Each city is an indexable pair whose element 0 is the x coordinate and
    element 1 is the y coordinate.
    """
    dx = city1[0] - city2[0]
    dy = city1[1] - city2[1]
    return math.sqrt(dx ** 2 + dy ** 2)
def total_distance(tour, cities):
    """Return the length of the closed tour.

    ``tour`` is a sequence of indices into ``cities``. The length is the sum
    of the distances between consecutive tour entries plus the closing edge
    from the last entry back to the first. An empty tour raises IndexError.
    """
    length = 0
    for here, there in zip(tour, tour[1:]):
        length += distance(cities[here], cities[there])
    # Close the cycle: last city back to the first one.
    return length + distance(cities[tour[-1]], cities[tour[0]])
def build_greedy_tour(n, dist, start):
    """Build a nearest-neighbour tour over cities ``0..n-1``.

    ``dist`` is indexed as ``dist[a][b]`` and gives the distance from city
    ``a`` to city ``b``. Starting from ``start``, the closest not-yet-visited
    city is appended repeatedly until every city is in the tour.

    Returns the tour as a list of city indices beginning with ``start``.
    Raises KeyError if ``start`` is not in ``range(n)``.
    """
    remaining = set(range(n))
    remaining.remove(start)
    route = [start]
    here = start
    while remaining:
        # Looking up a city in the current row yields its distance from `here`.
        row = dist[here]
        here = min(remaining, key=row.__getitem__)
        remaining.remove(here)
        route.append(here)
    return route
def swap_route(tour, i, j):
    """Reverse, in place, the part of ``tour`` at positions ``i + 1 .. j``.

    The arguments may be given in either order; the smaller one is treated
    as ``i``. The elements at the smaller index and after the larger index
    stay where they are. Returns None.
    """
    if i > j:
        i, j = j, i
    lo, hi = i + 1, j + 1
    tour[lo:hi] = tour[lo:hi][::-1]
def distance_cities(tour, i, j, cities):
    """Return the distance between the cities at tour positions ``i`` and ``j``.

    Only ``j`` is wrapped: a value of ``len(tour)`` or more has one tour
    length subtracted from it, so ``j == len(tour)`` refers to position 0.
    ``i`` is used exactly as given.
    """
    size = len(tour)
    wrapped_j = j - size if j >= size else j
    return distance(cities[tour[i]], cities[tour[wrapped_j]])
def delta_swapped(tour, i, j, cities):
    """Return the change in tour length for the 2-opt move on ``i`` and ``j``.

    The move removes the edges (i, i + 1) and (j, j + 1) and adds the edges
    (i, j) and (i + 1, j + 1), all given as tour positions. A negative
    result means the move shortens the tour.
    """
    removed = (
        distance_cities(tour, i, i + 1, cities),
        distance_cities(tour, j, j + 1, cities),
    )
    added = (
        distance_cities(tour, i, j, cities),
        distance_cities(tour, i + 1, j + 1, cities),
    )
    change = 0
    for length in removed:
        change -= length
    for length in added:
        change += length
    return change
def two_opt(n, cities, tour, dist):
    """Improve ``tour`` in place with first-improvement 2-opt moves.

    The position pairs (i, j) with ``j >= i + 2`` are scanned in order. The
    first pair whose move shortens the tour by more than 1e-9 is applied and
    the scan restarts from the beginning. The search stops when a full scan
    finds no such move, or when TIME_LIMIT_TWO_OPT seconds have elapsed
    (checked after each applied move).

    NOTE: despite its name, ``dist`` is handled as a number here: each applied
    move's length change is added to it.

    Returns ``(tour, dist)`` with ``dist`` updated by the applied moves.
    """
    began = time.monotonic()
    while True:
        # Look for the first improving move in scan order.
        move = None
        for i in range(n - 1):
            for j in range(i + 2, n):
                change = delta_swapped(tour, i, j, cities)
                if change < -1e-9:
                    move = (i, j, change)
                    break
            if move is not None:
                break
        if move is None:
            # A complete scan found nothing to improve.
            break
        i, j, change = move
        swap_route(tour, i, j)
        dist += change
        if time.monotonic() - began > TIME_LIMIT_TWO_OPT:
            break
    return tour, dist
def reverse_segment_if_better(tour, i, j, k, cities):
    """Apply the first improving 3-opt reconnection at positions i, j, k.

    The three tour edges considered are (i - 1, i), (j - 1, j) and
    (k - 1, k), with ``k`` taken modulo ``len(tour)`` when reading the last
    endpoint. Four reconnections are compared with the current edges, in
    this order, and the first that is strictly shorter is applied in place:

    1. reverse ``tour[i:j]``
    2. reverse ``tour[j:k]``
    3. reverse ``tour[i:k]``
    4. exchange the segments ``tour[i:j]`` and ``tour[j:k]``

    Returns the resulting change in length (negative), or 0 when none of the
    reconnections is shorter and the tour is left untouched.
    """
    size = len(tour)
    a, b = tour[i - 1], tour[i]
    c, d = tour[j - 1], tour[j]
    e, f = tour[k - 1], tour[k % size]

    def span(p, q):
        # Distance between two cities given by their indices.
        return distance(cities[p], cities[q])

    current = span(a, b) + span(c, d) + span(e, f)
    flip_first = span(a, c) + span(b, d) + span(e, f)
    flip_second = span(a, b) + span(c, e) + span(d, f)
    exchange = span(a, d) + span(e, b) + span(c, f)
    flip_both = span(f, b) + span(c, d) + span(e, a)

    if current > flip_first:
        tour[i:j] = tour[i:j][::-1]
        return -current + flip_first
    if current > flip_second:
        tour[j:k] = tour[j:k][::-1]
        return -current + flip_second
    if current > flip_both:
        tour[i:k] = tour[i:k][::-1]
        return -current + flip_both
    if current > exchange:
        # Put the second segment in front of the first one.
        tour[i:k] = tour[j:k] + tour[i:j]
        return -current + exchange
    return 0
def all_segment(n):
    """Return a generator of index triples ``(i, j, k)`` with ``i < j < k``.

    ``i`` and ``j`` range over ``0..n-1``. ``k`` goes up to ``n - 1`` when
    ``i`` is 0 and up to ``n`` otherwise. Triples are produced in
    lexicographic order.
    """
    firsts = range(n)

    def _triples():
        for i in firsts:
            # k is allowed to reach n only when i is not 0.
            k_stop = n + 1 if i > 0 else n
            for j in range(i + 1, n):
                for k in range(j + 1, k_stop):
                    yield (i, j, k)

    return _triples()
def three_opt(n, cities, tour):
    """Improve ``tour`` in place with repeated 3-opt sweeps and return it.

    One sweep tries every index triple produced by ``all_segment(n)`` and
    applies each improving reconnection as soon as it is found. Sweeps are
    repeated until a sweep yields no improvement or TIME_LIMIT_THREE_OPT
    seconds have elapsed.

    The deadline is checked after every triple, not only between sweeps.
    A single sweep visits on the order of n**3 triples, so checking only
    between sweeps would let a large instance overrun the limit. Every
    reconnection is applied as a whole, so stopping in the middle of a
    sweep still leaves ``tour`` a valid permutation.

    Args:
        n: number of tour positions to sweep over.
        cities: city coordinates, indexed by the values stored in ``tour``.
        tour: list of city indices; modified in place.

    Returns:
        The same ``tour`` object.
    """
    deadline = time.monotonic() + TIME_LIMIT_THREE_OPT
    while True:
        delta = 0
        for a, b, c in all_segment(n):
            delta += reverse_segment_if_better(tour, a, b, c, cities)
            if time.monotonic() > deadline:
                # Out of time: abandon the rest of this sweep.
                break
        # delta >= 0 means the sweep applied no improving move.
        if delta >= 0 or time.monotonic() > deadline:
            break
    return tour
def solve(cities):
    """Return the best tour found over ``cities``, as a list of city indices.

    A symmetric distance matrix is computed first. Then every city is tried
    as a starting point, in random order: a greedy tour is built from it and
    refined with ``three_opt``, and the shortest result seen so far is kept.
    The restart loop stops early once TIME_LIMIT_ALL seconds have elapsed,
    checked after each restart.

    Returns None when ``cities`` is empty, because no restart is performed.
    """
    count = len(cities)

    # Fill both halves of the symmetric matrix from one distance computation.
    dist = [[0] * count for _ in range(count)]
    for row in range(count):
        for col in range(row, count):
            separation = distance(cities[row], cities[col])
            dist[row][col] = separation
            dist[col][row] = separation

    began = time.monotonic()
    best_tour = None
    best_length = -1  # negative means "no tour recorded yet"
    for first_city in random.sample(range(count), count):
        candidate = build_greedy_tour(count, dist, first_city)
        candidate = three_opt(count, cities, candidate)
        length = total_distance(candidate, cities)
        if best_length < 0 or length < best_length:
            best_tour, best_length = candidate, length
        if time.monotonic() - began > TIME_LIMIT_ALL:
            break
    return best_tour
def eprint(msg):
    """Print ``msg`` to standard error instead of standard output."""
    stream = sys.stderr
    print(msg, file=stream)
def main(input_file):
    """Solve the instance stored in ``input_file`` and report the result.

    The cities are loaded with ``read_input``, the tour found by ``solve`` is
    emitted with ``print_tour``, and the tour's total length is then written
    to standard error.
    """
    cities = read_input(input_file)
    best_tour = solve(cities)
    print_tour(best_tour)
    length = total_distance(best_tour, cities)
    eprint("Distance for {}: {}".format(input_file, length))
if __name__ == '__main__':
    # Validate the command line explicitly rather than with ``assert``:
    # assertions are removed under ``python -O``, which would turn a missing
    # argument into an IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        sys.exit("usage: {} INPUT_FILE".format(
            sys.argv[0] if sys.argv else "solver"))
    main(sys.argv[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment