Last active
June 26, 2024 02:16
-
-
Save Geson-anko/c2847695bcb170658aff7ad23ea8876b to your computer and use it in GitHub Desktop.
混合密度ネットワークを実装したもの。https://github.com/tonyduan/mixture-density-network , https://mikedusenberry.com/mixture-density-networks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import Tensor, Size | |
from torch.distributions import Distribution, Normal | |
import warnings | |
class NormalMixture(Distribution): | |
"""Computes Mixture Density Distribution of Normal distribution.""" | |
SQRT_2_PI = (2 * torch.pi) ** 0.5 | |
arg_constraints = {} | |
def __init__( | |
self, | |
log_pi: Tensor, | |
mu: Tensor, | |
sigma: Tensor, | |
eps: float = 1e-6, | |
validate_args: bool | None = None | |
) -> None: | |
"""Constructor for the NormalMixture class. | |
This constructor initializes the parameters of the mixture normal distribution and calls the parent class constructor. | |
log_pi, mu, sigma are must be same shape. | |
Args: | |
log_pi (Tensor): Tensor representing the mixture log ratios of each normal distribution. | |
mu (Tensor): Tensor representing the means of each normal distribution. | |
sigma (Tensor): Tensor representing the standard deviations of each normal distribution. | |
eps (float): A small value for numerical stability. | |
validate_args (bool | None): Whether to validate the arguments. | |
Shape: | |
log_pi, mu, sigma: (*, Components) | |
""" | |
assert log_pi.shape == mu.shape == sigma.shape | |
batch_shape = log_pi.shape[:-1] | |
super().__init__(batch_shape, validate_args) | |
self.log_pi = log_pi | |
self.mu = mu | |
self.sigma = sigma | |
self.eps = eps | |
def _extended_shape(self, shape: Size) -> Size: | |
return *self.log_pi.shape[:-1], *shape | |
def rsample(self, sample_shape: Size = Size()) -> Tensor: | |
if len(sample_shape) != 0: | |
warnings.warn(f"Not implemented for specfied sample shape: {sample_shape}") | |
pi = self.log_pi.exp() | |
samples = torch.multinomial(pi.view(-1, pi.size(-1)), 1, ).view(*pi.shape[:-1], 1) | |
sample_mu = self.mu.gather(-1, samples).squeeze(-1) | |
sample_sigma = self.sigma.gather(-1, samples).squeeze(-1) | |
return torch.randn_like(sample_mu) * sample_sigma + sample_mu | |
def sample(self, sample_shape: Size = Size()) -> Tensor: | |
return self.rsample(sample_shape) | |
def log_prob(self, value: Tensor) -> Tensor: | |
normal_prob = - 0.5 * ((value.unsqueeze(-1) - self.mu) / (self.sigma + self.eps)) ** 2 - torch.log(self.SQRT_2_PI * self.sigma + self.eps) | |
return torch.logsumexp(self.log_pi + normal_prob, -1) | |
class MixtureDensityNetwork(nn.Module):
    """Map input features to a ``NormalMixture`` via per-component linear heads.

    For each mixture component the module owns three independent
    ``nn.Linear(in_features, out_features)`` heads: one for the mean, one for
    the scale (passed through softplus) and one for the mixture logit.
    """

    def __init__(self, in_features: int, out_features: int, num_components: int) -> None:
        """Create ``num_components`` heads for each of mean, scale and logit."""
        super().__init__()

        def build_heads() -> nn.ModuleList:
            # One linear head per mixture component.
            heads = [nn.Linear(in_features, out_features) for _ in range(num_components)]
            return nn.ModuleList(heads)

        # Creation order determines parameter registration / init order.
        self.mu_layers = build_heads()
        self.sigma_layers = build_heads()
        self.logits_layers = build_heads()

    def forward(self, x: Tensor) -> NormalMixture:
        """Return the mixture predicted for ``x``.

        Each head output is stacked on a new trailing axis, so every
        distribution parameter has shape ``(*head_output_shape, num_components)``.
        """
        per_component_mu = [head(x) for head in self.mu_layers]
        per_component_sigma = [F.softplus(head(x)) for head in self.sigma_layers]
        per_component_logit = [head(x) for head in self.logits_layers]

        means = torch.stack(per_component_mu, dim=-1)
        scales = torch.stack(per_component_sigma, dim=-1)
        # Normalise the logits over the component axis into log weights.
        log_weights = F.log_softmax(torch.stack(per_component_logit, dim=-1), dim=-1)
        return NormalMixture(log_weights, means, scales)
if __name__ == "__main__":
    # Smoke test 1: a hand-built mixture with batch shape (3, 2) and 8 components.
    param_shape = (3, 2, 8)
    log_weights = torch.randn(param_shape).log_softmax(-1)
    means = torch.randn(param_shape)
    scales = torch.abs(torch.randn(param_shape))
    mixture = NormalMixture(log_weights, means, scales)

    draw = mixture.sample()
    assert draw.shape == (3, 2)
    assert mixture.log_prob(draw).shape == (3, 2)

    # Smoke test 2: the network maps an 8-feature vector to a 4-dim mixture
    # with 10 components.
    model = MixtureDensityNetwork(8, 4, 10)
    predicted = model(torch.randn(8))
    assert isinstance(predicted, NormalMixture)
    assert predicted.sample().shape == (4,)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment