Skip to content

Instantly share code, notes, and snippets.

@cgarciae
Created August 28, 2020 03:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cgarciae/184d735141547f12e01d629e8e85ee72 to your computer and use it in GitHub Desktop.
Save cgarciae/184d735141547f12e01d629e8e85ee72 to your computer and use it in GitHub Desktop.
import numpy as np
import jax
import jax.numpy as jnp
import elegy
import optax
class MixtureModel(elegy.Module):
    """Network producing ``k`` two-valued components and ``k`` gating probabilities.

    A shared hidden layer feeds ``k`` small linear heads (two outputs each)
    and one gating head whose softmax assigns a probability to every
    component.
    """

    def __init__(self, k: int):
        """Store the number of mixture components.

        Args:
            k: Number of component heads and of gating outputs.
        """
        super().__init__()
        self.k = k

    def call(self, x):
        """Run the forward pass.

        Args:
            x: Input array fed to the shared ``backbone`` linear layer.
                Presumably shaped ``(batch, features)`` -- TODO confirm
                against callers.

        Returns:
            A tuple ``(y, probs)``:

            * ``y`` -- the ``k`` component outputs stacked along axis 1, so
              the last axis has size 2. Index 1 of that last axis has been
              passed through ``1 + elu`` (positive-valued); index 0 is the
              raw linear output.
            * ``probs`` -- softmax over the ``k`` gating logits, taken along
              the last axis.
        """
        hidden = elegy.nn.Linear(64, name="backbone")(x)
        hidden = jax.nn.relu(hidden)

        # One two-output head per component, stacked on a new axis 1.
        # NOTE(review): every head is created with the same name
        # ("component"); confirm elegy gives each call its own parameters.
        y: jnp.ndarray = jnp.stack(
            [
                elegy.nn.Linear(2, name="component")(hidden)
                for _ in range(self.k)
            ],
            axis=1,
        )

        # equivalent to: y[..., 1] = 1.0 + jax.nn.elu(y[..., 1])
        # JAX arrays are immutable, so use the functional ``.at[...].set``
        # update. The older ``jax.ops.index_update`` helper has been removed
        # from JAX and raises AttributeError on current releases.
        y = y.at[..., 1].set(1.0 + jax.nn.elu(y[..., 1]))

        logits = elegy.nn.Linear(self.k, name="gating")(hidden)
        probs = jax.nn.softmax(logits, axis=-1)

        return y, probs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment