Skip to content

Instantly share code, notes, and snippets.

@tweakimp
Last active December 17, 2018 21:12
Show Gist options
  • Save tweakimp/ac33581b1cca9c1da213ea52e88b00f0 to your computer and use it in GitHub Desktop.
Save tweakimp/ac33581b1cca9c1da213ea52e88b00f0 to your computer and use it in GitHub Desktop.
A* pathfinding in StarCraft II
"""
A* pathfinding for StarCraft II in Python
How can I optimize create_path?
"""
import heapq
import random
import time
import sc2
from sc2 import Race, maps, run_game
from sc2.ids.ability_id import LIFT
from sc2.ids.unit_typeid import COMMANDCENTER
from sc2.player import Bot
from sc2.position import Point2, Point3
def terrain_to_z_height(h):
    """Linearly rescale a terrain height value onto an integer z coordinate.

    A value of 0 maps to -100 and 255 maps to 100; the result is rounded to
    the nearest integer. Assumes ``h`` is a 0-255 terrain height value.
    NOTE(review): confirm this scale matches the z coordinates expected by the
    installed sc2 debug-drawing API.
    """
    scaled = 200 * h / 255
    return round(scaled - 100)
class PriorityQueue:
    """Min-priority queue backed by a binary heap (``heapq``).

    The item with the lowest priority is returned first. Entries are stored as
    ``(priority, item)`` tuples, so equal priorities are ordered by comparing
    the items themselves.
    """

    def __init__(self):
        self.elements = list()  # heap of (priority, item) tuples

    def empty(self):
        """Return True when no items are queued."""
        return not self.elements

    def put(self, item, priority):
        """Queue ``item`` with the given ``priority``."""
        entry = (priority, item)
        heapq.heappush(self.elements, entry)

    def get(self):
        """Remove and return the item with the smallest priority."""
        _priority, item = heapq.heappop(self.elements)
        return item
def heuristic(a, b, diagonal_cost=1.414):
    """Estimate the remaining path cost between two grid cells (octile distance).

    The search moves on an 8-connected grid where straight steps cost 1 and
    diagonal steps cost ``diagonal_cost``. The default of 1.414 matches
    ``cost()``. The cheapest obstacle-free route takes ``min(dx, dy)``
    diagonal steps plus ``max(dx, dy) - min(dx, dy)`` straight steps.

    Unlike Manhattan distance, which charges 2 for every diagonal, this never
    overestimates the true cost. A* using it therefore returns least-cost
    paths, although it may expand more nodes than the greedier Manhattan
    estimate.

    Args:
        a, b: ``(x, y)`` pairs (any 2-element iterable, e.g. ``Point2``).
        diagonal_cost: cost of one diagonal step; keep it equal to the cost
            used by the search, or the estimate may become inadmissible.

    Returns:
        The estimated cost as a number; 0 when ``a == b``.
    """
    (x1, y1) = a
    (x2, y2) = b
    dx = abs(x1 - x2)
    dy = abs(y1 - y2)
    # Each diagonal step replaces one straight step, adding (diagonal_cost - 1).
    return max(dx, dy) + (diagonal_cost - 1) * min(dx, dy)
def cost(a, b):
    """Return the cost of stepping from grid point ``a`` to grid point ``b``.

    A step that leaves x or y unchanged costs 1. A step that changes both
    coordinates (a diagonal) costs 1.414, an approximation of sqrt(2).
    Identical points also yield 1. Both arguments need ``.x`` and ``.y``
    attributes.
    """
    dx, dy = a.x - b.x, a.y - b.y
    if dx and dy:
        return 1.414
    return 1
class astar(sc2.BotAI):
    """Benchmark bot that lifts its command centers and then times A* on the map.

    Two instances are meant to play each other. The ``dummy`` instance only
    lifts its command center. The non-dummy instance waits until it has no
    grounded command center left and iteration 30 has been reached. It then:

    - collects the cells the pathing grid marks with 0 into
      ``all_pathable_points``;
    - runs :meth:`create_path` between two fixed points;
    - draws the path with debug boxes and prints timings.

    It leaves the game at iteration 40.
    """

    # 8-connected neighbourhood: (dx, dy) offsets of the surrounding cells.
    _NEIGHBOR_OFFSETS = ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1))

    def __init__(self, dummy=True):
        """Create the bot.

        Args:
            dummy: if True, the bot only lifts its command center and does no
                pathfinding.
        """
        super().__init__()
        self.actions = []  # actions queued during the current step
        self.iter = None  # last iteration seen by the non-dummy bot
        self.directions = set()  # (point, neighbour) pairs of pathable cells
        self.dummy = dummy
        self.random_points = []  # [start, end] used for the benchmark path
        self.path = []  # last path from create_path, start ... end inclusive
        self.all_pathable_points = None  # set of Point2, filled once lifted

    async def on_step(self, iteration):
        """Lift command centers; on the non-dummy bot, benchmark and draw A*."""
        for base in self.units(COMMANDCENTER):
            self.actions.append(base(LIFT))
        await self.do_actions(self.actions)
        self.actions = []
        # The dummy client only lifts its command center to clear all paths.
        if self.dummy:
            return
        self.iter = iteration
        # No grounded COMMANDCENTER is left (it has lifted) and warm-up is over.
        if not self.units(COMMANDCENTER) and self.iter >= 30:
            if not self.all_pathable_points:
                await self._load_pathable_points()
                # Directions depend only on the pathable cells, so build them once.
                start_time = time.perf_counter()
                self.create_directions()
                print("direction calculation time", time.perf_counter() - start_time)
            # Test with points that are far away from each other.
            # For random endpoints: random.sample(list(self.all_pathable_points), 2)
            self.random_points = [Point2((123, 19)), Point2((21, 143))]
            # Time only the search itself, not the debug drawing.
            start_time = time.perf_counter()
            self.create_path(self.random_points[0], self.random_points[1])
            path_time = time.perf_counter() - start_time
            await self.draw()
            if self.path:
                print("time per node", round(path_time / len(self.path), 10))
            await self._client.send_debug()
        # Kill clients at the same iteration every time to profile the same task.
        if self.iter == 40:
            await self._client.leave()

    async def _load_pathable_points(self):
        """Refresh game info and store every cell whose pathing-grid value is 0.

        NOTE(review): this treats grid value 0 as pathable, as the original did;
        confirm against the installed python-sc2 pathing_grid encoding.
        """
        self._game_info = await self._client.get_game_info()
        pathing_grid = self._game_info.pathing_grid
        self.all_pathable_points = {
            Point2((x, y))
            for x in range(pathing_grid.width)
            for y in range(pathing_grid.height)
            if pathing_grid[Point2((x, y))] == 0
        }

    def _neighbors(self, point):
        """Return the 8-connected neighbours of ``point`` that are in all_pathable_points."""
        pathable = self.all_pathable_points
        x, y = point.x, point.y
        # Build each candidate Point2 once and test membership on that object.
        candidates = (Point2((x + dx, y + dy)) for dx, dy in self._NEIGHBOR_OFFSETS)
        return [candidate for candidate in candidates if candidate in pathable]

    def create_directions(self):
        """Fill ``self.directions`` with every (point, neighbour) pair of pathable cells."""
        self.directions = {
            (point, near_point)
            for point in self.all_pathable_points
            for near_point in self._neighbors(point)
        }

    def create_path(self, start, end):
        """Find a path from ``start`` to ``end`` with A* over all_pathable_points.

        Moves are 8-connected. Step costs come from ``cost()`` and the estimate
        to the goal from ``heuristic()``.

        Args:
            start: Point2 to start from.
            end: Point2 to reach.

        Returns:
            The path as a list of Point2 from ``start`` to ``end`` inclusive. It
            is also stored in ``self.path``. The list is empty when ``end``
            cannot be reached; previously that case raised KeyError.
        """
        frontier = PriorityQueue()
        frontier.put(start, 0)
        came_from = {start: None}
        cost_so_far = {start: 0}
        # Nodes expanded at their current cost_so_far. Popping one again is a
        # stale heap entry: re-expanding it could not relax any neighbour.
        expanded = set()
        while not frontier.empty():
            current = frontier.get()
            if current == end:
                break
            if current in expanded:
                continue
            expanded.add(current)
            current_cost = cost_so_far[current]
            for next_node in self._neighbors(current):
                new_cost = current_cost + cost(current, next_node)
                if next_node not in cost_so_far or new_cost < cost_so_far[next_node]:
                    cost_so_far[next_node] = new_cost
                    came_from[next_node] = current
                    # A cheaper route was found, so the node may be expanded again.
                    expanded.discard(next_node)
                    frontier.put(next_node, new_cost + heuristic(end, next_node))

        if end not in came_from:
            # Frontier exhausted without reaching the goal.
            self.path = []
            return self.path

        path = []
        node = end
        while node != start:
            path.append(node)
            node = came_from[node]
        path.append(start)
        path.reverse()
        self.path = path
        return self.path

    async def draw(self):
        """Draw a small white debug box at every point of ``self.path``.

        Each box is flat, 0.2 x 0.2 around the point, at terrain height + 0.1.
        It is queued with ``self._client.debug_box_out``; on_step calls
        ``send_debug()`` after drawing.
        """
        white = Point3((255, 255, 255))
        for point in self.path:
            height = terrain_to_z_height(self.get_terrain_height(point))
            lower = Point3((point.x - 0.1, point.y - 0.1, height + 0.1))
            upper = Point3((point.x + 0.1, point.y + 0.1, height + 0.1))
            self._client.debug_box_out(lower, upper, white)
# Player 1 runs the A* benchmark. Player 2 (dummy) only lifts its command
# center so that no building blocks the paths.
game_map = maps.get("KairosJunctionLE")
players = [
    Bot(Race.Terran, astar(dummy=False)),
    Bot(Race.Terran, astar(dummy=True)),
]
run_game(game_map, players, realtime=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment