Skip to content

Instantly share code, notes, and snippets.

@cmdr2
Last active July 29, 2023 20:07
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save cmdr2/2e7ca7f9630e553bae4a7eb599c3c526 to your computer and use it in GitHub Desktop.
Save cmdr2/2e7ca7f9630e553bae4a7eb599c3c526 to your computer and use it in GitHub Desktop.
import os
import time
# Machine-specific locations of the original-format Stable Diffusion
# checkpoint and its inference YAML config. Set the SD_MODEL_PATH /
# SD_CONFIG_PATH environment variables to point elsewhere instead of
# editing this file; the literal paths below are used when they are unset.
MODEL_PATH = os.environ.get("SD_MODEL_PATH", "F:/models/stable-diffusion/sd-v1-4.ckpt")
CONFIG_PATH = os.environ.get("SD_CONFIG_PATH", "F:/models/stable-diffusion/v1-inference.yaml")
def diff():
    """Benchmark Stable Diffusion image generation with ``diffusers``.

    Converts the original-format checkpoint at ``MODEL_PATH`` (using the
    ``CONFIG_PATH`` inference config) into a diffusers pipeline, moves it to
    ``cuda:0`` in float16 with attention slicing enabled, and generates one
    image for a fixed prompt with seed 42. Prints the load time and the
    generation time in seconds, then saves the image to ``diffusers.jpg`` in
    the current working directory.

    Requires a CUDA device: both the RNG generator and the pipeline are
    placed on it.
    """
    print('diffusers')

    # Imports are local to this function, so running the sdkit benchmark
    # never imports torch/transformers/diffusers through this path.
    import torch
    from transformers import logging as tr_logging

    tr_logging.set_verbosity_error()  # suppress unnecessary logging

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        load_pipeline_from_original_stable_diffusion_ckpt,
    )

    extract_ema = False

    print("Loading pipeline from original stable diffusion checkpoint")
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system time is adjusted mid-benchmark.
    t = time.perf_counter()
    generator = torch.Generator("cuda").manual_seed(42)
    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
        checkpoint_path=MODEL_PATH,
        original_config_file=CONFIG_PATH,
        extract_ema=extract_ema,
        from_safetensors=False,
    )
    pipe = pipe.to('cuda:0').to(torch.float16)
    pipe.enable_attention_slicing()
    # Optional memory-saving / export toggles, disabled for this benchmark:
    # pipe.enable_vae_slicing()
    # pipe.enable_sequential_cpu_offload()
    # pipe.save_pretrained('diff')

    # CUDA work (such as the float16 cast above) is queued asynchronously.
    # Wait for it so the reported load time includes all of it.
    torch.cuda.synchronize()
    print("Loading complete in ", time.perf_counter() - t, 'sec')

    t = time.perf_counter()
    image = pipe('photo of an astronaut riding a horse', generator=generator).images[0]
    print("made image in ", time.perf_counter() - t, 'sec')
    image.save('diffusers.jpg')
def sd():
    """Benchmark Stable Diffusion image generation with ``sdkit``.

    Registers ``MODEL_PATH`` as the 'stable-diffusion' model on a fresh sdkit
    ``Context``, loads it, and generates images for a fixed prompt (seed 42,
    'plms' sampler, 50 inference steps). Prints the load time and the
    generation time in seconds, then saves the first image to ``sdkit.jpg``
    in the current working directory.
    """
    print('sdkit')

    # Imports are local to this function, so running the diffusers benchmark
    # never imports sdkit through this path.
    import torch
    from sdkit import Context
    from sdkit.models import load_model
    from sdkit.generate import generate_images

    c = Context()
    c.model_paths['stable-diffusion'] = MODEL_PATH

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t = time.perf_counter()
    load_model(c, 'stable-diffusion')
    # Loading may leave GPU work queued asynchronously. Wait for it, when a
    # GPU is present, so the load time is comparable with the diffusers path.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print('loaded model in ', time.perf_counter() - t, 'sec')

    t = time.perf_counter()
    images = generate_images(
        c,
        prompt='photo of an astronaut riding a horse',
        seed=42,
        sampler_name='plms',
        num_inference_steps=50,
    )
    print('generated image in ', time.perf_counter() - t, 'sec')
    images[0].save('sdkit.jpg')
import sys

if __name__ == "__main__":
    # Usage: python <this script> [sdkit|diffusers]
    # 'sdkit' runs the sdkit benchmark. Any other value, or no argument at
    # all, runs the diffusers benchmark. The original indexed sys.argv[1]
    # unconditionally and raised IndexError when no argument was given.
    backend = sys.argv[1] if len(sys.argv) > 1 else 'diffusers'
    if backend == 'sdkit':
        sd()
    else:
        diff()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment