Created February 15, 2023 22:34
Save ttumiel/1d7201140d462ca5a35641d6ab5ef5a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Numpy/torch Complex Observation Handling | |
import pytest | |
import tree | |
import gym | |
from gym import spaces | |
import numpy as np | |
import torch | |
def make_storage(obs_space: gym.Space, batch_dims: tuple, device: torch.device):
    """Allocate zero-filled tensors mirroring the structure of one sample of ``obs_space``.

    One sample is drawn to use as a template. Each leaf of the returned structure is
    a zero tensor of shape ``batch_dims + np.shape(leaf)``. Its dtype is whatever
    ``torch.tensor`` infers for that sample leaf, and it lives on ``device``.

    Args:
        obs_space: Space whose (possibly nested) sample structure is mirrored.
        batch_dims: Leading dimensions prepended to every leaf. It must be a tuple,
            because it is concatenated with ``np.shape(leaf)``.
        device: Device the zero tensors are allocated on.

    Returns:
        The same nested structure as ``obs_space.sample()``, with tensor leaves.
    """
    def _zeros_like_leaf(leaf):
        leaf_shape = batch_dims + np.shape(leaf)
        # Only the dtype torch would pick for this leaf is needed, not the tensor.
        leaf_dtype = torch.tensor(leaf).dtype
        return torch.zeros(leaf_shape, dtype=leaf_dtype, device=device)

    template = obs_space.sample()
    return tree.map_structure(_zeros_like_leaf, template)
def to_tensor(o, device: torch.device):
    """Convert every leaf of a (possibly nested) observation to a tensor on ``device``.

    Leaves may be NumPy arrays, NumPy scalars or plain Python numbers. Non-array
    leaves are first wrapped with ``np.asarray``, because ``torch.from_numpy``
    rejects anything that is not an ``ndarray``. For ndarray leaves
    ``torch.as_tensor`` shares memory with the array, just as ``torch.from_numpy``
    does. The result therefore only becomes an independent copy when ``.to(device)``
    actually moves it to another device.

    Args:
        o: Observation structure (array, tuple, dict, ...) to convert.
        device: Target device for every resulting tensor.

    Returns:
        The same nested structure, with tensor leaves on ``device``.
    """
    return tree.map_structure(
        lambda leaf: torch.as_tensor(np.asarray(leaf)).to(device), o
    )
def save_obs(samples, storage, index: int):
    """Write ``samples`` into ``storage`` at position ``index`` of the first dimension.

    The write happens leaf by leaf and in place: each storage leaf gets
    ``storage_leaf[index] = sample_leaf``. ``samples`` and ``storage`` must share
    the same nested structure, since ``tree.map_structure`` walks them together.
    Returns ``None``.
    """
    def _write(sample_leaf, storage_leaf):
        storage_leaf[index] = sample_leaf

    tree.map_structure(_write, samples, storage)
def minibatch(obs, idx):
    """Gather a minibatch from every leaf of ``obs`` using the indices ``idx``.

    ``idx`` is applied as ``leaf[idx]`` to each leaf. In this file it is a NumPy
    array of integer indices (a shuffled slice of ``np.arange``). Any index the
    leaf type supports (int, slice, index array) works the same way.

    Args:
        obs: Nested observation structure whose leaves support ``__getitem__``.
        idx: Indices selecting the minibatch along the first dimension.

    Returns:
        The same nested structure, with each leaf indexed by ``idx``.
    """
    return tree.map_structure(lambda leaf: leaf[idx], obs)
def flatten_obs(obs):
    """Merge the first two dimensions of every leaf into a single leading dimension.

    In this file the leaves come from ``make_storage`` with
    ``batch_dims=(num_steps, num_envs)``. Flattening therefore yields a leading
    dimension of ``num_steps * num_envs`` and leaves the trailing dimensions unchanged.
    """
    def _merge_leading(leaf):
        trailing = tuple(leaf.shape[2:])
        return leaf.reshape((-1, *trailing))

    return tree.map_structure(_merge_leading, obs)
def stack_obs(o):
    """Stack a sequence of identically structured observations leaf-wise.

    Corresponding leaves from every element of ``o`` are combined with
    ``np.stack``, which adds a new leading dimension of length ``len(o)``.
    """
    def _stack_leaves(*leaves):
        return np.stack(leaves)

    return tree.map_structure(_stack_leaves, *o)
# Observation spaces that test_obs is parametrized over: a scalar space, an array
# space, and Tuple/Dict containers that nest both of them.
obs = [
    spaces.Discrete(4),
    spaces.Box(0, 1, (10, 2)),
    spaces.Tuple(
        (
            spaces.Discrete(4),
            spaces.Box(0, 1, (10, 2)),
        )
    ),
    spaces.Dict(
        {
            "discrete": spaces.Discrete(4),
            "continuous": spaces.Box(0, 1, (10, 2)),
        }
    ),
]
@pytest.fixture()
def num_steps():
    """Number of rollout steps test_obs simulates (the storage's first dimension)."""
    steps = 1
    return steps
@pytest.fixture()
def num_envs():
    """Number of parallel environments sampled per step (the storage's second dimension)."""
    envs = 2
    return envs
def check_batch_shape(x, s):
    """Assert that every leaf of the nested structure ``x`` has first dimension ``s``."""
    def _check_leading(leaf):
        assert leaf.shape[0] == s

    tree.map_structure(_check_leading, x)
@pytest.mark.parametrize("o", obs)
def test_obs(o, num_steps, num_envs):
    """Round-trip samples of space ``o`` through storage, flattening and minibatching.

    The test simulates ``num_steps`` rollout steps over ``num_envs`` environments and
    writes each stacked step into preallocated storage. It then flattens the two
    leading dimensions and checks that every shuffled minibatch has the expected
    leading size.
    """
    # Fall back to CPU so the test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    storage = make_storage(o, (num_steps, num_envs), device)

    for step in range(num_steps):
        # Simulate a worker rollout: one sample per env, stacked leaf-wise, then
        # moved to the device. stack_obs takes only the samples; to_tensor takes
        # the device.
        samples = [o.sample() for _ in range(num_envs)]
        samples = to_tensor(stack_obs(samples), device)
        check_batch_shape(samples, num_envs)
        save_obs(samples, storage, step)

    b_obs = flatten_obs(storage)
    bs = num_steps * num_envs
    check_batch_shape(b_obs, bs)

    # max(1, ...) avoids a zero step in range() when bs == 1.
    minibatch_size = max(1, bs // 2)
    idxs = np.arange(bs)
    np.random.shuffle(idxs)
    for start in range(0, bs, minibatch_size):
        mb_idxs = idxs[start:start + minibatch_size]
        mb = minibatch(b_obs, mb_idxs)
        # The final minibatch is shorter when bs is not a multiple of minibatch_size.
        check_batch_shape(mb, len(mb_idxs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.