Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DuaneNielsen/534eecb6602828f4ed0f755cd7e1f385 to your computer and use it in GitHub Desktop.
Save DuaneNielsen/534eecb6602828f4ed0f755cd7e1f385 to your computer and use it in GitHub Desktop.
from typing import Optional
import torch
from torch import tensor
from tensordict import TensorDict
from torchrl.data import CompositeSpec, BoundedTensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, DiscreteTensorSpec, \
UnboundedContinuousTensorSpec
from torchrl.envs import (
EnvBase,
Transform,
TransformedEnv,
)
from torchrl.envs.utils import check_env_specs, step_mdp
from torchrl.envs.transforms.transforms import _apply_to_composite
from torch.nn.functional import interpolate
from torchvision.utils import make_grid
from math import prod
"""
A minimal stateless vectorized gridworld in pytorch rl
Action space: (0, 1, 2, 3) -> N, E, S, W
Features:
- walls
- one-time pickup rewards or penalties
- done tiles
- outputs a fully observable RGB image
See the gen_params function to set up the world.
An example of configuring the env and performing a rollout is at the bottom of the file.
After implementing this, I think that TorchRL doesn't really support
vectorized stateless environments, which is OK.
Maybe I will write an even simpler, stateless, non-vectorized version of this gridworld.
"""
# Per-action (d_row, d_col) offsets added to player_pos.
# player_pos is indexed as (row, col) into the (..., rows, cols) tile layers,
# and the renderer draws row 0 at the top of the image. So "north" means
# row - 1 and "east" means col + 1.
# Action index -> direction: 0=N, 1=E, 2=S, 3=W.
action_vec = tensor([
    [-1, 0],  # N: up one row
    [0, 1],   # E: right one column
    [1, 0],   # S: down one row
    [0, -1],  # W: left one column
])

# RGB colours (uint8) used when rendering tiles.
yellow = tensor([255, 255, 0], dtype=torch.uint8)  # player
red = tensor([255, 0, 0], dtype=torch.uint8)  # negative reward tile
green = tensor([0, 255, 0], dtype=torch.uint8)  # positive reward tile
pink = tensor([255, 0, 255], dtype=torch.uint8)  # currently unused by the renderer
violet = tensor([226, 43, 138], dtype=torch.uint8)  # currently unused by the renderer
white = tensor([255, 255, 255], dtype=torch.uint8)  # wall
def _step(state):
    """Advance every environment in ``state`` by one action (stateless).

    Everything needed is read from ``state``: ``player_pos`` (row, col),
    the ``wall_tiles`` / ``reward_tiles`` / ``done_tiles`` layers and the
    chosen ``action`` (trailing singleton dim). A move into a wall tile
    (value 1) is rejected and the player stays put. Whatever reward sits on
    the destination tile is collected and that tile is zeroed. ``done`` is
    the value of ``done_tiles`` at the destination.

    The input tensordict is NOT modified. The successor state, plus
    ``reward`` (float, shape ``(*batch, 1)``) and ``done`` (bool, shape
    ``(*batch, 1)``), are returned under the ``"next"`` key of a new
    TensorDict with batch shape ``state.shape``.
    """
    device = state.device

    # Flatten to a single leading batch dim so that batch_range can be paired
    # with per-env (row, col) indices; an unbatched state becomes a batch of 1.
    state_flat = state.view(prod(state.shape))
    batch_range = torch.arange(state_flat.size(0), device=device)
    walls_flat = state_flat['wall_tiles']
    pos_flat = state_flat['player_pos']

    # Propose the move. action_vec is moved to the device first, because
    # indexing a CPU tensor with an index tensor on another device fails.
    move = action_vec.to(device)[state_flat['action'][:, 0]]
    proposed = pos_flat + move
    hits_wall = walls_flat[batch_range, proposed[:, 0], proposed[:, 1]] == 1
    player_pos = torch.where(hits_wall[..., None], pos_flat, proposed)

    # One-hot mask of the player's (possibly unchanged) tile per env.
    player_pos_mask = torch.zeros_like(walls_flat, dtype=torch.bool)
    player_pos_mask[batch_range, player_pos[:, 0], player_pos[:, 1]] = True

    # Restore the caller's batch dims.
    player_pos = player_pos.reshape(state['player_pos'].shape)
    player_pos_mask = player_pos_mask.reshape(state['wall_tiles'].shape)

    # Pick up any reward. Work on a copy so the caller's tensordict keeps
    # describing the pre-step world. The mask has exactly one True per env,
    # so boolean indexing yields one value per env in batch order.
    reward_tiles = state['reward_tiles'].clone()
    reward = reward_tiles[player_pos_mask].reshape(*state.shape, 1)
    reward_tiles[player_pos_mask] = 0.

    # The episode ends if the player stands on a done tile.
    done = state['done_tiles'][player_pos_mask].reshape(*state.shape, 1)

    next_state = {
        'player_pos': player_pos,
        'wall_tiles': state['wall_tiles'],
        'reward_tiles': reward_tiles,
        'done_tiles': state['done_tiles'],
        'reward': reward,
        'done': done,
    }
    return TensorDict({'next': next_state}, state.shape)
def _reset(self, state=None):
    """Return the initial world state for an episode.

    - ``state is None``: fresh parameters are generated with
      ``self.gen_params`` for the env's own ``batch_size``.
    - ``state`` empty: fresh parameters are generated for ``state.shape``.
    - otherwise: ``state`` is treated as user-supplied world parameters and a
      copy of it is returned.
    """
    if state is None:
        return self.gen_params(self.batch_size)
    if state.is_empty():
        return self.gen_params(state.shape)
    # Return a copy so the reset output never aliases the caller's tensordict.
    # NOTE(review): some TorchRL versions also reject a _reset that returns
    # its input object; confirm against the pinned version.
    return state.clone()
def gen_params(batch_size=None):
    """Build the world parameters (the initial state) of the gridworld.

    Edit the layers below to change the layout. ``player_pos`` is a
    (row, col) index into the tile layers.

    Args:
        batch_size: optional batch shape. When truthy, the single world is
            expanded to that shape and materialised with ``contiguous()``,
            so each env owns its own storage.

    Returns:
        TensorDict with keys ``player_pos`` (int64, 2 values),
        ``wall_tiles`` (uint8 0/1), ``reward_tiles`` (float32) and
        ``done_tiles`` (bool), plus any leading batch dims.
    """
    world = TensorDict(
        {
            "player_pos": tensor([2, 2], dtype=torch.int64),
            # 1 = impassable wall.
            "wall_tiles": tensor(
                [
                    [1, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1],
                    [1, 0, 0, 0, 1],
                    [1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 1],
                ],
                dtype=torch.uint8,
            ),
            # Reward collected (once) when the player enters the tile.
            "reward_tiles": tensor(
                [
                    [0, 0, 0, 0, 0],
                    [0, 1, 1, -1, 0],
                    [0, 1, 0, 1, 0],
                    [0, -1, 1, 1, 0],
                    [0, 0, 0, 0, 0],
                ],
                dtype=torch.float32,
            ),
            # Entering a True tile ends the episode.
            "done_tiles": tensor(
                [
                    [0, 0, 0, 0, 0],
                    [0, 0, 0, 1, 0],
                    [0, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0],
                    [0, 0, 0, 0, 0],
                ],
                dtype=torch.bool,
            ),
        },
        batch_size=[],
    )
    if not batch_size:
        return world
    return world.expand(batch_size).contiguous()
def _make_spec(self, td_params):
    """Declare the env's observation, input, action and reward specs.

    Shapes are derived from ``td_params`` (as produced by ``gen_params``).
    Its batch shape prefixes every spec, and the grid size is read from
    ``td_params['wall_tiles']``, so changing the layout in ``gen_params``
    does not require editing the specs.
    """
    batch_size = td_params.shape
    # The trailing two dims of the tile layers are the grid (rows, cols).
    grid_shape = torch.Size((*batch_size, *td_params['wall_tiles'].shape[-2:]))
    self.observation_spec = CompositeSpec(
        wall_tiles=BoundedTensorSpec(
            minimum=0,
            maximum=1,
            shape=grid_shape,
            dtype=torch.uint8,
        ),
        reward_tiles=UnboundedContinuousTensorSpec(
            shape=grid_shape,
            dtype=torch.float32,
        ),
        done_tiles=BoundedTensorSpec(
            minimum=0,
            maximum=1,
            shape=grid_shape,
            dtype=torch.bool,
        ),
        player_pos=UnboundedDiscreteTensorSpec(
            shape=torch.Size((*batch_size, 2)),
            dtype=torch.int64,
        ),
        shape=torch.Size((*batch_size,)),
    )
    # The input spec mirrors the observation spec: the step function reads
    # these same keys back from its input tensordict.
    self.input_spec = self.observation_spec.clone()
    # Four discrete actions (N, E, S, W), with a trailing singleton dim.
    self.action_spec = DiscreteTensorSpec(4, shape=torch.Size((*batch_size, 1)))
    self.reward_spec = UnboundedContinuousTensorSpec(shape=torch.Size((*batch_size, 1)))
def _set_seed(self, seed: Optional[int]) -> None:
    """Seed torch's global RNG and store the generator on ``self.rng``.

    ``seed=None`` leaves the global RNG untouched, because
    ``torch.manual_seed`` rejects ``None``. In that case ``self.rng`` is
    set to torch's default generator.
    """
    if seed is None:
        self.rng = torch.default_generator
        return
    # torch.manual_seed seeds the global generator and returns it.
    self.rng = torch.manual_seed(seed)
class Gridworld(EnvBase):
    """Stateless gridworld env assembled from the module-level helpers.

    World state (player position and tile layers) travels in the tensordict
    passed to ``step``/``reset`` rather than living on the env.

    Args:
        td_params: optional world parameters used to size the specs.
            Generated with ``gen_params(batch_size)`` when omitted.
        device: device forwarded to ``EnvBase``.
        batch_size: env batch shape. Defaults to the batch shape of
            ``td_params``, so that ``self.batch_size`` agrees with the
            specs built from it.
    """

    metadata = {
        "render_modes": ["human", ""],
        "render_fps": 30,
    }
    # Not batch-locked: tensordicts whose batch shape differs from
    # env.batch_size may be stepped (the demo steps a [64] batch through an
    # unbatched env).
    batch_locked = False

    def __init__(self, td_params=None, device="cpu", batch_size=None):
        if td_params is None:
            td_params = self.gen_params(batch_size)
        if batch_size is None:
            # Match the params' batch dims so the specs (shaped from
            # td_params) and the env's batch_size agree.
            batch_size = td_params.shape
        super().__init__(device=device, batch_size=batch_size)
        self._make_spec(td_params)
        self.shape = batch_size

    # Bind the module-level helpers as env methods.
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec
    _reset = _reset
    _step = staticmethod(_step)
    _set_seed = _set_seed
class RGBFullObsTransform(Transform):
    """Render the full grid state as a uint8 RGB image under ``"observation"``.

    Walls are white, positive-reward tiles green, negative-reward tiles red
    and the player yellow; done tiles are not drawn. The image is upscaled
    to 64x64 and has shape ``(*batch, 3, 64, 64)``. All input keys are
    passed through unchanged.
    """

    def forward(self, tensordict):
        return self._call(tensordict)

    def _call(self, td):
        batch_shape = td.batch_size
        # Flatten to one leading dim so per-env positions can be scattered
        # with batch_range (an unbatched td becomes a batch of 1).
        td_flat = td.view(prod(batch_shape))
        walls = td_flat['wall_tiles']
        rewards = td_flat['reward_tiles']
        player_pos = td_flat['player_pos']
        device = walls.device
        batch_range = torch.arange(td_flat.size(0), device=device)

        # Channels-last canvas (N, rows, cols, 3), built on the input's device.
        image = torch.zeros(*walls.shape, 3, dtype=torch.uint8, device=device)
        image[walls == 1] = white.to(device)
        image[rewards > 0] = green.to(device)
        image[rewards < 0] = red.to(device)
        # player_pos is (row, col), matching how the tile layers are indexed.
        image[batch_range, player_pos[:, 0], player_pos[:, 1]] = yellow.to(device)

        # Channels-first, then upscale (interpolate's default mode, 'nearest').
        image = interpolate(image.permute(0, 3, 1, 2), size=[64, 64])
        # Restore the caller's batch dims. This also covers batch_size [1]
        # and multi-dim batches, not just the unbatched case.
        observation = image.reshape(*batch_shape, 3, 64, 64)
        return TensorDict({
            "observation": observation,
            **td
        }, batch_size=batch_shape)

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        # Spec of the rendered image.
        # NOTE(review): shape carries no batch dims; a batched env may need
        # (*batch, 3, 64, 64) here.
        return BoundedTensorSpec(
            minimum=0,
            maximum=255,
            shape=torch.Size((3, 64, 64)),
            dtype=torch.uint8,
            device=observation_spec.device
        )
if __name__ == '__main__':
    from matplotlib import pyplot as plt

    def simple_rollout(env, steps=100, batch_size=None):
        """Roll out uniformly random actions in ``env`` and stack the transitions.

        Args:
            env: the environment to step.
            steps: number of transitions to record.
            batch_size: batch shape of the rollout (default ``[1]``).

        Returns:
            TensorDict with batch shape ``[steps, *batch_size]``.
        """
        batch_size = [1] if batch_size is None else batch_size
        # Empty container with the final batch shape, filled row by row below.
        data = TensorDict({}, [steps, *batch_size])
        # Fresh world parameters define each env's initial state.
        _data = env.reset(env.gen_params(batch_size=batch_size))
        for i in range(steps):
            _data["action"] = env.action_spec.rand(shape=batch_size)
            _data = env.step(_data)
            data[i] = _data
            # Move "next" to the root. keep_other=True keeps the remaining
            # keys around for the next stateless step.
            _data = step_mdp(_data, keep_other=True)
        return data

    base_env = Gridworld()
    check_env_specs(base_env)
    env = TransformedEnv(
        base_env,
        RGBFullObsTransform(in_keys=['player_pos', 'wall_tiles'], out_keys=['observation'])
    )
    check_env_specs(env)
    data = simple_rollout(env, batch_size=[64])

    # Animate the batch as a grid of images, one frame per timestep.
    _, ax = plt.subplots(1)
    img_plt = ax.imshow(make_grid(data[0]['observation']).permute(1, 2, 0))
    for timestep in data:
        frame = make_grid(timestep['observation']).permute(1, 2, 0)
        img_plt.set_data(frame)
        plt.pause(1.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment