Skip to content

Instantly share code, notes, and snippets.

@harujoh
Created February 16, 2021 05:15
Show Gist options
  • Save harujoh/5265f94dd70af70663c7925caa16b7e5 to your computer and use it in GitHub Desktop.
Save harujoh/5265f94dd70af70663c7925caa16b7e5 to your computer and use it in GitHub Desktop.
たし算をTensorFlow.dll上で実行
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace MinTensorFlow
{
/// <summary>
/// Opaque handle for a native <c>TF_Status*</c> (as returned by <c>TF_NewStatus</c>).
/// Only the pointer is held on the managed side; the status object itself lives in native memory.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Status
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_Graph*</c> (as returned by <c>TF_NewGraph</c>).
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Graph
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_Operation*</c> (as returned by <c>TF_FinishOperation</c>).
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct TF_Operation
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_OperationDescription*</c> (as returned by <c>TF_NewOperation</c>).
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_OperationDescription
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_Tensor*</c> (as returned by <c>TF_NewTensor</c>).
/// A default-initialized value holds a null pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Tensor
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_Session*</c> (as returned by <c>TF_NewSession</c>).
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Session
{
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native <c>TF_SessionOptions*</c> (as returned by <c>TF_NewSessionOptions</c>).
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_SessionOptions
{
    private IntPtr _handle;
}
/// <summary>
/// Names one output of an operation: the operation handle followed by the output's index,
/// laid out sequentially to match the native <c>TF_Output</c> struct.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct TF_Output
{
    /// <summary>The operation that produces the output.</summary>
    public TF_Operation oper;

    /// <summary>Index of the output within <see cref="oper"/>.</summary>
    public int index;

    /// <summary>Creates a reference to output number <paramref name="index"/> of <paramref name="oper"/>.</summary>
    public TF_Output(TF_Operation oper, int index) => (this.oper, this.index) = (oper, index);
}
/// <summary>
/// Managed mirror of the native <c>TF_Buffer</c> struct: data pointer, byte length and
/// deallocator function pointer. Used for the <c>run_options</c> / <c>run_metadata</c>
/// arguments of <c>TF_SessionRun</c>.
/// </summary>
/// <remarks>
/// The deallocator is stored as a raw function pointer rather than a delegate. The native
/// field is declared as <c>void (*)(void* data, size_t length)</c> (two parameters), so a
/// three-parameter delegate type does not describe it. Keeping every field an
/// <see cref="IntPtr"/> also makes the struct blittable: arrays of it can be pinned and
/// passed without per-element marshalling. The fields are private and never assigned from
/// managed code, so no caller observes the type change.
/// </remarks>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Buffer
{
    private IntPtr data;
    private IntPtr length;           // size_t on the native side
    private IntPtr data_deallocator; // native function pointer, or null
}
/// <summary>
/// Cdecl callback type intended to release a native buffer's data.
/// NOTE(review): the TensorFlow C API declares <c>TF_Buffer::data_deallocator</c> as
/// <c>void (*)(void* data, size_t length)</c>, i.e. two parameters. The trailing <c>arg</c>
/// parameter here has no native counterpart; confirm before invoking or implementing this from native code.
/// </summary>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)] internal delegate void data_deallocator(IntPtr data, IntPtr len, IntPtr arg);
/// <summary>
/// Cdecl callback handed to <c>TF_NewTensor</c> as the routine that frees the tensor's
/// buffer. It receives the data pointer, the byte length and the caller-supplied
/// <c>deallocator_arg</c>.
/// </summary>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)] internal delegate void deallocator(IntPtr data, IntPtr len, IntPtr arg);
/// <summary>
/// Status codes reported by <c>TF_GetCode</c>. The numeric values mirror the native
/// <c>TF_Code</c> enumeration and must not be renumbered; members are listed in value order.
/// </summary>
public enum TF_Code
{
    TF_OK = 0,
    TF_CANCELLED = 1,
    TF_UNKNOWN = 2,
    TF_INVALID_ARGUMENT = 3,
    TF_DEADLINE_EXCEEDED = 4,
    TF_NOT_FOUND = 5,
    TF_ALREADY_EXISTS = 6,
    TF_PERMISSION_DENIED = 7,
    TF_RESOURCE_EXHAUSTED = 8,
    TF_FAILED_PRECONDITION = 9,
    TF_ABORTED = 10,
    TF_OUT_OF_RANGE = 11,
    TF_UNIMPLEMENTED = 12,
    TF_INTERNAL = 13,
    TF_UNAVAILABLE = 14,
    TF_DATA_LOSS = 15,
    TF_UNAUTHENTICATED = 16
}
/// <summary>
/// Element types used by <c>TF_NewTensor</c> and <c>TF_SetAttrType</c> and returned by
/// <c>TF_TensorType</c>. The numeric values mirror the native <c>TF_DataType</c> enumeration.
/// </summary>
internal enum TF_DataType
{
    TF_FLOAT = 1,
    TF_DOUBLE = 2,
    TF_INT32 = 3,        // Int32 tensors are always in 'host' memory.
    TF_UINT8 = 4,
    TF_INT16 = 5,
    TF_INT8 = 6,
    TF_STRING = 7,
    TF_COMPLEX64 = 8,    // Single-precision complex
    TF_COMPLEX = TF_COMPLEX64, // Old identifier kept for API backwards compatibility
    TF_INT64 = 9,
    TF_BOOL = 10,
    TF_QINT8 = 11,       // Quantized int8
    TF_QUINT8 = 12,      // Quantized uint8
    TF_QINT32 = 13,      // Quantized int32
    TF_BFLOAT16 = 14,    // Float32 truncated to 16 bits. Only for cast ops.
    TF_QINT16 = 15,      // Quantized int16
    TF_QUINT16 = 16,     // Quantized uint16
    TF_UINT16 = 17,
    TF_COMPLEX128 = 18,  // Double-precision complex
    TF_HALF = 19,
    TF_RESOURCE = 20,
    TF_VARIANT = 21,
    TF_UINT32 = 22,
    TF_UINT64 = 23
}
/// <summary>
/// Owns a value (typically a native handle) and releases it exactly once through a
/// caller-supplied deleter. Release happens on <see cref="Dispose"/>, or on finalization
/// if the owner was never disposed.
/// </summary>
/// <typeparam name="T">Type of the owned value.</typeparam>
class unique_ptr<T> : IDisposable
{
    private readonly T v;

    // Cleared once the deleter has run so that Dispose and the finalizer cannot release twice.
    private Action<T> _deleter;

    /// <param name="v">The value to own.</param>
    /// <param name="deleteter">Invoked with <paramref name="v"/> on release; may be null (nothing to release).</param>
    public unique_ptr(T v, Action<T> deleteter)
    {
        this.v = v;
        _deleter = deleteter;
    }

    /// <summary>Returns the owned value without transferring ownership.</summary>
    public T get() => v;

    /// <summary>Releases the owned value now; later calls are no-ops.</summary>
    public void Dispose()
    {
        Release();
        GC.SuppressFinalize(this);
    }

    // Safety net for owners that were never disposed. It must not throw: an exception on the
    // finalizer thread terminates the process, hence the null-tolerant Release.
    ~unique_ptr()
    {
        Release();
    }

    private void Release()
    {
        Action<T> deleter = _deleter;
        _deleter = null;
        deleter?.Invoke(v);
    }
}
/// <summary>
/// <see cref="unique_ptr{T}"/> specialised for native tensor handles; the deleter is
/// typically <c>TF_DeleteTensor</c>.
/// </summary>
class unique_tensor_ptr : unique_ptr<TF_Tensor>
{
    public unique_tensor_ptr(TF_Tensor v, Action<TF_Tensor> deleteter)
        : base(v, deleteter)
    {
    }
}
/// <summary>
/// Minimal TensorFlow C API driver: P/Invoke declarations plus small graph-building helpers.
/// <see cref="Main"/> builds the graph <c>feed + 2</c>, runs it with <c>feed = 3</c> and prints the sum.
/// </summary>
class Program
{
    // Bare library name: .NET adds the platform-specific prefix/suffix when probing, so this
    // resolves to tensorflow.dll on Windows and libtensorflow.so on Linux without edits.
    const string TensorFlowLibrary = "tensorflow";

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Status TF_NewStatus();
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Graph TF_NewGraph();
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteGraph(TF_Graph g);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteStatus(TF_Status s);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_OperationDescription TF_NewOperation(TF_Graph graph, string op_type, string oper_name);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrShape(TF_OperationDescription desc, string attr_name, long[] dims, int num_dims);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Code TF_GetCode(TF_Status s);

    // TF_Message returns a const char* that belongs to the status object. Importing it with a
    // `string` return type makes the marshaller free that buffer with CoTaskMemFree after
    // copying it, which corrupts the native heap. The raw pointer is imported instead and
    // copied by the TF_Message wrapper below.
    [DllImport(TensorFlowLibrary, EntryPoint = "TF_Message", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    private static extern IntPtr TF_MessageNative(TF_Status s);

    /// <summary>Returns a managed copy of the message stored in <paramref name="s"/>.</summary>
    public static string TF_Message(TF_Status s) => Marshal.PtrToStringAnsi(TF_MessageNative(s));

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Tensor TF_NewTensor(TF_DataType dtype, IntPtr dims, int num_dims, IntPtr data, IntPtr len, deallocator deallocator, IntPtr deallocator_arg);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteTensor(TF_Tensor t);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_DataType TF_TensorType(TF_Tensor t);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_AddInputList(TF_OperationDescription desc, TF_Output[] inputs, int num_inputs);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_SessionOptions TF_NewSessionOptions();
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_EnableXLACompilation(TF_SessionOptions options, byte enable);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opt, TF_Status status);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteSessionOptions(TF_SessionOptions opt);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_CloseSession(TF_Session session, TF_Status status);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteSession(TF_Session session, TF_Status status);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr TF_TensorData(TF_Tensor t);
    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SessionRun(TF_Session session, TF_Buffer[] run_options, TF_Output[] inputs, TF_Tensor[] input_values, int ninputs, TF_Output[] outputs, IntPtr output_values, int noutputs, TF_Operation[] target_opers, int ntargets, TF_Buffer[] run_metadata, TF_Status status);

    // TensorFlow keeps the native thunk of the deallocator until the tensor's buffer is
    // released. A delegate created on the fly at each call site is only referenced by that
    // call and may be garbage-collected first, so one statically rooted instance is shared.
    static readonly deallocator s_deallocator = Deallocator;

    /// <summary>
    /// Adds a Placeholder op to <paramref name="graph"/>. A null or empty <paramref name="dims"/>
    /// leaves the shape unspecified. The caller checks <paramref name="s"/>.
    /// </summary>
    static TF_Operation Placeholder(TF_Graph graph, TF_Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, List<long> dims = null)
    {
        PlaceholderHelper(graph, s, name, dtype, dims ?? new List<long>(), out TF_Operation op);
        return op;
    }

    static void PlaceholderHelper(TF_Graph graph, TF_Status s, string name, TF_DataType dtype, List<long> dims, out TF_Operation op)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "Placeholder", name);
        TF_SetAttrType(desc, "dtype", dtype);
        // Only set "shape" when dimensions were given. Passing num_dims == 0 would pin the
        // placeholder to a scalar instead of leaving its shape unspecified.
        if (dims is { Count: > 0 })
        {
            TF_SetAttrShape(desc, "shape", dims.ToArray(), dims.Count);
        }
        op = TF_FinishOperation(desc, s);
    }

    /// <summary>
    /// Creates a scalar INT32 tensor holding <paramref name="v"/>. The returned tensor must be
    /// released with TF_DeleteTensor (directly or by handing it to <see cref="CSession.SetInputs"/>).
    /// </summary>
    static TF_Tensor Int32Tensor(Int32 v)
    {
        const int num_bytes = sizeof(Int32);
        Int32[] values = { v };
        // Pin the managed buffer so TensorFlow can reference it. The pin's handle travels as
        // deallocator_arg; the 6th argument is the routine that frees the memory, and
        // Deallocator releases the pin.
        GCHandle handle = GCHandle.Alloc(values, GCHandleType.Pinned);
        try
        {
            return TF_NewTensor(TF_DataType.TF_INT32, IntPtr.Zero, 0, handle.AddrOfPinnedObject(), (IntPtr)num_bytes, s_deallocator, GCHandle.ToIntPtr(handle));
        }
        catch
        {
            // A managed exception here means the P/Invoke itself failed (e.g. the library
            // could not be loaded), so TensorFlow will never call the deallocator.
            handle.Free();
            throw;
        }
    }

    /// <summary>Tensor deallocator: frees the GCHandle passed as <paramref name="arg"/>.</summary>
    static void Deallocator(IntPtr data, IntPtr len, IntPtr arg)
    {
        var gch = GCHandle.FromIntPtr(arg);
        gch.Free();
    }

    /// <summary>Adds a Const op holding the INT32 scalar <paramref name="v"/>.</summary>
    /// <exception cref="InvalidOperationException">TensorFlow reported an error.</exception>
    static TF_Operation ScalarConst(Int32 v, TF_Graph graph, TF_Status s, string name = "scalar")
    {
        TF_Tensor tensor = Int32Tensor(v);
        try
        {
            return Const(tensor, graph, s, name);
        }
        finally
        {
            // The Const op keeps its own copy of the value (TF_SetAttrTensor stores it in the
            // attribute), so the temporary tensor is released deterministically here rather
            // than whenever a finalizer happens to run.
            TF_DeleteTensor(tensor);
        }
    }

    static TF_Operation Const(TF_Tensor t, TF_Graph graph, TF_Status s, string name)
    {
        ConstHelper(t, graph, s, name, out TF_Operation op);
        return op;
    }

    static void ConstHelper(TF_Tensor t, TF_Graph graph, TF_Status s, string name, out TF_Operation op)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "Const", name);
        TF_SetAttrTensor(desc, "value", t, s);
        Check(s);
        TF_SetAttrType(desc, "dtype", TF_TensorType(t));
        op = TF_FinishOperation(desc, s);
        Check(s);
    }

    static void AddOpHelper(TF_Operation l, TF_Operation r, TF_Graph graph, TF_Status s, string name, out TF_Operation op, bool check)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "AddN", name);
        TF_Output[] add_inputs = { new TF_Output(l, 0), new TF_Output(r, 0) };
        TF_AddInputList(desc, add_inputs, add_inputs.Length);
        op = TF_FinishOperation(desc, s);
        if (check)
        {
            Check(s);
        }
    }

    /// <summary>Adds an AddN op summing output 0 of <paramref name="l"/> and <paramref name="r"/>.</summary>
    static TF_Operation Add(TF_Operation l, TF_Operation r, TF_Graph graph, TF_Status s, string name = "add")
    {
        AddOpHelper(l, r, graph, s, name, out TF_Operation op, true);
        return op;
    }

    /// <summary>
    /// Thin wrapper over a TF_Session. It records feeds and fetches and owns two sets of
    /// tensors: the input tensors given to <see cref="SetInputs"/> (deleted after each
    /// <see cref="Run"/>) and the output tensors produced by <see cref="Run"/>. Dispose closes
    /// and deletes the session.
    /// </summary>
    class CSession : IDisposable
    {
        TF_Session session_;
        // True only when TF_NewSession reported success; a failed session must not be closed or deleted.
        bool session_open_;
        TF_Output[] inputs_ = { };
        TF_Tensor[] input_values_ = { };
        TF_Output[] outputs_ = { };
        TF_Tensor[] output_values_ = { };
        TF_Operation[] targets_ = { };

        public CSession(TF_Graph graph, TF_Status s, bool use_XLA = false)
        {
            TF_SessionOptions opts = TF_NewSessionOptions();
            try
            {
                TF_EnableXLACompilation(opts, (byte)(use_XLA ? 1 : 0));
                session_ = TF_NewSession(graph, opts, s);
                session_open_ = TF_GetCode(s) == TF_Code.TF_OK;
            }
            finally
            {
                TF_DeleteSessionOptions(opts);
            }
        }

        /// <summary>Sets the feeds; takes ownership of the given tensors (deleting any previously set).</summary>
        public void SetInputs(List<(TF_Operation, TF_Tensor)> inputs)
        {
            DeleteInputValues();
            inputs_ = new TF_Output[inputs.Count];
            input_values_ = new TF_Tensor[inputs.Count];
            for (int i = 0; i < inputs.Count; i++)
            {
                (TF_Operation oper, TF_Tensor value) = inputs[i];
                inputs_[i] = new TF_Output(oper, 0);
                input_values_[i] = value;
            }
        }

        /// <summary>Sets the fetches (output 0 of each operation), releasing any previous outputs.</summary>
        public void SetOutputs(List<TF_Operation> outputs)
        {
            ResetOutputValues();
            outputs_ = new TF_Output[outputs.Count];
            for (int i = 0; i < outputs.Count; i++)
            {
                outputs_[i] = new TF_Output(outputs[i], 0);
            }
            output_values_ = new TF_Tensor[outputs.Count];
        }

        /// <summary>
        /// Runs the session. The outcome is reported in <paramref name="s"/>, and the input
        /// tensors are deleted afterwards, so inputs must be set again before the next run.
        /// </summary>
        /// <exception cref="InvalidOperationException">Inputs are not set, or the session is not open.</exception>
        public void Run(TF_Status s)
        {
            if (inputs_.Length != input_values_.Length)
            {
                throw new InvalidOperationException("Call SetInputs() before Run()");
            }
            if (!session_open_)
            {
                throw new InvalidOperationException("The session is not open.");
            }
            // Free tensors produced by an earlier Run before reusing the slots.
            ResetOutputValues();
            output_values_ = new TF_Tensor[outputs_.Length];
            // TensorFlow writes one tensor handle per output into this array, so it must stay
            // pinned for the duration of the call.
            GCHandle pinned = GCHandle.Alloc(output_values_, GCHandleType.Pinned);
            try
            {
                TF_SessionRun(
                    session_, null, inputs_, input_values_, inputs_.Length,
                    outputs_, pinned.AddrOfPinnedObject(), outputs_.Length, targets_,
                    targets_.Length, null, s);
            }
            finally
            {
                pinned.Free();
            }
            DeleteInputValues();
        }

        /// <summary>
        /// Releases owned tensors and, if the session was opened, closes and deletes it.
        /// <paramref name="s"/> reflects the last native call. Safe to call more than once.
        /// </summary>
        public void CloseAndDelete(TF_Status s)
        {
            DeleteInputValues();
            ResetOutputValues();
            if (!session_open_)
            {
                return;
            }
            TF_CloseSession(session_, s);
            TF_DeleteSession(session_, s);
            session_open_ = false;
        }

        public void Dispose()
        {
            TF_Status s = TF_NewStatus();
            try
            {
                CloseAndDelete(s);
            }
            finally
            {
                TF_DeleteStatus(s);
            }
        }

        void DeleteInputValues()
        {
            foreach (TF_Tensor t in input_values_)
            {
                TF_DeleteTensor(t);
            }
            input_values_ = Array.Empty<TF_Tensor>();
        }

        void ResetOutputValues()
        {
            foreach (TF_Tensor t in output_values_)
            {
                TF_DeleteTensor(t);
            }
            output_values_ = Array.Empty<TF_Tensor>();
        }

        /// <summary>Returns output <paramref name="i"/> of the last run; it stays owned by this session.</summary>
        public TF_Tensor output_tensor(int i) { return output_values_[i]; }
    }

    static void Main(string[] args)
    {
        // Uncomment to reduce TensorFlow's native log output. The target is
        // EnvironmentVariableTarget.Machine, so this requires administrator rights.
        //Environment.SetEnvironmentVariable("TF_CPP_MIN_LOG_LEVEL", "2", EnvironmentVariableTarget.Machine);
        TF_Status s = TF_NewStatus();
        TF_Graph graph = TF_NewGraph();
        try
        {
            // Make a placeholder operation.
            TF_Operation feed = Placeholder(graph, s);
            Check(s);
            // Make a constant operation with the scalar "2".
            TF_Operation two = ScalarConst(2, graph, s);
            Check(s);
            // Add operation.
            TF_Operation add = Add(feed, two, graph, s);
            Check(s);
            // Create a session for this graph; disposing it releases tensors and the session.
            using (CSession csession = new CSession(graph, s))
            {
                Check(s);
                // Run the graph. The input tensor is owned and deleted by the session wrapper.
                csession.SetInputs(new List<(TF_Operation, TF_Tensor)> { (feed, Int32Tensor(3)) });
                csession.SetOutputs(new List<TF_Operation> { add });
                csession.Run(s);
                Check(s);
                TF_Tensor output = csession.output_tensor(0);
                IntPtr result = TF_TensorData(output);
                Int32 output_contents = Marshal.ReadInt32(result);
                Console.WriteLine(output_contents);
                // Release the session (and its output tensor) before the graph it was built from.
                csession.CloseAndDelete(s);
                Check(s);
            }
        }
        finally
        {
            TF_DeleteGraph(graph);
            TF_DeleteStatus(s);
        }
        Console.WriteLine("Complete.");
        Console.Read();
    }

    /// <summary>Throws if <paramref name="s"/> holds a non-OK code.</summary>
    /// <exception cref="InvalidOperationException">Carries the status message as its text.</exception>
    static void Check(TF_Status s)
    {
        if (TF_GetCode(s) != TF_Code.TF_OK)
        {
            throw new InvalidOperationException(TF_Message(s));
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment