Skip to content

Instantly share code, notes, and snippets.

@samiede
Created July 16, 2021 12:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samiede/9962c084bff4f821141bf78c342f85b3 to your computer and use it in GitHub Desktop.
Environment Wrappers
import gym
import numpy as np
import cv2
import torch
import collections
import torchvision.transforms as T
from PIL import Image
class FrameStackingEnv(gym.Wrapper):
    """Gym wrapper that frame-skips, max-pools pairs of raw frames, reduces
    them to their luminance (Y of YUV) channel, and keeps a stack of the
    last ``num_stack`` processed frames.

    The stacked frames are passed through ``transform`` before being
    returned from :meth:`reset` and :meth:`step`.

    NOTE(review): assumes the wrapped env emits RGB uint8 frames of shape
    (H, W, 3) -- confirm against the underlying environment.
    """

    def __init__(self, env, num_stack=4, random_start=30,
                 transform=T.Compose([
                     T.ToPILImage(),
                     # FIX: ``Image.CUBIC`` was a deprecated alias removed in
                     # Pillow 10; ``Image.BICUBIC`` is the same resampling
                     # filter and works on old and new Pillow alike.
                     T.Resize((84, 84), interpolation=Image.BICUBIC),
                     T.ToTensor(),
                 ])):
        """
        :param env: environment to wrap.
        :param num_stack: number of processed frames kept in the stack.
        :param random_start: number of no-op steps taken on reset to
            randomize the start state (0 disables them).
        :param transform: torchvision transform applied to the stacked
            frames before they are returned.
        """
        super().__init__(env)
        self.env = env
        self.n = num_stack
        self.transform = transform
        self.random_start = random_start
        self.last_unprocessed_frame = None
        # a deque with maxlen drops the oldest frame automatically once full
        self.buffer = collections.deque(maxlen=num_stack)

    @staticmethod
    def _preprocess_frame(first_frame, second_frame):
        """Return the pixel-wise max of two RGB frames reduced to its
        luminance (Y) channel as a 2-D ``torch.Tensor``.

        The max over consecutive frames removes Atari sprite flicker.
        """
        image_r = np.maximum(first_frame[:, :, 0], second_frame[:, :, 0])
        image_g = np.maximum(first_frame[:, :, 1], second_frame[:, :, 1])
        image_b = np.maximum(first_frame[:, :, 2], second_frame[:, :, 2])
        # OpenCV expects BGR channel order
        image = np.stack((image_b, image_g, image_r), axis=-1)
        img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
        y, _, _ = cv2.split(img_yuv)
        return torch.from_numpy(y)

    def _refill_buffer(self, frame):
        """Fill the whole stack with ``frame`` and return the transformed
        stacked observation (shared tail of every ``reset`` path)."""
        for _ in range(self.n):
            self.buffer.append(frame)
        return self.transform(torch.stack(tuple(self.buffer), dim=0))

    def reset(self, fill=False, **kwargs):
        """Reset the env, optionally take ``random_start`` no-op steps, and
        return the initial stacked observation.

        :param fill: when True, skip the no-op start and build the stack
            from two consecutive resets instead.
        """
        if self.random_start > 0 and not fill:
            # FIX: the original duplicated this whole branch for
            # ``random_start >= 2`` and ``0 < random_start < 2``; they are
            # identical because ``range(random_start - 2)`` is empty when
            # random_start < 2, so a single ``> 0`` check suffices.
            super().reset(**kwargs)
            for _ in range(self.random_start - 2):
                # NOTE(review): this goes through self.step and therefore
                # performs frame_skip raw env steps per iteration -- confirm
                # that (rather than a single raw no-op) is intended.
                self.step(0)
            first_frame, _, _, _ = super().step(0)
            second_frame, _, _, _ = super().step(0)
        else:
            # NOTE(review): resetting twice to obtain two frames restarts
            # the episode in between -- both frames are initial frames.
            first_frame = super().reset(**kwargs)
            second_frame = super().reset(**kwargs)
        frame = self._preprocess_frame(first_frame, second_frame)
        return self._refill_buffer(frame)

    def step(self, action, frame_skip=4):
        """Repeat ``action`` ``frame_skip`` times, max-pool the last two raw
        frames into one processed frame, push it onto the stack, and return
        ``(stacked_obs, summed_reward, done, info)``.

        NOTE(review): with ``frame_skip < 2`` the raw, unstacked observation
        is returned instead -- callers relying on a consistent observation
        type should avoid that path.
        """
        if frame_skip < 2:
            obs, reward, _done, info = super().step(action)
            return obs, reward, _done, info
        reward_sum = 0
        done = False
        # all but the last two skipped frames: only accumulate reward/done
        for _ in range(frame_skip - 2):
            obs, reward, _done, info = super().step(action)
            reward_sum += reward
            done = done or _done
        # second-to-last raw frame, kept for max-pooling with the last one
        obs, reward, _done, info = super().step(action)
        reward_sum += reward
        self.last_unprocessed_frame = obs
        done = done or _done
        # last raw frame
        obs, reward, _done, info = super().step(action)
        reward_sum += reward
        done = done or _done
        frame = self._preprocess_frame(self.last_unprocessed_frame, obs)
        self.buffer.append(frame)
        return (self.transform(torch.stack(tuple(self.buffer), dim=0)),
                reward_sum, done, info)

    def render(self, mode='human', *args):
        """Render the wrapped env; in ``rgb_array`` mode also return a copy
        of the current (processed) frame stack.

        FIX: the original declared ``*kwargs`` -- positional varargs
        misleadingly named like keyword args; they were never used, so the
        rename to ``*args`` is caller-compatible.
        """
        if mode == 'rgb_array':
            return self.buffer.copy(), super().render('rgb_array')
        return super().render(mode)

    def snapshot(self):
        """Return ``(emulator_state, frame_stack_copy)`` for later
        restoration via :meth:`restore_state_and_buffer`."""
        # ALE snapshot for Atari; reload with .restore_state()
        state = self.env.clone_state()
        return state, self.buffer.copy()

    def restore_state_and_buffer(self, state, buffer):
        """Restore a snapshot taken by :meth:`snapshot` and return the
        transformed stacked observation."""
        self.env.restore_state(state)
        self.buffer = buffer.copy()
        return self.transform(torch.stack(tuple(self.buffer), dim=0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment