@oborchers
Created February 12, 2019 17:30
Computation of the MDN Loss Function
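For reference, the quantity computed below is the mean negative log-likelihood of a Gaussian mixture over a batch of N targets y_n with K components, written here in standard MDN notation (this formula is added for context and is not part of the original gist):

\mathcal{L} = -\frac{1}{N} \sum_{n=1}^{N} \log \sum_{k=1}^{K} \alpha_k(x_n)\, \mathcal{N}\big(y_n \mid \mu_k(x_n), \sigma_k(x_n)\big)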
import tensorflow as tf
from tensorflow_probability import distributions as tfd

# These two constants are assumed to be defined elsewhere in the original context:
# `no_parameters` is the number of parameter types per component (alpha, mu, sigma -> 3)
# and `components` is the number of mixture components (the value here is illustrative).
no_parameters = 3
components = 2


def slice_parameter_vectors(parameter_vector):
    """Returns an unpacked list of parameter vectors (alpha, mu, sigma)."""
    return [parameter_vector[:, i * components:(i + 1) * components]
            for i in range(no_parameters)]


def gnll_loss(y, parameter_vector):
    """Computes the mean negative log-likelihood loss of y given the mixture parameters."""
    alpha, mu, sigma = slice_parameter_vectors(parameter_vector)  # Unpack parameter vectors

    gm = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=alpha),
        components_distribution=tfd.Normal(
            loc=mu,
            scale=sigma))

    log_likelihood = gm.log_prob(tf.transpose(y))  # Evaluate log-probability of y

    return -tf.reduce_mean(log_likelihood, axis=-1)
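Below is a minimal, hypothetical usage sketch showing how gnll_loss can be plugged into a Keras model as a custom loss; the layer structure, activations, and toy data are illustrative assumptions, not part of the original gist. The output layer must emit no_parameters * components values per sample, concatenated in the order alpha, mu, sigma, so that slice_parameter_vectors can split them; softmax keeps alpha a valid probability vector and softplus is one common choice for keeping sigma positive.

import numpy as np

# Hypothetical model: outputs are concatenated as [alpha, mu, sigma]
inputs = tf.keras.Input(shape=(1,))
h = tf.keras.layers.Dense(32, activation="relu")(inputs)
alphas = tf.keras.layers.Dense(components, activation="softmax")(h)   # mixture weights
mus = tf.keras.layers.Dense(components)(h)                            # component means
sigmas = tf.keras.layers.Dense(components, activation="softplus")(h)  # positive std devs
pvector = tf.keras.layers.Concatenate()([alphas, mus, sigmas])
model = tf.keras.Model(inputs, pvector)

# Keras calls the loss as (y_true, y_pred), matching gnll_loss(y, parameter_vector)
model.compile(loss=gnll_loss, optimizer=tf.keras.optimizers.Adam())

# Toy data, purely illustrative
x_train = np.random.rand(256, 1).astype("float32")
y_train = np.random.rand(256, 1).astype("float32")
model.fit(x_train, y_train, epochs=5)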