Skip to content

Instantly share code, notes, and snippets.

@renatolfc
Created September 1, 2018 19:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save renatolfc/07b173c93f515e6ddd1225cbb43e5522 to your computer and use it in GitHub Desktop.
Save renatolfc/07b173c93f515e6ddd1225cbb43e5522 to your computer and use it in GitHub Desktop.
Code to plot a stack of four RGB frames as used in the Visual Banana navigation challenge
import matplotlib
matplotlib.use('Agg')
from matplotlib import pylab as plt
import matplotlib.gridspec as gridspec
import matplotlib.backends.backend_agg as agg
from matplotlib.ticker import FuncFormatter, MaxNLocator
import numpy as np
# You need pygame, matplotlib and numpy: !pip install pygame matplotlib numpy
# You need to create a pygame screen as well
# Something like this:
# pygame.init()
# screen = pygame.display.set_mode(VIEW_RESOLUTION, pygame.DOUBLEBUF)
# Then you pass the screen as an argument to `show_agent`
#
# If you want to watch every frame, call `show_agent` right after
# performing each action.
# Stacking / skipping parameters. Neither is referenced by the code in this
# file; presumably they describe how observations were built — confirm.
STACK_SIZE = 4
FRAME_SKIP = 1

# (width, height) in pixels of the rendered figure (used with dpi=96).
VIEW_RESOLUTION = (1280, 720)

# Arrow glyph for each discrete action index, in index order 0..3.
ACTIONS = dict(enumerate(('↑', '↓', '←', '→')))
def _action_tick_label(value, _pos):
    """Return the arrow glyph for the action index at tick position *value*.

    Used as a ``FuncFormatter`` callback on the action-bar x axis; ticks that
    do not map to a known action index get an empty label.
    """
    return ACTIONS.get(int(round(value)), '')


def _plot_frame_stack(frames, row_offset, label):
    """Draw the four frames of *frames* as a 2x2 block of the 9x2 grid.

    Each frame spans two grid rows, starting at grid row *row_offset*.
    ``frames[:, i, :, :]`` is transposed with ``(1, 2, 0)`` so that the first
    axis ends up last, as ``imshow`` expects for colour images; this assumes
    *frames* is laid out as (channels, stack, height, width) — TODO confirm
    against the caller.
    """
    for i in range(4):
        ax = plt.subplot2grid((9, 2), (row_offset + (i // 2) * 2, i % 2), rowspan=2)
        ax.imshow(frames[:, i, :, :].transpose(1, 2, 0))
        # Stack index 0 is titled "- 3" and index 3 "- 0".
        ax.set_title('%s - %d' % (label, 3 - i))


def _plot_action(action):
    """Draw a one-hot bar on the last grid row with *action*'s cell lit."""
    one_hot = np.zeros((1, len(ACTIONS)))
    one_hot[0, action] = 1
    ax = plt.subplot2grid((9, 2), (8, 0), colspan=2)
    ax.imshow(one_hot, cmap='gray')
    # Label each integer x tick with the corresponding action arrow.
    ax.xaxis.set_major_formatter(FuncFormatter(_action_tick_label))
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))


def _render_rgba(fig):
    """Rasterize *fig* with Agg and return ``(rgba_bytes, (width, height))``.

    Uses ``buffer_rgba`` instead of ``renderer.tostring_rgb``, which was
    deprecated in Matplotlib 3.8 and removed in 3.10.
    """
    canvas = agg.FigureCanvasAgg(fig)
    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    # Take the size from the pixel buffer itself so bytes and size always agree.
    height, width = rgba.shape[:2]
    return rgba.tobytes(), (width, height)


def show_agent(state, next_state, action, screen):
    """Render one transition (state, action, next state) onto a pygame screen.

    Draws the four frames of *state* in the top part of a 9x2 subplot grid,
    the four frames of *next_state* below them, and a one-hot bar
    highlighting *action* on the last row. The figure is then blitted onto
    *screen*, and the pygame display is refreshed.

    Args:
        state: frame stack indexed as ``state[:, i, :, :]`` for i in 0..3;
            presumably (channels, stack, height, width) — TODO confirm.
        next_state: frame stack with the same layout as *state*.
        action: integer index into ``ACTIONS`` of the action taken.
        screen: pygame surface, e.g. from ``pygame.display.set_mode``.
    """
    # The module header already requires pygame. Importing it here keeps the
    # module importable for plotting-only use and shares the caller's
    # initialised pygame instance.
    import pygame

    fig = plt.figure(0, figsize=(VIEW_RESOLUTION[0] / 96, VIEW_RESOLUTION[1] / 96), dpi=96)
    try:
        _plot_frame_stack(state, 0, 'State')
        _plot_frame_stack(next_state, 4, 'Next State')
        _plot_action(action)
        fig.tight_layout()

        raw_data, size = _render_rgba(fig)
        surf = pygame.image.fromstring(raw_data, size, "RGBA")
        screen.blit(surf, surf.get_rect())
        pygame.display.update()
    finally:
        # Always release figure 0; otherwise a failed call would leave stale
        # axes that the next call would draw on top of.
        plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment