Skip to content

Instantly share code, notes, and snippets.

@mazzzystar
Forked from madebyollin/stable_diffusion_m1.py
Created November 12, 2022 15:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mazzzystar/2201718a9e95da4b7341a2b52f0b35b9 to your computer and use it in GitHub Desktop.
Save mazzzystar/2201718a9e95da4b7341a2b52f0b35b9 to your computer and use it in GitHub Desktop.
Stable Diffusion on Apple Silicon GPUs via CoreML; 2s / step on M1 Pro
# ------------------------------------------------------------------
# EDIT: I eventually found a faster way to run SD on macOS, via MPSGraph (~0.8s / step on M1 Pro):
# https://github.com/madebyollin/maple-diffusion
# The original CoreML-related code & discussion is preserved below :)
# ------------------------------------------------------------------
# you too can run stable diffusion on the apple silicon GPU (no ANE sadly)
#
# quick test portraits (each took 50 steps x 2s / step ~= 100s on my M1 Pro):
# * https://i.imgur.com/5ywISvm.png
# * https://i.imgur.com/94fMA22.png
# * https://i.imgur.com/WOpSweZ.png
# * https://i.imgur.com/Hns46Rk.png
# the default pytorch / cpu pipeline took ~4.2s / step and did not use the GPU
#
# how to use the GPU
# 0. install coremltools: https://coremltools.readme.io/docs/installation
# 1. follow the steps in this notebook: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
# (but stop at 'pipe = pipe.to("cuda")'. I used fp32 checkpoints, but it might work fine with fp16 too)
# 2. run this gist's code in a new cell instead. it will be very slow the first time, faster on reruns.
# 3. if it worked, proceed to run the image generation cell, which should now use the GPU
import coremltools as ct
from pathlib import Path
import torch as th
def generate_coreml_model_via_awful_hacks(f, out_name):
    """Trace ``f`` with TorchScript, convert it to a CoreML ML Program and save it.

    Two workarounds are applied to get the conversion through:

    * ``th.einsum`` is temporarily replaced by a ``th.bmm``-based shim that
      understands only the two equations handled in ``fake_einsum``. The
      original function is restored before returning, including when tracing,
      conversion or saving raises.
    * The ``broadcast_to`` and ``gelu`` entries of coremltools' torch op
      registry are replaced by the converters defined below. These overrides
      are NOT undone afterwards.

    The model is traced and converted with fixed input shapes:
    ``(2, 4, 64, 64)``, ``(1,)`` and ``(2, 77, 768)``.

    Args:
        f: Module to convert. It is called with three positional tensors and
            its result is indexed with ``["sample"]``.
        out_name: Path the converted model is saved to.

    Raises:
        ValueError: If ``th.einsum`` is called during tracing with an equation
            the shim does not support.
    """
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.torch_op_registry import (
        _TORCH_OPS_REGISTRY,
        register_torch_op,
    )
    import coremltools.converters.mil.frontend.torch.ops as cml_ops

    def fake_einsum(a, b, c):
        # Express the two supported einsum equations as batched matmuls.
        if a == 'b i d, b j d -> b i j':
            return th.bmm(b, c.permute(0, 2, 1))
        if a == 'b i j, b j d -> b i d':
            return th.bmm(b, c)
        raise ValueError(f"unsupported einsum {a} on {b.shape} {c.shape}")

    # Drop any existing converter first so the registration below takes its place.
    if "broadcast_to" in _TORCH_OPS_REGISTRY:
        del _TORCH_OPS_REGISTRY["broadcast_to"]

    @register_torch_op
    def broadcast_to(context, node):
        # Reuse coremltools' `expand` converter for `broadcast_to` nodes.
        return cml_ops.expand(context, node)

    if "gelu" in _TORCH_OPS_REGISTRY:
        del _TORCH_OPS_REGISTRY["gelu"]

    @register_torch_op
    def gelu(context, node):
        # Only the first input of the node is forwarded to the MIL gelu op.
        context.add(mb.gelu(x=context[node.inputs[0]], name=node.name))

    class Undictifier(th.nn.Module):
        """Wrap a module so that it returns the "sample" entry of its output."""

        def __init__(self, m):
            super().__init__()
            self.m = m

        def forward(self, *args, **kwargs):
            return self.m(*args, **kwargs)["sample"]

    sample_shape = (2, 4, 64, 64)
    timestep_shape = (1,)
    hidden_states_shape = (2, 77, 768)

    # Patch as late as possible and always undo it: a failure below must not
    # leave the process-wide th.einsum pointing at the limited shim.
    orig_einsum = th.einsum
    th.einsum = fake_einsum
    try:
        print("tracing")
        f_trace = th.jit.trace(
            Undictifier(f),
            (th.zeros(*sample_shape), th.zeros(*timestep_shape), th.zeros(*hidden_states_shape)),
            strict=False,
            check_trace=False,
        )
        print("converting")
        f_coreml_fp16 = ct.convert(
            f_trace,
            inputs=[
                ct.TensorType(shape=sample_shape),
                ct.TensorType(shape=timestep_shape),
                ct.TensorType(shape=hidden_states_shape),
            ],
            convert_to="mlprogram",
            compute_precision=ct.precision.FLOAT16,
            skip_model_load=True,
        )
        f_coreml_fp16.save(str(out_name))
    finally:
        th.einsum = orig_einsum
    print("the deed is done")
class UNetWrapper:
    """Callable stand-in for a UNet that runs inference through a saved CoreML model.

    Instances expose ``in_channels`` (copied from the wrapped module) and can
    be called with ``(sample, timestep, encoder_hidden_states)``; the call
    returns ``{"sample": tensor}``.
    """

    def __init__(self, f, out_name="unet.mlpackage"):
        """Load the CoreML model at ``out_name``, generating it from ``f`` if absent.

        Args:
            f: The original module. Its ``in_channels`` attribute is copied,
                and it is converted only when ``out_name`` does not exist.
            out_name: Path of the saved CoreML model.
        """
        self.in_channels = f.in_channels
        if not Path(out_name).exists():
            print("generating coreml model")
            generate_coreml_model_via_awful_hacks(f, out_name)
            print("saved")
        # not only does ANE take forever to load because it recompiles each time - it then doesn't work!
        # and NSLocalizedDescription = "Error computing NN outputs."; is not helpful... GPU it is
        print("loading saved coreml model")
        self.f = ct.models.MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU)
        print("loaded")

    def __call__(self, sample, timestep, encoder_hidden_states):
        """Run one prediction and return the first model output as a float32 tensor.

        Args:
            sample: Tensor passed to the model as input "sample_1".
            timestep: Value wrapped into a one-element int32 array ("timestep").
            encoder_hidden_states: Tensor passed as input "input_35".

        Returns:
            ``{"sample": tensor}`` where ``tensor`` is the first value of the
            model's output dict converted to ``th.float32``.

        Raises:
            RuntimeError: If the model returns no outputs.
        """
        # NOTE(review): "sample_1" / "input_35" are presumably the input names
        # assigned during conversion - confirm against the saved model's spec.
        args = {
            # detach().cpu() keeps .numpy() working for tensors that require
            # grad or live on another device.
            "sample_1": sample.detach().cpu().numpy(),
            "timestep": th.tensor([timestep], dtype=th.int32).numpy(),
            "input_35": encoder_hidden_states.detach().cpu().numpy(),
        }
        outputs = self.f.predict(args)
        if not outputs:
            # Fail loudly instead of implicitly returning None.
            raise RuntimeError("CoreML model returned no outputs")
        first_output = next(iter(outputs.values()))
        return {"sample": th.tensor(first_output, dtype=th.float32)}
# Swap the pipeline's UNet for the CoreML-backed wrapper. `pipe` is not defined in this
# file; per the usage notes in the header it must already exist in the notebook session.
pipe.unet = UNetWrapper(pipe.unet)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment