-
-
Save ethancaballero/ab19ec0b3e5d8ab2a9f515b6125d6c80 to your computer and use it in GitHub Desktop.
Trust region update
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from https://github.com/pfnet/chainerrl/blob/master/chainerrl/agents/acer.py#L203
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
from torch.nn import functional as F | |
def update_avg(shared_avg_model, model, alpha):
    """Blend ``model``'s parameters into ``shared_avg_model`` as a running average.

    For every pair of parameters (matched by iteration order of
    ``parameters()``) this computes, in place::

        avg <- alpha * avg + (1 - alpha) * param

    Args:
        shared_avg_model: module holding the averaged parameters; mutated.
        model: module holding the most recent parameters; left untouched.
        alpha: weight kept from the old average. May be a Python number or a
            tensor broadcastable to each parameter (e.g. a 1-element tensor).
    """
    # The update is applied in place so each averaged parameter keeps its
    # existing storage. Rebinding ``.data`` to a newly allocated tensor would
    # detach the parameter from that storage -- presumably a problem when the
    # average model lives in shared memory, as its name suggests.
    with torch.no_grad():
        for shared_avg_param, param in zip(shared_avg_model.parameters(),
                                           model.parameters()):
            shared_avg_param.mul_(alpha).add_(param * (1 - alpha))
# Computes a trust region loss based on an existing loss and two distributions
# model/distribution/loss is from the most recent params; ref_distribution is from the average model's params
def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
    """Return a surrogate loss whose gradient is the trust-region-adjusted gradient.

    With ``g`` the gradient of ``loss`` and ``k`` the gradient of the negative
    KL divergence ``-KL(ref_distribution || distribution)``, both taken with
    respect to ``model``'s parameters, the adjusted gradient is::

        z* = g - max(0, (k^T g - threshold) / ||k||^2) * k

    The returned value is ``sum(param * z*)`` with ``z*`` held constant, so
    calling ``backward()`` on it accumulates exactly ``z*`` into each
    parameter's ``.grad``.

    Args:
        model: module whose parameters produced ``distribution`` and ``loss``.
        distribution: probabilities from the current parameters (must be part
            of the autograd graph of ``model``).
        ref_distribution: probabilities from the average model; treated as a
            constant.
        loss: single-element loss computed from ``model``.
        threshold: trust region constraint value (delta in the formula).

    Returns:
        A single-element tensor; the surrogate trust region loss.

    Side effects:
        ``model``'s gradients are cleared. The autograd graphs of ``loss`` and
        ``distribution`` are retained so the caller can still use them.
    """
    params = list(model.parameters())

    def _param_grads(output):
        # autograd.grad leaves every ``.grad`` field alone, so nothing leaks
        # into other leaves that share the graph. Parameters that do not
        # influence ``output`` come back as None and are treated as zero.
        grads = torch.autograd.grad(output, params, retain_graph=True,
                                    allow_unused=True)
        return [torch.zeros_like(p) if grad_p is None else grad_p.detach()
                for p, grad_p in zip(params, grads)]

    # Gradients of the original loss
    g = _param_grads(loss)

    # KL divergence DKL[pi(*|s_i; theta_a) || pi(*|s_i; theta)], summed over
    # all elements. The reference distribution is detached: it is a fixed
    # target here, and kl_div cannot always differentiate through its target.
    kl = F.kl_div(distribution.log(), ref_distribution.detach(), reduction='sum')
    # Gradients of the (negative) KL loss (the direction that increases KL)
    k = _param_grads(-kl)

    # Preserve the original post-condition: the model's gradients are cleared
    model.zero_grad()

    # Dot products of the flattened gradients
    k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
    k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)

    # Trust region scaling; guard against division by zero when k vanishes
    if k_dot_k.item() > 0:
        trust_factor = torch.clamp((k_dot_g - threshold) / k_dot_k, min=0)
    else:
        trust_factor = torch.zeros_like(k_dot_k)

    # z* = g - max(0, (k^T*g - delta) / ||k||^2_2)*k
    z_star = [g_p - trust_factor * k_p for g_p, k_p in zip(g, k)]

    # d(trust_loss)/d(param) == z_star, because z_star carries no graph
    trust_loss = 0
    for param, z_star_p in zip(params, z_star):
        trust_loss += (param * z_star_p).sum()
    return trust_loss
# Minimal usage example: one trust-region step followed by an average update.
l_r = 1e-4  # plain SGD step size used for the manual update below

# Current model (m) and reference/average model (r_m). The softmax dimension
# is given explicitly: omitting it relies on a deprecated implicit choice,
# which for 2-D input is dim=1.
m = nn.Sequential(nn.Linear(5, 3), nn.Softmax(dim=1))
r_m = nn.Sequential(nn.Linear(5, 3), nn.Softmax(dim=1))

x = Variable(torch.ones(1, 5))
y = Variable(torch.ones(1, 3))

d = m(x)      # distribution from the current parameters
r_d = r_m(x)  # distribution from the average model's parameters

# NOTE(review): d is a softmax output (each row sums to 1) and y is constant,
# so this mean is the same value for any parameters and its gradient is zero
# up to rounding -- fine as a smoke test, not a meaningful objective.
l = (d - y).mean(1)

t_r_l = _trust_region_loss(m, d, r_d, l, 1)
t_r_l.backward()

# Manual gradient descent step on the current model
for param in m.parameters():
    param.data -= l_r * param.grad.data

# Move the average model towards the updated model (keep 99% of the old average)
update_avg(r_m, m, torch.FloatTensor(1).zero_() + .99)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, I got a RuntimeError: the derivative for 'target' is not implemented...