Created July 16, 2021
Environment Wrappers
import collections

import cv2
import gym
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

class FrameStackingEnv(gym.Wrapper):
    """Stacks the last `num_stack` preprocessed frames into one observation."""

    def __init__(self, env, num_stack=4, random_start=30, transform=None):
        super().__init__(env)
        self.env = env
        self.n = num_stack
        # Default pipeline: convert the stacked frames to a PIL image, resize
        # to 84x84 with bicubic interpolation, and convert back to a tensor.
        self.transform = transform if transform is not None else T.Compose(
            [T.ToPILImage(),
             T.Resize((84, 84), interpolation=Image.BICUBIC),
             T.ToTensor()])
        self.random_start = random_start
        self.last_unprocessed_frame = None
        self.buffer = collections.deque(maxlen=num_stack)
    @staticmethod
    def _preprocess_frame(first_frame, second_frame):
        """Max-pool two consecutive RGB frames (removes Atari sprite flicker)
        and return the luminance (Y) channel as a uint8 tensor."""
        image_r = np.maximum(first_frame[:, :, 0], second_frame[:, :, 0])
        image_g = np.maximum(first_frame[:, :, 1], second_frame[:, :, 1])
        image_b = np.maximum(first_frame[:, :, 2], second_frame[:, :, 2])
        # OpenCV expects BGR channel order
        image = np.stack((image_b, image_g, image_r), axis=-1)
        img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
        y, _, _ = cv2.split(img_yuv)
        return torch.from_numpy(y)
    def reset(self, fill=False, **kwargs):
        """Reset the env, optionally take `random_start` no-op steps, and fill
        the frame buffer with the preprocessed initial frame."""
        if self.random_start > 0 and not fill:
            super(FrameStackingEnv, self).reset(**kwargs)
            # No-op phase; the final two raw frames are kept for preprocessing.
            for _ in range(max(self.random_start - 2, 0)):
                self.step(0)
            first_frame, _, _, _ = super(FrameStackingEnv, self).step(0)
            second_frame, _, _, _ = super(FrameStackingEnv, self).step(0)
        else:
            # No no-op phase: use the initial observation for both frames.
            first_frame = super(FrameStackingEnv, self).reset(**kwargs)
            second_frame = first_frame
        frame = self._preprocess_frame(first_frame, second_frame)
        for _ in range(self.n):
            self.buffer.append(frame)
        return self.transform(torch.stack(tuple(self.buffer), dim=0))
""" | |
Take the action once and get the return value for the step in the environment | |
Repeat the action frame_skip - 2 times. In the frame_skip - 1 action, save the frame | |
in order to preprocess the current and the last frame of the `render` method | |
""" | |
def step(self, action, frame_skip=4): | |
if frame_skip < 2: | |
obs, reward, _done, info = super(FrameStackingEnv, self).step(action) | |
return obs, reward, _done, info | |
reward_sum = 0 | |
done = False | |
# info = None | |
for i in range(0, frame_skip - 2): | |
obs, reward, _done, info = super(FrameStackingEnv, self).step(action) | |
reward_sum += reward | |
done = done or _done | |
# second-to-last frame | |
obs, reward, _done, info = super(FrameStackingEnv, self).step(action) | |
reward_sum += reward | |
self.last_unprocessed_frame = obs | |
done = done or _done | |
# last frame | |
obs, reward, _done, info = super(FrameStackingEnv, self).step(action) | |
reward_sum += reward | |
done = done or _done | |
frame = self._preprocess_frame(self.last_unprocessed_frame, obs) | |
# 0,1,2 -> 1,2,3 | |
# self.buffer[:, :, 1:self.n] = self.buffer[0:self.n-1] | |
# self.buffer[:, :, 0] = frame | |
self.buffer.append(frame) | |
return self.transform(torch.stack(tuple(self.buffer), dim=0)), reward_sum, done, info | |
    def render(self, mode='human', **kwargs):
        if mode == 'rgb_array':
            # Return the current frame stack alongside the raw RGB frame.
            return self.buffer.copy(), super(FrameStackingEnv, self).render('rgb_array')
        return super(FrameStackingEnv, self).render(mode)
    def snapshot(self):
        # Atari-specific: clone the ALE emulator state. Load it again with
        # restore_state_and_buffer().
        state = self.env.clone_state()
        return state, self.buffer.copy()

    def restore_state_and_buffer(self, state, buffer):
        self.env.restore_state(state)
        self.buffer = buffer.copy()
        return self.transform(torch.stack(tuple(self.buffer), dim=0))
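
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist). It assumes gym[atari]
# is installed and uses the classic 4-tuple step API that the wrapper above is
# written against. The environment id "PongNoFrameskip-v4" is only an
# illustrative choice; any Atari env with RGB observations should work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    env = FrameStackingEnv(gym.make("PongNoFrameskip-v4"))

    state = env.reset()  # stacked, preprocessed initial observation (4x84x84)
    done = False
    while not done:
        action = env.action_space.sample()  # random policy, for illustration
        state, reward, done, info = env.step(action, frame_skip=4)

    # snapshot()/restore_state_and_buffer() rely on the clone_state()/
    # restore_state() methods of the underlying ALE environment, so this
    # part is Atari-specific.
    state = env.reset()
    ale_state, frames = env.snapshot()
    state = env.restore_state_and_buffer(ale_state, frames)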