Skip to content

Instantly share code, notes, and snippets.

@frangipane
Created May 8, 2020 03:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save frangipane/4adca6481bf55f2260ff215c5686851b to your computer and use it in GitHub Desktop.
Save frangipane/4adca6481bf55f2260ff215c5686851b to your computer and use it in GitHub Desktop.
Visualize actor critic agent in mini gridworld
"""Visualize actor critic agent in mini gridworld
Plot (dynamic):
- agent navigating gridworld env
- bar plot of discrete actions
- line plot of value function over time
"""
import matplotlib.pyplot as plt
class PolicyPlot:
    """Live bar chart of a policy over a discrete set of actions.

    One bar is created per entry of ``action_names``; each call to
    :meth:`plot` updates the bar heights in place and redraws the figure.
    """

    def __init__(self, action_names):
        """Create the figure and an initial chart with all bars at zero.

        Args:
            action_names: Sized sequence of labels, one per action. Used as
                the x tick labels and to determine the number of bars.
        """
        # Interactive mode, so drawing does not block the caller's loop.
        plt.ion()
        self.fig, self.ax = plt.subplots(figsize=(6, 4))
        self.action_names = action_names

        # initialize blank plot
        self._rects = self.ax.bar(range(len(self.action_names)),
                                  [0] * len(self.action_names),
                                  tick_label=self.action_names)
        self.ax.set_ylabel("Probability")
        self.ax.set_title("Policy")

    def plot(self, pi):
        """Set the bar heights from ``pi`` and redraw.

        Args:
            pi: Iterable of bar heights, paired positionally with the bars
                via ``zip`` (extra entries on either side are ignored).

        Returns:
            The figure owned by this plot.
        """
        for rect, h in zip(self._rects, pi):
            rect.set_height(h)

        # Recompute the data limits from the updated bars *before*
        # autoscaling; in the opposite order the view is scaled using the
        # limits of the previous call and lags one update behind.
        self.ax.relim()
        self.ax.autoscale_view(True, True, True)

        self.fig.canvas.draw()
        # Short pause lets the GUI event loop process the redraw.
        plt.pause(0.0001)
        return self.fig
class ValuePlot:
    """Live line plot of the most recent entries of a value history.

    Callers append to ``self.values`` and then call :meth:`plot`; only the
    last ``trailing_frames`` entries are displayed.
    """

    def __init__(self, trailing_frames):
        """Create the figure and an empty red line.

        Args:
            trailing_frames: Maximum number of most recent values to show.
        """
        # Interactive mode, so drawing does not block the caller's loop.
        plt.ion()
        self.fig, self.ax = plt.subplots(figsize=(6, 4))
        self.ax.set_title("Value")
        self.ax.set_xlabel("iteration")
        self.ax.set_ylabel("value")

        self.trailing_frames = trailing_frames
        self.values = []  # full history; appended to by the caller

        # initialize blank plot
        # Draw on this object's own axes rather than on pyplot's implicit
        # "current" axes, so the line cannot land on another figure.
        self.val_plt, = self.ax.plot([], [], 'r-')

    def plot(self):
        """Redraw the line from the trailing window of ``self.values``.

        The x coordinates are the indices of the displayed values within
        the full history.

        Returns:
            The figure owned by this plot.
        """
        frame_num = len(self.values)
        start = max(0, frame_num - self.trailing_frames)
        self.val_plt.set_data(range(start, frame_num),
                              self.values[start:frame_num])

        # Recompute the data limits from the new line data *before*
        # autoscaling; in the opposite order the view is scaled using the
        # limits of the previous call and lags one update behind.
        self.ax.relim()
        self.ax.autoscale_view(True, True, True)

        self.fig.canvas.draw()
        # Short pause lets the GUI event loop process the redraw.
        plt.pause(0.0001)
        return self.fig
# Run the agent for `num_episodes` episodes, rendering the environment and
# updating the policy/value plots on every step.
for episode in range(num_episodes):
    obs, i = env.reset(), 0

    # Fresh plots per episode; they are closed by plt.close('all') below.
    policy_plot = PolicyPlot(action_names=[n.name for n in env.Actions])
    value_plot = ValuePlot(trailing_frames=75)

    while True:
        i += 1
        env.render('human')
        action, policy, value = agent.get_action(obs)

        value_plot.values.append(value)
        value_fig = value_plot.plot()
        policy_fig = policy_plot.plot(policy)

        obs, reward, done, _ = env.step(action)

        # Stop the episode when it is done or the frame budget is used up.
        if done or i == args.max_frames:
            break

        # Stop the episode if the render window was closed.
        if env.window is not None and env.window.closed:
            break

    # If the render window was closed, stop running further episodes.
    # Guard against a missing window, as the in-episode check above does,
    # so a None window cannot raise AttributeError here.
    if env.window is not None and env.window.closed:
        break

    env.render(close=True)  # close window after episode is over
    plt.close('all')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment