# Flax AOT compile cache sdxl
# Source: GitHub gist entrpn/a8e12289a5a1a9c2c1859d6aa22cfa0d
# Created: September 21, 2023 17:42
import time

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionXLPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
from jax.experimental.compilation_cache import compilation_cache as cc

# Point JAX's persistent compilation cache at a local directory so compiled
# executables can be reused across runs of this script.
# Prefer `set_cache_dir` when the installed JAX provides it; otherwise fall
# back to the `initialize_cache` call this script was originally written with.
_set_cache_dir = getattr(cc, "set_cache_dir", None) or cc.initialize_cache
_set_cache_dir("/tmp/sdxl_cache")

# Dtype passed to `from_pretrained` and the checkpoint to load.
dtype = jnp.bfloat16
model_id = "pcuenq/stable-diffusion-xl-base-1.0-flax"
def to_bf16(t):
    """Return a copy of pytree *t* with every leaf cast to ``bfloat16``.

    Leaves that are already ``bfloat16`` are passed through untouched so no
    redundant cast is issued for them.

    Args:
        t: A pytree whose leaves expose ``.dtype`` and ``.astype`` (e.g. JAX
            or NumPy arrays).

    Returns:
        A pytree with the same structure as *t* and ``bfloat16`` leaves.
    """
    def _cast(leaf):
        if leaf.dtype == jnp.bfloat16:
            return leaf
        return leaf.astype(jnp.bfloat16)

    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(_cast, t)
def create_key(seed=0):
    """Return a JAX PRNG key built from the integer *seed* (default ``0``)."""
    return jax.random.PRNGKey(seed)
def get_pipeline_params():
    """Build the dummy, per-device inputs used to compile and benchmark.

    Relies on the module-level ``pipeline`` object, so it must only be called
    after the pipeline has been loaded.

    Returns:
        A 7-tuple ``(prompt_ids, rng, num_inference_steps, height, width, g,
        neg_prompt_ids)`` where:

        * ``prompt_ids`` / ``neg_prompt_ids`` come from passing one copy of a
          placeholder prompt per device through ``pipeline.prepare_inputs``
          and then ``shard``;
        * ``rng`` is the seed-0 key split into one key per device;
        * ``num_inference_steps``, ``height`` and ``width`` are plain ints;
        * ``g`` is a float32 array of shape ``(prompt_ids.shape[0], 1)``
          filled with the guidance scale.
    """
    n_devices = jax.device_count()

    def _prepare_per_device(text):
        # One copy of the text per device, prepared by the pipeline and then
        # sharded across devices.
        return shard(pipeline.prepare_inputs([text] * n_devices))

    rng = jax.random.split(create_key(0), n_devices)

    # The prompt content is a placeholder: a 77-character string of "a".
    placeholder = 77 * "a"
    prompt_ids = _prepare_per_device(placeholder)
    neg_prompt_ids = _prepare_per_device(placeholder)

    num_inference_steps = 40
    height = 1024
    width = 1024
    guidance_scale = 9

    # One guidance value per leading entry of prompt_ids, with a trailing
    # singleton axis.
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]

    return prompt_ids, rng, num_inference_steps, height, width, g, neg_prompt_ids
# Load the pipeline object and its parameter tree.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    model_id,
    dtype=dtype,
)

# Cast each model component's weights to bfloat16, then replicate the whole
# parameter tree for use with pmap.
params['vae'] = to_bf16(params['vae'])
params['text_encoder'] = to_bf16(params['text_encoder'])
params['text_encoder_2'] = to_bf16(params['text_encoder_2'])
params['unet'] = to_bf16(params['unet'])
p_params = replicate(params)

(prompt_ids, rng, num_inference_steps,
 height, width, g, neg_prompt) = get_pipeline_params()

# Ahead-of-time compile: lower the pmapped `_generate` with example inputs and
# compile it once, timing the whole step.
start_time = time.perf_counter()
p_generate = pmap(
    pipeline._generate,
    # Positions 3, 4, 5 and 9 of the call below (num_inference_steps, height,
    # width and the trailing False) are treated as static, broadcast values.
    static_broadcasted_argnums=[3, 4, 5, 9],
).lower(
    prompt_ids,
    p_params,
    rng,
    num_inference_steps,
    height,
    width,
    g,
    None,  # NOTE(review): presumably the optional `latents` argument - confirm
    neg_prompt,
    False,  # NOTE(review): presumably `return_latents` - confirm
).compile()
print("Compile time:", time.perf_counter() - start_time)
def generate():
    """Run the pre-compiled generation three times and print timings.

    Each iteration rebuilds the dummy inputs via ``get_pipeline_params``
    (outside the timed region), calls the module-level compiled
    ``p_generate`` with the replicated ``p_params``, waits for the result to
    be ready, and prints the total and per-step wall-clock time.

    Returns:
        None. Results are only reported through ``print``.
    """
    print("Start...")
    print("Version", jax.__version__)
    for _ in range(3):
        (prompt_ids, rng, num_inference_steps,
         _, _, g, neg_prompt) = get_pipeline_params()

        # The compiled executable takes only the non-static arguments, in
        # their original order.
        start_time = time.perf_counter()
        images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt)
        # Block so the measurement covers the computation, not just dispatch.
        images.block_until_ready()
        elapsed = time.perf_counter() - start_time

        print(f"For {num_inference_steps} steps", elapsed)
        print("Avg per step", elapsed / num_inference_steps)
# Minimal Gradio front end: a single button that triggers the benchmark.
with gr.Blocks(css="style.css") as demo:
    # NOTE(review): this slider's value is not wired into `generate`, so it
    # currently has no effect on the benchmark.
    batch_size = gr.Slider(
        label="Batch size",
        minimum=0,
        maximum=16,
        step=1,
        value=1,
    )
    # NOTE(review): `.style(...)` depends on the installed Gradio version
    # still providing it - confirm before upgrading Gradio.
    btn = gr.Button("Benchmark!").style(
        margin=False,
        rounded=(False, True, True, False),
        full_width=False,
    )
    # `generate` takes no inputs and only prints its timings.
    btn.click(fn=generate, inputs=[])

demo.launch(share=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment