Skip to content

Instantly share code, notes, and snippets.

@yvt
Created August 6, 2017 12:16
Show Gist options
  • Save yvt/fef3622bc97d57af63d3d98ff3d21410 to your computer and use it in GitHub Desktop.
Save yvt/fef3622bc97d57af63d3d98ff3d21410 to your computer and use it in GitHub Desktop.
#version 310 es
// Vulkan-style GLSL ES 3.10 compute shader: 1-D sliding-window convolution of a
// uint input buffer with a kernel_size-tap uint kernel (one output per invocation).
precision mediump float;

// Workgroup size is declared exactly once here. local_size is derived from it via
// gl_WorkGroupSize (a compile-time constant usable for sizing shared arrays), so
// the two can no longer drift apart. gl_WorkGroupSize may only be referenced after
// this layout declaration.
layout(local_size_x = 64) in;
const uint local_size = gl_WorkGroupSize.x;
const uint kernel_size = 4u;

// Per-workgroup staging area: one element per invocation plus a trailing halo of
// kernel_size - 1 elements needed by the last invocations' windows.
shared uint in_values[local_size + kernel_size - 1u];
// Per-workgroup copy of the kernel coefficients.
shared uint kernel_values[kernel_size];

// Kernel coefficients (read-only).
layout(set = 0, binding = 0) readonly buffer ConvolutionParameter {
    uint kernel_values[kernel_size];
} conv_param;

// Input samples (read-only, runtime-sized).
layout(set = 0, binding = 1) readonly buffer ConvolutionInput {
    uint data[];
} conv_in;

// Output samples (write-only, runtime-sized).
layout(set = 0, binding = 2) writeonly buffer ConvolutionOutput {
    uint data[];
} conv_out;
// Compute entry point: invocation g writes one output element,
//   conv_out.data[g] = sum over k in [0, kernel_size) of
//                      conv_in.data[g + k] * conv_param.kernel_values[k],
// i.e. a sliding dot product (the kernel is not flipped), in uint arithmetic.
// Inputs and coefficients are staged in shared memory first; the accumulation
// loop reads only shared memory.
// NOTE(review): there are no length checks; the host presumably sizes
// conv_in.data to at least (dispatch width + kernel_size - 1) elements and
// conv_out.data to the dispatch width — confirm.
void main()
{
    uint lid = gl_LocalInvocationID.x;
    uint gid = gl_GlobalInvocationID.x;

    // Coefficients: the first kernel_size invocations copy one tap each.
    if (lid < kernel_size) {
        kernel_values[lid] = conv_param.kernel_values[lid];
    }

    // Inputs: every invocation stages its own element; the last
    // kernel_size - 1 invocations additionally fetch the trailing halo.
    in_values[lid] = conv_in.data[gid];
    if (lid + kernel_size > local_size) {
        uint halo = kernel_size - 1u;
        in_values[lid + halo] = conv_in.data[gid + halo];
    }

    // Make every staged value visible to the whole workgroup before use.
    groupMemoryBarrier();
    barrier();

    uint acc = 0u;
    for (uint k = 0u; k < kernel_size; k++) {
        acc += kernel_values[k] * in_values[lid + k];
    }
    conv_out.data[gid] = acc;
}
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Device-buffer view of the input samples. The one-element array stands in for
// the unsized `uint data[]` of the GLSL ConvolutionInput block; callers index
// past element 0.
struct ConvolutionInput { uint data[1]; };
// Device-buffer view of the kernel coefficients: a fixed 4-tap uint kernel.
struct ConvolutionParameter { uint kernel_values[4]; };
// Device-buffer view of the output samples. The one-element array stands in for
// the unsized `uint data[]` of the GLSL ConvolutionOutput block; callers index
// past element 0.
struct ConvolutionOutput { uint data[1]; };
// Threadgroup-shared scratch for main0.
struct main0_workgroup
{
    // 64 staged inputs (one per thread) plus a 4 - 1 = 3 element halo: 67 total.
    uint in_values[64 + 4 - 1];
    // Copy of the 4 kernel coefficients.
    uint kernel_values[4];
};
// Compute kernel: each thread produces one output element,
//   conv_out.data[g] = sum over i in [0, 4) of conv_in.data[g + i] * conv_param.kernel_values[i],
// after staging its input window and the 4 coefficients in threadgroup memory.
// NOTE(review): the 60u halo threshold and the 67-element in_values array assume
// exactly 64 threads per threadgroup; MSL has no counterpart to GLSL's
// local_size_x here, so the host must dispatch 64-thread threadgroups — confirm.
// NOTE(review): buffer indices are input=0, parameter=1, output=2, whereas the
// GLSL source declares parameter=binding 0, input=binding 1; confirm which
// layout the host binds.
// NOTE(review): no bounds checks; the host presumably sizes conv_in to at least
// (grid width + 3) elements.
kernel void main0(uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]],
                  uint3 gl_GlobalInvocationID [[thread_position_in_grid]],
                  const device ConvolutionInput& conv_in [[buffer(0)]],
                  const device ConvolutionParameter& conv_param [[buffer(1)]],
                  device ConvolutionOutput& conv_out [[buffer(2)]])
{
    // No initializer: MSL does not allow initializing threadgroup variables, and a
    // per-thread zero-fill would race with other threads' stores below. With 64
    // threads, every element read after the barrier is written before it.
    threadgroup main0_workgroup workgroup;

    uint local_id = gl_LocalInvocationID.x;
    uint global_id = gl_GlobalInvocationID.x;

    // Each thread stages its own input; threads 61..63 also fetch the 3-element halo.
    workgroup.in_values[local_id] = conv_in.data[global_id];
    if (local_id > 60u)
    {
        workgroup.in_values[(local_id + 4u) - 1u] = conv_in.data[(global_id + 4u) - 1u];
    }

    // The first 4 threads stage one kernel coefficient each.
    if (local_id < 4u)
    {
        workgroup.kernel_values[local_id] = conv_param.kernel_values[local_id];
    }

    // Execution barrier plus threadgroup-memory fence: all staged values are
    // visible to every thread in the threadgroup before the loads below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint sum = 0u;
    for (uint i = 0u; i < 4u; i++)
    {
        sum += workgroup.in_values[local_id + i] * workgroup.kernel_values[i];
    }
    conv_out.data[global_id] = sum;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment