Skip to content

Instantly share code, notes, and snippets.

@praeclarum
Created March 3, 2022 22:22
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save praeclarum/b8077771fb341a1f9c28240113e00425 to your computer and use it in GitHub Desktop.
Save praeclarum/b8077771fb341a1f9c28240113e00425 to your computer and use it in GitHub Desktop.
Translation of Apple's MPSGraph MNIST sample to Xamarin.iOS
#nullable enable
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Foundation;
using Metal;
using MetalPerformanceShaders;
using MetalPerformanceShadersGraph;
namespace TestMPSGraph
{
public class MnistTest
{
const int batchSize = 16;
readonly int numTrainingIterations = 300;
readonly IMTLCommandQueue commandQueue = MTLDevice.SystemDefault!.CreateCommandQueue()!;
readonly MnistGraph graph = new (batchSize);
readonly MnistData data = new ();
public event EventHandler? BatchTrained;
public void Run()
{
MPSCommandBuffer? latestCommandBuffer = null;
for (var i = 0; i < numTrainingIterations; i++)
{
latestCommandBuffer = RunTrainingIterationBatch((i + 1) / (float)numTrainingIterations);
}
latestCommandBuffer?.WaitUntilCompleted();
}
MPSCommandBuffer RunTrainingIterationBatch(float progress)
{
var commandBuffer = MPSCommandBuffer.Create(commandQueue);
var xInput = data.GetRandomTrainingBatch(commandQueue.Device, batchSize, out var yLabels);
graph.EncodeTrainingBatch(commandBuffer, xInput, yLabels, loss =>
{
commandBuffer.Dispose();
Console.WriteLine($"Progress: {progress*100:000.0}% Loss: {loss}");
BatchTrained?.Invoke(this, EventArgs.Empty);
});
commandBuffer.Commit();
return commandBuffer;
}
}
public class MnistData
{
public const int ImageSize = 28;
public const int NumClasses = 10;
readonly Random random = new Random();
const int ImageMetadataPrefixSize = 16;
readonly int totalNumberOfTrainImages;
const string trainImagesUrl = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz";
const string trainLabelsUrl = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz";
readonly byte[] dataTrainLabel;
readonly byte[] dataTrainImage;
static readonly HttpClient http = new HttpClient();
public MnistData()
{
var ti = LoadUrlAsync(trainImagesUrl);
var tl = LoadUrlAsync(trainLabelsUrl);
dataTrainImage = ti.Result;
dataTrainLabel = tl.Result;
totalNumberOfTrainImages = dataTrainLabel.Length - ImageMetadataPrefixSize;
}
static async Task<byte[]> LoadUrlAsync(string url)
{
var cacheDir =
Path.Combine(
NSFileManager.DefaultManager.GetUrl(NSSearchPathDirectory.CachesDirectory, NSSearchPathDomain.User, null, false, out var error).Path,
"com.kruegersystems.TestMPSGraph");
Directory.CreateDirectory(cacheDir);
var name = Path.GetFileNameWithoutExtension(url);
var path = Path.Combine(cacheDir, name);
try
{
var data = await File.ReadAllBytesAsync(path).ConfigureAwait(false);
Console.WriteLine($"Loaded {url}");
return data;
}
catch
{
Console.WriteLine(cacheDir);
Console.WriteLine($"Downloading {url}");
using var stream = await http.GetStreamAsync(url).ConfigureAwait(false);
using var gz = new System.IO.Compression.GZipStream(stream, System.IO.Compression.CompressionMode.Decompress);
using var ostream = new MemoryStream();
await gz.CopyToAsync(ostream).ConfigureAwait(false);
var data = ostream.ToArray();
await File.WriteAllBytesAsync(path, data).ConfigureAwait(false);
Console.WriteLine($"Downloaded {url}");
return data;
}
}
public MPSGraphTensorData GetRandomTrainingBatch(IMTLDevice device, int batchSize, out MPSGraphTensorData labels)
{
var inputVals = new float[batchSize * ImageSize * ImageSize];
var labelVals = new float[batchSize * NumClasses];
for (var batchInd = 0; batchInd < batchSize; batchInd++)
{
var randomImageIdx = random.Next(totalNumberOfTrainImages);
var valueOffset = ImageMetadataPrefixSize + randomImageIdx * ImageSize * ImageSize;
for (var ind = 0; ind < ImageSize * ImageSize; ind++) {
inputVals[batchInd * ImageSize * ImageSize + ind] = dataTrainImage[valueOffset + ind] / 255.0f;
}
var labelOffset = ImageMetadataPrefixSize + randomImageIdx;
for (int classIdx = 0; classIdx < NumClasses; classIdx++)
{
labelVals[batchInd * NumClasses + classIdx] = classIdx == dataTrainLabel[labelOffset] ? 1.0f : 0.0f;
}
}
labels = MPSGraphTensorData.Create (device, labelVals, batchSize, NumClasses);
return MPSGraphTensorData.Create (device, inputVals, batchSize, ImageSize * ImageSize);
}
}
public class MnistGraph : MPSGraph
{
const float lambda = 0.01f;
readonly int imageSize;
readonly int numClasses;
readonly int batchSize;
readonly MPSGraphConvolution2DOpDescriptor convDesc = MPSGraphConvolution2DOpDescriptor.Create(
strideInX: 1,
strideInY: 1,
dilationRateInX: 1,
dilationRateInY: 1,
groups: 1,
paddingStyle: MPSGraphPaddingStyle.Same,
dataLayout: MPSGraphTensorNamedDataLayout.Nhwc,
weightsLayout: MPSGraphTensorNamedDataLayout.Hwio)!;
readonly MPSGraphPooling2DOpDescriptor poolDesc = MPSGraphPooling2DOpDescriptor.Create(
kernelWidth: 2,
kernelHeight: 2,
strideInX: 2,
strideInY: 2,
paddingStyle: MPSGraphPaddingStyle.Same,
dataLayout: MPSGraphTensorNamedDataLayout.Nhwc)!;
readonly MPSGraphTensor[] inferenceTensors;
readonly MPSGraphOperation[] inferenceOps;
readonly MPSGraphTensor[] trainingTensors;
readonly MPSGraphOperation[] trainingOps;
readonly MPSGraphTensor sourcePlaceholder;
readonly MPSGraphTensor labelsPlaceholder;
public MnistGraph(int batchSize)
{
this.imageSize = MnistData.ImageSize;
this.numClasses = MnistData.NumClasses;
this.batchSize = batchSize;
Options = MPSGraphOptions.SynchronizeResults;// | MPSGraphOptions.Verbose;
Console.WriteLine(Options);
sourcePlaceholder = this.Placeholder(new[] { batchSize, imageSize * imageSize }, null);
labelsPlaceholder = this.Placeholder(new[] { batchSize, numClasses }, null);
var variables = new List<MPSGraphTensor>();
var reshapedInput = this.Reshape(sourcePlaceholder, shape: new[] { batchSize, imageSize, imageSize, 1 }, null);
var conv0 = AddConvLayer(reshapedInput, weightsShape: new int[4] { 5, 5, 1, 32 }, convDesc, variables);
var pool0 = this.MaxPooling2D(conv0, poolDesc, null);
var conv1Tensor = AddConvLayer(pool0, weightsShape: new int[4] { 5, 5, 32, 64 }, convDesc, variables);
var pool1Tensor = this.MaxPooling2D(conv1Tensor, poolDesc, null);
var reshape = this.Reshape(pool1Tensor, new[] { -1, 64 * 7 * 7 }, null);
var fc0 = AddFullyConnectedLayer(reshape, weightsShape: new int[2] { 7 * 7 * 64, 1024 }, hasActivation: true, variables);
var fc1 = AddFullyConnectedLayer(fc0, weightsShape: new int[2] { 1024, numClasses }, hasActivation: false, variables);
var softmax = this.SoftMax(fc1, axis: -1, null);
var loss = this.SoftMaxCrossEntropy(fc1, labels: labelsPlaceholder, axis: -1, MPSGraphLossReductionType.Sum, null);
var batchSizeT = this.Constant((float)batchSize);
var lossMean = this.Division(loss, batchSizeT, null);
inferenceTensors = new[] { softmax };
inferenceOps = Array.Empty<MPSGraphOperation>();
trainingTensors = new[] { lossMean };
trainingOps = GetAssignOperations(lossMean, variables);
}
MPSGraphOperation[] GetAssignOperations(MPSGraphTensor loss, List<MPSGraphTensor> variables)
{
var grads = this.Gradients(loss, variables.ToArray(), null);
var lambdaT = this.Constant(lambda);
var updateOps = new List<MPSGraphOperation>();
foreach (var (k, value) in grads)
{
var key = (MPSGraphTensor)k;
var update = this.StochasticGradientDescent(lambdaT, key, (MPSGraphTensor)value, null);
var assign = this.Assign(key, update, null);
updateOps.Add(assign);
}
return updateOps.ToArray();
}
MPSGraphTensor AddFullyConnectedLayer(MPSGraphTensor source, int[] weightsShape, bool hasActivation, List<MPSGraphTensor> variables)
{
var weightCount = 1;
foreach (var length in weightsShape)
{
weightCount *= length;
}
var biasCount = weightsShape[1];
var weightsValues = GetRandomData(weightCount, -0.2f, 0.2f);
var biasesValues = new float[biasCount];
Array.Fill(biasesValues, 0.1f);
var weights = this.Variable(weightsValues, weightsShape);
var biases = this.Variable(biasesValues, new[] { biasCount });
var fc = this.MatrixMultiplication(source, weights, null);
var fcBias = this.Addition(fc, biases, null);
variables.Add(weights);
variables.Add(biases);
if (!hasActivation)
return fcBias;
var activation = this.ReLU(fcBias, null);
return activation;
}
MPSGraphTensor AddConvLayer(MPSGraphTensor source, int[] weightsShape, MPSGraphConvolution2DOpDescriptor desc, List<MPSGraphTensor> variables)
{
var weightCount = 1;
foreach (var length in weightsShape)
{
weightCount *= length;
}
var biasCount = weightsShape[3];
var convWeightsValues = GetRandomData(weightCount, -0.2f, 0.2f);
var weights = this.Variable(convWeightsValues, weightsShape);
var biases = this.Variable(0.1f, new[] { biasCount });
var conv = this.Convolution2D(source, weights, desc, null);
var convBias = this.Addition(conv, biases, null);
var activation = this.ReLU(convBias, null);
variables.Add(weights);
variables.Add(biases);
return activation;
}
readonly Random random = new Random();
readonly Semaphore doubleBufferSemaphore = new Semaphore(2, 2);
float[] GetRandomData(int length, float min, float max)
{
var d = max - min;
var r = new float[length];
for (var i = 0; i < length; i++)
{
r[i] = ((float)random.NextDouble() * d) + min;
}
return r;
}
public MPSGraphTensorData EncodeTrainingBatch(MPSCommandBuffer commandBuffer, MPSGraphTensorData sourceTensorData, MPSGraphTensorData labelsTensorData, Action<float>? completion)
{
doubleBufferSemaphore.WaitOne();
var executionDesc = new MPSGraphExecutionDescriptor
{
CompletionHandler = (results, error) =>
{
// This is necessary because there's a weird synchronization issue with
// this callback. I have requested support from Apple about it.
// Same things happens in Swift, so just some bug or mistake in the sample.
Thread.Sleep(5);
var lossTensorData = results[trainingTensors[0]];
var loss = new[] { 0.0f };
lossTensorData.Read(loss);
doubleBufferSemaphore.Release();
if (completion is { } c)
{
BeginInvokeOnMainThread(() => c(loss[0]));
}
}
};
var feed = NSDictionary<MPSGraphTensor, MPSGraphTensorData>.FromObjectsAndKeys(
new[] { sourceTensorData, labelsTensorData },
new[] { sourcePlaceholder, labelsPlaceholder },
2);
var fetch = this.Encode(commandBuffer, feed, trainingTensors, trainingOps, executionDesc);
return fetch[trainingTensors[0]];
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment