
Evan Walters (evanatyourservice), Louisville, CO
evanatyourservice / Beta-TCVAE in JAX Flax
Created January 1, 2024 20:10
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
"""
There's a typo in most B-TCVAE implementations on GitHub, so I thought I'd make a
quick gist of a working B-TCVAE.