Skip to content

Instantly share code, notes, and snippets.

@ethancaballero
Last active Apr 24, 2018
Embed
What would you like to do?
Trust region update
# Adapted from https://github.com/pfnet/chainerrl/blob/master/chainerrl/agents/acer.py#L203
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
def update_avg(shared_avg_model, model, alpha):
    """Blend ``model``'s parameters into ``shared_avg_model`` as a running average.

    Every parameter of ``shared_avg_model`` is overwritten with
    ``alpha * avg_param + (1 - alpha) * param`` (an exponential moving average
    with decay ``alpha``).

    The update is applied in place on the existing parameter storage instead of
    rebinding ``.data``. Rebinding allocates a brand-new tensor, so any storage
    sharing is lost. NOTE(review): the name suggests the average model may live
    in shared memory across processes. If so, the old rebinding version made
    the update invisible to the other processes; the in-place form keeps it
    visible.

    Args:
        shared_avg_model: module holding the averaged parameters; modified in place.
        model: module whose current parameters are blended in (read only).
            Parameters are paired positionally via ``parameters()``, so both
            modules are expected to share the same architecture.
        alpha: decay weight. Either a Python number or a tensor that broadcasts
            against every parameter (e.g. a 1-element tensor of shape ``(1,)``).
    """
    for avg_param, param in zip(shared_avg_model.parameters(), model.parameters()):
        # In-place ops keep the same storage (and the same leaf Parameter);
        # broadcasting replaces the manual expand_as of the 1-element alpha.
        avg_param.data.mul_(alpha).add_((1 - alpha) * param.data)
# Computes a trust region loss based on an existing loss and two distributions
# model/distribution/loss is from the most recent params; ref_distribution is from the average model's params
def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
# Compute gradients from original loss
model.zero_grad()
loss.backward(retain_graph=True)
g = [Variable(p.grad.data.clone()) for p in model.parameters()]
model.zero_grad()
# KL divergence k = grad theta0*DKL[pi(*|s_i; theta_a) || pi(*|s_i; theta)]
kl = F.kl_div(distribution.log(), ref_distribution, size_average=False)
# Compute gradients from (negative) KL loss (increases KL divergence)
(-kl).backward(retain_graph=True)
k = [Variable(p.grad.data.clone()) for p in model.parameters()]
model.zero_grad()
# Compute dot products of gradients
k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
# Compute trust region update
if k_dot_k.data.numpy()[0] > 0:
trust_factor = torch.clamp((k_dot_g-threshold)/k_dot_k, min=0)
else:
trust_factor = Variable(torch.zeros(1))
# z* = g - max(0, (k^T*g - delta) / ||k||^2_2)*k
z_star = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
trust_loss = 0
for param, z_star_p in zip(model.parameters(), z_star):
trust_loss += (param * z_star_p).sum()
return trust_loss
# Smoke test: one trust-region SGD step on a toy softmax model, followed by an
# exponential-moving-average update (decay 0.99) of the reference model.
l_r = 1e-4
# Explicit dim=1 normalises each row's 3 outputs. Omitting dim is deprecated,
# emits a warning, and picked dim=1 for 2-D input anyway.
m = nn.Sequential(nn.Linear(5, 3), nn.Softmax(dim=1))
r_m = nn.Sequential(nn.Linear(5, 3), nn.Softmax(dim=1))
x = torch.ones(1, 5)
y = torch.ones(1, 3)
d = m(x)
r_d = r_m(x)
l = (d - y).mean(1)
t_r_l = _trust_region_loss(m, d, r_d, l, 1)
# Backprop the surrogate so each param.grad holds the projected gradient.
t_r_l.backward()
# Plain SGD step on the trust-region-adjusted gradients.
for param in m.parameters():
    param.data -= l_r * param.grad.data
# alpha stays a 1-element float tensor, matching what update_avg blends with.
update_avg(r_m, m, torch.full((1,), 0.99))
@keven425
Copy link

keven425 commented Jun 6, 2017

Hi, I'm getting this error

TypeError: backward() got an unexpected keyword argument 'retain_graph'

for line

loss.backward(retain_graph=True)

Is the retain_graph keyword arg required?

*** UPDATE ***
Never mind: in PyTorch <= 0.1.12 this keyword is named `retain_variables` instead of `retain_graph`.

@leesunfreshing
Copy link

leesunfreshing commented Apr 24, 2018

Hi, I got a RuntimeError: the derivative for 'target' is not implemented...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment