Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
# VizdoomEnvMultiplayer: one player of a multiplayer ViZDoom deathmatch.
# Player 0 hosts the game (`-host <num_players>`); the other players are meant to join it (`-join`).
#
# NOTE(review): this block was recovered from a scraped web page. All indentation was
# lost and several lines were truncated mid-statement (flagged below), so the block is
# not valid Python as written. Restore it from the upstream source before use.
class VizdoomEnvMultiplayer(VizdoomEnv):
# level, skip_frames and level_map are forwarded to VizdoomEnv; player_id selects
# host (0) vs client; num_players is the total number of players, including the host.
def __init__(self, level, player_id, num_players, skip_frames, level_map='map01'):
super().__init__(level, skip_frames=skip_frames, level_map=level_map)
self.player_id = player_id
self.num_players = num_players
# Number of step() calls since the last reset().
self.timestep = 0
# When False, step() returns (None, None, None, None) instead of an observation;
# the multi-agent worker sets this via env.unwrapped.update_state before each step.
self.update_state = True
# True for the process that hosts the game (player 0).
def _is_server(self):
return self.player_id == 0
# Creates and configures the game once; returns immediately if already initialized.
def _ensure_initialized(self, mode='algo'):
if self.initialized:
# Doom env already initialized!
# NOTE(review): the next line is corrupted: a bare `return` has been fused with the
# tail of a later statement `... = DoomGame()`; the assignment target and any calls
# between them were lost. SyntaxError as written.
return = DoomGame()
# NOTE(review): the tail `, 2**32-1))` below is the remnant of a lost call (it looks
# like a random seed drawn up to 2**32-1); the map-setting call that this comment
# warns about is also missing.
# Setting an invalid level map will cause the game to freeze silently, 2**32-1))
if mode == 'algo':
if self._is_server():
# This process will function as a host for a multiplayer game with this many players (including the host).
# It will wait for other machines to connect using the -join parameter and then
# start the game when everyone is connected.
# NOTE(review): the call that consumed the following implicitly-concatenated
# argument string was lost; as written the literals form a no-op expression.
f'-host {self.num_players} '
'-deathmatch ' # Deathmatch rules are used for the game.
'+timelimit 10.0 ' # The game (episode) will end after this many minutes have elapsed.
'+sv_forcerespawn 1 ' # Players will respawn automatically after they die.
'+sv_noautoaim 1 ' # Autoaim is disabled for all players.
'+sv_respawnprotect 1 ' # Players will be invulnerable for two seconds after spawning.
'+sv_spawnfarthest 1 ' # Players will be spawned as far as possible from any other players.
'+sv_nocrouch 1 ' # Disables crouching.
'+viz_respawn_delay 1 ' # Sets delay between respawns (in seconds).
'+viz_nocheat 1', # Disables depth and labels buffer and the ability to use commands
# that could interfere with multiplayer game.
# Name your agent and select color
# colors:
# 0 - green, 1 - gray, 2 - brown, 3 - red, 4 - light gray, 5 - light brown, 6 - light red, 7 - light blue'+name Host +colorset 0')
# NOTE(review): the `'+name Host +colorset 0')` fragment fused onto the comment above
# was originally code (the host's name/colour argument) and is now commented out.
# NOTE(review): the client branch (presumably `else:` for player_id != 0) appears to
# have been lost; its `'-join')` and `'+name AI +colorset 0')` arguments survive only
# inside the comments below.
# TODO: port, name
# Join existing game.'-join') # Connect to a host for a multiplayer game.
# Name your agent and select color
# colors:
# 0 - green, 1 - gray, 2 - brown, 3 - red, 4 - light gray, 5 - light brown, 6 - light red, 7 - light blue'+name AI +colorset 0')
# NOTE(review): no call that actually starts the configured game is visible before
# the flag below is set — confirm against upstream.
self.initialized = True
# Resets per-episode counters and returns the current screen buffer with axis 0
# moved last (presumably CHW -> HWC; confirm the configured screen format).
# NOTE(review): `mode` is unused in the visible body and `_ensure_initialized` is never
# called here; those lines may have been lost together with the right-hand side of
# `self.state =` below (SyntaxError as written).
def reset(self, mode='algo'):
self.timestep = 0
self.update_state = True
self.state =
img = self.state.screen_buffer
return np.transpose(img, (1, 2, 0))
# Applies one discrete action and returns (observation, reward, done, info), or
# (None, None, None, None) when update_state is False.
# `action` is an index into the discrete action space; it is converted to a one-hot
# list of uint8 button states below.
def step(self, action):
info = {'num_frames': self.skip_frames}
# convert action to vizdoom action space (one hot)
act = np.zeros(self.action_space.n)
act[action] = 1
act = np.uint8(act)
act = act.tolist()
# NOTE(review): the next two lines are truncated: the calls that submit `act`,
# advance the game (the surviving `, self.update_state)` suggests update_state is an
# argument of that call) and read the last reward were lost. SyntaxError as written.
reward = 0, self.update_state)
reward +=
self.timestep += 1
if not self.update_state:
return None, None, None, None
# NOTE(review): right-hand sides lost (presumably the current game state and the
# episode-finished flag).
state =
done =
if not done:
observation = np.transpose(state.screen_buffer, (1, 2, 0))
# NOTE(review): `game_variables` is never used, and nothing separates the observation
# computed above from the zero observation below; an `else:` branch (and probably an
# info update using game_variables) appears to have been lost.
game_variables = self._game_variables_dict(state)
observation = np.zeros(self.observation_space.shape, dtype=np.uint8)
return observation, reward, done, info
def safe_get(q, timeout=1e6, msg='Queue timeout'):
    """Get the next item from ``q``, retrying forever when an attempt times out.

    Using queue.get() with timeout is necessary, otherwise KeyboardInterrupt is not handled.

    Args:
        q: a ``queue.Queue`` or ``multiprocessing`` queue (both raise ``queue.Empty``).
        timeout: seconds to wait per attempt before logging ``msg`` and retrying.
        msg: warning logged each time an attempt times out.

    Returns:
        The next item from ``q``.
    """
    import logging
    from queue import Empty

    while True:
        try:
            return q.get(timeout=timeout)
        except Empty:
            # Not an error: the producer is just slow. Report it and keep waiting.
            logging.getLogger(__name__).warning(msg)
class TaskType(Enum):
    """Commands the parent process sends to a MultiAgentEnvWorker via its task queue."""

    INIT = 0  # create the environment inside the worker process
    TERMINATE = 1  # shut down the worker's environment
    RESET = 2  # reset the environment and return its observation
    STEP = 3  # step with update_state=False (intermediate frame, no observation)
    STEP_UPDATE = 4  # step with update_state=True (returns obs, reward, done, info)
    INFO = 5  # query the environment's per-episode info
class MultiAgentEnvWorker:
    """Runs one player's environment in a dedicated daemon process.

    The parent sends ``(action, task_type)`` tuples through ``task_queue``. Results of
    RESET / INFO / STEP / STEP_UPDATE tasks come back through ``result_queue``. Every
    task is acknowledged with ``task_queue.task_done()`` so the parent can block on
    ``task_queue.join()``.
    """

    def __init__(self, player_id, num_players, make_env_func):
        self.player_id = player_id
        self.num_players = num_players
        self.make_env_func = make_env_func

        self.task_queue, self.result_queue = JoinableQueue(), JoinableQueue()

        # The worker loop must be running before the parent enqueues INIT and blocks on
        # task_queue.join(); otherwise that join() never returns.
        self.process = Process(target=self.start, daemon=True)
        self.process.start()

    @staticmethod
    def _logger():
        """Return the module logger (imported lazily; no module-level logger is assumed)."""
        import logging
        return logging.getLogger(__name__)

    def _init(self):
        """Create this player's environment with the factory passed to __init__."""
        self._logger().info('Initializing env for player %d...', self.player_id)
        env = self.make_env_func(player_id=self.player_id, num_players=self.num_players)
        return env

    def _terminate(self, env):
        """Close ``env``; tolerates TERMINATE arriving before INIT (``env`` is None)."""
        self._logger().info('Stop env for player %d...', self.player_id)
        if env is not None:
            env.close()
        self._logger().info('Env with player %d terminated!', self.player_id)

    @staticmethod
    def _get_info(env):
        """Specific to custom VizDoom environments.

        Returns ``env.unwrapped.get_info_all()`` if the env provides it, else ``{}``.
        """
        info = {}
        if hasattr(env.unwrapped, 'get_info_all'):
            info = env.unwrapped.get_info_all()  # info for the new episode
        return info

    def start(self):
        """Worker-process main loop: execute tasks from task_queue until TERMINATE."""
        env = None

        while True:
            action, task_type = safe_get(self.task_queue)

            if task_type == TaskType.INIT:
                # INIT produces no result; the parent only waits on task_queue.join().
                env = self._init()
                self.task_queue.task_done()
                continue

            if task_type == TaskType.TERMINATE:
                self._terminate(env)
                self.task_queue.task_done()
                break

            if task_type == TaskType.RESET:
                results = env.reset()
            elif task_type == TaskType.INFO:
                results = self._get_info(env)
            elif task_type in (TaskType.STEP, TaskType.STEP_UPDATE):
                # collect obs, reward, done, and info
                env.unwrapped.update_state = task_type == TaskType.STEP_UPDATE
                results = env.step(action)
            else:
                raise Exception(f'Unknown task type {task_type}')

            self.result_queue.put(results)
            self.task_queue.task_done()
class VizdoomMultiAgentEnv:
    """Multi-agent wrapper that drives one VizDoom player per worker process.

    Per-agent values are exchanged as dicts keyed by the player index as a string
    ('0', '1', ...). ``step`` additionally sets ``dones['__all__']``.
    """

    def __init__(self, num_players, make_env_func, env_config):
        self.num_players = num_players
        self.skip_frames = 4

        # Throwaway env used only to read the action/observation spaces.
        env = make_env_func(player_id=-1, num_players=num_players)  # temporary
        self.action_space = env.action_space
        self.observation_space = env.observation_space
        env.close()

        self.workers = [MultiAgentEnvWorker(i, num_players, make_env_func) for i in range(num_players)]

        for worker in self.workers:
            worker.task_queue.put((None, TaskType.INIT))
            # Stagger initialization (presumably so the host, player 0, is up before
            # the clients try to join).
            time.sleep(0.1)  # just in case

        for worker in self.workers:
            worker.task_queue.join()

        self._logger().info('%d agent workers initialized!', len(self.workers))

    @staticmethod
    def _logger():
        """Return the module logger (imported lazily; no module-level logger is assumed)."""
        import logging
        return logging.getLogger(__name__)

    def await_tasks(self, data, task_type, timeout=None):
        """Send one task to every worker and collect the results.

        Task result is always a tuple of dicts, e.g.:
        ({'0': 0th_agent_obs, '1': 1st_agent_obs, ...},
         {'0': 0th_agent_reward, '1': 1st_agent_reward, ...},
         ...)

        If your "task" returns only one result per agent (e.g. reset() returns only the
        observation), the result will be a tuple of length 1. It is a responsibility of
        the caller to index appropriately.

        Args:
            data: dict mapping str(player_index) to that agent's payload (e.g. its
                action), or None to send None to every agent.
            task_type: a TaskType member.
            timeout: seconds per wait attempt for each result (default 0.02); waiting
                is retried until the result arrives.
        """
        if data is None:
            data = {str(i): None for i in range(self.num_players)}

        assert len(data) == self.num_players

        # Dispatch to the clients first and to the host (player 0) last.
        for i, worker in enumerate(self.workers[1:], start=1):
            worker.task_queue.put((data[str(i)], task_type))
        self.workers[0].task_queue.put((data[str(0)], task_type))

        result_dicts = None
        for i, worker in enumerate(self.workers):
            results = safe_get(
                worker.result_queue,
                timeout=0.02 if timeout is None else timeout,
                msg=f'Takes a surprisingly long time to process task {task_type}, retry...',
            )
            # Balance the JoinableQueue's unfinished-task counter for every get().
            worker.result_queue.task_done()

            if not isinstance(results, (tuple, list)):
                results = [results]

            if result_dicts is None:
                result_dicts = tuple({} for _ in results)

            for j, r in enumerate(results):
                result_dicts[j][str(i)] = r

        return result_dicts

    def info(self):
        """Return a dict of per-agent info (see MultiAgentEnvWorker._get_info)."""
        info = self.await_tasks(None, TaskType.INFO)[0]
        return info

    def reset(self):
        """Reset every agent's env; return a dict of per-agent observations."""
        observation = self.await_tasks(None, TaskType.RESET)[0]
        return observation

    def step(self, actions):
        """Repeat ``actions`` for ``skip_frames`` frames; return (obs, rew, dones, infos).

        The first ``skip_frames - 1`` frames are run without updating observations; only
        the final frame's results are returned.
        """
        # NOTE(review): results of these intermediate steps (including any reward) are
        # discarded; confirm that dropping intermediate-frame rewards is intended.
        for _ in range(self.skip_frames - 1):
            self.await_tasks(actions, TaskType.STEP)

        obs, rew, dones, infos = self.await_tasks(actions, TaskType.STEP_UPDATE)
        dones['__all__'] = all(dones.values())
        return obs, rew, dones, infos

    def close(self):
        """Ask every worker to terminate and wait for its process to exit."""
        self._logger().info('Stopping multi env...')
        for worker in self.workers:
            worker.task_queue.put((None, TaskType.TERMINATE))
        for worker in self.workers:
            worker.process.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment