Skip to content

Instantly share code, notes, and snippets.

@lucataco
Created September 11, 2023 19:58
Show Gist options
  • Save lucataco/338ed0efd2041ddf093f2bace84a6aee to your computer and use it in GitHub Desktop.
Save lucataco/338ed0efd2041ddf093f2bace84a6aee to your computer and use it in GitHub Desktop.
Replicate-LoRA-manual-load-weights
import os
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
from safetensors import safe_open
from dataset_and_utils import TokenEmbeddingsHandler
# Load the SDXL base model from its fp16 safetensors variant in half
# precision, then place it on the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Replace the default scheduler with Euler (K_EULER), built from the
# existing scheduler's configuration.
euler_scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = euler_scheduler
# Load the trained LoRA tensors (directly onto the GPU) and copy them into the
# UNet. strict=False is needed because the file holds only a subset of the
# UNet's parameters (missing keys are expected). However, strict=False also
# silently drops *unexpected* keys -- tensors whose names match no UNet
# parameter -- so the result is checked below instead of being discarded.
with safe_open("weights/lora.safetensors", framework="pt", device="cuda") as f:
    tensors = {key: f.get_tensor(key) for key in f.keys()}

load_result = pipe.unet.load_state_dict(tensors, strict=False)  # should take < 2 seconds
if load_result.unexpected_keys:
    # NOTE(review): LoRA parameter names generally only exist on the UNet once
    # LoRA attention processors have been installed on it; if every key shows
    # up here, that setup step is likely missing -- confirm against the
    # trainer's output format.
    print(
        f"WARNING: {len(load_result.unexpected_keys)} of {len(tensors)} tensors "
        "in weights/lora.safetensors matched no UNet parameter and were ignored; "
        "the LoRA may not be applied. "
        f"First unmatched key: {load_result.unexpected_keys[0]}"
    )
# Textual-inversion setup: hand both SDXL text encoders and their matching
# tokenizers (same order) to the embeddings handler, then load the trained
# token embeddings from disk.
text_encoders = [
    pipe.text_encoder,
    pipe.text_encoder_2,
]
tokenizers = [
    pipe.tokenizer,
    pipe.tokenizer_2,
]
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings("weights/embeddings.pti")
# "<s0><s1>" are placeholder tokens -- presumably the ones whose embeddings
# come from weights/embeddings.pti; confirm against the trainer.
prompt = "A <s0><s1> emoji of a man"

# Fixed seed for reproducible output. Set seed = None to draw a random
# 16-bit seed (0..65535) from os.urandom instead.
seed = 57727
seed = int.from_bytes(os.urandom(2), "big") if seed is None else seed
print(f"Using seed: {seed}")

# Keyword arguments passed to the pipeline call.
generator = torch.Generator("cuda").manual_seed(seed)
common_args = dict(
    prompt=prompt,
    guidance_scale=7.5,
    generator=generator,
    num_inference_steps=50,
)
# Run the pipeline and write the first generated image to output.png.
result = pipe(**common_args)
image = result.images[0]
image.save("output.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment