@cgarciae
Created August 28, 2020 03:28
Loss
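import elegy
import jax.numpy as jnp
import jax.scipy.stats

def safe_log(x):
    # Assumed helper (not defined in this gist): clamp before log to avoid -inf.
    return jnp.log(jnp.maximum(x, 1e-7))

class MixtureNLL(elegy.Loss):
    def call(self, y_true, y_pred):
        # Shapes inferred from indexing: y is (batch, k, 2) holding means [..., 0]
        # and std devs [..., 1]; probs is (batch, k) mixture weights.
        y, probs = y_pred
        # Assumes y_true is (batch, 1); repeat it across the k components.
        y_true = jnp.broadcast_to(y_true, (y_true.shape[0], y.shape[1]))
        return -safe_log(
            jnp.sum(
                probs
                * jax.scipy.stats.norm.pdf(y_true, loc=y[..., 0], scale=y[..., 1]),
                axis=1,
            ),
        )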
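The loss above is the per-sample negative log-likelihood of a Gaussian mixture, -log(sum_k probs_k * N(y_true; mu_k, sigma_k)). A minimal framework-free sketch of the same computation, assuming the shapes noted in the comments (inferred, not stated in the gist), is shown below. The function names `mixture_nll` and `mixture_nll_logspace` are hypothetical. The second variant is a swapped-in technique, not the gist's: it works in log space with `logsumexp` so that very small component densities do not underflow to zero before the log is taken.

```python
import jax.numpy as jnp
import jax.scipy.stats
from jax.scipy.special import logsumexp


def mixture_nll(y_true, y, probs, eps=1e-7):
    """Per-sample Gaussian-mixture NLL (hypothetical helper).

    Assumed shapes: y_true (batch, 1); y (batch, k, 2) with means in
    [..., 0] and std devs in [..., 1]; probs (batch, k) summing to 1.
    """
    y_true = jnp.broadcast_to(y_true, (y_true.shape[0], y.shape[1]))
    # Density of each target under each component, shape (batch, k).
    pdf = jax.scipy.stats.norm.pdf(y_true, loc=y[..., 0], scale=y[..., 1])
    # Weighted sum over components; eps stands in for safe_log's clamp.
    return -jnp.log(jnp.sum(probs * pdf, axis=1) + eps)


def mixture_nll_logspace(y_true, y, probs):
    """Same quantity via logsumexp, avoiding pdf underflow (hypothetical)."""
    y_true = jnp.broadcast_to(y_true, (y_true.shape[0], y.shape[1]))
    log_pdf = jax.scipy.stats.norm.logpdf(y_true, loc=y[..., 0], scale=y[..., 1])
    # log(sum_k p_k * pdf_k) == logsumexp_k(log p_k + log pdf_k)
    return -logsumexp(jnp.log(probs) + log_pdf, axis=1)


# Usage: two targets, two equally weighted components N(0, 1) and N(2, 1).
y_true = jnp.array([[0.0], [1.0]])
y = jnp.array([[[0.0, 1.0], [2.0, 1.0]]] * 2)
probs = jnp.array([[0.5, 0.5]] * 2)
nll = mixture_nll(y_true, y, probs)
nll_log = mixture_nll_logspace(y_true, y, probs)
```

Both functions return one loss value per sample (shape `(batch,)`); in the gist, reducing that vector to a scalar is left to `elegy.Loss`.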