Skip to content

Instantly share code, notes, and snippets.

@manujosephv

manujosephv/base_mdn.py Secret

Last active Mar 15, 2021
Embed
What would you like to do?
# Sample implementation for educational purposes.
# For the full implementation, see https://github.com/manujosephv/pytorch_tabular
class BaseMDN(BaseModel):
    """Base class for models whose output head is a Mixture Density Network.

    The forward pass is ``unpack_input -> backbone -> mdn``. The MDN head
    returns three tensors, mixing weights ``pi``, scales ``sigma`` and means
    ``mu``. They are returned together with the backbone features.

    NOTE(review): ``self.backbone``, ``self.mdn`` and ``self.hparams`` are not
    created in this class. Presumably ``BaseModel.__init__`` builds them from
    ``config``; confirm against the base class.
    """

    def __init__(self, config: DictConfig, **kwargs):
        super().__init__(config, **kwargs)

    @abstractmethod
    def unpack_input(self, x: Dict):
        """Extract the backbone input from the batch dictionary ``x``.

        Must be implemented by subclasses; which keys are read depends on the
        concrete model.
        """

    def forward(self, x: Dict):
        """Run ``x`` through ``unpack_input``, the backbone and the MDN head.

        Args:
            x: Batch dictionary, handed to ``unpack_input``.

        Returns:
            Dict with keys ``"pi"``, ``"sigma"``, ``"mu"`` (the three outputs
            of ``self.mdn``) and ``"backbone_features"`` (the backbone output
            that was fed to the MDN head).
        """
        x = self.unpack_input(x)
        x = self.backbone(x)
        pi, sigma, mu = self.mdn(x)
        return {"pi": pi, "sigma": sigma, "mu": mu, "backbone_features": x}

    def sample(self, x: Dict, n_samples: Optional[int] = None, ret_model_output=False):
        """Draw samples from the mixture distribution predicted for ``x``.

        Args:
            x: Batch dictionary, as accepted by ``forward``.
            n_samples: Forwarded unchanged to ``self.mdn.generate_samples``;
                how ``None`` is interpreted is up to that method.
            ret_model_output: If True, also return the ``forward`` output dict.

        Returns:
            ``samples``, or ``(samples, model_output)`` when
            ``ret_model_output`` is True.
        """
        ret_value = self.forward(x)
        samples = self.mdn.generate_samples(
            ret_value["pi"], ret_value["sigma"], ret_value["mu"], n_samples
        )
        if ret_model_output:
            return samples, ret_value
        return samples

    def _weight_penalty(self, module, coefficient):
        """Return ``coefficient * ||w||`` over all parameters of ``module``.

        Every parameter tensor of ``module`` is flattened and concatenated, so
        a single norm is taken over the whole sub-layer. The norm order is
        ``mdn_config.weight_regularization``, passed as ``torch.norm``'s ``p``.
        """
        # reshape (not view) so non-contiguous parameters do not raise.
        weights = torch.cat([w.reshape(-1) for w in module.parameters()])
        return coefficient * torch.norm(
            weights, self.hparams.mdn_config.weight_regularization
        )

    def calculate_loss(self, y, pi, sigma, mu, tag="train"):
        """Compute the MDN loss for targets ``y`` and log it.

        The loss is the mean negative log-likelihood of ``y`` under the mixture
        ``(pi, sigma, mu)``. When ``mdn_config.weight_regularization`` is not
        None, weight-norm penalties on ``self.mdn.sigma``, ``self.mdn.pi`` and
        ``self.mdn.mu`` are added. Each penalty is scaled by its own coefficient
        (``lambda_sigma``, ``lambda_pi``, ``lambda_mu``) and is skipped when
        that coefficient is not positive.

        Args:
            y: Targets, passed to ``self.mdn.log_prob``.
            pi: Mixing weights output by the MDN head.
            sigma: Scales output by the MDN head.
            mu: Means output by the MDN head.
            tag: Prefix of the logged metric name ``f"{tag}_loss"``.
                ``on_step`` is set only for ``"train"`` and ``on_epoch`` only
                for ``"valid"``.

        Returns:
            The loss tensor.
        """
        cfg = self.hparams.mdn_config
        # Negative log-likelihood, averaged over every element of log_prob.
        loss = torch.mean(-self.mdn.log_prob(pi, sigma, mu, y))
        if cfg.weight_regularization is not None:
            if cfg.lambda_sigma > 0:
                loss = loss + self._weight_penalty(
                    self.mdn.sigma, cfg.lambda_sigma
                )
            if cfg.lambda_pi > 0:
                # The pi penalty is scaled by lambda_pi, not lambda_sigma.
                loss = loss + self._weight_penalty(self.mdn.pi, cfg.lambda_pi)
            if cfg.lambda_mu > 0:
                loss = loss + self._weight_penalty(self.mdn.mu, cfg.lambda_mu)
        # NOTE(review): self.log presumably comes from a PyTorch Lightning
        # LightningModule base; confirm against BaseModel.
        self.log(
            f"{tag}_loss",
            loss,
            on_epoch=(tag == "valid"),
            on_step=(tag == "train"),
            logger=True,
            prog_bar=True,
        )
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment