Skip to content

Instantly share code, notes, and snippets.

@adonath
Created March 12, 2024 12:49
Show Gist options
  • Save adonath/3f16b30498c60f25cf1349792c15283c to your computer and use it in GitHub Desktop.
Save adonath/3f16b30498c60f25cf1349792c15283c to your computer and use it in GitHub Desktop.
A short comparison of native CPU and GPU convolution and FFT-based convolution in MLX
import logging
import timeit
from dataclasses import dataclass, field
from typing import Callable, Optional

import matplotlib.pyplot as plt
import mlx.core as mx
import numpy as np
from scipy.signal import fftconvolve as scipy_fftconvolve
# Module-level logger used to report benchmark progress.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# Without any configured handler Python falls back to logging.lastResort,
# which only emits WARNING and above, so the INFO progress messages would be
# silently dropped. basicConfig() installs a stderr handler on the root
# logger; the root level stays at WARNING, so only this logger's INFO records
# (which propagate to that handler) are shown, not other libraries' chatter.
logging.basicConfig()
def _centered(arr, newshape):
newshape = mx.array(newshape)
currshape = mx.array(arr.shape)
startind = (currshape - newshape) // 2
endind = startind + newshape
myslice = [slice(startind[k].item(), endind[k].item()) for k in range(len(endind))]
return arr[tuple(myslice)]
def convolve_fft(image, kernel, stream=None):
    """Convolve two MLX arrays via FFT over their last two axes ("same" size).

    Both inputs are zero-padded to the full linear-convolution size before
    the real 2-D FFT, multiplied in frequency space, transformed back, and
    the result is center-cropped so its last two axes match ``image``.
    Leading axes are left to MLX broadcasting and are not cropped.

    Args:
        image: MLX array whose last two axes are the spatial dimensions.
        kernel: MLX array whose last two axes are the spatial dimensions.
        stream: Stream/device passed through to the ``mx.fft`` calls;
            ``None`` uses MLX's default.

    Returns:
        MLX array with the spatial size of ``image``.
    """
    # Full linear-convolution size per spatial axis; padding the FFTs to this
    # size prevents the wrap-around of circular convolution. rfft2/irfft2
    # operate on axes (-2, -1) by default, matching these sizes.
    shape = [i + k - 1 for i, k in zip(image.shape[-2:], kernel.shape[-2:])]
    image_ft = mx.fft.rfft2(image, s=shape, stream=stream)
    kernel_ft = mx.fft.rfft2(kernel, s=shape, stream=stream)
    result = mx.fft.irfft2(image_ft * kernel_ft, s=shape, stream=stream)
    # Crop only the spatial axes back to the image size; keep leading axes
    # exactly as produced (they may come from broadcasting).
    target = tuple(result.shape[:-2]) + tuple(image.shape[-2:])
    return _centered(result, target)
@dataclass
class BenchmarkSpec:
    """Configuration and collected timings for one benchmarked convolution method."""

    # Source text naming the convolution callable (e.g. "mx.conv2d"); the
    # benchmark loop splices it into a timeit expression, so it must resolve
    # in the module's globals.
    method: str
    # Shape of the random input image, in the layout ``method`` expects.
    shape: tuple
    # Source text for the stream argument (e.g. "mx.cpu"), or None if unused.
    stream: Optional[str] = None
    # Measured times in seconds, appended one per kernel size.
    results: list = field(default_factory=list)
    # Maps a kernel side length to the kernel array shape for ``method``.
    kernel_shape: Callable[[int], tuple] = lambda x: (1, x, x, 1)
# Side length (pixels) of the square test image shared by every configuration.
image_size = 1024

# One entry per benchmarked method. ``method`` and ``stream`` are source text
# that the timing loop splices into a timeit expression; ``shape`` and
# ``kernel_shape`` follow each method's expected layout: channels-last for
# mx.conv2d (its default kernel_shape), leading singleton axes for the MLX
# FFT path, and plain 2-D arrays for scipy.
specs = {
    "cpu": BenchmarkSpec(
        method="mx.conv2d",
        shape=(1, image_size, image_size, 1),
        stream="mx.cpu",
    ),
    "gpu": BenchmarkSpec(
        method="mx.conv2d",
        shape=(1, image_size, image_size, 1),
        stream="mx.gpu",
    ),
    "cpu-fft": BenchmarkSpec(
        method="convolve_fft",
        shape=(1, 1, image_size, image_size),
        stream="mx.cpu",
        kernel_shape=lambda n: (1, 1, n, n),
    ),
    "cpu-fft-scipy": BenchmarkSpec(
        method="scipy_fftconvolve",
        shape=(image_size, image_size),
        kernel_shape=lambda n: (n, n),
    ),
}
# Kernel side lengths 2, 4, ..., 512 (powers of two, even on log axes).
kernel_sizes = [2**i for i in range(1, 10)]

# NOTE: ``image`` and ``kernel`` must stay module-level names: the timeit
# expression below looks them up in globals() at call time.
for name, spec in specs.items():
    image = mx.random.normal(loc=0, scale=1, shape=spec.shape)
    # MLX evaluates lazily: materialize the input now so generating the random
    # numbers is not charged to the timed convolution.
    mx.eval(image)
    uses_scipy = "scipy" in name
    if uses_scipy:
        # scipy works on NumPy arrays; convert outside the timed region so the
        # MLX -> NumPy copy is not part of the measurement.
        image = np.array(image)
        expr = f"{spec.method}(image, kernel, mode='same')"
    else:
        # mx.eval forces the lazy MLX graph to actually run inside the timing.
        expr = f"mx.eval({spec.method}(image, kernel, stream={spec.stream}))"

    for index, size in enumerate(kernel_sizes):
        log.info("Running %s with kernel size %d", name, size)
        kernel = mx.random.normal(loc=0, scale=1, shape=spec.kernel_shape(size))
        mx.eval(kernel)
        if uses_scipy:
            kernel = np.array(kernel)
        timer = timeit.Timer(expr, globals=globals())
        if index == 0:
            # Untimed warm-up (cheap: smallest kernel) so one-time start-up
            # costs, e.g. device/kernel initialization, do not inflate the
            # first measurement.
            timer.timeit(1)
        spec.results.append(timer.timeit(1))
# Runtime versus kernel side length, one line per method; log-log axes since
# both quantities span several orders of magnitude.
fig, ax = plt.subplots()
for label, spec in specs.items():
    ax.plot(kernel_sizes, spec.results, label=label)
ax.set_xlabel("Kernel size")
ax.set_ylabel("Time (s)")
ax.legend()
ax.set_xscale("log")
ax.set_yscale("log")

filename = "mlx-conv-mini-benchmark.png"
log.info(f"Writing {filename}")
fig.savefig(filename, dpi=150)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment