Skip to content

Instantly share code, notes, and snippets.

@ttumiel
Created February 15, 2023 22:34
Show Gist options
  • Save ttumiel/1d7201140d462ca5a35641d6ab5ef5a1 to your computer and use it in GitHub Desktop.
Save ttumiel/1d7201140d462ca5a35641d6ab5ef5a1 to your computer and use it in GitHub Desktop.
# NumPy/PyTorch handling of complex (nested Tuple/Dict) observations
import pytest
import tree
import gym
from gym import spaces
import numpy as np
import torch
def make_storage(obs_space: gym.Space, batch_dims: tuple, device: torch.device):
    """Allocate zero-filled tensors shaped like a batch of observations.

    One observation is sampled from ``obs_space`` and used as a template:
    for every leaf of that sample a ``torch.zeros`` tensor is created whose
    shape is ``batch_dims`` followed by the leaf's own shape and whose dtype
    is the dtype ``torch.tensor`` infers for the leaf.

    Args:
        obs_space: Space to sample the template observation from.
        batch_dims: Leading dimensions prepended to every leaf shape.
        device: Device on which the tensors are allocated.

    Returns:
        A structure matching the sampled observation, with a zero tensor in
        place of each leaf.
    """

    def zeros_for_leaf(leaf):
        full_shape = batch_dims + np.shape(leaf)
        leaf_dtype = torch.tensor(leaf).dtype
        return torch.zeros(full_shape, dtype=leaf_dtype, device=device)

    template = obs_space.sample()
    return tree.map_structure(zeros_for_leaf, template)
def to_tensor(o, device: torch.device):
    """Convert every leaf of the observation structure ``o`` to a tensor on ``device``.

    ``torch.as_tensor`` is used instead of ``torch.from_numpy`` so that leaves
    which are NumPy scalars or plain Python numbers (not only ``ndarray``
    instances) are accepted as well; ``ndarray`` leaves keep their dtype.

    Args:
        o: Observation structure whose leaves are array-like.
        device: Target device for the resulting tensors.

    Returns:
        A structure matching ``o`` with each leaf replaced by a tensor.
    """
    return tree.map_structure(
        lambda leaf: torch.as_tensor(leaf, device=device), o
    )
def save_obs(samples, storage, index: int):
    """Write each leaf of ``samples`` into position ``index`` of ``storage``.

    ``samples`` and ``storage`` are traversed together, so they must share
    the same structure. The storage leaves are modified in place; nothing
    is returned.
    """

    def write_leaf(value, buffer):
        buffer[index] = value

    tree.map_structure(write_leaf, samples, storage)
def minibatch(obs, idx: "int | slice | np.ndarray | torch.Tensor"):
    """Gather a minibatch of observations given indices ``idx``.

    Every leaf of ``obs`` is indexed with ``idx`` along its first axis, so
    ``idx`` may be anything the leaves accept as an index (a single integer,
    a slice, or an array of indices).

    Args:
        obs: Observation structure whose leaves support ``leaf[idx]``.
        idx: Index or indices selecting the entries of the minibatch.

    Returns:
        A structure matching ``obs`` with each leaf replaced by ``leaf[idx]``.
    """
    return tree.map_structure(lambda leaf: leaf[idx], obs)
def flatten_obs(obs):
    """Reshape every leaf of ``obs`` to ``(-1,) + leaf.shape[2:]``.

    For leaves with at least two axes this collapses the first two axes
    into a single leading axis and leaves the remaining axes untouched.
    """

    def merge_leading_axes(leaf):
        trailing = leaf.shape[2:]
        return leaf.reshape((-1,) + trailing)

    return tree.map_structure(merge_leading_axes, obs)
def stack_obs(o):
    """Stack a sequence of identically structured observations leaf by leaf.

    The observations in ``o`` are traversed in parallel and the corresponding
    leaves are combined with ``np.stack``, producing one structure whose
    leaves have a new leading axis of length ``len(o)``.
    """

    def stack_leaves(*leaves):
        return np.stack(leaves)

    return tree.map_structure(stack_leaves, *o)
# Observation spaces used to parametrize the test: a scalar space, an array
# space, and nested Tuple/Dict combinations of the two.
obs = [
    spaces.Discrete(4),
    spaces.Box(low=0, high=1, shape=(10, 2)),
    spaces.Tuple(
        (
            spaces.Discrete(4),
            spaces.Box(low=0, high=1, shape=(10, 2)),
        )
    ),
    spaces.Dict(
        {
            "discrete": spaces.Discrete(4),
            "continuous": spaces.Box(low=0, high=1, shape=(10, 2)),
        }
    ),
]
@pytest.fixture
def num_steps():
    """Fixture: number of rollout steps stored per test run."""
    steps = 1
    return steps
@pytest.fixture
def num_envs():
    """Fixture: number of environments sampled at each step."""
    envs = 2
    return envs
def check_batch_shape(x, s):
    """Assert that every leaf of ``x`` has a first dimension of size ``s``."""

    def assert_leading_dim(leaf):
        leading = leaf.shape[0]
        assert leading == s

    tree.map_structure(assert_leading_dim, x)
@pytest.mark.parametrize("o", obs)
def test_obs(o, num_steps, num_envs):
    """Round-trip observations of space ``o`` through storage and minibatching.

    For each step, ``num_envs`` observations are sampled, stacked, converted
    to tensors and written into pre-allocated storage. The storage is then
    flattened over its two leading axes and split into shuffled minibatches;
    the leading dimension is checked at every stage.
    """
    # Fall back to CPU so the test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    storage = make_storage(o, (num_steps, num_envs), device)

    for i in range(num_steps):
        # Simulate worker rollout
        samples = [o.sample() for _ in range(num_envs)]
        # `device` belongs to to_tensor; stack_obs takes only the samples.
        samples = to_tensor(stack_obs(samples), device)
        check_batch_shape(samples, num_envs)
        save_obs(samples, storage, i)

    b_obs = flatten_obs(storage)
    bs = num_steps * num_envs
    check_batch_shape(b_obs, bs)

    # Keep the step at least 1 so range() stays valid when bs == 1.
    minibatch_size = max(1, bs // 2)
    idxs = np.arange(bs)
    np.random.shuffle(idxs)
    for start in range(0, bs, minibatch_size):
        mb_idxs = idxs[start:start + minibatch_size]
        mb = minibatch(b_obs, mb_idxs)
        # The final minibatch may be shorter when bs is not divisible.
        check_batch_shape(mb, len(mb_idxs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment