Skip to content

Instantly share code, notes, and snippets.

@piEsposito
Created January 20, 2020 14:19
Show Gist options
  • Save piEsposito/433027d863ee39a03b1e0a67e0609aa0 to your computer and use it in GitHub Desktop.
Save piEsposito/433027d863ee39a03b1e0a67e0609aa0 to your computer and use it in GitHub Desktop.
class FrameStacker:
    """Keep a rolling window of preprocessed frames and return them stacked.

    Frames are held in a bounded ``deque``; each call to :meth:`stack`
    preprocesses the incoming frame, pushes it (evicting the oldest one) and
    returns every frame in memory concatenated along dimension 0.

    Relies on three names from the enclosing module: ``deque``, ``T`` (used as
    a transforms namespace providing ``Compose``, ``ToPILImage``, ``Resize``
    and ``ToTensor``) and ``device``.
    """

    def __init__(self, memory_size=4, frame_size=(84, 84)):
        """Build the frame memory and the resizing pipeline.

        Args:
            memory_size: Number of frames kept in memory (default 4).
            frame_size: ``(height, width)`` every frame is resized to
                (default ``(84, 84)``).
        """
        self.memory_size = memory_size
        self.frame_size = tuple(frame_size)
        self.memory = deque(maxlen=self.memory_size)
        self.transformer = T.Compose([
            T.ToPILImage(),
            T.Resize(self.frame_size),
            T.ToTensor(),
        ])
        self.reset()

    def reset(self):
        """Restart the memory by filling it with zero tensors.

        The number and size of the placeholders follow ``memory_size`` and
        ``frame_size`` so the stacked output has the same depth from the very
        first call to :meth:`stack`.
        """
        self.memory.clear()
        for _ in range(self.memory_size):
            # Placeholders are single-channel: (1, height, width).
            self.memory.append(torch.zeros(1, *self.frame_size, device=device))

    def preprocess_frame(self, frame):
        """Crop the frame, run it through the transformer and rescale it.

        Args:
            frame: Raw frame; indexed as ``frame[80:, :]`` and then handed to
                the transform pipeline.

        Returns:
            The transformed frame, moved to ``device`` and divided by 255.
        """
        # Drop the first 80 rows along the first axis before resizing.
        cropped = frame[80:, :]
        resized = self.transformer(cropped)
        # NOTE(review): ToTensor already rescales 8-bit images to [0, 1], so
        # this division may normalise twice depending on the frame dtype --
        # confirm against the caller. Kept to preserve existing behaviour.
        return resized.to(device) / 255

    def stack(self, frame):
        """Preprocess ``frame``, push it into memory and return the stack.

        Args:
            frame: Raw frame, as accepted by :meth:`preprocess_frame`.

        Returns:
            Every frame currently in memory, oldest first, concatenated along
            dimension 0.
        """
        # NOTE(review): the zero placeholders have one channel; presumably
        # frames are single-channel too -- verify, otherwise the depth of the
        # result changes as placeholders are evicted.
        self.memory.append(self.preprocess_frame(frame))
        return torch.cat(tuple(self.memory), dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment