Skip to content

Instantly share code, notes, and snippets.

@alper111
Last active June 8, 2024 12:37
Show Gist options
  • Save alper111/1feadf9e21cb2ef1548284bbe7d97ba1 to your computer and use it in GitHub Desktop.
Save alper111/1feadf9e21cb2ef1548284bbe7d97ba1 to your computer and use it in GitHub Desktop.
A Sokoban game with crates having MNIST digits on top of them
import torchvision
import pygame
import numpy as np
import gymnasium as gym
class MNISTSokoban(gym.Env):
    """Sokoban grid world whose crates and goals are drawn with MNIST digits.

    The map is a list of rows; every cell is a ``(background, tile)`` pair of
    one-character strings. ``background`` is ``"0"`` for a goal and ``" "``
    otherwise. ``tile`` is ``"#"`` for a wall, ``"@"`` for the agent, ``" "``
    for an empty cell, and any other character (a digit) for a crate.

    Observations are single-channel ``uint8`` images in which every cell
    covers a 32x32 pixel square. The reward is the fraction of goals that are
    covered by a crate; the episode terminates once that fraction reaches one
    and is truncated after ``max_steps`` steps.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 20}

    # Rendering constants (edge length of one cell in pixels, colours, pen width).
    _TILE = 32
    _BG_COLOR = (30, 30, 30)
    _WALL_COLOR = (80, 80, 80)
    _MARK_COLOR = (255, 255, 255)
    _MARK_WIDTH = 4

    def __init__(self, map_file: str = None, size: tuple[int, int] = None, max_crates: int = 5, max_steps=200,
                 render_mode: str = None, rand_digits: bool = False, rand_agent: bool = False, rand_x: bool = False):
        """Create the environment and load the MNIST training set.

        Args:
            map_file: Path of a text map to load on every reset. Takes
                precedence over ``size`` when both are given.
            size: ``(rows, cols)`` of the randomly generated map that is used
                when ``map_file`` is ``None``.
            max_crates: Upper bound on the number of crates in generated maps.
            max_steps: Number of steps after which an episode is truncated.
            render_mode: ``"human"`` to show a pygame window on every
                observation, ``"rgb_array"`` to make ``render()`` return the
                frame, or ``None``.
            rand_digits: Draw a new random MNIST sample for every digit each
                time a frame is rendered instead of one fixed sample per reset.
            rand_agent: Re-sample the agent triangle on every rendered frame.
            rand_x: Re-sample the cross mark on every rendered frame.
        """
        assert map_file is not None or size is not None, "Either map_file or size must be provided"
        self._map_file = map_file
        self._size = size
        self._max_crates = max_crates
        self._max_steps = max_steps
        self.render_mode = render_mode
        self.rand_digits = rand_digits
        self.rand_agent = rand_agent
        self.rand_x = rand_x

        self._shape = None
        self._window = None
        self._clock = None
        self._map = None
        self._digit_idx = None
        self._agent_loc = None
        self._x_corners = None
        self._a_corners = None
        # (row, col) offsets of actions 0..3: right, up, left, down.
        self._delta = np.array([[0, 1], [-1, 0], [0, -1], [1, 0]])
        self._t = 0

        dataset = torchvision.datasets.MNIST(root="data", train=True, download=True)
        self._data = dataset.data.numpy()
        _labels = dataset.targets.numpy()
        # For every digit class, the dataset indices of its samples.
        self._labels = {i: np.where(_labels == i)[0] for i in range(10)}

        self.action_space = gym.spaces.Discrete(4)
        if map_file is None:
            # The frame size of generated maps is known up front; for file
            # maps it is only known after the file has been read in reset().
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(size[0]*self._TILE, size[1]*self._TILE), dtype=np.uint8)

    def reset(self, *, seed: int = None, options: dict = None) -> tuple[np.ndarray, dict]:
        """Start a new episode and return ``(observation, info)``.

        Args:
            seed: Forwarded to ``gym.Env.reset``. The map, digit and mark
                sampling in this class draws from the global ``np.random``
                state, which this argument does not reseed.
            options: Accepted for Gymnasium API compatibility; unused.

        Raises:
            ValueError: If the map contains no agent tile ``"@"``.
        """
        super().reset(seed=seed)
        self._init_agent_mark()
        self._init_x_mark()
        self._init_digits()
        if self._map_file is not None:
            self._map = self.read_map(self._map_file)
        else:
            self._map = self.generate_map(self._size, max_crates=self._max_crates)
        self._shape = (len(self._map), len(self._map[0]))
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(self._shape[0]*self._TILE, self._shape[1]*self._TILE), dtype=np.uint8)
        self._agent_loc = self._find_agent()
        self._t = 0
        # _get_obs() already updates the window in "human" mode, so no second
        # render call is needed (it would show a frame that differs from the
        # returned observation whenever a rand_* flag is set).
        return self._get_obs(), self._get_info()

    def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Move the agent one cell and push a crate if one is in the way.

        The move is ignored when the target cell is a wall or lies outside the
        map, or when a crate in the target cell cannot be pushed because the
        cell behind it is not empty (or lies outside the map).

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        assert self._map is not None, "You must call reset() before calling step()"
        pos = self._agent_loc
        next_pos = pos + self._delta[action]
        if self._in_bounds(next_pos):
            next_tile = self._map[next_pos[0]][next_pos[1]][1]
            if next_tile == " ":
                self._move_agent(pos, next_pos)
            elif next_tile != "#":
                # The next cell holds a crate: it moves only onto an empty cell.
                further_pos = next_pos + self._delta[action]
                if self._in_bounds(further_pos):
                    further_bg, further_tile = self._map[further_pos[0]][further_pos[1]]
                    if further_tile == " ":
                        self._map[further_pos[0]][further_pos[1]] = (further_bg, next_tile)
                        self._move_agent(pos, next_pos)

        self._t += 1
        obs = self._get_obs()
        info = self._get_info()
        reward = self._get_reward()
        terminated = reward > 1 - 1e-6
        truncated = (self._t >= self._max_steps)
        return obs, reward, terminated, truncated, info

    def render(self):
        """Return the current frame in ``"rgb_array"`` mode, otherwise ``None``.

        The frame is the same single-channel array that is used as observation.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()
        return None

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self._window is not None:
            pygame.display.quit()
            pygame.quit()
            self._window = None
            self._clock = None

    def _in_bounds(self, pos) -> bool:
        """Return whether ``pos`` (row, col) addresses an existing map cell."""
        row, col = pos[0], pos[1]
        return 0 <= row < len(self._map) and 0 <= col < len(self._map[row])

    def _move_agent(self, src, dst) -> None:
        """Move the agent tile from ``src`` to ``dst`` keeping both backgrounds."""
        src_bg = self._map[src[0]][src[1]][0]
        dst_bg = self._map[dst[0]][dst[1]][0]
        self._map[src[0]][src[1]] = (src_bg, " ")
        self._map[dst[0]][dst[1]] = (dst_bg, "@")
        self._agent_loc = dst

    def _find_agent(self) -> np.ndarray:
        """Return the (row, col) of the first ``"@"`` tile in the map."""
        for i, row in enumerate(self._map):
            for j, (_, tile) in enumerate(row):
                if tile == "@":
                    return np.array([i, j])
        raise ValueError("The map does not contain an agent tile '@'")

    def _init_x_mark(self):
        """Sample the pixel offsets of the two strokes of the cross mark.

        The list holds ``x, y`` pairs for four points: the first stroke runs
        from the top-left to the bottom-right region of a cell, the second one
        from the top-right to the bottom-left region.
        """
        bounds = [(2, 9), (2, 9), (24, 31), (24, 31),
                  (24, 31), (2, 9), (2, 9), (24, 31)]
        self._x_corners = [np.random.randint(low, high) for low, high in bounds]

    def _init_agent_mark(self):
        """Sample the pixel offsets of the three vertices of the agent triangle.

        The list holds ``x, y`` pairs for the top, bottom-left and
        bottom-right vertex.
        """
        bounds = [(13, 20), (2, 9),
                  (2, 9), (24, 31),
                  (24, 31), (24, 31)]
        self._a_corners = [np.random.randint(low, high) for low, high in bounds]

    def _init_digits(self):
        """Pick one MNIST sample per digit class for the upcoming episode."""
        self._digit_idx = np.zeros(10, dtype=np.int64)
        for i in self._labels:
            self._digit_idx[i] = np.random.choice(self._labels[i])

    def _digit_surface(self, label: int):
        """Return a cell-sized pygame surface showing a sample of ``label``."""
        if self.rand_digits:
            sample_idx = np.random.choice(self._labels[label])
        else:
            sample_idx = self._digit_idx[label]
        image = np.stack([self._data[sample_idx]]*3, axis=-1)
        # pygame surfaces are indexed (x, y), hence the row/column swap.
        surface = pygame.surfarray.make_surface(np.transpose(image, (1, 0, 2)))
        return pygame.transform.scale(surface, (self._TILE, self._TILE))

    def _draw_agent(self, canvas, x: int, y: int) -> None:
        """Draw the agent triangle into the cell whose top-left pixel is (x, y)."""
        if self.rand_agent:
            self._init_agent_mark()
        c = self._a_corners
        vertices = [(x+c[0], y+c[1]), (x+c[2], y+c[3]), (x+c[4], y+c[5])]
        for start, end in ((0, 1), (0, 2), (1, 2)):
            pygame.draw.line(canvas, self._MARK_COLOR, vertices[start], vertices[end], self._MARK_WIDTH)
        # Discs on the vertices round off the line joints.
        for vertex in vertices:
            pygame.draw.circle(canvas, self._MARK_COLOR, vertex, self._MARK_WIDTH//2)

    def _draw_cross(self, canvas, x: int, y: int) -> None:
        """Draw the cross mark into the cell whose top-left pixel is (x, y)."""
        if self.rand_x:
            self._init_x_mark()
        c = self._x_corners
        points = [(x+c[0], y+c[1]), (x+c[2], y+c[3]), (x+c[4], y+c[5]), (x+c[6], y+c[7])]
        pygame.draw.line(canvas, self._MARK_COLOR, points[0], points[1], self._MARK_WIDTH)
        pygame.draw.line(canvas, self._MARK_COLOR, points[2], points[3], self._MARK_WIDTH)
        # Discs on the end points round off the stroke ends.
        for point in points:
            pygame.draw.circle(canvas, self._MARK_COLOR, point, self._MARK_WIDTH//2)

    def _render_frame(self):
        """Draw the map, update the window in ``"human"`` mode, return the frame.

        Goals are drawn as an MNIST zero, walls as grey squares, the agent as
        a white triangle and crates as their MNIST digit; a crate standing on
        a goal additionally gets a white cross drawn over it.

        Returns:
            The first colour channel of the canvas as a ``(rows*32, cols*32)``
            array.
        """
        n_rows, n_cols = self._shape
        canvas = pygame.Surface((n_cols*self._TILE, n_rows*self._TILE))
        canvas.fill(self._BG_COLOR)

        for i in range(n_rows):
            for j in range(n_cols):
                bg, tile = self._map[i][j]
                x, y = j*self._TILE, i*self._TILE
                # Non-goal cells keep the colour the canvas was filled with.
                if bg == "0":
                    canvas.blit(self._digit_surface(0), (x, y))

                if tile == "#":
                    pygame.draw.rect(canvas, self._WALL_COLOR, pygame.Rect(x, y, self._TILE, self._TILE))
                elif tile == "@":
                    self._draw_agent(canvas, x, y)
                elif tile != " ":
                    canvas.blit(self._digit_surface(int(tile)), (x, y))
                    if bg == "0":
                        self._draw_cross(canvas, x, y)

        if self.render_mode == "human":
            if self._window is None:
                pygame.init()
                pygame.display.init()
                self._window = pygame.display.set_mode((n_cols*self._TILE, n_rows*self._TILE))
            if self._clock is None:
                self._clock = pygame.time.Clock()
            self._window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self._clock.tick(self.metadata["render_fps"])

        # array3d is indexed (x, y, channel); keep one channel and swap the
        # axes so that the result is indexed (row, col).
        return np.transpose(pygame.surfarray.array3d(canvas)[:, :, 0], (1, 0))

    def _get_obs(self) -> np.ndarray:
        """Return the rendered frame as the observation."""
        return self._render_frame()

    def _get_info(self) -> dict:
        """Return a snapshot of the map under the key ``"map"``.

        The rows are copied so that later steps do not alter an info dict that
        was returned earlier (cells are immutable tuples).
        """
        return {"map": [list(row) for row in self._map]}

    def _get_reward(self) -> float:
        """Return the fraction of goal cells that are covered by a crate.

        Returns ``0.0`` when the map has no goal at all.
        """
        n_crossed = 0
        n_total = 0
        for row in self._map:
            for bg, fg in row:
                if bg != "0":
                    continue
                n_total += 1
                if fg not in ("#", " ", "@"):
                    n_crossed += 1
        if n_total == 0:
            return 0.0
        return n_crossed / n_total

    @property
    def map(self) -> list[list[tuple[str, str]]]:
        """The current map as a list of rows of ``(background, tile)`` pairs."""
        return self._map

    @staticmethod
    def read_map(map_file: str) -> list[list[tuple[str, str]]]:
        """Parse a text map into rows of ``(background, tile)`` pairs.

        The character ``"0"`` becomes an empty goal cell ``("0", " ")``; every
        other character ``x`` becomes ``(" ", x)``. Lines are stripped of
        surrounding whitespace and blank lines are skipped.

        Raises:
            ValueError: If the file contains no map rows.
        """
        with open(map_file, "r") as f:
            lines = f.readlines()

        _map = []
        for line in lines:
            symbols = line.strip()
            if not symbols:
                # A blank line would otherwise become an empty, unusable row.
                continue
            _map.append([("0", " ") if x == "0" else (" ", x) for x in symbols])
        if not _map:
            raise ValueError(f"The map file {map_file!r} does not contain any rows")
        return _map

    @staticmethod
    def generate_map(size: tuple[int, int] = (10, 10), max_crates: int = 5) -> list[list[tuple[str, str]]]:
        """Create a walled random map with between 1 and ``max_crates`` crates.

        The outermost ring of cells is wall. Crates (digits 1-9), an equal
        number of goals and the agent are placed on distinct cells that are at
        least two cells away from the map border.

        Args:
            size: ``(rows, cols)`` of the map; both must be at least 5.
            max_crates: Largest number of crates that may be placed.
        """
        ni, nj = size
        # Objects are placed at offsets 2 .. n-3, which needs at least 5 cells
        # per axis; a smaller size could never satisfy the capacity check below.
        assert ni >= 5 and nj >= 5, "The size of the map must be at least 5x5"
        assert max_crates >= 1, "max_crates must be at least 1"
        inner_cols = nj - 4
        total_middle_tiles = (ni-4)*inner_cols
        assert (2*max_crates+1) <= total_middle_tiles, \
            "The number of crates (together with their goals) must be less than the total non-edge empty tiles"

        _map = [[(" ", " ") for _ in range(nj)] for _ in range(ni)]
        for i in range(ni):
            for j in range(nj):
                if i == 0 or i == ni-1 or j == 0 or j == nj-1:
                    _map[i][j] = (" ", "#")

        n = np.random.randint(1, max_crates+1)
        digits = np.random.randint(1, 10, n)
        # n crate cells, then n goal cells, then the agent cell - all distinct.
        locations = np.random.permutation(total_middle_tiles)[:(2*n+1)]

        def to_cell(flat_index):
            # Map an index of the inner region to (row, col) of the full map.
            row, col = divmod(int(flat_index), inner_cols)
            return row + 2, col + 2

        for i, x_i in enumerate(digits):
            di, dj = to_cell(locations[i])
            _map[di][dj] = (" ", str(int(x_i)))
            di, dj = to_cell(locations[i+n])
            _map[di][dj] = ("0", " ")
        ax, ay = to_cell(locations[-1])
        _map[ax][ay] = (" ", "@")
        return _map
def example_map1(path: str = "map1.txt"):
    """Write a small example map (9 columns x 7 rows) to ``path``.

    Symbols: ``#`` wall, ``@`` agent, ``0`` goal, any other digit a crate.
    Every row has the same width so that the file describes a rectangular
    grid.

    Args:
        path: Destination file; defaults to ``map1.txt`` in the working
            directory.
    """
    map_lines = [
        "#########\n",
        "#     0 #\n",
        "# 0  5  #\n",
        "#   1   #\n",
        "# @  30 #\n",
        "#       #\n",
        "#########",
    ]
    with open(path, "w") as f:
        f.writelines(map_lines)
if __name__ == "__main__":
    # Demo: run 20 episodes with uniformly random actions in a pygame window.
    # To play a map from a file instead, construct the environment with
    # MNISTSokoban(map_file="map1.txt", max_crates=2, max_steps=200, render_mode="human").
    env = MNISTSokoban(size=(7, 7), max_crates=3, max_steps=50, render_mode="human",
                       rand_digits=True, rand_agent=True, rand_x=True)
    try:
        for _ in range(20):
            env.reset()
            done = False
            while not done:
                action = env.action_space.sample()
                _, _, term, trun, _ = env.step(action)
                done = term or trun
    finally:
        # Release the environment's resources even if an episode raises.
        env.close()
@alper111
Copy link
Author

alper111 commented Jun 7, 2024

Screen.Recording.2024-06-07.at.5.23.26.PM.mov
Screen.Recording.2024-06-07.at.5.36.59.PM.mov
Screen.Recording.2024-06-07.at.5.48.02.PM.mov

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment