Skip to content

Instantly share code, notes, and snippets.

@entrpn
Created September 21, 2023 17:45
Show Gist options
  • Save entrpn/687774b554ed29956f43f05a204adb65 to your computer and use it in GitHub Desktop.
Save entrpn/687774b554ed29956f43f05a204adb65 to your computer and use it in GitHub Desktop.
Flax SDXL inference
import jax
import jax.numpy as jnp
import numpy as np
# Inventory the JAX devices. Only the kind of the first device is inspected.
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
# Raise explicitly instead of using `assert`: assertions are stripped when
# Python runs with -O, which would let the script continue without a TPU.
if "TPU" not in device_type:
    raise RuntimeError(
        "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
    )
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
from flax.training.common_utils import shard
# Computation dtype handed to the pipeline loader below.
dtype = jnp.bfloat16
# Hugging Face Hub repository holding the Flax SDXL base weights.
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"


def to_bf16(t):
    """Cast every array leaf of the pytree *t* to ``jnp.bfloat16``.

    Leaves that are already bfloat16 are returned as-is (same object, no copy).
    Uses ``jax.tree_util.tree_map``: the ``jax.tree_map`` alias is deprecated
    and has been removed in newer JAX releases.

    Args:
        t: A pytree (e.g. nested dicts) whose leaves expose ``dtype`` and
            ``astype``.

    Returns:
        A pytree with the same structure and every leaf in bfloat16.
    """
    return jax.tree_util.tree_map(
        lambda x: x if x.dtype == jnp.bfloat16 else x.astype(jnp.bfloat16), t
    )
# Load the Flax SDXL pipeline together with its parameter pytree.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id,
    use_safetensors=True,
    dtype=dtype,
)
# NOTE(review): `dtype=` presumably sets only the computation dtype, hence the
# explicit cast of each model component's weights below. Confirm against the
# diffusers docs.
for component in ("vae", "text_encoder", "text_encoder_2", "unet"):
    params[component] = to_bf16(params[component])
# Each device produces `imgs_per_device` images, so every prompt is repeated
# device_count * imgs_per_device times before tokenizing and sharding.
imgs_per_device = 1


def _replicate_for_batch(text):
    """Return *text* repeated once per image slot across all devices."""
    return [text] * (jax.device_count() * imgs_per_device)


prompt = _replicate_for_batch(
    "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
)
prompt_ids = shard(pipeline.prepare_inputs(prompt))

neg_prompt = _replicate_for_batch("fog, grainy, purple")
neg_prompt_ids = shard(pipeline.prepare_inputs(neg_prompt))

# Copy the parameter pytree onto every device for the jit=True pipeline call.
p_params = replicate(params)
def create_key(seed=0):
    """Build a JAX PRNG key from the integer *seed* (defaults to 0)."""
    key = jax.random.PRNGKey(seed)
    return key
# Split seed 0 into one key per device. `generate` passes the whole stack when
# do_jit is set and only the first key otherwise.
rng = jax.random.split(create_key(0), jax.device_count())
# Forwarded to the pipeline as `jit=`; also selects sharded vs. single inputs.
do_jit = True
def generate(prompt_ids, neg_prompt_ids, *, num_inference_steps=40, guidance_scale=9.0):
    """Run the SDXL pipeline on pre-tokenized prompts and return its ``images``.

    Reads the module-level ``pipeline``, ``do_jit``, ``rng`` and either
    ``p_params`` (when ``do_jit`` is true) or ``params`` (otherwise). With
    ``do_jit`` true, the sharded ids, replicated params and per-device keys are
    passed through unchanged. Otherwise only the first shard of each id array
    and the first key are used, together with the unreplicated ``params``.

    Args:
        prompt_ids: Sharded token ids for the positive prompt.
        neg_prompt_ids: Sharded token ids for the negative prompt.
        num_inference_steps: Number of inference steps forwarded to the
            pipeline (default 40, the previously hard-coded value).
        guidance_scale: Guidance scale forwarded to the pipeline (default 9.0,
            the previously hard-coded value).

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # Select inputs up front; only the branch that is taken touches its globals.
    if do_jit:
        ids, neg_ids, weights, key = prompt_ids, neg_prompt_ids, p_params, rng
    else:
        ids, neg_ids, weights, key = prompt_ids[0], neg_prompt_ids[0], params, rng[0]
    return pipeline(
        ids,
        weights,
        key,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_ids,
        guidance_scale=guidance_scale,
        jit=do_jit,
    ).images
import time

# Number of timed runs averaged for the steady-state latency figure.
num_timed_runs = 5

# JAX dispatches device work asynchronously. Each measurement therefore waits
# on the results with `jax.block_until_ready`; without it the clock can stop
# before the computation has finished and the timings under-report.
# perf_counter is monotonic and intended for measuring intervals.
start = time.perf_counter()
_ = jax.block_until_ready(generate(prompt_ids, neg_prompt_ids))
print(f"Compiled in {time.perf_counter() - start}")
start = time.perf_counter()
for _ in range(num_timed_runs):
    images = generate(prompt_ids, neg_prompt_ids)
    jax.block_until_ready(images)
print(f"Inference in {(time.perf_counter() - start) / num_timed_runs}")
print("images.shape:", images.shape)
# Record a profiler trace of one generation into `trace_path`.
trace_path = "/tmp/tensorboard"
with jax.profiler.trace(trace_path):
    images = generate(prompt_ids, neg_prompt_ids)
    # Block inside the context: dispatch is asynchronous, and leaving the
    # trace before the device work completes would drop it from the profile.
    jax.block_until_ready(images)
print("images.shape:", images.shape)
print("images.dtype:", images.dtype)
# Merge the two leading axes (presumably devices x images-per-device, given
# the sharded inputs) into one batch axis. Then convert to PIL and write each
# image as castle_<index>.png.
num_shards, per_shard = images.shape[0], images.shape[1]
images = images.reshape((num_shards * per_shard, *images.shape[-3:]))
images = pipeline.numpy_to_pil(np.array(images))
for idx, pil_image in enumerate(images):
    pil_image.save(f"castle_{idx}.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment