Skip to content

Instantly share code, notes, and snippets.

@ttumiel
Created February 15, 2023 22:12
Show Gist options
  • Save ttumiel/47c0db298eb7a0b6c120b64c1462b89f to your computer and use it in GitHub Desktop.
Save ttumiel/47c0db298eb7a0b6c120b64c1462b89f to your computer and use it in GitHub Desktop.
# JAX version of the batching/minibatching test for complex (nested) gym observation spaces
import pytest
from gym import spaces
import numpy as np
import jax.numpy as jnp
import jax.tree_util as tree
def batch_obs(o):
    """Merge a sequence of per-step observation pytrees into one batch.

    ``o`` is a sequence of structurally identical pytrees, one per rollout
    step. Matching leaves are joined with ``jnp.concatenate`` along axis 0,
    so if every step has a leaf of shape ``(n, ...)``, the result has shape
    ``(len(o) * n, ...)``.
    """
    def _join(*leaves_across_steps):
        return jnp.concatenate(leaves_across_steps)

    return tree.tree_map(_join, *o)
def minibatch(obs, idx):
    """Select the rows ``idx`` from every leaf of the observation pytree ``obs``.

    Each leaf is indexed as ``leaf[idx]``, so ``idx`` may be anything the
    leaf's ``__getitem__`` accepts (an index array, a slice, ...).
    """
    def _take_rows(leaf):
        return leaf[idx]

    return tree.tree_map(_take_rows, obs)
def stack_obs(o):
    """Stack per-environment observation pytrees along a new leading axis.

    ``o`` is a sequence of structurally identical pytrees (one per env).
    Matching leaves are combined with ``np.stack``, so the result's leaves
    gain a leading dimension of size ``len(o)``.
    """
    def _stack_envs(*leaves_across_envs):
        return np.stack(leaves_across_envs)

    return tree.tree_map(_stack_envs, *o)
def _discrete_space():
    # Fresh Discrete(4) each call, so no two entries share a space instance.
    return spaces.Discrete(4)


def _box_space():
    # Fresh Box in [0, 1] with shape (10, 2) each call.
    return spaces.Box(0, 1, (10, 2))


# Spaces under test: two flat spaces plus the two composite spaces (Tuple,
# Dict) whose samples are nested structures.
obs = [
    _discrete_space(),
    _box_space(),
    spaces.Tuple((_discrete_space(), _box_space())),
    spaces.Dict({"discrete": _discrete_space(), "continuous": _box_space()}),
]
@pytest.fixture(params=[1, 3])
def num_steps(request):
    """Number of rollout steps collected per test case.

    Parametrized to cover two cases:
    - the single-step case;
    - a multi-step case, where ``batch_obs`` really concatenates several
      per-step batches.

    With a fixed value of 1, the cross-step concatenation was never exercised.
    """
    return request.param
@pytest.fixture
def num_envs():
    """Number of parallel environments sampled at each rollout step."""
    envs_per_step = 2
    return envs_per_step
def check_batch_shape(x, s):
    """Assert that every leaf of the pytree ``x`` has leading dimension ``s``.

    Args:
        x: A pytree (nested tuple/dict/list) whose leaves are array-likes.
        s: Expected size of axis 0 for every leaf.

    Raises:
        AssertionError: If a leaf is 0-d or its axis-0 size differs from ``s``.
            The message gives the leaf's position in ``tree_leaves`` order and
            its shape.
    """
    # Walk the leaves directly. tree_map would build a throwaway tree of Nones
    # just to run an assertion for its side effect.
    for i, leaf in enumerate(tree.tree_leaves(x)):
        # np.shape also works for Python/NumPy scalars, which have no .shape.
        shape = np.shape(leaf)
        assert len(shape) > 0, (
            f"leaf {i} is 0-d (shape {shape}); expected leading dim {s}"
        )
        assert shape[0] == s, (
            f"leaf {i} has shape {shape}; expected leading dim {s}"
        )
@pytest.mark.parametrize("o", obs)
def test_obs(o, num_steps, num_envs):
    """Round-trip samples from space ``o`` through stack -> batch -> minibatch.

    The test runs in three stages:
    1. Simulate ``num_steps`` rollout steps over ``num_envs`` environments.
    2. Check that the batched pytree has ``num_steps * num_envs`` rows.
    3. Slice it into minibatches and check that each slice has the expected
       number of rows.
    """
    # Simulate worker rollout: one stacked pytree (leading dim num_envs) per step.
    storage = [
        stack_obs([o.sample() for _ in range(num_envs)])
        for _ in range(num_steps)
    ]
    b_obs = batch_obs(storage)
    bs = num_steps * num_envs
    check_batch_shape(b_obs, bs)

    # max(1, ...) keeps the range step positive when bs == 1
    # (bs // 2 would be 0 and range() would raise ValueError).
    minibatch_size = max(1, bs // 2)
    idxs = jnp.arange(bs)
    for start in range(0, bs, minibatch_size):
        mb = minibatch(b_obs, idxs[start:start + minibatch_size])
        # The final slice is shorter when bs is not a multiple of minibatch_size.
        check_batch_shape(mb, min(minibatch_size, bs - start))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment