@madebyollin
Last active February 10, 2024 02:25
Stable Diffusion on Apple Silicon GPUs via CoreML; 2s / step on M1 Pro
# ------------------------------------------------------------------
# EDIT: I eventually found a faster way to run SD on macOS, via MPSGraph (~0.8s / step on M1 Pro):
# https://github.com/madebyollin/maple-diffusion
# The original CoreML-related code & discussion is preserved below :)
# ------------------------------------------------------------------
# you too can run stable diffusion on the apple silicon GPU (no ANE sadly)
#
# quick test portraits (each took 50 steps x 2s / step ~= 100s on my M1 Pro):
# * https://i.imgur.com/5ywISvm.png
# * https://i.imgur.com/94fMA22.png
# * https://i.imgur.com/WOpSweZ.png
# * https://i.imgur.com/Hns46Rk.png
# the default pytorch / cpu pipeline took ~4.2s / step and did not use the GPU
#
# how to use the GPU
# 0. https://coremltools.readme.io/docs/installation
# 1. https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb
# (but stop at 'pipe = pipe.to("cuda")'. also I used fp32 checkpoints, but it might work fine with fp16)
# 2. run this gist's code in a new cell instead. it will be very slow the first time, faster on reruns.
# 3. if it worked, proceed to run the image generation cell, which should now use the GPU
import coremltools as ct
from pathlib import Path
import torch as th
def generate_coreml_model_via_awful_hacks(f, out_name):
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op, _TORCH_OPS_REGISTRY
    import coremltools.converters.mil.frontend.torch.ops as cml_ops
    orig_einsum = th.einsum
    def fake_einsum(a, b, c):
        if a == 'b i d, b j d -> b i j': return th.bmm(b, c.permute(0, 2, 1))
        if a == 'b i j, b j d -> b i d': return th.bmm(b, c)
        raise ValueError(f"unsupported einsum {a} on {b.shape} {c.shape}")
    th.einsum = fake_einsum
    if "broadcast_to" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["broadcast_to"]
    @register_torch_op
    def broadcast_to(context, node): return cml_ops.expand(context, node)
    if "gelu" in _TORCH_OPS_REGISTRY: del _TORCH_OPS_REGISTRY["gelu"]
    @register_torch_op
    def gelu(context, node): context.add(mb.gelu(x=context[node.inputs[0]], name=node.name))
    class Undictifier(th.nn.Module):
        def __init__(self, m):
            super().__init__()
            self.m = m
        def forward(self, *args, **kwargs): return self.m(*args, **kwargs)["sample"]
    print("tracing")
    f_trace = th.jit.trace(Undictifier(f), (th.zeros(2, 4, 64, 64), th.zeros(1), th.zeros(2, 77, 768)), strict=False, check_trace=False)
    print("converting")
    f_coreml_fp16 = ct.convert(f_trace,
        inputs=[ct.TensorType(shape=(2, 4, 64, 64)), ct.TensorType(shape=(1,)), ct.TensorType(shape=(2, 77, 768))],
        convert_to="mlprogram", compute_precision=ct.precision.FLOAT16, skip_model_load=True)
    f_coreml_fp16.save(f"{out_name}")
    th.einsum = orig_einsum
    print("the deed is done")
class UNetWrapper:
    def __init__(self, f, out_name="unet.mlpackage"):
        self.in_channels = f.in_channels
        if not Path(out_name).exists():
            print("generating coreml model"); generate_coreml_model_via_awful_hacks(f, out_name); print("saved")
        # not only does ANE take forever to load because it recompiles each time - it then doesn't work!
        # and NSLocalizedDescription = "Error computing NN outputs."; is not helpful... GPU it is
        print("loading saved coreml model"); f_coreml_fp16 = ct.models.MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU); print("loaded")
        self.f = f_coreml_fp16
    def __call__(self, sample, timestep, encoder_hidden_states):
        args = {"sample_1": sample.numpy(), "timestep": th.tensor([timestep], dtype=th.int32).numpy(), "input_35": encoder_hidden_states.numpy()}
        for v in self.f.predict(args).values(): return {"sample": th.tensor(v, dtype=th.float32)}
pipe.unet = UNetWrapper(pipe.unet)
@patrickvonplaten

Also cc @pcuenca here just FYI

@madebyollin
Author

@tcapelle Glad to hear it worked!

@MatthewWaller Cool! I was able to get that text encoder code through the conversion process with a few minor fixes, and it seems to work (the generated image is slightly different, but 99% identical):

class CLIPUndictifier(th.nn.Module):
    def __init__(self, m):
        super().__init__()
        self.m = m
    def forward(self, *args, **kwargs): 
        return self.m(*args, **kwargs)[0]

def convert_text_encoder(text_encoder, outname):    
    import transformers
    transformers.models.clip.modeling_clip.CLIPTextTransformer.attention_mask = transformers.models.clip.modeling_clip.CLIPTextTransformer._build_causal_attention_mask(None, 1, 77, th.float)
    def _fake_build_causal_mask(self, *args, **kwargs):
        return self.attention_mask
    transformers.models.clip.modeling_clip.CLIPTextTransformer._build_causal_attention_mask = _fake_build_causal_mask
    f_trace = th.jit.trace(CLIPUndictifier(text_encoder), (th.zeros(1, 77, dtype=th.long)), strict=False, check_trace=False)

    f_coreml = ct.convert(f_trace, 
               inputs=[ct.TensorType(shape=(1, 77))],
               convert_to="mlprogram", compute_precision=ct.precision.FLOAT16, skip_model_load=True)
    f_coreml.save(outname)

class TextEncoderWrapper:
    def __init__(self, f, out_name="text_encoder.mlpackage"):
        if not Path(out_name).exists():
            print("generating coreml model"); convert_text_encoder(f, out_name); print("saved")
        print("loading saved coreml model"); self.f = ct.models.MLModel(out_name, compute_units=ct.ComputeUnit.CPU_AND_GPU); print("loaded")
    
    def __call__(self, input):
        args = {"input_ids_1": input.float().numpy()}
        for v in self.f.predict(args).values():
            return (th.tensor(v, dtype=th.float32),)

pipe.text_encoder = TextEncoderWrapper(pipe.text_encoder)

(image: it works)

The schedulers aren't that complicated internally iirc (just some element wise mixing of gaussian noise / existing latents / model prediction), so I expect it is possible to reproduce them via either approach (convert the logic to CoreML, or even rewrite in Swift). I haven't messed with the schedulers at all personally, though :) so I'm not sure which way is easiest.
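
For intuition, here is a toy sketch of that kind of element-wise update (a deterministic, DDIM-style step; not any specific diffusers scheduler, and the coefficient names are made up):

import torch as th

def toy_scheduler_step(latents, noise_pred, alpha_prod_t, alpha_prod_prev):
    # predict the fully-denoised latent from the model's noise prediction...
    pred_x0 = (latents - (1 - alpha_prod_t) ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
    # ...then re-noise it to the previous (less noisy) timestep
    return alpha_prod_prev ** 0.5 * pred_x0 + (1 - alpha_prod_prev) ** 0.5 * noise_pred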

@MatthewWaller

That's awesome @madebyollin ! It's working for me!

Yeah, that scheduler is tricky. Not sure how I would do a CoreML version of the scheduler, and Swift rewrite would take quite a bit of time I think. Even something as simple as tensor multiplication is non-trivial. But who knows? Might be able to dive in.

@MatthewWaller

Turns out I got everything done but the tokenizer!

The memory usage is still huge when I run it, 8 GB. And I've run it on an iPad Pro, the one with 16 GB, with the ANE and everything apparently working!

Now I'm wondering if I can take it further and run it on a phone.

However, I'm not able to run the fp16 version of the model at all. Firstly because on the Mac I'm running on CPU, and CPU doesn't have some of the layers working for half/fp16.

Any thoughts on what I could do to get the unet converted and working on fp16?

@tcapelle

tcapelle commented Sep 20, 2022

Do you know how to make this work with diffusers@master? They changed the attention and the UNet. It may be easier to start with the TensorFlow implementation...
It appears that commenting out the unsliced_attention override does the trick.

# def unsliced_attention(self, query, key, value, _sequence_length, _dim):
#     attn = (torch.einsum("b i d, b j d -> b i j", query, key) * self.scale).softmax(dim=-1)
#     attn = torch.einsum("b i j, b j d -> b i d", attn, value)
#     return self.reshape_batch_dim_to_heads(attn)
# diffusers.models.attention.CrossAttention._attention = unsliced_attention

I started doing some benchmarks here: https://wandb.ai/capecape/stable_diffusions/reports/Stable-Diffusion-Inference-Performance--VmlldzoyNjY0ODYz in case you guys are interested.
For Apple M1, the TensorFlow implementation is twice as fast.

@madebyollin
Author

madebyollin commented Sep 21, 2022

@MatthewWaller Awesome! The tokenizer logic is fairly self-contained and I think the HuggingFace folks already transcribed the similar (identical?) GPT2 tokenizer to Swift here. For fp16 - since fp16 seems to work on my M1 Pro GPU, I'd be surprised if it does not work on recent iPhone / iPad chips as well. Is it possible to just ct.convert the fp16 mlpackage on Mac (without ever loading it on Mac) and then load that model directly on the iPad?

@tcapelle Great writeup, thanks for linking! I need to try François' version now :)

@MatthewWaller

MatthewWaller commented Sep 26, 2022

Hey I did it! here it is running on a phone

The tokenizer was similar but took a bit of work to fix up.

@madebyollin You're right, however, that it seems like ANE isn't supported, and I think I've narrowed it down.

What I did was break apart the model, and I found that the ANE fails to load when the UNet attention blocks are included.

Here we have a paper on optimizing for ANE.

It looks like the custom attention we're using isn't up to par. I'm just not sure which part. Maybe it wants a more formal translation with the mb? Hmmm

@MatthewWaller

Sorry to bother you, @madebyollin, but could you share how you got fp16 to work?

I tried

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=token)

prompt = "Pineapple on a white table"
pipe.to("mps")

image = pipe(prompt).images[0]
image.save("hmmm.png")

And it is simply not happy with me.

Anyone else following along, let me know how you did it!

@madebyollin
Author

@MatthewWaller I've never tested fp16 + mps, so I'm not sure if it's even possible at this point 😅. All I got working (above) was converting in fp32 and storing the exported coreml model in fp16. I think the diffusers repo is the right place to ask about diffusers + torch + mps + fp16 bugs.

(I would bet the specific layernorm issue can be hacked around similar to what I did above, i.e. replace torch.layer_norm with a composition of simpler ops like this... but you would probably run into other issues after that, since the mps backend is so bleeding edge)
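
For reference, an untested sketch of what that layer_norm decomposition could look like, patched in the same spirit as the fake_einsum hack in the gist (the function name is made up):

import torch as th

def layer_norm_from_simple_ops(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # layer norm rebuilt from mean / sub / mul / rsqrt over the trailing dims
    dims = tuple(range(-len(normalized_shape), 0))
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    out = (x - mean) * th.rsqrt(var + eps)
    if weight is not None: out = out * weight
    if bias is not None: out = out + bias
    return out

# hacky swap, same spirit as the fake_einsum patch above; untested
th.nn.functional.layer_norm = layer_norm_from_simple_ops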

@MatthewWaller

Oh! I gotcha, I misunderstood. That's a clever idea about hacking around layer_norm. But yeah, I think something else is afoot. I think I'll try asking through more formal Apple Support tickets what might be going on with the ANE. A lot of the logs are marked private on my iPad (although I might also get around that by renting an M1 Mac), but in general I'm wondering what they would suggest for those attention layers that seem to be tripping up the ANE.

@Lukas1h

Lukas1h commented Oct 8, 2022

@MatthewWaller, any chance you'd be willing to make this open source? Sounds like a cool project.

@MatthewWaller

I'll definitely be sharing a ton of the steps I've taken, but I also want to put it out as an app for sale when it's ready.

@Lukas1h

Lukas1h commented Oct 9, 2022

Thanks. What would really be the use case if it takes up to an hour to generate images? Also any updates on utilizing the ANE? Would that decrease times enough for it to be more usable?

@MatthewWaller

MatthewWaller commented Oct 9, 2022

Right now I got it down to 5 min on a new phone (About 15-20 min on a 4-year-old phone)! So not too bad :)

@Lukas1h

Lukas1h commented Oct 9, 2022

Impressive! Looking forward to the release!

@tcapelle

tcapelle commented Oct 9, 2022

@MatthewWaller have you tried exporting the Tensorflow model to CoreML?

@MatthewWaller

I have not! Are you thinking it might be faster or more efficient?

@madebyollin
Author

FWIW, I did eventually get a version of Stable Diffusion working faster-than-TensorFlow by using Metal Performance Shaders Graph, and was able to somewhat reproduce Matt's proof of concept on iOS.

It's annoying to work on because of the hard memory cap & inscrutable MPSGraph memory allocations, though. I'm also not sure if MPSGraph uses the ANE at all (it may be locked behind this flag 🤔).

@MatthewWaller Regarding attention on the ANE via CoreML, if you can narrow down exactly which part of the cross-attention mechanism is ANE-unfriendly, you might be able to rewrite that part in an equivalent form that works:

  • If it's the matmuls, you may be able to express those matmuls as 8 1x1 convs (1 per attention head).
  • If it's the softmax, maybe rewrite softmax as its constituent elementwise / reduce ops
  • If it's just failing because the tensors involved are so huge, you can try to slice the attention graph up across heads

There's also the possibility of writing faster GPU attention in Metal, using an algorithm like https://github.com/HazyResearch/flash-attention; MPSGraph doesn't support custom layers AFAICT, but CoreML might.
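
Untested sketches of the second and third ideas, a decomposed softmax and per-head slicing (the shapes assume the reshape_heads_to_batch_dim layout used in diffusers, and the function names are made up):

import torch as th

def softmax_from_primitives(x, dim=-1):
    # softmax spelled out as its constituent elementwise / reduce ops
    x = x - x.max(dim=dim, keepdim=True).values  # subtract max for numerical stability
    e = x.exp()
    return e / e.sum(dim=dim, keepdim=True)

def sliced_attention(q, k, v):
    # q, k, v: (batch * heads, seq, dim_per_head); process one head-batch
    # slice at a time so each matmul / softmax stays small
    scale = q.shape[-1] ** -0.5
    outs = []
    for i in range(q.shape[0]):
        attn = softmax_from_primitives(q[i] @ k[i].transpose(0, 1) * scale, dim=-1)
        outs.append(attn @ v[i])
    return th.stack(outs)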

@tcapelle

This is insane; it runs very fast on my M1 Pro MacBook Pro. Are you releasing this as an app? What can I do to help you make it faster/better on macOS?
(image: inference_swift)

@MatthewWaller

@madebyollin could I include this version in my Mac version of the app I'm building? Seems to be crashing on iOS. Not respecting the additional memory request on some devices.

@MatthewWaller

BTW, @madebyollin, the ANE trouble is just the sheer number of reshapes:

for instance here:

    def forward(self, x, context=None, mask=None):
        batch_size, sequence_length, dim = x.shape

        q = self.to_q(x)
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.reshape_heads_to_batch_dim(q)
        k = self.reshape_heads_to_batch_dim(k)
        v = self.reshape_heads_to_batch_dim(v)

        # TODO(PVP) - mask is currently never used. Remember to re-implement when used

        # attention, what we cannot get enough of
        hidden_states = self._attention(q, k, v, sequence_length, dim)

        return self.to_out(hidden_states)

If I break each of those self.reshape_heads_to_batch_dim calls out into a separate model, then the ANE compiles. But all together, there are too many.

I even got the number way way down by using a couple of einsums but it wasn't enough.
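
For reference, an untested sketch of what folding the head split into the einsums might look like: it still needs one view per tensor, but drops the separate reshape_heads_to_batch_dim round trips (function name is made up):

import torch as th

def attention_with_fewer_reshapes(q, k, v, heads=8):
    # q, k, v: (batch, seq, heads * dim_per_head)
    b, s, c = q.shape
    d = c // heads
    q, k, v = (t.view(b, s, heads, d) for t in (q, k, v))  # one view each, no batch-dim shuffle
    attn = th.einsum("b i h d, b j h d -> b h i j", q, k).mul(d ** -0.5).softmax(dim=-1)
    out = th.einsum("b h i j, b j h d -> b i h d", attn, v)
    return out.reshape(b, s, c)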

@madebyollin
Author

@tcapelle Thanks! I'll probably leave Maple Diffusion as-is for now. I might try and speed it up a bit more if I get impatient. The main speedup opportunities I'm aware of are:

  1. Disable more of the MEM-HACK stuff on macOS (I originally had it working around 5% faster than current, so I think there's some regression to fix here)
  2. Try to get the mysterious level1 optimization flag working
  3. Try and implement some custom attention logic in Metal, using the faster flash attention algorithm. This could speed things up by a lot, but it seems like it would require a ton of hacking.

There might also be more speedup opportunities I overlooked (I'm not familiar with Swift or any of Apple's dev tools - just DNNs).

I don't currently plan to make a full-featured, user-friendly app based on Maple Diffusion (too much work!). But since Maple Diffusion is self-contained and free to use, I expect it could be slotted in as a backend for other apps like Diffusion Bee pretty easily.

@MatthewWaller Yup, feel free to incorporate Maple Diffusion in anything you're building :)

Maple Diffusion should hopefully work on iOS if you use the 4GB memory limit capability. iirc memory usage peaks around ~3.8GB, so unfortunately it won't run without that.

I'm not sure why the ANE would care about the number of reshapes - typically a graph compiler will fold all reshapes / permutes into simple indexing logic inside whichever kernel consumes the data. Perhaps the ANE only works with contiguous input? In which case you might be able to force contiguous input by adding some pointless ops before the matmul (add 0.0 or something) - I'm not really sure.
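
Purely speculative, but the "pointless ops" idea might look something like this (helper name and shapes are made up):

import torch as th

def maybe_force_materialize(t):
    # no-op elementwise add, in case the ANE path only accepts a freshly
    # materialized (contiguous) tensor as matmul input; untested guesswork
    return t + 0.0

q = maybe_force_materialize(th.randn(16, 4096, 40))
k = maybe_force_materialize(th.randn(16, 4096, 40))
attn = th.bmm(q, k.permute(0, 2, 1))  # (16, 4096, 4096)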

@MatthewWaller

Hey thanks! @madebyollin

Yeah, that reshape is silly business.

Is there, for instance, a way to do this reshape:

q = q.view(q.shape[0], q.shape[1], self.heads, q.shape[2] // self.heads)
k = k.view(k.shape[0], k.shape[1], self.heads, k.shape[2] // self.heads)
v = v.view(v.shape[0], v.shape[1], self.heads, v.shape[2] // self.heads)

Without doing a reshape? Some sort of matmul that would go from (2, 4096, 320) to (2, 4096, 8, 40)? As a sort of cheat. I don't think the coremltools converter cares that they are contiguous, and in fact I specified contiguous with no difference.

@Lukas1h

Lukas1h commented Oct 10, 2022

Awesome work, @madebyollin!

I was hoping @MatthewWaller would release some of the code so I could experiment with it myself, but he hasn't yet, so I'm glad you did!

Any chance of running this on iPhones with 3 GB of RAM (SE Second Gen), or does it need the full 4 GB of memory?

@MatthewWaller, have you tested whether your implementation performs on devices with under 4 GB of memory?

@madebyollin
Author

@Lukas1h I just checked and the current code uses around 3.4GB of peak memory. So it's probably quite difficult to get this MPSGraph version working fast on 3GB iOS devices, unfortunately. The CoreML-based approach might work better, assuming it allows you to swap parts of the model to disc without paying a huge graph-recompilation penalty when you load them back in.

@julien-c

julien-c commented Dec 9, 2022

BTW I assume you've seen https://github.com/apple/ml-stable-diffusion 🔥

@madebyollin
Author

BTW I assume you've seen https://github.com/apple/ml-stable-diffusion 🔥

Yes, but it's good to link that repo in this thread!

AFAICT, most (all?) of the UNet speedup Apple got was through optimizations in CoreML itself. On 13.0, on my test device (M1 Pro MBP), Apple's CoreML model graph did not seem any faster than the one I exported back in August. I think this matches what Birchlabs observed in their testing: all of the speedup comes from the 13.1 upgrade.

Once I upgrade to 13.1 I'd love to re-check how various SD runners compare (MPSGraph vs. Apple's CoreML model vs. my naive CoreML export of the diffusers model) :)

@x4080

x4080 commented Jun 4, 2023

@madebyollin did you get a chance to test the latest Stable Diffusion speeds? How do they compare to the latest PyTorch on Apple Silicon?

@madebyollin
Author

@x4080 I did an informal check of UNet speeds today, 2023-06-04 (on M1 Pro 16GB, macOS Ventura 13.4).

  • PyTorch 2.0.1 / MPS: 1.10 it / s (iirc this is using MPSGraph on GPU, just with a lot of overhead and minimal operator fusion)


  • CoreML / CPU_AND_GPU / Apple's SPLIT_EINSUM config: 1.39 it / s


  • MPSGraph / GPU (Maple Diffusion): 1.44 it / s (0.69 s / it)


  • CoreML / ALL (CPU+GPU+ANE) / Apple's SPLIT_EINSUM config: 1.85 it / s. This is Apple's recommended config for good reason, but I observe a huge on-initial-model-load delay waiting for ANECompilerService, which makes it annoying to use in practice 😞


Summarizing: I get a 1.68x speedup from using Apple's CoreML+ANE approach vs. just using PyTorch / MPS. It seems like using the ANE is important on the lower-end M1/M2 and M1/M2 Pro chips. For the higher-end chips, the ANE probably doesn't contribute anything, but I expect running a low-overhead GPU compute graph (CoreML or Maple Diffusion) still helps.
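
For anyone reproducing this: the compute-unit configs above are just the coremltools load options (same API the gist uses); "unet.mlpackage" here stands in for whichever converted model you're testing:

import coremltools as ct

unet_gpu = ct.models.MLModel("unet.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_GPU)  # CPU_AND_GPU rows above
unet_ane = ct.models.MLModel("unet.mlpackage", compute_units=ct.ComputeUnit.ALL)          # CPU+GPU+ANE row above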

@x4080

x4080 commented Jun 4, 2023

@madebyollin thanks
