Implementation of DQN, Double DQN, Bootstrap DQN, and Bootstrap DQN with Randomized Prior in PyTorch on a toy environment
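The bootstrapped ensemble and the additive prior network appear to follow Osband et al., "Deep Exploration via Bootstrapped DQN" (2016) and "Randomized Prior Functions for Deep Reinforcement Learning" (2018). The toy environment is a catch game on a GRID_SIZE x GRID_SIZE grid: a fruit falls one row per step, and the agent slides a three-cell basket along the bottom row to catch it (+1 reward) or miss it (-1).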
# extending on code from
# https://github.com/58402140/Fruit
import os
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
import copy
import time
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Size of game grid, e.g. 10 -> 10x10
GRID_SIZE = 10
# How often (in episodes) to evaluate the policy
EVALUATE_EVERY = 10
# Save images of the policy at each evaluation if True, otherwise only at the end
SAVE_IMAGES = False
# How often to print statistics
PRINT_EVERY = 1
# Whether to use Double DQN or regular DQN targets
USE_DOUBLE_DQN = True
# How often (in episodes) to sync the replica target network
TARGET_UPDATE = 10
# Number of evaluation episodes to run
N_EVALUATIONS = 100
# Number of heads for the ensemble (1 falls back to DQN)
N_ENSEMBLE = 5
# Probability of each experience going to each head
BERNOULLI_P = 1.
# Weight for the randomized prior, 0. disables it
PRIOR_SCALE = 1.
# Number of episodes to run
N_EPOCHS = 1000
# Batch size to use for learning
BATCH_SIZE = 128
# Buffer size for experience replay
BUFFER_SIZE = 1000
# Epsilon-greedy exploration: probability of a random action, 0. disables it
EPSILON = .0
# Gamma discount factor in the Q update
GAMMA = .8
# Gradient clipping setting
CLIP_GRAD = 1
# Learning rate for Adam
ADAM_LEARNING_RATE = 1E-3
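# How these settings map onto the model variants implemented below:
# N_ENSEMBLE = 1 falls back to a single-head (Double) DQN, PRIOR_SCALE = 0.
# disables the randomized prior, and USE_DOUBLE_DQN switches the bootstrap
# target between the vanilla and double Q-learning rules.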
random_state = np.random.RandomState(11)

def seed_everything(seed=1234):
    #random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    #torch.backends.cudnn.deterministic = True

seed_everything(22)
def save_img(epoch):
    if 'images_{}'.format(epoch) not in os.listdir('.'):
        os.mkdir('images_{}'.format(epoch))
    frame = 0
    while True:
        screen, reward = (yield)
        plt.imshow(screen[0], interpolation='none')
        plt.title("reward: {}".format(reward))
        plt.savefig('images_{}/{}.png'.format(epoch, frame))
        frame += 1
def episode():
    """
    Coroutine of a single episode.

    Actions have to be explicitly sent to this coroutine;
    actions are 0, 1, 2 for left, don't-move, and right.
    """
    x, y, z = (
        random_state.randint(0, GRID_SIZE),      # X of fruit
        0,                                       # Y of fruit
        random_state.randint(1, GRID_SIZE - 1),  # X of basket center
    )
    while True:
        X = np.zeros((GRID_SIZE, GRID_SIZE))  # Reset grid
        X = X.astype("float32")
        X[y, x] = 1.  # Draw fruit
        bar = range(z - 1, z + 2)
        X[-1, bar] = 1.  # Draw basket
        # The game ends when the fruit reaches the penultimate row of the grid;
        # the end is either a win or a loss.
        end = int(y >= GRID_SIZE - 2)
        reward = 0
        # can add this for dense rewards
        #if x in bar:
        #    reward = 1
        if end and x not in bar:
            reward = -1
        if end and x in bar:
            reward = 1
        action = yield X[None], reward
        if end:
            break
        # translate actions: 0 -> left (-1), 1 -> stay (0), 2 -> right (+1)
        action = action - 1
        z = min(max(z + action, 1), GRID_SIZE - 2)
        y += 1
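# A minimal sketch of the coroutine protocol used throughout this file:
#
#     ep = episode()
#     S, reward = next(ep)    # prime the coroutine; yields the first frame
#     S, reward = ep.send(2)  # step the game, e.g. move the basket right
#
# Once the terminal frame has been yielded, the next send raises StopIteration.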
def experience_replay(batch_size, max_size):
    """
    Coroutine of experience replay.

    Provide a new experience by calling send, which in turn yields
    a random batch of previous replay experiences.
    """
    memory = []
    while True:
        inds = np.arange(len(memory))
        if batch_size <= len(memory):
            batch = [memory[i] for i in random_state.choice(inds, size=batch_size, replace=True)]
        else:
            batch = None
        # send None to just get random experiences, without changing the buffer
        experience = yield batch
        if experience is not None:
            memory.append(experience)
        if len(memory) > max_size:
            memory.pop(0)
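# Usage sketch: prime with next(), then send experiences in; a batch comes
# back once the buffer holds at least batch_size items, otherwise None:
#
#     replay = experience_replay(BATCH_SIZE, BUFFER_SIZE)
#     next(replay)                     # prime the coroutine
#     batch = replay.send(experience)  # insert and sample
#     batch = replay.send(None)        # sample without inserting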
class CoreNet(nn.Module):
    def __init__(self):
        super(CoreNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, 1, padding=(1, 1))
        self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=(1, 1))
        #self.conv3 = nn.Conv2d(16, 16, 3, 1, padding=(1, 1))

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        #x = F.relu(self.conv3(x))
        x = x.view(-1, 10 * 10 * 16)
        return x

class HeadNet(nn.Module):
    def __init__(self):
        super(HeadNet, self).__init__()
        self.fc1 = nn.Linear(10 * 10 * 16, 100)
        self.fc2 = nn.Linear(100, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class EnsembleNet(nn.Module):
    def __init__(self, n_ensemble):
        super(EnsembleNet, self).__init__()
        self.core_net = CoreNet()
        self.net_list = nn.ModuleList([HeadNet() for k in range(n_ensemble)])

    def _core(self, x):
        return self.core_net(x)

    def _heads(self, x):
        return [net(x) for net in self.net_list]

    def forward(self, x, k):
        return self.net_list[k](self.core_net(x))
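# EnsembleNet is the bootstrapped architecture: one shared convolutional core
# feeding N_ENSEMBLE independent fully connected heads. Passing an index k
# evaluates a single head; NetWithPrior below uses _core/_heads to evaluate
# all heads from one cached core forward pass when k is None.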
class NetWithPrior(nn.Module):
    def __init__(self, net, prior, prior_scale=1.):
        super(NetWithPrior, self).__init__()
        self.net = net
        self.prior_scale = prior_scale
        if self.prior_scale > 0.:
            self.prior = prior

    def forward(self, x, k):
        if hasattr(self.net, "net_list"):
            if k is not None:
                if self.prior_scale > 0.:
                    return self.net(x, k) + self.prior_scale * self.prior(x, k).detach()
                else:
                    return self.net(x, k)
            else:
                core_cache = self.net._core(x)
                net_heads = self.net._heads(core_cache)
                if self.prior_scale <= 0.:
                    return net_heads
                else:
                    prior_core_cache = self.prior._core(x)
                    prior_heads = self.prior._heads(prior_core_cache)
                    return [n + self.prior_scale * p.detach()
                            for n, p in zip(net_heads, prior_heads)]
        else:
            raise ValueError("Only works with a net_list model")
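# NetWithPrior implements the randomized prior trick: head k computes
#
#     Q_k(s) = f_k(s) + PRIOR_SCALE * p_k(s)
#
# where f_k is trainable and the prior p_k is a fixed, randomly initialized
# network held constant through .detach(), so gradients only flow into f_k.
# (The prior's parameters end up registered in the module, but they never
# receive gradients, so the optimizer leaves them unchanged.)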
prior_net = EnsembleNet(N_ENSEMBLE)
policy_net = EnsembleNet(N_ENSEMBLE)
policy_net = NetWithPrior(policy_net, prior_net, PRIOR_SCALE)
target_net = EnsembleNet(N_ENSEMBLE)
target_net = NetWithPrior(target_net, prior_net, PRIOR_SCALE)
opt = optim.Adam(policy_net.parameters(), lr=ADAM_LEARNING_RATE)
# change minibatch setup to use masking...
exp_replay = experience_replay(BATCH_SIZE, N_ENSEMBLE * BUFFER_SIZE)
next(exp_replay)  # Start experience-replay coroutine
# Stores the average evaluation reward from each evaluation, over all epochs
accumulation_rewards = []
overall_time = 0.
for i in range(N_EPOCHS):
    start = time.time()
    ep = episode()
    S, won = next(ep)  # Start coroutine of single episode
    epoch_losses = [0. for k in range(N_ENSEMBLE)]
    epoch_steps = [1. for k in range(N_ENSEMBLE)]
    # Bootstrap exploration: act greedily w.r.t. one randomly chosen head per episode
    heads = list(range(N_ENSEMBLE))
    random_state.shuffle(heads)
    active_head = heads[0]
    try:
        policy_net.train()
        while True:
            if random_state.rand() < EPSILON:
                action = random_state.randint(0, 3)
            else:
                # Take the index of the maximum Q-value of the active head.
                # The index (0, 1, 2) is translated to a move (-1, 0, +1)
                # inside the episode coroutine.
                with torch.no_grad():
                    action = np.argmax(policy_net(torch.Tensor(S[None]), active_head).detach().data.numpy(), axis=-1)[0]
            S_prime, won = ep.send(action)
            ongoing_flag = 1.
            exp_mask = random_state.binomial(1, BERNOULLI_P, N_ENSEMBLE)
            experience = (S, action, won, S_prime, ongoing_flag, exp_mask)
            S = S_prime
            batch = exp_replay.send(experience)
            if batch:
                inputs = []
                actions = []
                rewards = []
                nexts = []
                ongoing_flags = []
                masks = []
                for b_i in batch:
                    s, a, r, s_prime, ongoing_flag, mask = b_i
                    rewards.append(r)
                    inputs.append(s)
                    actions.append(a)
                    nexts.append(s_prime)
                    ongoing_flags.append(ongoing_flag)
                    masks.append(mask)
                mask = torch.Tensor(np.array(masks))
                # precalculate the next-step Q values for every head
                all_target_next_Qs = [n.detach() for n in target_net(torch.Tensor(nexts), None)]
                all_Qs = policy_net(torch.Tensor(inputs), None)
                if USE_DOUBLE_DQN:
                    all_policy_next_Qs = [n.detach() for n in policy_net(torch.Tensor(nexts), None)]
                # set grads to 0 before iterating heads
                opt.zero_grad()
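                # Per-head Bellman targets computed below, in math:
                #   Double DQN: y = r + GAMMA * ongoing * Q_target(s', argmax_a Q_policy(s', a))
                #   DQN:        y = r + GAMMA * ongoing * max_a Q_target(s', a)
                # where ongoing is 0 at the end of an episode, cutting off the bootstrap term.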
                for k in range(N_ENSEMBLE):
                    if USE_DOUBLE_DQN:
                        policy_next_Qs = all_policy_next_Qs[k]
                        next_Qs = all_target_next_Qs[k]
                        # select actions with the online net, evaluate them with the target net
                        policy_actions = policy_next_Qs.max(1)[1][:, None]
                        next_max_Qs = next_Qs.gather(1, policy_actions)
                        next_max_Qs = next_max_Qs.squeeze()
                    else:
                        next_Qs = all_target_next_Qs[k]
                        next_max_Qs = next_Qs.max(1)[0]
                        next_max_Qs = next_max_Qs.squeeze()
                    # zero the bootstrap term at the end of an episode
                    next_max_Qs = torch.Tensor(ongoing_flags) * next_max_Qs
                    target_Qs = torch.Tensor(np.array(rewards).astype("float32")) + GAMMA * next_max_Qs
                    # get current step predictions
                    Qs = all_Qs[k]
                    Qs = Qs.gather(1, torch.LongTensor(np.array(actions)[:, None].astype("int32")))
                    Qs = Qs.squeeze()
                    # update each head only on the experiences its bootstrap mask
                    # assigned to it; dims must match so broadcasting stays aligned
                    full_loss = (Qs - target_Qs) ** 2
                    full_loss = mask[:, k] * full_loss
                    loss = torch.mean(full_loss)
                    #loss = F.smooth_l1_loss(Qs, target_Qs[:, None])
                    # retain the graph, since the shared core is reused by every head
                    loss.backward(retain_graph=True)
                    for param in policy_net.parameters():
                        if param.grad is not None:
                            # Multiply grads by 1 / K
                            param.grad.data *= 1. / N_ENSEMBLE
                    epoch_losses[k] += loss.detach().cpu().numpy()
                    epoch_steps[k] += 1.
                # After iterating all heads, do the update step
                torch.nn.utils.clip_grad_value_(policy_net.parameters(), CLIP_GRAD)
                opt.step()
    except StopIteration:
        # add the end-of-episode experience
        ongoing_flag = 0.
        # the next state is just S again, since it will be masked out by ongoing_flag anyway
        exp_mask = random_state.binomial(1, BERNOULLI_P, N_ENSEMBLE)
        experience = (S, action, won, S, ongoing_flag, exp_mask)
        exp_replay.send(experience)
    stop = time.time()
    overall_time += stop - start
    if TARGET_UPDATE > 0 and i % TARGET_UPDATE == 0:
        print("Updating target network at {}".format(i))
        target_net.load_state_dict(policy_net.state_dict())
    if i % PRINT_EVERY == 0:
        print("Epoch {}, head {}, loss: {}".format(
            i + 1, active_head,
            [epoch_losses[k] / float(epoch_steps[k]) for k in range(N_ENSEMBLE)]))
    if i % EVALUATE_EVERY == 0 or i == (N_EPOCHS - 1):
        if i == (N_EPOCHS - 1):
            # save images at the end for sure, with fewer evaluation episodes
            SAVE_IMAGES = True
            ORIG_N_EVALUATIONS = N_EVALUATIONS
            N_EVALUATIONS = 5
        if SAVE_IMAGES:
            img_saver = save_img(i)
            next(img_saver)
        evaluation_rewards = []
        for _ in range(N_EVALUATIONS):
            g = episode()
            S, reward = next(g)
            reward_trace = [reward]
            if SAVE_IMAGES:
                img_saver.send((S, reward))
            try:
                policy_net.eval()
                while True:
                    # evaluate by majority vote over all heads
                    acts = [np.argmax(q.data.numpy(), axis=-1)[0]
                            for q in policy_net(torch.Tensor(S[None]), None)]
                    act_counts = Counter(acts)
                    max_count = max(act_counts.values())
                    top_actions = [a for a in act_counts.keys() if act_counts[a] == max_count]
                    # break action ties with a random choice
                    random_state.shuffle(top_actions)
                    act = top_actions[0]
                    S, reward = g.send(act)
                    reward_trace.append(reward)
                    if SAVE_IMAGES:
                        img_saver.send((S, reward))
            except StopIteration:
                # the sum of the trace should be either -1 or +1
                evaluation_rewards.append(np.sum(reward_trace))
        accumulation_rewards.append(np.mean(evaluation_rewards))
        print("Evaluation reward {}".format(accumulation_rewards[-1]))
        if SAVE_IMAGES:
            img_saver.close()
    if i == (N_EPOCHS - 1):
        plt.figure()
        trace = np.array(accumulation_rewards)
        xs = np.array([int(n * EVALUATE_EVERY) for n in range(N_EPOCHS // EVALUATE_EVERY + 1)])
        plt.plot(xs, trace, label="Reward")
        plt.legend()
        plt.ylabel("Average Evaluation Reward ({})".format(ORIG_N_EVALUATIONS))
        model = "Double DQN" if USE_DOUBLE_DQN else "DQN"
        if N_ENSEMBLE > 1:
            model = "Bootstrap " + model
        if PRIOR_SCALE > 0.:
            model = model + " with randomized prior {}".format(PRIOR_SCALE)
        footnote_text = "Episodes\n"
        footnote_text += "\n"
        footnote_text += "\n"
        footnote_text += "Settings:\n"
        footnote_text += "{}\n".format(model)
        footnote_text += "Number of heads {}\n".format(N_ENSEMBLE)
        footnote_text += "Epsilon-greedy {}\n".format(EPSILON)
        if N_ENSEMBLE > 1:
            footnote_text += "Sharing mask probability {}\n".format(BERNOULLI_P)
        footnote_text += "Gamma discount {}\n".format(GAMMA)
        footnote_text += "Grad clip {}\n".format(CLIP_GRAD)
        footnote_text += "Adam, learning rate {}\n".format(ADAM_LEARNING_RATE)
        footnote_text += "Batch size {}\n".format(BATCH_SIZE)
        footnote_text += "Experience replay buffer size {}\n".format(BUFFER_SIZE)
        footnote_text += "Training time {}\n".format(overall_time)
        plt.xlabel(footnote_text)
        plt.tight_layout()
        plt.savefig("reward_traces.png")