@samiede
Created July 16, 2021 12:53
DQN with custom environment wrappers
import pickle
import gym
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import click
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from zennit.composites import COMPOSITES
import utils
from ReplayMemory import ReplayMemory, Transition
from EnvironmentWrappers import FrameStackingEnv
from pickle import dumps, loads
from model import DQN, LinearSchedule
from tqdm import tqdm
import os
import wandb
# matplotlib.use('TkAgg')
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
BATCH_SIZE = 32
GAMMA = 0.99
TARGET_UPDATE_STEPS = 10000
# TARGET_UPDATE_STEPS = 5
REPLAY_MEMORY_SIZE = 1000000
# REPLAY_MEMORY_START_SIZE = 50000
REPLAY_MEMORY_START_SIZE = 32
UPDATE_FREQUENCY = 4
NO_OP_MAX = 30
MAX_FRAME_COUNT = 50000000
LR = 0.00025
GRADIENT_MOMENTUM = 0.95
# SQUARED_GRADIENT_MOMENTUM = 0.95
MIN_SQUARED_GRADIENT = 0.01
SAVE_STATE_SCORE_INTERVAL = 4
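# Most of the hyperparameters above (batch size, discount factor, target-update period, replay
# capacity, update frequency, no-op max, frame budget and the RMSprop settings) match the values
# reported for the Nature DQN (Mnih et al., 2015). REPLAY_MEMORY_START_SIZE is currently set to a
# small debug value of 32; the paper's value of 50000 is the commented-out line above.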
@click.command()
@click.option('--seed', type=int)
@click.option('--cpu/--gpu', default=False)
@click.option('--env', type=str, default='BreakoutNoFrameskip-v4')
@click.option('--params', type=click.Path(dir_okay=False))
# @click.option('--wandb', default=True)
def main(seed, cpu, env, params):
    if seed is not None:
        torch.manual_seed(seed)

    steps_done = 0
    device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
    print(device)

    # NOTE: checkpoint_dir is used by torch.save() further down but was not defined in this
    # gist; 'checkpoints' is an assumed default.
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # NOTE: wandb.log() in the training loop needs an active run; the project name is a placeholder.
    wandb.init(project='dqn-atari')

    test_env = gym.make(env)
    env = gym.make(env)
    env = FrameStackingEnv(env, random_start=NO_OP_MAX)
    test_env = FrameStackingEnv(test_env, random_start=0)

    resize = T.Compose([T.ToPILImage(),
                        T.Resize((84, 84), interpolation=Image.CUBIC),
                        T.ToTensor()])

    def get_screen(env, transform=True):
        _, screen = env.render(mode='rgb_array')
        screen = screen.transpose((2, 0, 1))
        # env.render(mode='human')
        screen = torch.from_numpy(screen)
        if transform:
            return resize(screen)
        else:
            return screen

    policy_net = DQN((4, 84, 84), env.action_space.n, device).to(device)
    if params is not None:
        policy_dict = torch.load(params, map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
        policy_net.load_state_dict(policy_dict)

    target_net = DQN((4, 84, 84), env.action_space.n, device).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.RMSprop(policy_net.parameters(), lr=LR, momentum=GRADIENT_MOMENTUM, eps=MIN_SQUARED_GRADIENT)
    memory = ReplayMemory(100000)
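    # NOTE: the replay buffer above is created with capacity 100000, not the REPLAY_MEMORY_SIZE
    # constant (1000000) defined at the top of the file.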
    def optimize_model():
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        # Compute a mask of non-final states and concatenate the batch elements
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device,
                                      dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]
        if len(non_final_next_states) > 0:
            non_final_next_states = torch.cat(non_final_next_states)
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Compute Q(s_t, a_t): policy_net outputs Q(s_t, ·) for every action; the index of the action
        # that was actually taken is stored in action_batch. 'gather' selects, along dimension 1,
        # the entry at each of those indices, giving the Q-value of the taken action per batch element.
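        # Tiny illustration with hypothetical values (2 transitions, 2 actions):
        #   policy_net(state_batch)   ->  [[0.1, 0.5], [0.7, 0.3]]
        #   action_batch              ->  [[1], [0]]
        #   .gather(1, action_batch)  ->  [[0.5], [0.7]]   (Q-value of the taken action per row)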
        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # Compute V(s_{t+1}) for all next states.
        # Expected values of actions for non_final_next_states are computed based
        # on the "older" target_net, selecting their best value with max(1)[0].
        # The result is merged based on the mask, so that we end up with either the
        # expected state value or 0 in case the state was final.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        # .max(1)[0] returns the maximum Q-values themselves (index [1] would give the argmax actions)
        if len(non_final_next_states) > 0:
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()

        # Compute the expected Q values (reward + gamma * V(s_{t+1}))
        expected_state_action_values = reward_batch + (GAMMA * next_state_values)
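        # i.e. the one-step TD target:
        #   y_j = r_j                                          if s_{j+1} is terminal
        #   y_j = r_j + GAMMA * max_a Q_target(s_{j+1}, a)     otherwise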
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        # Gradients are clamped to make training more stable:
        # "Because the absolute value loss function |x| has a derivative of -1 for all negative values of x
        # and a derivative of 1 for all positive values of x, clipping the squared error to be between -1 and 1
        # corresponds to using an absolute value loss function for errors outside of the (-1,1) interval."
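        # Note: SmoothL1Loss (the Huber loss) above already implements this error clipping; the
        # clamp below additionally limits each parameter gradient to [-1, 1] as a further
        # stabilization step.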
        for param in policy_net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()
    def run_test_episode(model, env, max_steps=10000):  # -> reward, movie
        _state = env.reset().unsqueeze(0)
        idx = 0
        _done = False
        _reward = 0
        movie_frames = []
        x_vals = []
        meanings = env.unwrapped.get_action_meanings()
        action_vals = [[] for _ in meanings]
        while not _done and idx < max_steps:
            movie_frames.append(get_screen(env, False))
            _action, all_values = model.select_action(_state, steps_done, eps=0.05, log=True)
            _state, r, _done, _ = env.step(_action)
            _state = _state.unsqueeze(0)
            _reward += max(0, r)
            x_vals.append(idx)
            for i, l in enumerate(action_vals):
                l.append(all_values[0][i].item())
            idx += 1
        # helper expected from the surrounding project (e.g. utils); not defined in this gist
        log_state_action_values(x_vals, action_vals, meanings)
        return _reward, np.stack(movie_frames, 0)
    # prefill the replay memory with REPLAY_MEMORY_START_SIZE frames of random actions
    tq = tqdm()
    tq.set_description("Filling initial Replay Memory")
    state = env.reset(fill=True).unsqueeze(0)
    while steps_done < REPLAY_MEMORY_START_SIZE:
        tq.update(1)
        steps_done += 1
        action = torch.tensor([[env.action_space.sample()]], device=device)
        obs, reward, done, info = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        if not done:
            next_state = obs.unsqueeze(0)
        else:
            next_state = None
        # Store the transition in memory
        memory.push(state, action, next_state, reward)
        if done:
            state = env.reset(fill=True).unsqueeze(0)
        else:
            # advance the state; otherwise every prefill transition would pair the initial
            # reset state with later observations
            state = next_state
    tq.reset()
    steps_done = 0
    update_steps_done = 0
    tq.set_description("Training")
    state = env.reset().unsqueeze(0)
    score_sum = 0
    current_score_interval = 1
    while steps_done < MAX_FRAME_COUNT:
        tq.update(1)
        steps_done += 1
        action = policy_net.select_action(state, steps_done)
        lives = env.ale.lives()  # get lives before action
        obs, reward, done, info = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        reward.data.clamp_(-1, 1)
        score_sum += max(0, reward.item())

        # if we did not lose a life and the episode is not done
        if not done and lives == env.ale.lives():
            next_state = obs.unsqueeze(0)
        else:
            next_state = None

        # save an environment snapshot whenever the score reaches a new multiple of
        # SAVE_STATE_SCORE_INTERVAL that we haven't saved yet
        if score_sum > current_score_interval and score_sum % SAVE_STATE_SCORE_INTERVAL == 0:
            current_score_interval = score_sum
            # helper expected from the surrounding project (e.g. utils); not defined in this gist
            save_snapshots_to_disk(*env.snapshot(), score_sum)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        if steps_done % UPDATE_FREQUENCY == 0:
            optimize_model()
            update_steps_done += 1

            # every TARGET_UPDATE_STEPS parameter updates: sync the target net,
            # checkpoint the policy net and run an evaluation episode
            if update_steps_done % TARGET_UPDATE_STEPS == 0:
                target_net.load_state_dict(policy_net.state_dict())
                torch.save(policy_net.state_dict(), f'{checkpoint_dir}/policy_net_{steps_done}.pth')
                test_reward, frames = run_test_episode(policy_net, test_env)

        if done:
            wandb.log({'score': score_sum})
            score_sum = 0
            state = env.reset().unsqueeze(0)
        else:
            state = obs.unsqueeze(0)

    print('Complete')
    env.close()
    wandb.finish()
if __name__ == '__main__':
    main()
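# Example invocation (the script filename is assumed; adjust to wherever this file is saved):
#   python dqn_atari.py --seed 42 --env BreakoutNoFrameskip-v4
#   python dqn_atari.py --cpu --params checkpoints/policy_net_1000000.pth
# GPU is the default when available; pass --cpu to force CPU execution.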