Flax AOT compile cache for SDXL
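
"""Benchmark Flax Stable Diffusion XL with ahead-of-time (AOT) compilation.

Loads the SDXL Flax pipeline, casts its weights to bfloat16, AOT-compiles the
sharded `_generate` function via `pmap(...).lower(...).compile()`, and stores
compiled programs in /tmp/sdxl_cache so subsequent runs skip XLA compilation.
A Gradio button triggers three timed generation passes for benchmarking.
"""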
import time

import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
from jax.experimental.compilation_cache import compilation_cache as cc

from diffusers import FlaxStableDiffusionXLPipeline

# Persist compiled XLA programs on disk so later runs can skip recompilation.
cc.initialize_cache("/tmp/sdxl_cache")

dtype = jnp.bfloat16
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"


def to_bf16(t):
    # Cast every leaf of the parameter pytree that is not already bfloat16.
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16) if x.dtype != jnp.bfloat16 else x, t
    )


def create_key(seed=0):
    return jax.random.PRNGKey(seed)


def get_pipeline_params():
    # One PRNG key per local device.
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())
    # Dummy 77-character prompt, replicated once per device and sharded.
    prompt = 77 * "a"
    prompt = [prompt] * jax.device_count()
    prompt_ids = pipeline.prepare_inputs(prompt)
    prompt_ids = shard(prompt_ids)
    neg_prompt = 77 * "a"
    neg_prompt = [neg_prompt] * jax.device_count()
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    neg_prompt_ids = shard(neg_prompt_ids)
    num_inference_steps = 40
    height = 1024
    width = 1024
    guidance_scale = 9
    # One guidance-scale entry per device shard, shaped (num_devices, 1).
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]
    return prompt_ids, rng, num_inference_steps, height, width, g, neg_prompt_ids


pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id,
    dtype=dtype,
)

# Cast the model weights to bfloat16; other entries stay as loaded.
params['vae'] = to_bf16(params['vae'])
params['text_encoder'] = to_bf16(params['text_encoder'])
params['text_encoder_2'] = to_bf16(params['text_encoder_2'])
params['unet'] = to_bf16(params['unet'])

# Replicate the parameters across all local devices for pmap.
p_params = replicate(params)

(prompt_ids, rng, num_inference_steps,
 height, width, g, neg_prompt) = get_pipeline_params()

start_time = time.time()
# Ahead-of-time lower and compile the sharded generation function.
# Argnums 3, 4, 5, and 9 (num_inference_steps, height, width, and the
# trailing boolean flag) are broadcast as static values.
p_generate = pmap(
    pipeline._generate,
    static_broadcasted_argnums=[3, 4, 5, 9],
).lower(
    prompt_ids,
    p_params,
    rng,
    num_inference_steps,
    height,
    width,
    g,
    None,
    neg_prompt,
    False,
).compile()
print("Compile time:", time.time() - start_time)


def generate():
    print("Start...")
    print("Version", jax.__version__)
    # Three timed passes with the precompiled function.
    for _ in range(3):
        (prompt_ids, rng, num_inference_steps,
         _, _, g, neg_prompt) = get_pipeline_params()
        start_time = time.time()
        # The compiled function takes only the non-static arguments.
        images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt)
        images = images.block_until_ready()
        end_time = time.time()
        print(f"For {num_inference_steps} steps", end_time - start_time)
        print("Avg per step", (end_time - start_time) / num_inference_steps)


with gr.Blocks(css="style.css") as demo:
    # The slider is not wired into generate() yet.
    batch_size = gr.Slider(
        label="Batch size",
        minimum=0,
        maximum=16,
        step=1,
        value=1,
    )
    # `.style(...)` is the Gradio 3.x styling API.
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )
    btn.click(fn=generate, inputs=[])

demo.launch(share=True)