@takuma-yoneda
Created May 13, 2020 01:51
Code pieces in the PlaNet implementation
from collections import namedtuple

import torch
from torch.distributions import Normal, Independent
from torch.distributions.kl import kl_divergence


def calc_loss(self, trajectory):
    '''Take a trajectory and compute the variational (ELBO) loss terms.'''
    Loss = namedtuple('Loss', ['summary', 'reconst_term', 'kl_term'])

    # Initialize the latent state from the first observation.
    images, best_pix_ind = trajectory[0]
    obs_0 = images[0][0].to(self.device)
    posterior = self.encoder(obs_0)
    prev_state = posterior.rsample()

    # Accumulated loss terms:
    tot_reconst_term = 0
    tot_kl_term = 0
    for sample in trajectory:
        images, best_pix_ind = sample
        obs = images[0][1].to(self.device)
        action = best_pix_ind.to(self.device)

        # KL term: divergence between the encoder posterior and the prior
        # predicted by the transition model from (prev_state, action).
        prior = self.transition(torch.cat((prev_state, action), 1))
        posterior = self.encoder(obs)
        kl_term = kl_divergence(posterior, prior)

        # Reconstruction term: log-likelihood of the observation under the
        # decoder's output distribution.
        state = posterior.rsample()
        reconst_distr = self.decoder(state)
        reconst_term = reconst_distr.log_prob(obs.view((1, -1)))

        # Add up loss terms.
        tot_reconst_term += reconst_term
        tot_kl_term += kl_term
        prev_state = state.clone().detach()  # NOTE: copy tensor without gradients

    # Negative ELBO: maximize reconstruction log-likelihood, penalize KL.
    total_loss = -(tot_reconst_term - tot_kl_term)
    return Loss(summary=total_loss, reconst_term=tot_reconst_term, kl_term=tot_kl_term)
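
As a rough usage sketch, the returned Loss tuple would drive a training step along these lines (the agent, optimizer, and trajectories names here are assumptions for illustration, not part of the gist):

# Hypothetical training loop; `agent` exposes calc_loss and parameters(),
# and `trajectories` is an iterable of trajectories -- both assumed names.
optimizer = torch.optim.Adam(agent.parameters(), lr=1e-3)
for trajectory in trajectories:
    loss = agent.calc_loss(trajectory)
    optimizer.zero_grad()
    loss.summary.backward()  # total loss = -(reconstruction term - KL term)
    optimizer.step()
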
class GaussianWrapper:
    '''Wrap an nn.Module so that calling it returns a Gaussian distribution.'''

    def __init__(self, model, device):
        self.model = model  # PyTorch model (nn.Module)
        self.device = device

    def __call__(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        if len(out) == 2:
            # The model predicts both mean and standard deviation.
            mean, std_dev = out
        else:
            # The model predicts only the mean; assume unit standard deviation.
            mean = out.view((1, -1))
            std_dev = torch.ones((1, mean.numel()), device=self.device)
        return to_gaussian(mean, std_dev)

    def __getattr__(self, name):
        # Delegate attribute lookups to the wrapped model (except private names).
        if name.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
        return getattr(self.model, name)

    def __str__(self):
        return '<{} {}>'.format(type(self).__name__, self.model)
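
For illustration, wrapping a throwaway linear "decoder" exercises the mean-only branch of the wrapper (the nn.Linear stand-in is a made-up placeholder, not the gist's real decoder):

import torch
import torch.nn as nn

# Placeholder module standing in for the real decoder, which is not in the gist.
decoder = GaussianWrapper(nn.Linear(30, 64), device=torch.device('cpu'))
distr = decoder(torch.randn(1, 30))         # mean-only output -> unit std-dev branch
log_p = distr.log_prob(torch.randn(1, 64))  # scalar log-likelihood per batch element
print(decoder.weight.shape)                 # attribute access delegated to the wrapped nn.Linear
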
def to_gaussian(mean, std_dev):
    """Build a Gaussian with diagonal covariance from mean and per-dim std_dev."""
    if len(mean.shape) > 2:
        assert len(mean.shape) == 4  # image-shaped output: (N, C, H, W)
    else:
        assert len(mean.shape) == 2  # flat output: (N, D)
    # NOTE: Independent(Normal, 1) is a diagonal multivariate normal
    # (https://github.com/pytorch/pytorch/pull/11178)
    return Independent(Normal(mean, std_dev), 1)
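
A quick sanity check of the diagonal-Gaussian construction (the shapes are arbitrary, chosen just for illustration):

import torch
from torch.distributions import Normal, Independent

distr = Independent(Normal(torch.zeros(1, 5), torch.ones(1, 5)), 1)
print(distr.batch_shape, distr.event_shape)  # torch.Size([1]) torch.Size([5])
# log_prob sums over the event dimension, matching a MultivariateNormal
# with diagonal covariance:
print(distr.log_prob(torch.zeros(1, 5)))     # tensor([-4.5947])
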
@takuma-yoneda (Author):
self.transition, self.encoder, and self.decoder are each a PyTorch model (a class inheriting from nn.Module) wrapped by the GaussianWrapper class.
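
A minimal sketch of that wiring, assuming placeholder Encoder/Decoder/Transition modules (none of which are included in the gist):

class Agent:
    def __init__(self, device):
        # Encoder, Decoder, Transition: placeholder nn.Module classes (not in the gist).
        self.device = device
        self.encoder = GaussianWrapper(Encoder(), device)
        self.decoder = GaussianWrapper(Decoder(), device)
        self.transition = GaussianWrapper(Transition(), device)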
