Created
February 16, 2021 05:15
-
-
Save harujoh/5265f94dd70af70663c7925caa16b7e5 to your computer and use it in GitHub Desktop.
たし算をTensorFlow.dll上で実行
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Runtime.InteropServices; | |
namespace MinTensorFlow | |
{ | |
/// <summary>
/// Opaque handle for a native TF_Status. It carries a single pointer-sized field, so it crosses
/// the P/Invoke boundary exactly like the raw pointer returned by TF_NewStatus.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Status
{
    // Native pointer (named after tensorflow::Status in the C++ API); never assigned from managed code.
    private IntPtr _status;
}
/// <summary>
/// Opaque handle for a native TF_Graph; a single pointer-sized field so it marshals like the raw pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Graph
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native TF_Operation; a single pointer-sized field so it marshals like the raw pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct TF_Operation
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native TF_OperationDescription (an operation under construction);
/// a single pointer-sized field so it marshals like the raw pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_OperationDescription
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native TF_Tensor. A default-constructed value holds a zero (null) pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Tensor
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native TF_Session; a single pointer-sized field so it marshals like the raw pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Session
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Opaque handle for a native TF_SessionOptions; a single pointer-sized field so it marshals like the raw pointer.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_SessionOptions
{
    // Native pointer; never assigned from managed code.
    private IntPtr _handle;
}
/// <summary>
/// Identifies one output of an operation: the producing operation plus the output's index.
/// Laid out sequentially so arrays of it can be passed directly to the native API.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct TF_Output
{
    /// <summary>Operation that produces the value.</summary>
    public TF_Operation oper;

    /// <summary>The index of the output within <c>oper</c>.</summary>
    public int index;

    /// <summary>Creates a reference to output <paramref name="index"/> of <paramref name="oper"/>.</summary>
    public TF_Output(TF_Operation oper, int index)
    {
        this.index = index;
        this.oper = oper;
    }
}
/// <summary>
/// Managed mirror of a native TF_Buffer: a data pointer, its length, and a callback used to free
/// the data. The fields are never assigned from managed code.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
internal struct TF_Buffer
{
    private IntPtr _data;                       // const void* data
    private IntPtr _length;                     // size_t length (pointer-sized)
    private data_deallocator _dataDeallocator;  // marshalled as a native function pointer
}
/// <summary>
/// Callback stored in a TF_Buffer for releasing the buffer's data.
/// </summary>
/// <remarks>
/// The native TF_Buffer declares this as <c>void (*)(void* data, size_t length)</c>: two arguments,
/// unlike the three-argument tensor deallocator. The previous managed signature had an extra
/// <c>arg</c> parameter, which would read an undefined argument if TensorFlow ever invoked it.
/// </remarks>
/// <param name="data">Pointer to the buffer's data.</param>
/// <param name="len">Length of the data in bytes (size_t, hence pointer-sized).</param>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal delegate void data_deallocator(IntPtr data, IntPtr len);
/// <summary>
/// Cdecl callback handed to TF_NewTensor so native code can release caller-provided tensor memory.
/// </summary>
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
internal delegate void deallocator(
    IntPtr data,  // start of the tensor's data buffer
    IntPtr len,   // buffer length in bytes (size_t, hence pointer-sized)
    IntPtr arg);  // opaque value supplied as deallocator_arg when the tensor was created
/// <summary>
/// Status codes reported through a TF_Status (read with TF_GetCode). Values are explicit, so the
/// declaration order below (sorted by value) does not affect them.
/// </summary>
public enum TF_Code
{
    TF_OK = 0,
    TF_CANCELLED = 1,
    TF_UNKNOWN = 2,
    TF_INVALID_ARGUMENT = 3,
    TF_DEADLINE_EXCEEDED = 4,
    TF_NOT_FOUND = 5,
    TF_ALREADY_EXISTS = 6,
    TF_PERMISSION_DENIED = 7,
    TF_RESOURCE_EXHAUSTED = 8,
    TF_FAILED_PRECONDITION = 9,
    TF_ABORTED = 10,
    TF_OUT_OF_RANGE = 11,
    TF_UNIMPLEMENTED = 12,
    TF_INTERNAL = 13,
    TF_UNAVAILABLE = 14,
    TF_DATA_LOSS = 15,
    TF_UNAUTHENTICATED = 16,
}
/// <summary>
/// Element types of TensorFlow tensors, with values intended to mirror the native TF_DataType enum.
/// </summary>
internal enum TF_DataType
{
    TF_FLOAT = 1,
    TF_DOUBLE = 2,
    /// <summary>Int32 tensors are always in 'host' memory.</summary>
    TF_INT32 = 3,
    TF_UINT8 = 4,
    TF_INT16 = 5,
    TF_INT8 = 6,
    TF_STRING = 7,
    /// <summary>Single-precision complex.</summary>
    TF_COMPLEX64 = 8,
    /// <summary>Old identifier kept for API backwards compatibility (alias of TF_COMPLEX64).</summary>
    TF_COMPLEX = 8,
    TF_INT64 = 9,
    TF_BOOL = 10,
    /// <summary>Quantized int8.</summary>
    TF_QINT8 = 11,
    /// <summary>Quantized uint8.</summary>
    TF_QUINT8 = 12,
    /// <summary>Quantized int32.</summary>
    TF_QINT32 = 13,
    /// <summary>Float32 truncated to 16 bits. Only for cast ops.</summary>
    TF_BFLOAT16 = 14,
    /// <summary>Quantized int16.</summary>
    TF_QINT16 = 15,
    /// <summary>Quantized uint16.</summary>
    TF_QUINT16 = 16,
    TF_UINT16 = 17,
    /// <summary>Double-precision complex.</summary>
    TF_COMPLEX128 = 18,
    TF_HALF = 19,
    TF_RESOURCE = 20,
    TF_VARIANT = 21,
    TF_UINT32 = 22,
    TF_UINT64 = 23,
}
/// <summary>
/// Minimal managed analogue of C++ <c>std::unique_ptr</c>: owns a value and runs a deleter on it
/// at most once, either deterministically via <see cref="Dispose"/> or, as a fallback, from the
/// finalizer when the wrapper becomes unreachable.
/// </summary>
/// <typeparam name="T">Type of the owned value (for example an opaque native handle).</typeparam>
class unique_ptr<T> : IDisposable
{
    private readonly T _value;
    // Cleared after the first release so the deleter can never run twice.
    private Action<T> _deleter;

    /// <param name="v">Value to own.</param>
    /// <param name="deleteter">Callback that releases <paramref name="v"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="deleteter"/> is null.</exception>
    public unique_ptr(T v, Action<T> deleteter)
    {
        _value = v;
        // Fail here rather than with a NullReferenceException on the finalizer thread,
        // which would tear down the process.
        _deleter = deleteter ?? throw new ArgumentNullException(nameof(deleteter));
    }

    /// <summary>Returns the owned value without giving up ownership.</summary>
    public T get()
    {
        return _value;
    }

    /// <summary>Runs the deleter now (if it has not run yet) and suppresses the finalizer.</summary>
    public void Dispose()
    {
        Release();
        GC.SuppressFinalize(this);
    }

    ~unique_ptr()
    {
        Release();
    }

    /// <summary>Invokes the deleter exactly once; later calls are no-ops.</summary>
    private void Release()
    {
        Action<T> deleter = _deleter;
        if (deleter is null)
        {
            return;
        }
        _deleter = null;
        deleter(_value);
    }
}
/// <summary>
/// <c>unique_ptr</c> specialised for native TF_Tensor handles; in this file it is paired with
/// TF_DeleteTensor as the deleter.
/// </summary>
class unique_tensor_ptr : unique_ptr<TF_Tensor>
{
    /// <param name="v">Tensor handle to own.</param>
    /// <param name="deleteter">Callback that releases the tensor.</param>
    public unique_tensor_ptr(TF_Tensor v, Action<TF_Tensor> deleteter)
        : base(v, deleteter)
    {
    }
}
/// <summary>
/// Builds the graph <c>feed + 2</c> with the TensorFlow C API, runs it with <c>feed = 3</c>,
/// and prints the result.
/// </summary>
class Program
{
    // Native library name without an extension: Windows resolves it to tensorflow.dll, and
    // .NET Core / .NET 5+ also probe libtensorflow.so on Linux, so the name no longer needs editing per OS.
    private const string TensorFlowLibrary = "tensorflow";

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Status TF_NewStatus();

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Graph TF_NewGraph();

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteGraph(TF_Graph g);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteStatus(TF_Status s);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_OperationDescription TF_NewOperation(TF_Graph graph, string op_type, string oper_name);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrShape(TF_OperationDescription desc, string attr_name, long[] dims, int num_dims);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Code TF_GetCode(TF_Status s);

    // TF_Message returns a pointer into memory owned by the TF_Status. Declaring the return type as
    // `string` would make the marshaller free that pointer (CoTaskMemFree), i.e. free memory it does
    // not own, so the raw pointer is fetched and copied into a managed string instead.
    [DllImport(TensorFlowLibrary, EntryPoint = "TF_Message", CallingConvention = CallingConvention.Cdecl)]
    private static extern IntPtr TF_MessagePtr(TF_Status s);

    /// <summary>Returns a managed copy of the error message stored in <paramref name="s"/>.</summary>
    public static string TF_Message(TF_Status s) => Marshal.PtrToStringAnsi(TF_MessagePtr(s));

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Tensor TF_NewTensor(TF_DataType dtype, IntPtr dims, int num_dims, IntPtr data, IntPtr len, deallocator deallocator, IntPtr deallocator_arg);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteTensor(TF_Tensor t);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_DataType TF_TensorType(TF_Tensor t);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_AddInputList(TF_OperationDescription desc, TF_Output[] inputs, int num_inputs);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_SessionOptions TF_NewSessionOptions();

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_EnableXLACompilation(TF_SessionOptions options, byte enable);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opt, TF_Status status);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteSessionOptions(TF_SessionOptions opt);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr TF_TensorData(TF_Tensor t);

    [DllImport(TensorFlowLibrary, CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_SessionRun(
        TF_Session session, TF_Buffer[] run_options,
        TF_Output[] inputs, TF_Tensor[] input_values, int ninputs,
        TF_Output[] outputs, IntPtr output_values, int noutputs,
        TF_Operation[] target_opers, int ntargets,
        TF_Buffer[] run_metadata, TF_Status status);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_CloseSession(TF_Session session, TF_Status status);

    [DllImport(TensorFlowLibrary, CallingConvention = CallingConvention.Cdecl)]
    public static extern void TF_DeleteSession(TF_Session session, TF_Status status);

    /// <summary>
    /// Adds a Placeholder op. A <c>null</c> <paramref name="dims"/> is treated as an empty list.
    /// The status is left in <paramref name="s"/> for the caller to check.
    /// </summary>
    static TF_Operation Placeholder(TF_Graph graph, TF_Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, List<long> dims = null)
    {
        PlaceholderHelper(graph, s, name, dtype, dims ?? new List<long>(), out TF_Operation op);
        return op;
    }

    /// <summary>Builds the Placeholder op with the given dtype and (if non-null) shape.</summary>
    static void PlaceholderHelper(TF_Graph graph, TF_Status s, string name, TF_DataType dtype, List<long> dims, out TF_Operation op)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "Placeholder", name);
        TF_SetAttrType(desc, "dtype", dtype);
        // NOTE(review): an empty (non-null) list still sets "shape" with num_dims = 0, which TF
        // appears to treat as a scalar shape rather than an unknown one — confirm this is intended.
        if (dims is not null)
        {
            TF_SetAttrShape(desc, "shape", dims.ToArray(), dims.Count);
        }
        op = TF_FinishOperation(desc, s);
    }

    // Held in a static field so the delegate stays reachable: TensorFlow keeps the function pointer
    // and may call it long after TF_NewTensor returns. A delegate created per call from the method
    // group would become collectable as soon as TF_NewTensor returned, crashing the later callback.
    private static readonly deallocator s_deallocator = Deallocator;

    /// <summary>
    /// Creates a scalar INT32 tensor holding <paramref name="v"/>. The value lives in a pinned managed
    /// array whose GCHandle is passed as the deallocator argument and freed by <see cref="Deallocator"/>.
    /// </summary>
    static TF_Tensor Int32Tensor(Int32 v)
    {
        const int num_bytes = sizeof(Int32);
        Int32[] values = { v };
        GCHandle handle = GCHandle.Alloc(values, GCHandleType.Pinned);
        return TF_NewTensor(TF_DataType.TF_INT32, IntPtr.Zero, 0, handle.AddrOfPinnedObject(), (IntPtr)num_bytes, s_deallocator, GCHandle.ToIntPtr(handle));
    }

    /// <summary>Native deallocator callback: frees the GCHandle that pins the tensor's managed buffer.</summary>
    static void Deallocator(IntPtr data, IntPtr len, IntPtr arg)
    {
        GCHandle.FromIntPtr(arg).Free();
    }

    /// <summary>Adds a Const op holding the INT32 scalar <paramref name="v"/>.</summary>
    static TF_Operation ScalarConst(Int32 v, TF_Graph graph, TF_Status s, string name = "scalar")
    {
        TF_Tensor tensor = Int32Tensor(v);
        try
        {
            return Const(tensor, graph, s, name);
        }
        finally
        {
            // Release the temporary tensor as soon as the op is built, instead of leaving it to a
            // finalizer that runs at an arbitrary later time (or never, at process exit).
            TF_DeleteTensor(tensor);
        }
    }

    /// <summary>Adds a Const op whose value and dtype come from <paramref name="t"/>.</summary>
    static TF_Operation Const(TF_Tensor t, TF_Graph graph, TF_Status s, string name)
    {
        ConstHelper(t, graph, s, name, out TF_Operation op);
        return op;
    }

    /// <summary>Builds the Const op; throws via <see cref="Check"/> on failure.</summary>
    static void ConstHelper(TF_Tensor t, TF_Graph graph, TF_Status s, string name, out TF_Operation op)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "Const", name);
        TF_SetAttrTensor(desc, "value", t, s);
        Check(s);
        TF_SetAttrType(desc, "dtype", TF_TensorType(t));
        op = TF_FinishOperation(desc, s);
        Check(s);
    }

    /// <summary>
    /// Builds an AddN op summing output 0 of <paramref name="l"/> and <paramref name="r"/>.
    /// When <paramref name="check"/> is true, throws if construction failed.
    /// </summary>
    static void AddOpHelper(TF_Operation l, TF_Operation r, TF_Graph graph, TF_Status s, string name, out TF_Operation op, bool check)
    {
        TF_OperationDescription desc = TF_NewOperation(graph, "AddN", name);
        TF_Output[] add_inputs = { new TF_Output(l, 0), new TF_Output(r, 0) };
        TF_AddInputList(desc, add_inputs, add_inputs.Length);
        op = TF_FinishOperation(desc, s);
        if (check)
        {
            Check(s);
        }
    }

    /// <summary>Adds <c>l + r</c> to the graph and throws on failure.</summary>
    static TF_Operation Add(TF_Operation l, TF_Operation r, TF_Graph graph, TF_Status s, string name = "add")
    {
        AddOpHelper(l, r, graph, s, name, out TF_Operation op, true);
        return op;
    }

    /// <summary>
    /// Wrapper around a native TF_Session. It keeps the feeds and fetches for a run and owns the
    /// input and output tensors involved.
    /// </summary>
    class CSession
    {
        TF_Session session_;
        TF_Output[] inputs_;
        TF_Tensor[] input_values_;
        TF_Output[] outputs_;
        TF_Tensor[] output_values_;
        TF_Operation[] targets_ = { };
        bool deleted_;

        public CSession(TF_Graph graph, TF_Status s, bool use_XLA = false)
        {
            TF_SessionOptions opts = TF_NewSessionOptions();
            TF_EnableXLACompilation(opts, (byte)(use_XLA ? 1 : 0));
            session_ = TF_NewSession(graph, opts, s);
            TF_DeleteSessionOptions(opts);
        }

        /// <summary>
        /// Sets the feeds for the next <see cref="Run"/>. Takes ownership of the tensors;
        /// previously set input tensors are deleted.
        /// </summary>
        public void SetInputs(List<(TF_Operation, TF_Tensor)> inputs)
        {
            DeleteInputValues();
            inputs_ = new TF_Output[inputs.Count];
            input_values_ = new TF_Tensor[inputs.Count];
            for (int i = 0; i < inputs.Count; i++)
            {
                inputs_[i] = new TF_Output(inputs[i].Item1, 0);
                input_values_[i] = inputs[i].Item2;
            }
        }

        /// <summary>Sets the fetches (output 0 of each op) and deletes any held output tensors.</summary>
        public void SetOutputs(List<TF_Operation> outputs)
        {
            ResetOutputValues();
            outputs_ = new TF_Output[outputs.Count];
            for (int i = 0; i < outputs.Count; i++)
            {
                outputs_[i] = new TF_Output(outputs[i], 0);
            }
            // Zero-initialized (null) handles until Run fills them in.
            output_values_ = new TF_Tensor[outputs.Count];
        }

        /// <summary>
        /// Runs the session. Output tensors from a previous run are deleted first. Input tensors
        /// are deleted afterwards, so <see cref="SetInputs"/> must be called again before the next run.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The number of input tensors does not match the number of feeds (e.g. they were consumed by
        /// a previous run).
        /// </exception>
        public void Run(TF_Status s)
        {
            TF_Output[] inputs = inputs_ ?? Array.Empty<TF_Output>();
            TF_Tensor[] inputValues = input_values_ ?? Array.Empty<TF_Tensor>();
            if (inputs.Length != inputValues.Length)
            {
                throw new InvalidOperationException("Call SetInputs() before Run()");
            }
            TF_Output[] outputs = outputs_ ?? Array.Empty<TF_Output>();

            // Free tensors produced by the previous run before replacing the array (previously leaked).
            ResetOutputValues();
            output_values_ = new TF_Tensor[outputs.Length];

            // TF_SessionRun writes one TF_Tensor* per output into this buffer, so it must stay pinned.
            GCHandle pinned = GCHandle.Alloc(output_values_, GCHandleType.Pinned);
            try
            {
                TF_SessionRun(
                    session_, null, inputs, inputValues, inputs.Length,
                    outputs, pinned.AddrOfPinnedObject(), outputs.Length,
                    targets_, targets_.Length, null, s);
            }
            finally
            {
                pinned.Free();
            }
            DeleteInputValues();
        }

        /// <summary>
        /// Deletes held input/output tensors, then closes and deletes the native session. Calling it
        /// again only re-deletes tensors. If closing fails, an exception is thrown and the session is
        /// not deleted; the status of TF_DeleteSession is left in <paramref name="s"/>.
        /// </summary>
        public void CloseAndDelete(TF_Status s)
        {
            DeleteInputValues();
            ResetOutputValues();
            if (deleted_)
            {
                return;
            }
            deleted_ = true;
            TF_CloseSession(session_, s);
            Check(s);
            TF_DeleteSession(session_, s);
        }

        void DeleteInputValues()
        {
            if (input_values_ != null)
            {
                foreach (TF_Tensor tensor in input_values_)
                {
                    TF_DeleteTensor(tensor);
                }
            }
            input_values_ = null;
        }

        void ResetOutputValues()
        {
            if (output_values_ != null)
            {
                foreach (TF_Tensor tensor in output_values_)
                {
                    TF_DeleteTensor(tensor);
                }
            }
            output_values_ = null;
        }

        /// <summary>Returns output <paramref name="i"/>; it remains owned by this session.</summary>
        public TF_Tensor output_tensor(int i) { return output_values_[i]; }
    }

    static void Main(string[] args)
    {
        // EnvironmentVariableTarget.Machine requires administrator rights to run.
        //Environment.SetEnvironmentVariable("TF_CPP_MIN_LOG_LEVEL", "2", EnvironmentVariableTarget.Machine);
        TF_Status s = TF_NewStatus();
        TF_Graph graph = TF_NewGraph();

        // Make a placeholder operation.
        TF_Operation feed = Placeholder(graph, s);
        Check(s);

        // Make a constant operation with the scalar "2".
        TF_Operation two = ScalarConst(2, graph, s);
        Check(s);

        // Add operation.
        TF_Operation add = Add(feed, two, graph, s);
        Check(s);

        // Create a session for this graph.
        CSession csession = new CSession(graph, s);
        Check(s);

        // Run the graph with feed = 3.
        csession.SetInputs(new List<(TF_Operation, TF_Tensor)> { (feed, Int32Tensor(3)) });
        csession.SetOutputs(new List<TF_Operation> { add });
        csession.Run(s);
        Check(s);

        TF_Tensor output = csession.output_tensor(0);
        // Guard the raw 4-byte read below against an unexpected element type.
        if (TF_TensorType(output) != TF_DataType.TF_INT32)
        {
            throw new InvalidOperationException("Expected an INT32 output tensor.");
        }
        Int32 output_contents = Marshal.ReadInt32(TF_TensorData(output));
        Console.WriteLine(output_contents);

        // Release the session (and the output tensors it owns) before the graph and status.
        csession.CloseAndDelete(s);
        Check(s);
        TF_DeleteGraph(graph);
        TF_DeleteStatus(s);
        Console.WriteLine("Complete.");
        Console.Read();
    }

    /// <summary>Throws if <paramref name="s"/> holds a non-OK code.</summary>
    /// <exception cref="InvalidOperationException">Carries TensorFlow's error message.</exception>
    static void Check(TF_Status s)
    {
        TF_Code code = TF_GetCode(s);
        if (code != TF_Code.TF_OK)
        {
            throw new InvalidOperationException(TF_Message(s));
        }
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment