# ------------------------------------------------------------------
# EDIT: I eventually found a faster way to run SD on macOS, via MPSGraph (~0.8s / step on M1 Pro):
# https://github.com/madebyollin/maple-diffusion
# The original CoreML-related code & discussion is preserved below :)
# ------------------------------------------------------------------
# you too can run stable diffusion on the apple silicon GPU (no ANE sadly)
#
# quick test portraits (each took 50 steps x 2s / step ~= 100s on my M1 Pro):
# * https://i.imgur.com/5ywISvm.png
# * https://i.imgur.com/94fMA22.png
# * https://i.imgur.com/WOpSweZ.png
# * https://i.imgur.com/Hns46Rk.png
# the default pytorch / cpu pipeline took ~4.2s / step and did not use the GPU
#
# how to use the GPU
# 0. https://coremltools.readme.io/docs/installation
# 1. https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
#    (but stop at 'pipe = pipe.to("cuda")'. also I used fp32 checkpoints, but it might work fine with fp16)
# 2. run this gist's code in a new cell instead. it will be very slow the first time, faster on reruns.
# 3. if it worked, proceed to run the image generation cell, which should now use the GPU
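#
# for reference, step 1 boils down to something like this (a rough sketch - the exact
# model id and arguments depend on the diffusers version used in the notebook):
#
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
#   # stop here: do NOT call pipe = pipe.to("cuda") - keep the model on the CPU so it can be traced below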
import coremltools as ct
from pathlib import Path
import torch as th

def generate_coreml_model_via_awful_hacks(f, out_name):
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op, _TORCH_OPS_REGISTRY
    import coremltools.converters.mil.frontend.torch.ops as cml_ops
    # coremltools can't convert th.einsum, so monkey-patch the two einsum patterns
    # used by the UNet's attention layers into plain bmm calls before tracing
    orig_einsum = th.einsum
    def fake_einsum(a, b, c):
        if a == 'b i d, b j d -> b i j': return th.bmm(b, c.permute(0, 2, 1))
        if a == 'b i j, b j d -> b i d': return th.bmm(b, c)
        raise ValueError(f"unsupported einsum {a} on {b.shape} {c.shape}")
    th.einsum = fake_einsum
    # register torch ops that this version of coremltools doesn't handle out of the box
    if "broadcast_to" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["broadcast_to"]
    @register_torch_op
    def broadcast_to(context, node): return cml_ops.expand(context, node)
    if "gelu" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["gelu"]
    @register_torch_op
    def gelu(context, node): context.add(mb.gelu(x=context[node.inputs[0]], name=node.name))
    # the diffusers UNet returns a dict; unwrap it so jit.trace sees a plain tensor output
    class Undictifier(th.nn.Module):
        def __init__(self, m):
            super().__init__()
            self.m = m
        def forward(self, *args, **kwargs): return self.m(*args, **kwargs)["sample"]
    print("tracing")
    f_trace = th.jit.trace(Undictifier(f), (th.zeros(2, 4, 64, 64), th.zeros(1), th.zeros(2, 77, 768)), strict=False, check_trace=False)
    print("converting")
    f_coreml_fp16 = ct.convert(f_trace,
                               inputs=[ct.TensorType(shape=(2, 4, 64, 64)), ct.TensorType(shape=(1,)), ct.TensorType(shape=(2, 77, 768))],
                               convert_to="mlprogram", compute_precision=ct.precision.FLOAT16, skip_model_load=True)
    f_coreml_fp16.save(f"{out_name}")
    th.einsum = orig_einsum
    print("the deed is done")

class UNetWrapper:
    def __init__(self, f, out_name="unet.mlpackage"):
        self.in_channels = f.in_channels
        if not Path(out_name).exists():
            print("generating coreml model"); generate_coreml_model_via_awful_hacks(f, out_name); print("saved")
        # not only does ANE take forever to load because it recompiles each time - it then doesn't work!
        # and NSLocalizedDescription = "Error computing NN outputs."; is not helpful... GPU it is
        print("loading saved coreml model"); f_coreml_fp16 = ct.models.MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU); print("loaded")
        self.f = f_coreml_fp16
    def __call__(self, sample, timestep, encoder_hidden_states):
        args = {"sample_1": sample.numpy(), "timestep": th.tensor([timestep], dtype=th.int32).numpy(), "input_35": encoder_hidden_states.numpy()}
        for v in self.f.predict(args).values(): return {"sample": th.tensor(v, dtype=th.float32)}

# swap the CoreML-backed wrapper in for the pipeline's original UNet
pipe.unet = UNetWrapper(pipe.unet)
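
# once the wrapper is swapped in, the notebook's image generation cell runs unchanged
# and should now hit the GPU via CoreML. a rough sketch (prompt is just an example, and
# the output may live under "images" rather than "sample" depending on the diffusers version):
#
#   out = pipe("portrait photo of a person, studio lighting", num_inference_steps=50)
#   out["sample"][0].save("portrait.png")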
@tcapelle Thanks! I'll probably leave Maple Diffusion as-is for now. I might try and speed it up a bit more if I get impatient. The main speedup opportunities I'm aware of are:

- Disable more of the `MEM-HACK` stuff on macOS (I originally had it working around 5% faster than current, so I think there's some regression to fix here)
- Try to get the mysterious `level1` optimization flag working
- Try and implement some custom attention logic in Metal, using the faster flash attention algorithm. This could speed things up by a lot, but it seems like it would require a ton of hacking.

There might also be more speedup opportunities I overlooked (I'm not familiar with Swift or any of Apple's dev tools - just DNNs).
I don't currently plan to make a full-featured, user-friendly app based on Maple Diffusion (too much work!). But since Maple Diffusion is self-contained and free to use, I expect it could be slotted in as a backend for other apps like Diffusion Bee pretty easily.
@MatthewWaller Yup, feel free to incorporate Maple Diffusion in anything you're building :)
Maple Diffusion should hopefully work on iOS if you use the 4GB memory limit capability. iirc memory usage peaks around ~3.8GB, so unfortunately it won't run without that.
I'm not sure why the ANE would care about the number of reshapes - typically a graph compiler will fold all reshapes / permutes into simple indexing logic inside whichever kernel consumes the data. Perhaps the ANE only works with contiguous input? In which case you might be able to force contiguous input by adding some pointless ops before the matmul (add 0.0 or something) - I'm not really sure.
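To make the "pointless op" idea concrete, here is a minimal sketch (purely hypothetical - I haven't verified that it changes what the ANE actually receives):

import torch
q = torch.randn(2, 4096, 320)
heads = 8
q = q.view(q.shape[0], q.shape[1], heads, q.shape[2] // heads)
q = q + 0.0  # no-op add, intended to hand the downstream matmul a freshly materialized tensor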
Hey thanks! @madebyollin
Yeah, that reshape is silly business.
Is there, for instance, a way to do this reshape:
q = q.view(q.shape[0], q.shape[1], self.heads, q.shape[2] // self.heads)
k = k.view(k.shape[0], k.shape[1], self.heads, k.shape[2] // self.heads)
v = v.view(v.shape[0], v.shape[1], self.heads, v.shape[2] // self.heads)
Without doing a reshape? Some sort of matmul that would go from (2, 4096, 320) to (2, 4096, 8, 40)? As a sort of cheat. I don't think the coremltools converter cares that they are contiguous, and in fact I specified contiguous with no difference.
Awesome work, @madebyollin!
I was hoping @MatthewWaller would release some of the code so I could experiment with it myself, but he hasn't yet, so I'm glad you did!
Any chance of running this on iPhones with 3GB of RAM (SE second gen), or does it need the full 4GB of memory?
@MatthewWaller, have you tested whether your implementation performs well on devices with under 4GB of memory?
@Lukas1h I just checked, and the current code uses around 3.4GB of peak memory. So it's probably quite difficult to get this MPSGraph version working fast on 3GB iOS devices, unfortunately. The CoreML-based approach might work better, assuming it allows you to swap parts of the model to disk without paying a huge graph-recompilation penalty when you load them back in.
BTW I assume you've seen https://github.com/apple/ml-stable-diffusion 🔥
> BTW I assume you've seen https://github.com/apple/ml-stable-diffusion 🔥

Yes, but it's good to link that repo in this thread!
AFAICT, most (all?) of the UNet speedup Apple got was through optimizations in CoreML itself. On macOS 13.0, on my test device (M1 Pro MBP), Apple's CoreML model graph did not seem any faster than the one I exported back in August. I think this matches what Birchlabs observed in their testing - all of the speedup comes from the 13.1 upgrade.
Once I upgrade to 13.1 I'd love to re-check how various SD runners compare (MPSGraph vs. Apple's CoreML model vs. my naive CoreML export of the diffusers model) :)
@madebyollin did you have any chance to test the latest Stable Diffusion speeds? How do they compare to the latest PyTorch on Apple silicon?
@x4080 I did an informal check of UNet speeds today, 2023-06-04 (on M1 Pro 16GB, macOS Ventura 13.4).
- PyTorch 2.0.1 / MPS: 1.10 it / s (iirc this is using MPSGraph on GPU, just with a lot of overhead and minimal operator fusion)
- CoreML / CPU_AND_GPU / Apple's SPLIT_EINSUM config: 1.39 it / s
- MPSGraph / GPU (Maple Diffusion): 1.44 it / s (0.69 s / it)
- CoreML / ALL (CPU+GPU+ANE) / Apple's SPLIT_EINSUM config: 1.85 it / s. This is Apple's recommended config for good reason, but I observe a huge on-initial-model-load delay waiting for ANECompilerService, which makes it annoying to use in practice 😞
Summarizing: I get a 1.68x speedup from using Apple's CoreML+ANE approach vs. just using PyTorch / MPS. It seems like using the ANE is important on the lower-end M{1,2} and M{1,2} Pro chips. For the higher-end chips, the ANE probably doesn't contribute anything, but I expect running a low-overhead GPU compute graph (CoreML or Maple) still helps.
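For reference, the two CoreML rows above use the same SPLIT_EINSUM model and differ only in the compute units requested at load time; a minimal sketch (the .mlpackage path is just a placeholder):

import coremltools as ct
unet_gpu = ct.models.MLModel("unet.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_GPU)
unet_ane = ct.models.MLModel("unet.mlpackage", compute_units=ct.ComputeUnit.ALL)  # CPU + GPU + ANE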
@madebyollin thanks
BTW, @madebyollin, the ANE trouble is just the sheer number of reshapes:
for instance here:
If I break all of those self.reshape_heads_to_batch_dim calls out into a separate model, then the ANE compiles. But altogether there are too many.
I even got the number way, way down by using a couple of einsums, but it wasn't enough.
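For context, that diffusers helper does roughly the following for each of q, k, and v, so every attention layer contributes several reshapes (paraphrased from memory - it may not match the exact diffusers version in use):

import torch

def reshape_heads_to_batch_dim(tensor, heads):
    # (batch, seq, dim) -> (batch * heads, seq, dim // heads) via reshape + permute + reshape
    b, seq_len, dim = tensor.shape
    tensor = tensor.reshape(b, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(b * heads, seq_len, dim // heads)

q = reshape_heads_to_batch_dim(torch.randn(2, 4096, 320), heads=8)  # -> (16, 4096, 40)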