Created
July 17, 2024 16:26
-
-
Save awni/b144431c3ea3ccd02cbf8c6710a396e5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable, Tuple | |
import operator | |
from functools import reduce | |
from itertools import product | |
import mlx.core as mx | |
def _interpolate( | |
x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False | |
): | |
dims = x.ndim - 2 | |
if dims != len(scale_factor): | |
raise ValueError("A scale needs to be provided for each spatial dimension") | |
B, *N, C = x.shape | |
# Compute the sampling grid | |
indices = [] | |
for i, (n, s) in enumerate(zip(N, scale_factor)): | |
indices.append(indices_fn(n, s, align_corners, i, dims)) | |
# Sample and compute the weights | |
samples = [] | |
weights = [] | |
for idx_weight in product(*indices): | |
idx, weight = zip(*idx_weight) | |
samples.append(x[(slice(None),) + idx]) | |
weights.append(reduce(operator.mul, weight)) | |
# Interpolate | |
return sum(wi * xi for wi, xi in zip(weights, samples)) | |
def _linear_indices(N, scale, align_corners, dim, ndims):
    """Return the two neighbours and blend weights for linear interpolation.

    The fractional sample positions come from ``_scaled_indices`` and are
    clamped to ``[0, N - 1]``, so positions outside the axis reuse the edge
    sample.

    Returns:
        A pair ``((left_index, 1 - w), (right_index, w))`` where the indices
        are ``int32`` arrays (floor and ceil of the position) and ``w`` is the
        fractional part of the position with one extra trailing axis.
    """
    positions = mx.clip(
        _scaled_indices(N, scale, align_corners, dim, ndims),
        a_min=0,
        a_max=N - 1,
    )
    lower = mx.floor(positions)
    upper = mx.ceil(positions)

    # Distance from the left neighbour. The extra trailing axis presumably
    # lets the weight broadcast over the channel axis of the gathered samples
    # -- confirm against the caller.
    frac = mx.expand_dims(positions - lower, -1)

    left = (lower.astype(mx.int32), 1 - frac)
    right = (upper.astype(mx.int32), frac)
    return (left, right)
def _scaled_indices(N, scale, align_corners, dim, ndims):
    """Compute fractional input positions for each output sample on one axis.

    Args:
        N: Input size along the axis.
        scale: Scale factor for the axis; the output size is ``int(scale * N)``.
        align_corners: If true, the first and last output samples land exactly
            on input positions ``0`` and ``N - 1``. Otherwise samples are
            spaced ``1 / scale`` apart and the grid is centred on the input's
            centre, so positions may fall outside ``[0, N - 1]``.
        dim: Which of the ``ndims`` spatial axes this is.
        ndims: Number of spatial axes.

    Returns:
        A ``float32`` array of positions reshaped to ``ndims`` dimensions, with
        size 1 everywhere except axis ``dim``, so the per-axis grids broadcast
        against each other.
    """
    M = int(scale * N)
    if align_corners:
        # A single output sample has no "last corner" to align; pin it to
        # position 0 rather than dividing by M - 1 == 0.
        step = (N - 1) / (M - 1) if M > 1 else 0.0
        indices = mx.arange(M, dtype=mx.float32) * step
    else:
        step = 1 / scale
        # Offset that centres the M output positions on the N input positions.
        start = ((M - 1) * step - N + 1) / 2
        indices = mx.arange(M, dtype=mx.float32) * step - start

    shape = [1] * ndims
    shape[dim] = -1
    return indices.reshape(shape)
def interpolate(x, scale_factor, align_corners=False):
    """Resample ``x`` along its spatial axes using linear interpolation.

    Args:
        x: Input array; ``_interpolate`` unpacks its shape as
            ``(batch, *spatial, channels)``.
        scale_factor: Scale to apply along each spatial axis (one entry per
            spatial axis).
        align_corners: Whether the first and last output samples are aligned
            with the first and last input samples.

    Returns:
        The interpolated array.
    """
    return _interpolate(
        x=x,
        scale_factor=scale_factor,
        indices_fn=_linear_indices,
        # Previously dropped, so callers' align_corners=True was ignored.
        align_corners=align_corners,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment