Skip to content

Instantly share code, notes, and snippets.

@bjuergens
Created January 21, 2021 14:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bjuergens/baceb62a68c113a3ab770184629f2cb1 to your computer and use it in GitHub Desktop.
Save bjuergens/baceb62a68c113a3ab770184629f2cb1 to your computer and use it in GitHub Desktop.
Quick test for combining procgen envs with different multiprocessing handlers
import gym
import tap
import gc
import logging
from multiprocessing import Pool
import os
from dask.distributed import Client, LocalCluster
class Params(tap.Tap):
    """Command-line options for the procgen / multiprocessing-handler quick test."""

    # NOTE: tap uses the trailing comment of each field as its --help text,
    # so these comments are user-facing.
    n: int = 100  # How many times should the work function be run?
    info: int = 0  # Print a progress message every `info` iterations (0 disables it).
    gc_force: bool = False  # Should gc.collect() be called after every work call?
    steps: int = 0  # How many random steps to take in each env (0 means no reset and no steps).
    env: str = "procgen:procgen-heist-v0"  # Environment id passed to gym.make.
    ph: str = "sequence"  # Parallelization handler: sequence, mp or dask.
def work(steps, gc_force, info, env_id, i):
    """Create one gym environment, optionally step it, then release it.

    This is the unit of work that ``main`` runs sequentially, in a
    ``multiprocessing.Pool`` or on a Dask cluster.

    Args:
        steps: number of random actions to take after ``env.reset()``;
            0 (or any falsy value) skips the reset and stepping entirely.
        gc_force: if true, run ``gc.collect()`` after the env is released.
        info: print ``i`` whenever ``i`` is a multiple of ``info``;
            0 (or any falsy value) disables progress printing.
        env_id: environment id passed to ``gym.make``.
        i: index of this iteration, used only for progress printing.
    """
    if info and i % info == 0:
        print(str(i))
    env = gym.make(env_id,
                   distribution_mode="memory",
                   use_monochrome_assets=False,
                   restrict_themes=True,
                   use_backgrounds=False)
    try:
        if steps:
            env.reset()
            for _ in range(steps):
                env.step(env.action_space.sample())
    finally:
        # Release the environment explicitly instead of relying on garbage
        # collection; this also runs if reset/step raises.
        env.close()
    if gc_force:
        # Drop the last local reference first: while ``env`` is still bound,
        # gc.collect() cannot reclaim the environment created by this call.
        del env
        logging.info("calling gc.collect()...")
        gc.collect()
def _run_sequential(args):
    """Run all ``args.n`` work items one after another in this process."""
    logging.info("starting singlethreaded...")
    for i in range(args.n):
        work(args.steps, args.gc_force, args.info, args.env, i)


def _run_mp(args):
    """Run all work items on a ``multiprocessing.Pool`` sized to the CPU count."""
    logging.info("starting with mp...")
    params = [(args.steps, args.gc_force, args.info, args.env, i)
              for i in range(args.n)]
    with Pool(os.cpu_count()) as pool:
        pool.starmap(work, params)


def _run_dask(args):
    """Run all work items on a local Dask cluster and wait for completion."""
    logging.info("starting with dask...")
    # One single-threaded worker process per CPU, matching the mp handler,
    # rather than one worker process per iteration.
    with LocalCluster(processes=True, asynchronous=False, threads_per_worker=1,
                      n_workers=os.cpu_count(), memory_pause_fraction=False,
                      interface="lo") as cluster, Client(cluster) as client:
        # The "dashboard" service is only present when the dashboard could be
        # started; log None instead of crashing with KeyError.
        services = client.scheduler_info().get("services", {})
        logging.info("Dask dashboard available at port: %s", services.get("dashboard"))
        n = args.n
        futures = client.map(work,
                             [args.steps] * n,
                             [args.gc_force] * n,
                             [args.info] * n,
                             [args.env] * n,
                             list(range(n)))
        client.gather(futures)


def main(args):
    """Run ``work`` ``args.n`` times using the handler named by ``args.ph``.

    Args:
        args: parsed ``Params`` providing ``n``, ``steps``, ``gc_force``,
            ``info``, ``env`` and ``ph``.

    Raises:
        RuntimeError: if ``args.ph`` is not "sequence", "mp" or "dask".
    """
    runners = {
        "sequence": _run_sequential,
        "mp": _run_mp,
        "dask": _run_dask,
    }
    logging.info("starting %s iterations", args.n)
    runner = runners.get(args.ph)
    if runner is None:
        raise RuntimeError("unknown value for ph: " + str(args.ph))
    runner(args)
    logging.info("done")
if __name__ == "__main__":
    # Without this, the root logger stays at its default WARNING level and
    # every logging.info(...) progress message in this script is dropped.
    logging.basicConfig(level=logging.INFO)
    main(Params(underscores_to_dashes=True).parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment