Monitor training progress and save the model of D4PG.
import torch
import numpy as np
from collections import deque

# Note: `env` and `brain_name` are expected to be defined in the surrounding
# scope (the usual Unity ML-Agents setup).

# Train the agent.
def train_d4pg(agent, memory, n_episodes=10, mode='train',
               actor_pth='./checkpoint/d4pg_actor_checkpoint.pth',
               critic_pth='./checkpoint/d4pg_critic_checkpoint.pth'):
    '''Set up the training configuration and print episodic performance
    measures, such as average score and average actor/critic loss.

    Params
    ======
        agent (class object): the D4PG agent
        memory (class attribute): agent's replay buffer, used to track memory size
        n_episodes (int): maximum number of training episodes
        mode (string): 'train' or 'test'; in test mode the agent follows the greedy policy only
        actor_pth (string path): checkpoint file name for the actor network
        critic_pth (string path): checkpoint file name for the critic network
    '''
    scores = []
    scores_window = deque(maxlen=100)   # scores of the last 100 episodes
    c_loss_window = deque(maxlen=100)   # critic losses of the last 100 episodes
    a_loss_window = deque(maxlen=100)   # actor losses of the last 100 episodes

    for i_episode in range(1, n_episodes + 1):
        env_info = env.reset(train_mode=True)[brain_name]  # reset the environment and activate train_mode
        state = env_info.vector_observations               # get the current state
        score = 0
        agent.running_c_loss = 0
        agent.running_a_loss = 0
        agent.training_cnt = 0
        # agent.reset()  # reset OUNoise

        while True:
            action = agent.act(state, mode)
            env_info = env.step(action)[brain_name]        # send the action to the environment
            next_state = env_info.vector_observations      # get the next state
            reward = env_info.rewards[0]                   # get the reward
            done = env_info.local_done[0]                  # see if the episode has finished
            agent.step(state, action, reward, next_state, done)
            score += reward
            state = next_state
            if done:
                break

        scores_window.append(score)
        scores.append(score)
        c_loss_window.append(agent.running_c_loss / (agent.training_cnt + 0.0001))  # small epsilon avoids division by zero
        a_loss_window.append(agent.running_a_loss / (agent.training_cnt + 0.0001))  # small epsilon avoids division by zero

        print('\rEpisode {:>4}\tAverage Score:{:>6.3f}\tMemory Size:{:>5}\tCLoss:{:>12.8f}\tALoss:{:>10.6f}'.format(
            i_episode, np.mean(scores_window), len(memory), np.mean(c_loss_window), np.mean(a_loss_window)), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {:>4}\tAverage Score:{:>6.3f}\tMemory Size:{:>5}\tCLoss:{:>12.8f}\tALoss:{:>10.6f}'.format(
                i_episode, np.mean(scores_window), len(memory), np.mean(c_loss_window), np.mean(a_loss_window)))
        if np.mean(scores_window) >= 31:  # stop once the 100-episode average score reaches 31
            break

    torch.save(agent.actor_local.state_dict(), actor_pth)    # save actor weights
    torch.save(agent.critic_local.state_dict(), critic_pth)  # save critic weights
    return scores
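
For reference, a minimal sketch of how this function might be called. It assumes a Unity ML-Agents environment and a D4PG agent class exposing the attributes used above (actor_local, critic_local, running_c_loss, running_a_loss, training_cnt, and a memory replay buffer); the Agent class, its constructor arguments, and the environment path are hypothetical and should be adapted to your project.

# Hypothetical setup; train_d4pg reads `env` and `brain_name` as globals.
from unityagents import UnityEnvironment

env = UnityEnvironment(file_name='Reacher.app')      # path to your built environment
brain_name = env.brain_names[0]

agent = Agent(state_size=33, action_size=4, seed=0)  # hypothetical D4PG agent constructor
scores = train_d4pg(agent, agent.memory, n_episodes=300)
env.close()

Passing agent.memory separately lets the progress line report the replay-buffer size via len(memory) without the function needing to know the buffer's internals.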