Skip to content

Instantly share code, notes, and snippets.

@pcuenca
Created July 29, 2022 12:29
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pcuenca/3112687cf0d169a047c51825ed7d0eff to your computer and use it in GitHub Desktop.
Save pcuenca/3112687cf0d169a047c51825ed7d0eff to your computer and use it in GitHub Desktop.
Metal kernel that makes a contiguous copy of MLMultiArray storage
//
// MakeContiguousKernel.metal
//
// Created by Pedro Cuenca on 20220307.
// Copyright © 2022 LateNiteSoft S.L. All rights reserved.
//
#include <metal_stdlib>
using namespace metal;
/// Copies a strided (non-contiguous) array into contiguous, row-major order.
///
/// Both arrays are held in 2D textures with one element per texel, stored in
/// the red channel and laid out linearly (linear index = y * width + x).
/// Each thread produces one element of the contiguous output:
///   1. its texel position in `contiguous_array` is flattened to a linear index;
///   2. that index is decomposed into per-dimension indices using `shape_ptr`,
///      with the last dimension varying fastest;
///   3. the per-dimension indices are multiplied by `stride_ptr` and summed to
///      get the element's linear offset in the strided source;
///   4. that offset is mapped back to a texel of `source_array` and copied.
///
/// - Parameters:
///   - source_array: strided source data; only the red channel is read.
///   - contiguous_array: destination; the sample is written to every channel.
///   - ndim: number of entries in `shape_ptr` and `stride_ptr`.
///   - shape_ptr: `ndim` dimension sizes (assumed non-zero).
///   - stride_ptr: `ndim` strides of the source layout, in elements.
///   - gid: thread position; one thread per texel of `contiguous_array`.
kernel void MakeContiguousKernel(texture2d<float, access::read> source_array [[ texture(0) ]],
                                 texture2d<float, access::write> contiguous_array [[ texture(1) ]],
                                 const device uint &ndim [[ buffer(0) ]],
                                 const device uint *shape_ptr [[ buffer(1) ]],
                                 const device uint *stride_ptr [[ buffer(2) ]],
                                 uint2 gid [[thread_position_in_grid]])
{
    const uint dst_width = contiguous_array.get_width();
    const uint dst_height = contiguous_array.get_height();

    // When the grid is rounded up to whole threadgroups, some threads fall
    // outside the destination texture. Skip them instead of relying on
    // out-of-bounds texture access behavior.
    if (gid.x >= dst_width || gid.y >= dst_height) {
        return;
    }

    // Linear index of this element in the contiguous (row-major) output.
    uint remaining = gid.y * dst_width + gid.x;

    // Peel off per-dimension indices, innermost (last) dimension first, and
    // accumulate the corresponding offset in the strided source.
    uint strided_offset = 0;
    for (uint d = ndim; d > 0; --d) {
        const uint dim = shape_ptr[d - 1];
        strided_offset += (remaining % dim) * stride_ptr[d - 1];
        remaining /= dim;
    }

    // Map the source offset back to 2D using the source texture's own width
    // rather than assuming it matches the destination's.
    const uint src_width = source_array.get_width();
    const uint2 src_pos = uint2(strided_offset % src_width, strided_offset / src_width);

    const float sample = source_array.read(src_pos).r;
    // Same as the implicit scalar-to-vector conversion: `sample` in all channels.
    contiguous_array.write(float4(sample), gid);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment