Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Created December 31, 2023 14:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sayakpaul/404ebe1601d05d7c2bc3dc66fa59dbf3 to your computer and use it in GitHub Desktop.
Save sayakpaul/404ebe1601d05d7c2bc3dc66fa59dbf3 to your computer and use it in GitHub Desktop.
# SDXL: 0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038
# SD: 0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804
from diffusers import DiffusionPipeline
from peft import LoraConfig
import argparse
import torch
def load_pipeline(pipeline_id):
    """Instantiate a diffusers pipeline from a Hub repo id (or local path)."""
    return DiffusionPipeline.from_pretrained(pipeline_id)
def get_lora_config():
    """Build a pair of LoRA configurations: (text encoder config, UNet config).

    The torch RNG is re-seeded to 0 immediately before each config is created
    so that the randomly initialized LoRA weights (``init_lora_weights=False``)
    stay reproducible across runs.
    """
    lora_rank = 4

    torch.manual_seed(0)
    text_encoder_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        init_lora_weights=False,
    )

    torch.manual_seed(0)
    denoiser_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_rank,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        init_lora_weights=False,
    )

    return text_encoder_config, denoiser_config
def get_dummy_inputs():
    """Return deterministic keyword arguments for a quick pipeline call.

    The generator is the global torch RNG seeded to 0, so repeated calls
    produce identical sampling noise.
    """
    return {
        "prompt": "A painting of a squirrel eating a burger",
        "num_inference_steps": 2,
        "guidance_scale": 6.0,
        "output_type": "np",
        "generator": torch.manual_seed(0),
    }
def run_inference(args):
    """Attach fresh LoRA adapters to a pipeline, run it, and print an image slice.

    Prints the last 3x3 corner of the final channel of the first generated
    image, rounded to 4 decimals, as a comma-separated line.
    """
    pipe = load_pipeline(pipeline_id=args.pipeline_id)
    text_lora_config, unet_lora_config = get_lora_config()

    pipe.text_encoder.add_adapter(text_lora_config)
    pipe.unet.add_adapter(unet_lora_config)
    # SDXL-style pipelines carry a second text encoder; give it an adapter too.
    if hasattr(pipe, "text_encoder_2"):
        pipe.text_encoder_2.add_adapter(text_lora_config)

    images = pipe(**get_dummy_inputs()).images
    corner_slice = images[0, -3:, -3:, -1].flatten().tolist()
    print(", ".join(str(round(value, 4)) for value in corner_slice))
if __name__ == "__main__":
    # CLI entry point: pick one of the tiny test pipelines and run it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline_id",
        type=str,
        default="hf-internal-testing/tiny-sd-pipe",
        choices=[
            "hf-internal-testing/tiny-sd-pipe",
            "hf-internal-testing/tiny-sdxl-pipe",
        ],
    )
    cli_args = parser.parse_args()
    run_inference(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment