Skip to content

Instantly share code, notes, and snippets.

@n-taku
Last active August 8, 2020 13:31
Show Gist options
  • Save n-taku/9471e1ba470408e8ec03d9ef4d8e7131 to your computer and use it in GitHub Desktop.
Save n-taku/9471e1ba470408e8ec03d9ef4d8e7131 to your computer and use it in GitHub Desktop.
UnityのBarracudaで推論するサンプル
using Unity.Barracuda;
/// <summary>
/// Classifies a handwritten digit image by running a Barracuda worker
/// created from the supplied model asset.
/// </summary>
/// <remarks>
/// The instance owns its <see cref="IWorker"/>. Call <see cref="Dispose"/> when finished
/// so the worker is released deterministically; the finalizer only remains as a
/// last-resort fallback for callers that never dispose.
/// </remarks>
public class MnistInference : System.IDisposable
{
    // Width and height of the input image passed to the Tensor constructor.
    const int ImageSize = 28;

    // Minimum number of floats required to fill a 1 x 28 x 28 x 1 tensor.
    const int InputLength = ImageSize * ImageSize;

    readonly IWorker worker;

    // Set once the worker has been released so it is never disposed twice.
    bool disposed;

    /// <summary>
    /// Loads the model asset and creates a CSharpBurst worker for it.
    /// </summary>
    /// <param name="modelAsset">Model asset handed to <c>ModelLoader.Load</c>.</param>
    public MnistInference(NNModel modelAsset)
    {
        var runtimeModel = ModelLoader.Load(modelAsset);
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.CSharpBurst, runtimeModel);
    }

    /// <summary>
    /// Runs inference and returns the index of the highest score, which is the predicted digit.
    /// </summary>
    /// <param name="inputFloats">
    /// 28x28 image as floats in the range 0..1, laid out from the top-left origin
    /// towards the bottom-right.
    /// </param>
    /// <returns>Index of the largest score; 0 when no score exceeds <c>float.MinValue</c>.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="inputFloats"/> is null.</exception>
    /// <exception cref="System.ArgumentException"><paramref name="inputFloats"/> has fewer than 784 elements.</exception>
    /// <exception cref="System.ObjectDisposedException">The instance has already been disposed.</exception>
    public int Inference(float[] inputFloats)
    {
        if (inputFloats == null)
        {
            throw new System.ArgumentNullException(nameof(inputFloats));
        }

        // Only reject arrays that are too short; longer arrays are left to the
        // Tensor constructor so existing callers are not broken.
        if (inputFloats.Length < InputLength)
        {
            throw new System.ArgumentException(
                "Input must contain at least " + InputLength + " values (28x28 image).",
                nameof(inputFloats));
        }

        if (disposed)
        {
            throw new System.ObjectDisposedException(nameof(MnistInference));
        }

        // Run the model.
        var scores = InferenceOnnx(inputFloats);

        // Find the index of the maximum score; that index is the inferred digit.
        var maxScore = float.MinValue;
        int maxIndex = 0;
        for (int i = 0; i < scores.Length; i++)
        {
            float score = scores[i];
            if (maxScore < score)
            {
                maxScore = score;
                maxIndex = i;
            }
        }

        return maxIndex;
    }

    /// <summary>
    /// Executes the worker on the given input and returns a copy of the output scores.
    /// </summary>
    /// <param name="input">Flat image data used to build a 1 x 28 x 28 x 1 tensor.</param>
    /// <returns>A newly allocated array holding the output values.</returns>
    private float[] InferenceOnnx(float[] input)
    {
        // 'using' guarantees the input tensor is released even if Execute throws.
        using (var inputTensor = new Tensor(1, ImageSize, ImageSize, 1, input))
        {
            worker.Execute(inputTensor);

            // The tensor returned by PeekOutput is not created by this class and,
            // per the Barracuda documentation, stays owned by the worker, so it is
            // not disposed here.
            var outputTensor = worker.PeekOutput();

            // Copy the values so the result does not alias storage the worker may reuse.
            return (float[])outputTensor.ToReadOnlyArray().Clone();
        }
    }

    /// <summary>
    /// Releases the worker. Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        ReleaseWorker();
        System.GC.SuppressFinalize(this);
    }

    // Fallback for callers that never call Dispose.
    ~MnistInference()
    {
        ReleaseWorker();
    }

    // Disposes the worker at most once; tolerates a worker that was never
    // assigned because the constructor threw.
    void ReleaseWorker()
    {
        if (disposed)
        {
            return;
        }

        disposed = true;

        if (worker != null)
        {
            worker.Dispose();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment