import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from torch.distributions import Distribution


class NormalMixture(Distribution):
    """Mixture density distribution whose components are normal distributions."""

    SQRT_2_PI = (2 * torch.pi) ** 0.5
    arg_constraints = {}

    def __init__(
        self,
        log_pi: Tensor,
        mu: Tensor,
        sigma: Tensor,
        eps: float = 1e-6,
        validate_args: bool | None = None,
    ) -> None:
        """Constructor for the NormalMixture class.

        Initializes the parameters of the mixture of normal distributions and
        calls the parent class constructor. log_pi, mu and sigma must all have
        the same shape.

        Args:
            log_pi (Tensor): Mixture log weights of each normal component.
            mu (Tensor): Means of each normal component.
            sigma (Tensor): Standard deviations of each normal component.
            eps (float): Small value added for numerical stability.
            validate_args (bool | None): Whether to validate the arguments.

        Shape:
            log_pi, mu, sigma: (*, Components)
        """
        assert log_pi.shape == mu.shape == sigma.shape
        batch_shape = log_pi.shape[:-1]
        # Pass validate_args by keyword so it is not consumed as event_shape.
        super().__init__(batch_shape, validate_args=validate_args)
        self.log_pi = log_pi
        self.mu = mu
        self.sigma = sigma
        self.eps = eps

    def _extended_shape(self, sample_shape: Size = Size()) -> Size:
        return Size((*self.log_pi.shape[:-1], *sample_shape))

    def rsample(self, sample_shape: Size = Size()) -> Tensor:
        if len(sample_shape) != 0:
            warnings.warn(f"Not implemented for specified sample shape: {sample_shape}")
        pi = self.log_pi.exp()
        # Draw one component index per batch element according to the mixture weights.
        samples = torch.multinomial(pi.view(-1, pi.size(-1)), 1).view(*pi.shape[:-1], 1)
        sample_mu = self.mu.gather(-1, samples).squeeze(-1)
        sample_sigma = self.sigma.gather(-1, samples).squeeze(-1)
        # Reparameterized draw from the selected component.
        return torch.randn_like(sample_mu) * sample_sigma + sample_mu

    def sample(self, sample_shape: Size = Size()) -> Tensor:
        # sample() should not track gradients; rsample() is the differentiable path.
        with torch.no_grad():
            return self.rsample(sample_shape)

    def log_prob(self, value: Tensor) -> Tensor:
        # Per-component log density: -((x - mu)^2 / (2 sigma^2)) - log(sqrt(2 pi) * sigma).
        normal_prob = -0.5 * ((value.unsqueeze(-1) - self.mu) / (self.sigma + self.eps)) ** 2 - torch.log(
            self.SQRT_2_PI * self.sigma + self.eps
        )
        # Weighted sum over components, computed stably in log space.
        return torch.logsumexp(self.log_pi + normal_prob, -1)
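

# A minimal sanity check, not part of the original gist: with eps disabled, the log
# density above should match torch.distributions.MixtureSameFamily built from
# Categorical(logits=log_pi) and Normal(mu, sigma). Called from the demo block below.
def _check_against_mixture_same_family() -> None:
    from torch.distributions import Categorical, MixtureSameFamily, Normal

    log_pi = torch.log_softmax(torch.randn(5, 4), -1)
    mu = torch.randn(5, 4)
    sigma = torch.rand(5, 4) + 0.1
    ours = NormalMixture(log_pi, mu, sigma, eps=0.0)
    ref = MixtureSameFamily(Categorical(logits=log_pi), Normal(mu, sigma))
    x = torch.randn(5)
    assert torch.allclose(ours.log_prob(x), ref.log_prob(x), atol=1e-5)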


class MixtureDensityNetwork(nn.Module):
    """Predicts the parameters of a NormalMixture from input features."""

    def __init__(self, in_features: int, out_features: int, num_components: int) -> None:
        super().__init__()
        self.mu_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
        self.sigma_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
        self.logits_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))

    def forward(self, x: Tensor) -> NormalMixture:
        mu = torch.stack([l(x) for l in self.mu_layers], dim=-1)
        # softplus keeps the predicted standard deviations strictly positive.
        sigma = torch.stack([F.softplus(l(x)) for l in self.sigma_layers], dim=-1)
        # log_softmax over the component axis yields normalized mixture log weights.
        log_pi = torch.stack([l(x) for l in self.logits_layers], dim=-1).log_softmax(-1)
        return NormalMixture(log_pi, mu, sigma)
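

# A design note, not part of the original gist: an equivalent and often more compact
# formulation uses a single Linear head and a reshape instead of one Linear per
# component. A sketch assuming the same (*, out_features, num_components) parameter
# layout as MixtureDensityNetwork above; the class name is my own.
class FusedMixtureDensityNetwork(nn.Module):
    def __init__(self, in_features: int, out_features: int, num_components: int) -> None:
        super().__init__()
        self.out_features = out_features
        self.num_components = num_components
        # One weight matrix produces logits, mu and (pre-softplus) sigma at once.
        self.params = nn.Linear(in_features, 3 * out_features * num_components)

    def forward(self, x: Tensor) -> NormalMixture:
        p = self.params(x).unflatten(-1, (3, self.out_features, self.num_components))
        logits, mu, sigma = p.unbind(-3)
        return NormalMixture(logits.log_softmax(-1), mu, F.softplus(sigma))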


if __name__ == "__main__":
    shape = (3, 2, 8)
    log_pi = torch.log_softmax(torch.randn(shape), -1)
    mu = torch.randn(shape)
    std = torch.randn(shape).abs()
    dist = NormalMixture(log_pi, mu, std)
    out = dist.sample()
    assert out.shape == (3, 2)
    assert dist.log_prob(out).shape == (3, 2)

    net = MixtureDensityNetwork(8, 4, 10)
    out = net(torch.randn(8))
    assert isinstance(out, NormalMixture)
    assert out.sample().shape == (4,)
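
    # Added demo, not part of the original gist: run the MixtureSameFamily
    # consistency check defined above.
    _check_against_mixture_same_family()

    # A minimal training sketch on hypothetical toy data: fit the MDN by minimizing
    # the negative log-likelihood of the targets under the predicted mixture.
    x = torch.randn(64, 8)
    y = torch.randn(64, 4)
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    for _ in range(10):
        optim.zero_grad()
        loss = -net(x).log_prob(y).mean()
        loss.backward()
        optim.step()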