Skip to content

Instantly share code, notes, and snippets.

@awni
Created July 17, 2024 16:26
Show Gist options
  • Save awni/b144431c3ea3ccd02cbf8c6710a396e5 to your computer and use it in GitHub Desktop.
Save awni/b144431c3ea3ccd02cbf8c6710a396e5 to your computer and use it in GitHub Desktop.
from typing import Callable, Tuple
import operator
from functools import reduce
from itertools import product
import mlx.core as mx
def _interpolate(
    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
    """Resample the spatial axes of a channels-last array from per-axis taps.

    ``x`` is laid out as ``(B, *spatial, C)``; only the axes between the
    first and the last are resampled. For each spatial axis ``axis`` of size
    ``size``, ``indices_fn(size, factor, align_corners, axis, n_spatial)`` must
    return an iterable of ``(index_array, weight_array)`` taps. Every
    combination that picks one tap per axis gathers
    ``x[:, idx_0, ..., idx_{d-1}]``, scales it by the product of the chosen
    weights, and all such weighted samples are summed.

    Args:
        x: Input array, batch axis first and channel axis last.
        scale_factor: One scale per spatial axis.
        indices_fn: Per-axis tap generator (see above).
        align_corners: Forwarded unchanged to ``indices_fn``.

    Returns:
        The sum of all weighted gathered samples.

    Raises:
        ValueError: if ``len(scale_factor) != x.ndim - 2``.
    """
    n_spatial = x.ndim - 2
    if len(scale_factor) != n_spatial:
        raise ValueError("A scale needs to be provided for each spatial dimension")

    # Drop the leading batch axis and trailing channel axis.
    _, *spatial_sizes, _ = x.shape

    per_axis_taps = [
        indices_fn(size, factor, align_corners, axis, n_spatial)
        for axis, (size, factor) in enumerate(zip(spatial_sizes, scale_factor))
    ]

    weighted_terms = []
    for combo in product(*per_axis_taps):
        # combo holds one (index, weight) tap per spatial axis; the index
        # arrays broadcast against each other to form the output grid.
        positions, factors = zip(*combo)
        sample = x[(slice(None), *positions)]
        weight = reduce(operator.mul, factors)
        weighted_terms.append(weight * sample)
    return sum(weighted_terms)
def _linear_indices(N, scale, align_corners, dim, ndims):
    """Build the two linear-interpolation taps for one spatial axis.

    Source coordinates from ``_scaled_indices`` are clamped to ``[0, N - 1]``.
    Each coordinate then reads its floor and ceil neighbours, weighted by the
    complementary fractional distances.

    Returns:
        ``((left_idx, left_weight), (right_idx, right_weight))``. The indices
        are int32. Each weight carries an extra trailing axis of size 1 so it
        can broadcast over the channel axis of the gathered samples.
    """
    coords = mx.clip(
        _scaled_indices(N, scale, align_corners, dim, ndims), a_min=0, a_max=N - 1
    )
    left = mx.floor(coords)
    right = mx.ceil(coords)
    # Distance past the left neighbour. An exact integer coordinate gives 0,
    # so the duplicate right tap (ceil == floor) contributes nothing.
    frac = mx.expand_dims(coords - left, -1)
    left_tap = (left.astype(mx.int32), 1 - frac)
    right_tap = (right.astype(mx.int32), frac)
    return left_tap, right_tap
def _scaled_indices(N, scale, align_corners, dim, ndims):
    """Return fractional source coordinates for resampling one spatial axis.

    The output length is ``M = int(scale * N)``.
    - With ``align_corners``, the M samples span ``[0, N - 1]`` end to end.
      A single sample (``M == 1``) is placed at coordinate 0.
    - Otherwise, samples are spaced ``1 / scale`` apart and centred on the
      middle of the input, ``(N - 1) / 2``.

    Coordinates are not clipped here and may fall outside ``[0, N - 1]``.

    Args:
        N: Input size along this axis.
        scale: Resampling factor for this axis.
        align_corners: Select the corner-aligned spacing described above.
        dim: Which of the ``ndims`` spatial axes this is.
        ndims: Number of spatial axes.

    Returns:
        A float32 array shaped ``[1] * ndims`` except for ``M`` entries along
        axis ``dim``, so it broadcasts against the other axes' coordinates.
    """
    M = int(scale * N)
    if align_corners:
        # One output sample has no span to stretch across. Pin it to the first
        # input sample instead of dividing by M - 1 == 0.
        ratio = (N - 1) / (M - 1) if M > 1 else 0.0
        indices = mx.arange(M, dtype=mx.float32) * ratio
    else:
        step = 1 / scale
        # Offset that centres the M samples on the input's midpoint.
        start = ((M - 1) * step - N + 1) / 2
        indices = mx.arange(M, dtype=mx.float32) * step - start
    shape = [1] * ndims
    shape[dim] = -1
    return indices.reshape(shape)
def interpolate(x, scale_factor, align_corners=False):
    """Linearly resample the spatial axes of ``x`` by ``scale_factor``.

    ``x`` is laid out as ``(B, *spatial, C)``. The first and last axes are
    left untouched, and every axis in between is resampled with linear
    interpolation (bilinear for two spatial axes, trilinear for three, ...).

    Args:
        x: Input array with the batch axis first and the channel axis last.
        scale_factor: One scale per spatial axis (a sequence of length
            ``x.ndim - 2``), or a single number applied to every spatial
            axis. The output size along an axis is ``int(scale * size)``.
        align_corners: If True, the first and last output samples land
            exactly on the first and last input samples. Otherwise samples
            are spaced ``1 / scale`` apart and centred on the input.
            Sample coordinates are clamped to the valid input range.

    Returns:
        The resampled array, shaped ``(B, *resized_spatial, C)``.

    Raises:
        ValueError: if the number of scales does not match the number of
            spatial axes.
    """
    if isinstance(scale_factor, (int, float)):
        # Broadcast a single scale to every spatial axis.
        scale_factor = (scale_factor,) * (x.ndim - 2)
    return _interpolate(
        x=x,
        scale_factor=scale_factor,
        indices_fn=_linear_indices,
        # Must be forwarded; otherwise align_corners=True is silently ignored.
        align_corners=align_corners,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment