Skip to content

Instantly share code, notes, and snippets.

View entrpn's full-sized avatar

Juan Acevedo entrpn

  • Google
  • Lake Forest, CA
View GitHub Profile
@entrpn
entrpn / infer.py
Created September 21, 2023 17:45
Flax SDXL inference
import jax
import jax.numpy as jnp
import numpy as np
# Runtime sanity check: this script is written for TPU execution.
# `num_devices` is recorded here; presumably it is used later in the script
# (e.g. for sharding/replication) — not visible in this chunk, confirm.
num_devices = jax.device_count()
# `device_kind` of the first device is a human-readable string; the check
# below only requires that it mention "TPU".
device_type = jax.devices()[0].device_kind
# Explicit raise instead of `assert`: assert statements are removed when
# Python runs with -O, which would silently skip this guard.
if "TPU" not in device_type:
    raise RuntimeError(
        "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
    )
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
@entrpn
entrpn / app.py
Created September 21, 2023 17:42
Flax AOT compile cache sdxl
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
from flax.training.common_utils import shard