

Evan Walters evanatyourservice

Denver, CO
evanatyourservice / hellaswag_jax.py
Last active August 24, 2024 01:25
How to prepare and evaluate on HellaSwag in JAX
import json
from typing import Optional
from tqdm import tqdm
import numpy as np
import jax
from jax import jit
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
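A minimal sketch of the usual HellaSwag scoring scheme, assuming a causal language model that returns next-token logits. Each of the four endings is appended to the context, the log-probabilities of the ending tokens are summed, and the highest-scoring ending is the prediction. The field names `ctx`, `endings` and `label` follow the public HellaSwag jsonl files. `render_example`, `score_endings`, `predict` and the `encode` tokenizer callable are hypothetical names for illustration, not the gist's own functions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def render_example(example, encode):
    """Turn one HellaSwag record into padded (tokens, mask, label) arrays.

    example: dict with the "ctx", "endings" and "label" fields of a jsonl line.
    encode:  hypothetical tokenizer callable mapping str -> list[int].
    """
    ctx_tokens = encode(example["ctx"])
    rows, masks = [], []
    for ending in example["endings"]:
        # Leading space so the ending reads as a continuation of the context.
        end_tokens = encode(" " + ending)
        rows.append(ctx_tokens + end_tokens)
        # Mask is 1 only on ending tokens; context tokens are not scored.
        masks.append([0] * len(ctx_tokens) + [1] * len(end_tokens))
    max_len = max(len(r) for r in rows)
    # Pad with token id 0 (an assumption; padded positions have mask 0).
    tokens = np.zeros((len(rows), max_len), dtype=np.int32)
    mask = np.zeros((len(rows), max_len), dtype=np.float32)
    for i, (row, m) in enumerate(zip(rows, masks)):
        tokens[i, : len(row)] = row
        mask[i, : len(m)] = m
    return tokens, mask, int(example["label"])


def score_endings(logits, tokens, mask):
    """Per-ending log-likelihood of the ending tokens.

    logits: [num_endings, seq_len, vocab] model outputs for context + ending.
    tokens: [num_endings, seq_len] int token ids.
    mask:   [num_endings, seq_len] 1.0 where the token belongs to the ending.
    Returns (summed log-prob, log-prob divided by ending length).
    """
    # Position t predicts token t+1, so shift logits and targets by one.
    log_probs = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
    targets = tokens[:, 1:]
    target_mask = mask[:, 1:]
    # Pick out log p(token_{t+1} | tokens_{<=t}) at every position.
    token_log_probs = jnp.take_along_axis(
        log_probs, targets[..., None], axis=-1
    ).squeeze(-1)
    total = jnp.sum(token_log_probs * target_mask, axis=-1)
    length = jnp.sum(target_mask, axis=-1)
    return total, total / length


@jax.jit
def predict(logits, tokens, mask):
    """Index of the best ending by summed and by length-normalized log-prob."""
    total, normalized = score_endings(logits, tokens, mask)
    return jnp.argmax(total), jnp.argmax(normalized)
```

Accuracy is then the fraction of examples where the predicted index equals `label`. Normalizing by the number of ending tokens is one common choice; some evaluation harnesses normalize by the ending's character length instead. The sketch only covers data rendering and scoring, not the model call or the gist's own loop (which also imports Flax's `TrainState` and `optax`).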
evanatyourservice / Beta-TCVAE in JAX Flax
Created January 1, 2024 20:10
Beta-TCVAE in JAX Flax
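For reference, a minimal sketch of the Beta-TCVAE decomposition in plain `jax.numpy`, assuming a diagonal Gaussian posterior q(z|x) = N(mu, exp(logvar)) and a standard normal prior. It splits the KL term into index-code mutual information, total correlation and dimension-wise KL, estimating the aggregate posterior q(z) with minibatch weighted sampling. `gaussian_log_density` and `tcvae_terms` are hypothetical names; this is not the gist's code, which uses `tfd` distributions instead.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def gaussian_log_density(z, mu, logvar):
    """Elementwise log N(z; mu, exp(logvar))."""
    norm = -0.5 * (jnp.log(2.0 * jnp.pi) + logvar)
    return norm - 0.5 * (z - mu) ** 2 * jnp.exp(-logvar)


def tcvae_terms(z, mu, logvar, dataset_size):
    """Minibatch-weighted-sampling estimates of the three KL terms.

    z, mu, logvar: [batch, latent_dim]; z is a sample from q(z|x).
    dataset_size:  N, number of training examples (hypothetical argument name).
    Returns (mutual_info, total_correlation, dimwise_kl) as scalars.
    """
    batch_size = z.shape[0]
    zeros = jnp.zeros_like(z)
    # log q(z_i | x_i) and log p(z_i) under a standard normal prior.
    log_qz_given_x = gaussian_log_density(z, mu, logvar).sum(-1)
    log_pz = gaussian_log_density(z, zeros, zeros).sum(-1)

    # mat[i, j, k] = log q(z_ik | x_j): each sample scored under each posterior.
    mat = gaussian_log_density(z[:, None, :], mu[None, :, :], logvar[None, :, :])
    log_nm = math.log(batch_size * dataset_size)
    # log q(z_i) ~= logsumexp_j sum_k mat[i, j, k] - log(N * M)
    log_qz = logsumexp(mat.sum(-1), axis=1) - log_nm
    # log prod_k q(z_ik) ~= sum_k (logsumexp_j mat[i, j, k] - log(N * M))
    log_prod_qzk = (logsumexp(mat, axis=1) - log_nm).sum(-1)

    mutual_info = jnp.mean(log_qz_given_x - log_qz)
    total_correlation = jnp.mean(log_qz - log_prod_qzk)
    dimwise_kl = jnp.mean(log_prod_qzk - log_pz)
    return mutual_info, total_correlation, dimwise_kl
```

The three terms telescope back to the ordinary Monte Carlo KL estimate, mean(log q(z|x) - log p(z)), which makes a handy sanity check. A training loss would then be `recon + alpha * mi + beta * tc + gamma * dimwise_kl`; the Beta-TCVAE paper fixes alpha = gamma = 1 and tunes beta, and also describes a minibatch stratified sampling estimator as an alternative.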
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tensorflow_probability.substrates.jax import distributions as tfd
"""
There's a typo in most B-TCVAE implementations on GitHub, so I thought I'd make a
quick gist of a working B-TCVAE.