Skip to content

Instantly share code, notes, and snippets.

@BlGene
Created June 27, 2018 13:38
Show Gist options
  • Save BlGene/7a2585ed3726cd08ae536aea43493db4 to your computer and use it in GitHub Desktop.
Save BlGene/7a2585ed3726cd08ae536aea43493db4 to your computer and use it in GitHub Desktop.
import os
import glob
import multiprocessing
from collections import deque
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pygame
from pdb import set_trace
import numpy as np
# The environment lives in a different process, so its observations reach the
# Viewer through this process-safe queue (maxsize=0 means unbounded).
# NOTE(review): a module-level Queue is only shared with the env process if
# that process inherits it (e.g. fork start method) — confirm for your setup.
env_data_queue = multiprocessing.Queue(maxsize=0)
class Viewer:
    """Matplotlib dashboard showing env frames, run frames and a rolling plot.

    Layout (2x2 grid):
      * top-left  (``col_ax``): latest observation taken from the module-level
        ``env_data_queue`` (filled by ``env_callback``);
      * top-right (``net_ax``): latest ``obs`` passed to ``run_callback``,
        shown in grayscale;
      * bottom    (``plt_ax``): rolling plot of ``rew[0]`` and ``values[0]``.

    When ``video`` is enabled each redraw is also saved as ``./video/NNN.png``.
    """

    def __init__(self, transpose=True, zoom=None, video=False):
        """Create the figure, axes and rolling time-series buffers.

        Args:
            transpose: used by ``get_video_size``; when true the returned size
                is ``(obs.shape[1], obs.shape[0])``, otherwise
                ``(obs.shape[0], obs.shape[1])``.
            zoom: optional scale factor applied by ``get_video_size``.
            video: when true, delete any ``./video/*.png`` left from a previous
                run and save a PNG of the figure on every ``draw``.
        """
        self.env_initialized = False
        self.run_initialized = False
        self.transpose = transpose
        self.zoom = zoom

        if video:
            os.makedirs('./video', exist_ok=True)
            # Clear old frames so numbering restarts at 000 without stale files.
            for f in glob.glob('./video/*.png'):
                os.remove(f)
        self.frame_count = 0
        self.video = video

        self.col_data = None
        # Latest run observation; written by run_callback, read by draw.
        self.run_data = None

        plt.ion()  # interactive mode: figure updates without blocking
        self.fig = plt.figure(figsize=(2, 2))
        gs = gridspec.GridSpec(2, 2)
        gs.update(wspace=0.001, hspace=0.001)  # set the spacing between axes.
        self.col_ax = plt.subplot(gs[0, 0])  # env observation
        self.net_ax = plt.subplot(gs[0, 1])  # run observation
        self.plt_ax = plt.subplot(gs[1, :])  # reward / value time series
        self.col_ax.set_axis_off()
        self.net_ax.set_axis_off()
        self.plt_ax.set_axis_off()
        plt.subplots_adjust(wspace=0.5, hspace=0, left=0, bottom=0, right=1, top=1)

        # Rolling time series: one deque per plotted quantity, each holding at
        # most horizon_timesteps points.
        num_plots = 2
        self.horizon_timesteps = 30 * 5
        self.t = 0
        self.cur_plot = [None for _ in range(num_plots)]
        self.data = [deque(maxlen=self.horizon_timesteps) for _ in range(num_plots)]

    def env_callback(self, env, draw=True):
        """Push ``env._observation`` onto ``env_data_queue`` for ``draw``.

        Meant to be called on the environment side, which per the module
        comment runs in a different process. ``draw`` is accepted but unused.
        """
        obs = env._observation
        env_data_queue.put(obs)

    def run_callback(self, prev_obs, obs, actions, rew, masks, values, draw=True):
        """Record one step of run data, update the time-series plot and redraw.

        Stores a copy of ``obs`` for the top-right panel, appends ``rew[0]``
        and ``values[0]`` to the rolling buffers, and calls ``draw`` (which
        blocks until an env observation is available). ``prev_obs``,
        ``actions``, ``masks`` and ``draw`` are accepted but unused.
        """
        self.run_data = obs.copy()
        print("Run callback: ", rew, np.max(obs))

        # Only the first element is plotted — presumably env 0 of a batch
        # (TODO confirm against the caller).
        points = [rew[0], values[0]]
        for point, data_series in zip(points, self.data):
            data_series.append(point)
        self.t += 1

        # The visible window always spans exactly len(self.data[i]) steps.
        xmin, xmax = max(0, self.t - self.horizon_timesteps), self.t
        styles = [('C1', 'rew'), ('C2', 'val')]
        for i, (color, label) in enumerate(styles):
            if self.cur_plot[i] is not None:
                self.cur_plot[i].remove()
            self.cur_plot[i], = self.plt_ax.plot(
                range(xmin, xmax), list(self.data[i]), color=color, label=label)
        self.plt_ax.set_xlim(xmin, xmax)
        self.plt_ax.legend(loc='lower left')
        self.draw()

    def get_video_size(self, obs):
        """Return the (optionally transposed and zoomed) 2-D size of ``obs``.

        Also used by ``draw`` as a shape check: ``obs`` must be 2-D or 3-D
        with 1 or 3 channels in the last axis, else AssertionError.
        """
        assert len(obs.shape) == 2 or (len(obs.shape) == 3 and obs.shape[2] in [1, 3]), \
            "shape was {}".format(obs.shape)
        if self.transpose:
            video_size = obs.shape[1], obs.shape[0]
        else:
            video_size = obs.shape[0], obs.shape[1]
        if self.zoom is not None:
            # Bug fix: previously referenced an undefined global ``zoom``.
            video_size = int(video_size[0] * self.zoom), int(video_size[1] * self.zoom)
        return video_size

    def draw(self):
        """Refresh both image panels and the canvas; optionally save a frame.

        Blocks on ``env_data_queue.get()`` until the env side has put an
        observation. Image artists are created on the first usable data and
        updated in place with ``set_data`` afterwards.
        """
        col_data = env_data_queue.get()
        # A None item carries no frame; skip it instead of set_data(None).
        if col_data is not None:
            if self.env_initialized:
                self.col_screen.set_data(col_data)
            else:
                self.get_video_size(col_data)  # shape validation only
                self.col_screen = self.col_ax.imshow(col_data, aspect='auto')
                self.env_initialized = True

        if self.run_initialized:
            self.net_screen.set_data(self.run_data)
        elif self.run_data is not None and not np.all(self.run_data == 0):
            # Wait for a non-blank frame before creating the image —
            # presumably because imshow fixes its colour limits from the
            # first data it sees.
            self.get_video_size(self.run_data)  # shape validation only
            self.net_screen = self.net_ax.imshow(self.run_data, cmap='gray', aspect='auto')
            self.run_initialized = True

        self.fig.tight_layout()
        self.fig.canvas.draw()

        if self.video:
            fn = './video/{0:03d}.png'.format(self.frame_count)
            self.fig.savefig(fn, bbox_inches='tight', pad_inches=0)
            self.frame_count += 1

    @staticmethod
    def display_arr(screen, arr, transpose=True, video_size=(84, 84)):
        """Min-max normalise ``arr`` to 0..255, scale to ``video_size`` and blit at (0, 0).

        When ``transpose`` is true, axes 0 and 1 are swapped before building
        the pygame surface. A constant array is drawn as all zeros instead of
        dividing by a zero range.
        """
        arr_min, arr_max = arr.min(), arr.max()
        span = arr_max - arr_min
        if span == 0:
            scaled = np.zeros(arr.shape, dtype=np.uint8)
        else:
            # make_surface is documented for integer arrays, so cast explicitly.
            scaled = (255.0 * (arr - arr_min) / span).astype(np.uint8)
        pyg_img = pygame.surfarray.make_surface(scaled.swapaxes(0, 1) if transpose else scaled)
        pyg_img = pygame.transform.scale(pyg_img, video_size)
        screen.blit(pyg_img, (0, 0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment