# Gist by Qiao Zhang (zhangqiaorjc).
import itertools as it
import jax
import jax.numpy as jnp
# JAX runtime flags: switch on `jax_enable_x64` and set `jax_platform_name`
# to 'cpu'.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')

# Each size is available under a descriptive name and a one-letter alias;
# both names are bound to the same integer.
num_stages = 5
L = num_stages
batch_size = 6
N = batch_size
# zhangqiaorjc@google.com
import atexit
import functools
from absl import app
from absl import flags
from absl import logging
import jax
from jax.lib import xla_extension as xc