Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
PyTorch: numerically stable Normal distribution parameterised by log_scale
import math
import torch
from torch.distributions import Normal
from torch.distributions.utils import broadcast_all, _standard_normal
from torch.distributions.kl import register_kl
class StableNormal(Normal):
    """Normal distribution parameterised by ``log_scale`` instead of ``scale``.

    Storing the log of the standard deviation keeps values and gradients well
    behaved when the scale is very small or very large. ``log_prob`` and
    ``entropy`` work directly with ``log_scale`` rather than forming ``scale``
    or ``scale ** 2``, either of which can underflow/overflow.

    Args:
        loc: mean of the distribution (tensor or number).
        log_scale: natural log of the standard deviation (tensor or number);
            broadcast against ``loc``.
        validate_args: forwarded to ``Distribution``; ``None`` uses the global
            default.
    """

    # Validate the parameter that is actually stored. Checking ``scale > 0``
    # (as Normal does) would reject valid inputs whose exp(log_scale)
    # underflows to 0.
    arg_constraints = {
        "loc": torch.distributions.constraints.real,
        "log_scale": torch.distributions.constraints.real,
    }

    def __init__(self, loc, log_scale, validate_args=None):
        self.loc, self.log_scale = broadcast_all(loc, log_scale)
        batch_shape = self.loc.size()
        # Skip Normal.__init__: it assigns ``self.scale``, which is a read-only
        # property here. Initialise the Distribution base class directly.
        super(Normal, self).__init__(batch_shape, validate_args=validate_args)

    @property
    def scale(self):
        """Standard deviation, ``exp(log_scale)``."""
        return torch.exp(self.log_scale)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution broadcast to ``batch_shape``.

        ``Normal.expand`` cannot be reused: it refuses subclasses that override
        ``__init__`` and would try to assign the read-only ``scale``.
        """
        new = self._get_checked_instance(StableNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.log_scale = self.log_scale.expand(batch_shape)
        super(Normal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        """Log density at ``value``, computed from ``log_scale`` for stability."""
        if self._validate_args:
            self._validate_sample(value)
        # Standardise by multiplying with exp(-log_scale) instead of dividing by
        # scale ** 2: the variance under/overflows long before log_scale does.
        z = (value - self.loc) * torch.exp(-self.log_scale)
        return -0.5 * z.pow(2) - self.log_scale - math.log(math.sqrt(2 * math.pi))

    def entropy(self):
        """Differential entropy, using ``log_scale`` directly instead of log(scale)."""
        return 0.5 + 0.5 * math.log(2 * math.pi) + self.log_scale
@register_kl(StableNormal, StableNormal)
def _kl_normal_normal(p, q):
    """KL(p || q) between two StableNormal distributions, computed from log scales.

    With log_r = 2 * (p.log_scale - q.log_scale) and d = (p.loc - q.loc) / q.scale:

        KL = 0.5 * (exp(log_r) - 1 - log_r + d ** 2)

    Modified from https://github.com/pytorch/pytorch/blob/317b78d56ed434bb52030c3472affbd0feed8344/torch/distributions/kl.py#L407
    """
    var_ratio_log = 2 * (p.log_scale - q.log_scale)
    # Multiply by exp(-log_scale) rather than dividing by q.scale, which can
    # underflow to 0 for very negative log scales.
    t1 = ((p.loc - q.loc) * torch.exp(-q.log_scale)).pow(2)
    # expm1(x) - x equals exp(x) - 1 - x but avoids catastrophic cancellation
    # when the two scales are nearly equal (x close to 0).
    return 0.5 * (torch.expm1(var_ratio_log) - var_ratio_log + t1)
if __name__ == '__main__':
    # Smoke test: StableNormal must agree with torch's Normal when both describe
    # the same distribution (scale == exp(log_scale)).
    torch.manual_seed(0)
    shape = (10, 2, 4)

    # Keep scales away from 0 so log() is finite and Normal's positivity
    # check on `scale` cannot fail on an unlucky draw.
    mean = torch.rand(shape)
    scale = torch.rand(shape) + 0.1
    x = torch.rand(shape)
    normal = Normal(mean, scale)
    stable = StableNormal(mean, scale.log())

    mean2 = torch.rand(shape) + 1
    scale2 = torch.rand(shape) + 1
    normal2 = Normal(mean2, scale2)
    stable2 = StableNormal(mean2, scale2.log())

    # The two parameterisations round differently (exp(log(s)) != s exactly),
    # so compare with a tolerance instead of exact float equality.
    torch.testing.assert_close(stable.log_prob(x), normal.log_prob(x))
    assert torch.distributions.kl_divergence(normal, normal).mean() == 0
    assert torch.distributions.kl_divergence(stable, stable).mean() == 0
    kld_stable = torch.distributions.kl_divergence(stable, stable2)
    kld_normal = torch.distributions.kl_divergence(normal, normal2)
    torch.testing.assert_close(kld_stable, kld_normal)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment