Skip to content

Instantly share code, notes, and snippets.

@ngoodger
Last active November 13, 2021 04:51
Show Gist options
  • Save ngoodger/a98b1071952ef3af2d192ec3e96098ca to your computer and use it in GitHub Desktop.
def fori_body(i, val):
    """Take one random-action environment step and record it at row ``i``.

    Intended as the body function of ``jax.lax.fori_loop``.

    Args:
        i: Loop index supplied by ``fori_loop``; selects the row of the
            output buffers written on this step.
        val: Loop carry ``(env_state, action_key, all_obsv, all_reward,
            all_done)``.

    Returns:
        The updated carry, with the same structure as ``val``.

    Note:
        ``env.step`` is looked up on the module-level ``env`` global, so that
        name must be bound before the loop is traced.
    """
    env_state, action_key, all_obsv, all_reward, all_done = val
    # Split *before* drawing: the subkey is used exactly once (for the
    # action) and the other half becomes the next carried key, so no key is
    # ever consumed twice.
    action_key, action_subkey = jax.random.split(action_key)
    # Uniform random action in {0, 1}; shape () yields a scalar directly.
    action = jax.random.randint(action_subkey, (), 0, 2)
    env_state, obsv, reward, done, _ = env.step(env_state, action)
    all_obsv = all_obsv.at[i].set(obsv)
    all_reward = all_reward.at[i].set(reward)
    all_done = all_done.at[i].set(done)
    return env_state, action_key, all_obsv, all_reward, all_done
@pmap
@vmap
def rollout(key):
    """Run ``TIMESTEPS`` random-action steps of a ``JaxCartPole`` per key.

    Mapped with ``vmap`` over environments and ``pmap`` over devices, so the
    caller passes keys laid out as ``(devices, envs_per_device, key_size)``.

    Args:
        key: Per-environment PRNG key. It seeds both the reset and the action
            stream, so distinct keys give independent episodes.

    Returns:
        ``(all_obsv, all_reward, all_done)``. Per key these have shapes
        ``(TIMESTEPS, 4)``, ``(TIMESTEPS, 1)`` and ``(TIMESTEPS, 1)`` (bool).
        The mapped leading axes are added on top.
    """
    # Derive independent reset/action keys from the per-env key. A fixed
    # constant action key would be unbatched under vmap/pmap and give every
    # environment the identical action sequence.
    reset_key, action_key = jax.random.split(key)
    env = JaxCartPole()
    # NOTE(review): reset runs on this local instance, while fori_body steps
    # via the module-level ``env``; presumably JaxCartPole keeps all mutable
    # state in ``env_state`` so the two are interchangeable -- confirm.
    env_state, _ = env.reset(reset_key)
    # Width 4 is hard-coded; assumed to match JaxCartPole's observation size.
    all_obsv = jnp.zeros(shape=(TIMESTEPS, 4))
    all_reward = jnp.zeros(shape=(TIMESTEPS, 1))
    all_done = jnp.zeros(shape=(TIMESTEPS, 1), dtype=jnp.bool_)
    val = (env_state, action_key, all_obsv, all_reward, all_done)
    _, _, all_obsv, all_reward, all_done = jax.lax.fori_loop(
        0, TIMESTEPS, fori_body, val
    )
    return all_obsv, all_reward, all_done
NUM_ENV = 1
NUM_DEVICES = len(jax.local_devices())
# Environments are spread evenly across local devices: pmap maps the leading
# axis over devices and vmap maps NUM_ENV // NUM_DEVICES envs on each. Fail
# early with a clear message instead of an opaque reshape error below.
if NUM_ENV <= 0 or NUM_ENV % NUM_DEVICES != 0:
    raise ValueError(
        f"NUM_ENV ({NUM_ENV}) must be a positive multiple of the number of "
        f"local devices ({NUM_DEVICES})"
    )
seed = 0
key = jax.random.PRNGKey(seed)
# One key per environment, laid out as (devices, envs_per_device, key_size)
# to line up with rollout's @pmap / @vmap axes.
keys = jax.random.split(key, NUM_ENV).reshape(
    NUM_DEVICES, NUM_ENV // NUM_DEVICES, -1
)
# fori_body calls env.step on this module-level instance, so it must be bound
# before rollout is traced.
env = JaxCartPole()
all_obsv, all_reward, all_done = rollout(keys)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment