Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
class VizdoomEnvMultiplayer(VizdoomEnv):
    """VizDoom env for one participant of a multiplayer deathmatch on this machine.

    Player 0 hosts the game (``-host``). Every other player connects to it with
    ``-join 127.0.0.1``. Each ``step`` call advances the game via
    ``advance_action(1, self.update_state)``. The caller decides, through
    ``update_state``, whether that step refreshes and returns the game state.
    """

    def __init__(self, level, player_id, num_players, skip_frames, level_map='map01'):
        super().__init__(level, skip_frames=skip_frames, level_map=level_map)
        self.player_id = player_id
        self.num_players = num_players
        # Number of step() calls in the current episode; reset() zeroes it.
        self.timestep = 0
        # Passed to advance_action(); when False, step() returns only Nones.
        self.update_state = True

    def _is_server(self):
        """Return True when this player is the host, i.e. player_id is 0."""
        return 0 == self.player_id

    def _ensure_initialized(self, mode='algo'):
        """Create, configure and start the DoomGame once; later calls return immediately.

        :param mode: 'algo' hides the game window; any other value leaves
            window visibility untouched.
        """
        if self.initialized:
            return  # the Doom game is already running

        game = self.game = DoomGame()
        game.load_config(self.config_path)
        game.set_screen_resolution(self.screen_resolution)
        # An invalid map name makes the game hang silently instead of failing.
        game.set_doom_map(self.level_map)
        # NOTE(review): `random_integers` is deprecated on NumPy's RandomState
        # (`randint(low, high + 1)` is the replacement) and does not exist on
        # np.random.Generator. Confirm what type self.rng is.
        game.set_seed(self.rng.random_integers(0, 2**32-1))
        if mode == 'algo':
            game.set_window_visible(False)

        # +colorset indices: 0 green, 1 gray, 2 brown, 3 red, 4 light gray,
        # 5 light brown, 6 light red, 7 light blue.
        if not self._is_server():
            # TODO: port, name
            # Connect to the game hosted by player 0 on this machine.
            game.add_game_args('-join 127.0.0.1')
            # NOTE(review): every client uses the same name and color.
            game.add_game_args('+name AI +colorset 0')
        else:
            # Host a game for num_players players, host included. The host waits
            # for the others to -join and starts once everyone is connected.
            game.add_game_args(
                f'-host {self.num_players} '
                '-deathmatch '  # deathmatch rules
                '+timelimit 10.0 '  # episode ends after this many minutes
                '+sv_forcerespawn 1 '  # dead players respawn automatically
                '+sv_noautoaim 1 '  # autoaim disabled for all players
                '+sv_respawnprotect 1 '  # invulnerable for two seconds after spawning
                '+sv_spawnfarthest 1 '  # spawn as far as possible from other players
                '+sv_nocrouch 1 '  # crouching disabled
                '+viz_respawn_delay 1 '  # delay between respawns, in seconds
                '+viz_nocheat 1',  # no depth/labels buffers, no interfering commands
            )
            game.add_game_args('+name Host +colorset 0')

        game.set_mode(Mode.PLAYER)
        game.init()
        self.initialized = True

    def reset(self, mode='algo'):
        """Start a new episode and return its first screen buffer, axis 0 moved last.

        The axis move is presumably CHW -> HWC; confirm the screen format in the config.
        """
        self._ensure_initialized(mode)
        self.timestep, self.update_state = 0, True
        self.game.new_episode()
        self.state = self.game.get_state()
        return np.transpose(self.state.screen_buffer, (1, 2, 0))

    def step(self, action):
        """Apply discrete action index ``action`` for one ``advance_action`` call.

        Returns ``(observation, reward, done, info)``. When ``update_state`` is
        False the call returns ``(None, None, None, None)``, and the reward read
        for that step is not returned. When the episode is finished, the
        observation is all zeros.

        NOTE(review): info['num_frames'] is ``skip_frames`` even though a single
        advance_action(1, ...) happens per call. Confirm this is meant to account
        for the whole skip cycle.
        """
        self._ensure_initialized()
        info = {'num_frames': self.skip_frames}

        # VizDoom expects a per-button list, so one-hot encode the action index.
        one_hot = np.zeros(self.action_space.n)
        one_hot[action] = 1
        buttons = np.uint8(one_hot).tolist()

        reward = 0
        self.game.set_action(buttons)
        self.game.advance_action(1, self.update_state)
        reward += self.game.get_last_reward()
        self.timestep += 1

        if not self.update_state:
            return None, None, None, None

        state = self.game.get_state()
        done = self.game.is_episode_finished()
        if done:
            observation = np.zeros(self.observation_space.shape, dtype=np.uint8)
        else:
            observation = np.transpose(state.screen_buffer, (1, 2, 0))
            info.update(self.get_info(self._game_variables_dict(state)))
        return observation, reward, done, info
def safe_get(q, timeout=1e6, msg='Queue timeout'):
    """Wait for and return the next item from ``q``, retrying forever.

    ``q.get`` is always given a finite ``timeout``; the original rationale is
    that an unbounded ``get()`` would not let KeyboardInterrupt through. Each
    time the timeout expires, ``msg`` is logged with ``log.exception``
    (including the ``Empty`` traceback) and the wait starts again.
    """
    while True:
        try:
            item = q.get(timeout=timeout)
        except Empty:
            log.exception(msg)
            continue
        return item
class TaskType(Enum):
    """Commands sent from the main process to a MultiAgentEnvWorker."""

    INIT = 0  # create and seed the env; no result
    TERMINATE = 1  # close the env and leave the worker loop; no result
    RESET = 2  # env.reset()
    STEP = 3  # env.step() with update_state=False
    STEP_UPDATE = 4  # env.step() with update_state=True
    INFO = 5  # env.unwrapped.get_info_all(), if present
class MultiAgentEnvWorker:
    """Runs one player's env in a dedicated daemon process.

    The parent sends ``(action, TaskType)`` tuples through ``task_queue``.
    RESET, INFO, STEP and STEP_UPDATE each put one result on ``result_queue``;
    INIT and TERMINATE put nothing. The process starts as soon as the worker is
    constructed, but the env is only created when an INIT task arrives.
    """

    def __init__(self, player_id, num_players, make_env_func):
        self.player_id = player_id
        self.num_players = num_players
        self.make_env_func = make_env_func
        self.task_queue = JoinableQueue()
        self.result_queue = JoinableQueue()
        # The child runs self.start(). daemon=True means it is terminated when
        # the parent process exits.
        self.process = Process(target=self.start, daemon=True)
        self.process.start()

    def _init(self):
        """Build this player's env and seed it with the player index."""
        log.info('Initializing env for player %d...', self.player_id)
        new_env = self.make_env_func(player_id=self.player_id, num_players=self.num_players)
        new_env.seed(self.player_id)
        return new_env

    def _terminate(self, env):
        """Close ``env``, logging before and after."""
        log.info('Stop env for player %d...', self.player_id)
        env.close()
        log.info('Env with player %d terminated!', self.player_id)

    @staticmethod
    def _get_info(env):
        """Return ``env.unwrapped.get_info_all()`` when that method exists, else ``{}``.

        Specific to custom VizDoom environments.
        """
        unwrapped = env.unwrapped
        if hasattr(unwrapped, 'get_info_all'):
            return unwrapped.get_info_all()  # info for the new episode
        return {}

    def _execute(self, env, action, task_type):
        """Run a result-producing task (RESET / INFO / STEP / STEP_UPDATE) and return its result."""
        if task_type == TaskType.RESET:
            return env.reset()
        if task_type == TaskType.INFO:
            return self._get_info(env)
        if task_type in (TaskType.STEP, TaskType.STEP_UPDATE):
            # Only STEP_UPDATE asks the env to refresh and return its state.
            env.unwrapped.update_state = task_type == TaskType.STEP_UPDATE
            return env.step(action)  # (obs, reward, done, info)
        raise Exception(f'Unknown task type {task_type}')

    def start(self):
        """Worker-process main loop: handle tasks until TERMINATE arrives.

        ``task_done()`` is called on ``task_queue`` after each successfully
        handled task, so the parent can ``join()`` on it. If handling raises,
        ``task_done()`` is not called for that task.
        """
        env = None
        running = True
        while running:
            action, task_type = safe_get(self.task_queue)
            if task_type == TaskType.INIT:
                env = self._init()
            elif task_type == TaskType.TERMINATE:
                self._terminate(env)
                running = False
            else:
                self.result_queue.put(self._execute(env, action, task_type))
            self.task_queue.task_done()
class VizdoomMultiAgentEnv:
    """Multi-agent VizDoom env with one MultiAgentEnvWorker process per player.

    Per-player data (actions, observations, rewards, dones, infos) is passed as
    dicts keyed by the player index as a string: '0', '1', ...
    """

    def __init__(self, num_players, make_env_func, env_config):
        # NOTE(review): env_config is accepted but not used here.
        self.num_players = num_players
        self.skip_frames = 4

        # A throwaway env (player_id=-1), built only to read the spaces.
        probe_env = make_env_func(player_id=-1, num_players=num_players)
        self.action_space = probe_env.action_space
        self.observation_space = probe_env.observation_space
        probe_env.close()

        self.workers = [MultiAgentEnvWorker(i, num_players, make_env_func) for i in range(num_players)]

        # INIT is queued in player order, 0.1 s apart, so player 0 is asked to
        # initialize first; presumably the host must be up before clients join.
        for worker in self.workers:
            worker.task_queue.put((None, TaskType.INIT))
            time.sleep(0.1)  # just in case
        for worker in self.workers:
            worker.task_queue.join()
        log.info('%d agent workers initialized!', len(self.workers))

    def await_tasks(self, data, task_type, timeout=None):
        """Send ``task_type`` to every worker and gather the per-player results.

        :param data: dict mapping each player index as str to that player's
            payload (e.g. its action), or None to send None to every player.
        :param task_type: the TaskType to run; it must produce a result.
            INIT/TERMINATE produce none, so the result wait would retry forever.
        :param timeout: seconds per result_queue poll before logging and
            retrying; defaults to 0.02.
        :return: a tuple of dicts, one per element of a worker's result, each
            keyed by player index as str. For example, for a step::

                (
                    {'0': agent_0_obs, '1': agent_1_obs, ...},
                    {'0': agent_0_reward, '1': agent_1_reward, ...},
                    ...
                )

            If a task returns one non-tuple value per player (e.g. reset()
            returns only the observation), the result is a tuple of length 1.
            The caller must index it appropriately.
        """
        if data is None:
            data = {str(i): None for i in range(self.num_players)}
        assert len(data) == self.num_players

        # Queue the task for every other player first and for player 0 last.
        for idx in range(1, len(self.workers)):
            self.workers[idx].task_queue.put((data[str(idx)], task_type))
        self.workers[0].task_queue.put((data['0'], task_type))

        poll_timeout = 0.02 if timeout is None else timeout
        collected = None
        for idx, worker in enumerate(self.workers):
            worker.task_queue.join()  # wait until this worker has handled the task
            result = safe_get(
                worker.result_queue,
                timeout=poll_timeout,
                msg=f'Takes a surprisingly long time to process task {task_type}, retry...',
            )
            worker.result_queue.task_done()

            parts = result if isinstance(result, (tuple, list)) else [result]
            if collected is None:
                collected = tuple({} for _ in parts)
            key = str(idx)
            for slot, value in enumerate(parts):
                collected[slot][key] = value
        return collected

    def info(self):
        """Return ``{player: info}`` built from each env's ``get_info_all()``, or ``{}`` if it has none."""
        return self.await_tasks(None, TaskType.INFO)[0]

    def reset(self):
        """Reset every player's env and return ``{player: observation}``."""
        return self.await_tasks(None, TaskType.RESET)[0]

    def step(self, actions):
        """Repeat ``actions`` ({player: action}) for ``skip_frames`` worker steps.

        The first ``skip_frames - 1`` steps use STEP and their results are
        discarded. The returned ``(obs, rewards, dones, infos)`` dicts come from
        the final STEP_UPDATE. ``dones['__all__']`` is True only if every
        player is done.

        NOTE(review): nothing from the discarded steps, rewards included, is
        added to the returned dicts. Confirm this is intended.
        """
        for _ in range(self.skip_frames - 1):
            self.await_tasks(actions, TaskType.STEP)
        obs, rew, dones, infos = self.await_tasks(actions, TaskType.STEP_UPDATE)
        dones['__all__'] = all(dones.values())
        return obs, rew, dones, infos

    def close(self):
        """Send TERMINATE to every worker, then wait for all worker processes to exit."""
        log.info('Stopping multi env...')
        for worker in self.workers:
            worker.task_queue.put((None, TaskType.TERMINATE))
            time.sleep(0.1)
        for worker in self.workers:
            worker.process.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment