Skip to content

Instantly share code, notes, and snippets.

@bjuergens
Created Jan 21, 2021
Embed
What would you like to do?
Quick test for combining procgen envs with different multiprocessing handlers
import gym
import tap
import gc
import logging
from multiprocessing import Pool
import os
from dask.distributed import Client, LocalCluster
class Params(tap.Tap):
    """Command-line options for the procgen env quick test."""

    # NOTE: Tap uses the trailing comment on each field as its --help text,
    # so these comments are user-facing.
    n: int = 100  # how often should the work be run?
    info: int = 0  # how often to print messages? (0 = never)
    gc_force: bool = False  # should gc be called after every work?
    steps: int = 0  # should steps be done in the env? If so how many?
    env: str = "procgen:procgen-heist-v0"  # environment id passed to gym.make
    ph: str = "sequence"  # sequence or mp or dask
def work(steps, gc_force, info, env_id, i):
    """Create one gym environment and optionally step it and force a GC pass.

    Args:
        steps: Number of randomly sampled actions to take after a reset.
            A falsy value skips both the reset and the stepping.
        gc_force: When truthy, call ``gc.collect()`` before returning.
        info: Progress interval; ``i`` is printed whenever it is a multiple
            of ``info``. A falsy value disables printing.
        env_id: Environment id handed to ``gym.make``.
        i: Index of this work item, used only for progress output.
    """
    # Short-circuit keeps the modulo from running when info is 0.
    show_progress = bool(info) and i % info == 0
    if show_progress:
        print(str(i))

    make_options = {
        "distribution_mode": "memory",
        "use_monochrome_assets": False,
        "restrict_themes": True,
        "use_backgrounds": False,
    }
    env = gym.make(env_id, **make_options)

    if steps:
        env.reset()
        for _step in range(steps):
            action = env.action_space.sample()
            env.step(action)

    if not gc_force:
        return
    logging.info("calling gc.collect()...")
    gc.collect()
def main(args):
    """Run ``work`` ``args.n`` times with the handler selected by ``args.ph``.

    Args:
        args: Parsed options providing ``n``, ``steps``, ``gc_force``,
            ``info``, ``env`` and ``ph``. ``ph`` selects the handler:
            ``"sequence"`` runs in this process, ``"mp"`` uses a
            ``multiprocessing.Pool`` and ``"dask"`` uses a local dask
            cluster with ``args.n`` single-threaded worker processes.

    Raises:
        RuntimeError: If ``args.ph`` is not one of the values above.
    """
    logging.info("starting %s iterations", args.n)
    if args.ph == "sequence":
        logging.info("starting singlethreaded...")
        for i in range(args.n):
            work(args.steps, args.gc_force, args.info, args.env, i)
    elif args.ph == "mp":
        logging.info("starting with mp...")
        jobs = [
            (args.steps, args.gc_force, args.info, args.env, i)
            for i in range(args.n)
        ]
        with Pool(os.cpu_count()) as pool:
            pool.starmap(work, jobs)
    elif args.ph == "dask":
        # Context managers shut the client and the worker processes down
        # even when a task raises, instead of leaving them running.
        with LocalCluster(
            processes=True,
            asynchronous=False,
            threads_per_worker=1,
            n_workers=args.n,
            memory_pause_fraction=False,
            interface="lo",
        ) as cluster, Client(cluster) as client:
            logging.info(
                "Dask dashboard available at port: %s",
                client.scheduler_info()["services"]["dashboard"],
            )
            # client.map zips its iterables, so every argument of work()
            # needs its own list of length n.
            constant_columns = [
                [value] * args.n
                for value in (args.steps, args.gc_force, args.info, args.env)
            ]
            indices = list(range(args.n))
            futures = client.map(work, *constant_columns, indices)
            client.gather(futures)
    else:
        raise RuntimeError("unknown value for ph: " + str(args.ph))
    logging.info("done")
if __name__ == "__main__":
    # The root logger defaults to WARNING, which would silently drop every
    # logging.info() status message this script emits.
    logging.basicConfig(level=logging.INFO)
    main(Params(underscores_to_dashes=True).parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment