Skip to content

Instantly share code, notes, and snippets.

@pekaalto
Last active October 26, 2017 06:40
Show Gist options
  • Save pekaalto/2d465fea6db3d58b0de75709dea61623 to your computer and use it in GitHub Desktop.
Save pekaalto/2d465fea6db3d58b0de75709dea61623 to your computer and use it in GitHub Desktop.
A lightly edited copy of the OpenAI baselines SubprocVecEnv, adapted for StarCraft II (pysc2).
"""
Vectorized StarCraft II environment that runs each pysc2 env in its own subprocess.

Almost a direct copy of https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
"""
from multiprocessing import Process, Pipe
from pysc2.env import sc2_env, available_actions_printer
def _single_timestep(timesteps):
    """Return the only element of ``timesteps``; raise if there is not exactly one."""
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if len(timesteps) != 1:
        raise RuntimeError(
            "expected exactly one timestep (single-player env), got %d"
            % len(timesteps))
    return timesteps[0]


def worker(remote, env_fn_wrapper):
    """Serve commands for one single-player env inside a subprocess.

    Handles the single-player conversions between the vectorised API and the
    env: a single ``action`` is wrapped as ``[action]`` for ``env.step``, and
    the one-element timestep list returned by ``step``/``reset`` is unwrapped
    before being sent back.

    Args:
        remote: worker end of a ``multiprocessing.Pipe``. It receives
            ``(cmd, action)`` tuples, where ``cmd`` is ``'step'``,
            ``'reset'`` or ``'close'``, and sends back one timestep per
            ``'step'``/``'reset'``.
        env_fn_wrapper: object whose ``x`` attribute is a zero-argument
            callable that builds the env (a ``CloudpickleWrapper`` here).

    Raises:
        NotImplementedError: on an unknown command.
        RuntimeError: if the env does not return exactly one timestep.
    """
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, action = remote.recv()
            if cmd == 'step':
                remote.send(_single_timestep(env.step([action])))
            elif cmd == 'reset':
                remote.send(_single_timestep(env.reset()))
            elif cmd == 'close':
                break
            else:
                raise NotImplementedError("unknown command: %r" % (cmd,))
    finally:
        # Release the env on every exit path, not only on 'close'. The
        # original never called env.close() at all. The pipe is closed even
        # if env.close() itself fails.
        try:
            env.close()
        finally:
            remote.close()
class SC2VecEnv:
    """Run several single-player envs in parallel, one subprocess each.

    Each element of ``env_fns`` is a zero-argument callable that builds one
    env inside its worker process (see ``worker``). Commands go over one
    ``multiprocessing.Pipe`` per env, and results are returned as a list in
    env order.
    """

    def __init__(self, env_fns):
        """Start one worker process per env factory in ``env_fns``.

        Raises:
            ValueError: if ``env_fns`` is empty.
        """
        n_envs = len(env_fns)
        if n_envs == 0:
            raise ValueError("env_fns must contain at least one env factory")
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_envs)])
        self.ps = [
            Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
            for (work_remote, env_fn) in zip(self.work_remotes, env_fns)
        ]
        for p in self.ps:
            p.start()
        # The parent only talks through ``self.remotes``. Close its copies of
        # the worker ends so those file descriptors are not leaked here.
        for work_remote in self.work_remotes:
            work_remote.close()
        self.n_envs = n_envs
        self.closed = False

    def _step_or_reset(self, command, actions=None):
        """Send ``command`` with one action per env and collect one reply per env.

        Args:
            command: ``'step'`` or ``'reset'``.
            actions: iterable with exactly ``n_envs`` actions, or ``None`` to
                send ``None`` to every env.

        Raises:
            ValueError: if the number of actions differs from ``n_envs``.
        """
        # ``is None`` rather than ``actions or ...``: the truth value of a
        # numpy action array is ambiguous and would raise.
        if actions is None:
            actions = [None] * self.n_envs
        else:
            actions = list(actions)
            if len(actions) != self.n_envs:
                # zip() would silently skip some remotes, and the recv() loop
                # below would then block forever on envs that got no command.
                raise ValueError(
                    "expected %d actions, got %d" % (self.n_envs, len(actions)))
        for remote, action in zip(self.remotes, actions):
            remote.send((command, action))
        return [remote.recv() for remote in self.remotes]

    def step(self, actions):
        """Step every env with its action; return the list of timesteps."""
        return self._step_or_reset("step", actions)

    def reset(self):
        """Reset every env; return the list of initial timesteps."""
        return self._step_or_reset("reset", None)

    def close(self):
        """Shut down all workers and release the parent's pipe ends.

        Safe to call more than once; calls after the first do nothing.
        """
        if self.closed:
            return
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        for remote in self.remotes:
            remote.close()
        self.closed = True

    def reset_done_envs(self):
        """Intentional no-op.

        NOTE(review): presumably kept for interface parity with other
        vectorized envs; confirm callers rely on it.
        """
        pass
def make_sc2env(**kwargs):
    """Build a pysc2 ``SC2Env``, optionally wrapped to print available actions.

    Keyword Args:
        print_available_actions: if true, wrap the env in pysc2's
            ``AvailableActionsPrinter``. This replaces a commented-out
            debugging line. It defaults to off, so existing callers see no
            change.
        **kwargs: every other keyword is forwarded unchanged to
            ``sc2_env.SC2Env``.
    """
    # Popped here (rather than a keyword-only parameter) to keep the signature
    # unchanged and the file Python-2 compatible.
    print_available_actions = kwargs.pop("print_available_actions", False)
    env = sc2_env.SC2Env(**kwargs)
    if print_available_actions:
        env = available_actions_printer.AvailableActionsPrinter(env)
    return env
# Taken unchanged (apart from comments) from OpenAI baselines.
class CloudpickleWrapper(object):
    """Carry an arbitrary object across a process boundary using cloudpickle.

    multiprocessing serialises ``Process`` arguments with the standard
    ``pickle`` module, which cannot handle objects such as lambdas or locally
    defined functions. Wrapping the object makes pickle call the hooks below,
    which delegate serialisation to cloudpickle.
    """

    def __init__(self, x):
        # The wrapped payload; read it back as ``.x`` after unpickling.
        self.x = x

    def __getstate__(self):
        """Serialise the payload to bytes with cloudpickle (imported lazily)."""
        from cloudpickle import dumps
        payload_bytes = dumps(self.x)
        return payload_bytes

    def __setstate__(self, state):
        """Restore the payload; plain pickle can read cloudpickle's output."""
        from pickle import loads
        self.x = loads(state)
"""
# Example usage (FLAGS is assumed to be the calling script's parsed command-line flags):
from functools import partial
env_args = dict(
map_name=FLAGS.map_name,
step_mul=FLAGS.step_mul,
game_steps_per_episode=0,
screen_size_px=(FLAGS.resolution,) * 2,
minimap_size_px=(FLAGS.resolution,) * 2,
visualize=FLAGS.render
)
envs = SC2VecEnv((partial(make_sc2env, **env_args),) * FLAGS.n_envs)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment