Skip to content

Instantly share code, notes, and snippets.

Created October 20, 2021 19:25
Show Gist options
  • Save batu/d5bc9d90d2cff346e158e06dd1a23665 to your computer and use it in GitHub Desktop.
Inference on a Barracuda-loaded ONNX model
using System;
using System.Collections;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using Unity.Barracuda;
using UnityEngine;
using Random = UnityEngine.Random;
/// <summary>
/// MonoBehaviour that feeds agent observations to a Barracuda-loaded ONNX model and
/// exposes the model's output through the <c>InputHandler</c> interface
/// (a movement vector from GetMoveInput and a jump flag from GetJumpInputDown).
/// </summary>
// NOTE(review): this copy appears to be a web scrape of the original gist. Braces are
// missing throughout, and several statements seem to have been dropped:
//   - the bodies of the three `if (_navigationAgent.use...)` checks in AskForDecision,
//   - a `_worker.Execute(input)` call,
//   - the `_counter` increment,
//   - whatever starts the AskForDecision coroutine.
// Restore it from the original source before compiling. TODO: confirm against the gist.
public class Inferer : MonoBehaviour, InputHandler
// Value _counter must reach before AskForDecision queries the model
// (see the `_counter == decisionFrequency` check).
public int decisionFrequency = 5;
// ONNX model asset; loaded into _runtimeModel in Start via ModelLoader.Load.
public NNModel modelAsset;
// Agent component. Its useLocalRaycasts / useDepthMask / useOccupancyGrid flags gate
// the observation branches in AskForDecision.
private NavigationAgent _navigationAgent;
// Runtime model produced by ModelLoader.Load; used to create _worker.
private Model _runtimeModel;
// Inference worker created in Start.
// NOTE(review): no Dispose() of this worker is visible (e.g. in OnDestroy). Barracuda
// workers should be disposed when no longer needed. TODO: confirm one exists.
private IWorker _worker;
// Observation components cached in Start.
// Presumably they are read in AskForDecision to fill obsList; those calls are not
// visible in this copy.
private VectorObservation _vectorObservation;
private DepthMaskObservation _depthMaskObservation;
private OccupancyGridObservation _occupancyGridObservation;
private WhiskerObservation _whiskerObservation;
// Width of the model input vector; used as the second Tensor dimension in AskForDecision.
private int _inputShape;
// Unity callback: caches sibling components, loads the model and creates the inference worker.
void Start()
Debug.LogWarning("The Quaternion turn adjustment is active!");
_navigationAgent = GetComponent<NavigationAgent>();
_vectorObservation = GetComponent<VectorObservation>();
_depthMaskObservation = GetComponent<DepthMaskObservation>();
_whiskerObservation = GetComponent<WhiskerObservation>();
_occupancyGridObservation = GetComponent<OccupancyGridObservation>();
_runtimeModel = ModelLoader.Load(modelAsset);
// Derive the input width from the last axis of the model's first input shape...
_inputShape = _runtimeModel.inputs[0].shape[_runtimeModel.inputs[0].shape.Length - 1];
// NOTE(review): ...but this hard-coded value immediately overwrites it, so the model's
// declared input width is ignored. If the observation layout changes, this must be updated
// by hand. Confirm whether the override is still needed.
_inputShape = 505;
// NOTE(review): the message says "Input shape", but it interpolates _runtimeModel.outputs[0]
// (in Barracuda, Model.outputs holds output names, so verify). It does not print
// _inputShape or the input's shape.
print($"Input shape for the model: {_runtimeModel.outputs[0]}");
// Alternative backend: the reference C# implementation (slower), kept for debugging.
// _worker = WorkerFactory.CreateWorker(WorkerFactory.Type.CSharpRef, _runtimeModel);
// Active backend: Burst-compiled CPU worker.
_worker = WorkerFactory.CreateWorker(WorkerFactory.Type.CSharpBurst, _runtimeModel);
// Latest model decision (x/z move, rotated in AskForDecision); returned by GetMoveInput.
private Vector3 _movement;
// Passes since the last model query; reset to 0 after each decision.
// NOTE(review): the increment is not visible in this copy.
private int _counter = 0;
// Raw third model output; GetJumpInputDown treats values > 0.5 as a jump.
private float _jump;
// Coroutine that loops forever, yielding once per pass (`yield return null`).
// It queries the model whenever _counter equals decisionFrequency.
IEnumerator AskForDecision()
while (true)
if (_counter == decisionFrequency)
// Flat observation vector fed to the model.
List<float> obsList = new List<float>();
// NOTE(review): the bodies of the next three checks are missing from this copy.
// Presumably each appends the corresponding observation to obsList. TODO: restore.
if (_navigationAgent.useLocalRaycasts)
if (_navigationAgent.useDepthMask)
if (_navigationAgent.useOccupancyGrid)
// Batch of 1 with _inputShape features.
// This assumes obsList.Count == _inputShape (505). TODO: confirm.
// NOTE(review): `input` is never disposed in the visible code, so each decision
// would leak a tensor. Wrap it in `using` or call input.Dispose() after execution.
Tensor input = new Tensor(1, _inputShape, obsList.ToArray());
// NOTE(review): no `_worker.Execute(input)` is visible before PeekOutput, so as
// shown the worker never receives `input`.
// The tensor returned by PeekOutput is owned by the worker; it is only read here.
Tensor output = _worker.PeekOutput();
float[] movementArray = output.ToReadOnlyArray();
// Uses output[0] as x and output[1] as z of the move vector.
// This assumes the model emits at least 3 values. Verify against the model.
_movement = new Vector3(movementArray[0], 0, movementArray[1]);
// Rotate the planar move vector about Y by (90 - current yaw) degrees.
// Presumably this is the "Quaternion turn adjustment" warned about in Start.
_movement = Quaternion.AngleAxis(-transform.rotation.eulerAngles.y + 90, Vector3.up) * _movement;
// Third output drives jumping; see GetJumpInputDown.
_jump = movementArray[2];
_counter = 0;
yield return null;
// InputHandler: movement from the most recent model decision.
public Vector3 GetMoveInput()
return _movement;
// InputHandler: horizontal look is not driven by the model; always 0.
public float GetLookInputsHorizontal()
return 0;
// InputHandler: vertical look is not driven by the model; always 0.
public float GetLookInputsVertical()
return 0;
// InputHandler: true while the last jump output exceeds 0.5.
public bool GetJumpInputDown()
return _jump > .5f;
// InputHandler: sprint is never requested.
public bool GetSprintInputHeld()
return false;
// InputHandler: crouch is never requested.
public bool GetCrouchInputDown()
return false;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment