Skip to content

Instantly share code, notes, and snippets.

@harujoh
Last active February 7, 2021 14:03
Show Gist options
  • Save harujoh/a6fa0c9a448524c68b81ef796ff84537 to your computer and use it in GitHub Desktop.
Save harujoh/a6fa0c9a448524c68b81ef796ff84537 to your computer and use it in GitHub Desktop.
トップダウン型の自動微分
using System;
using System.Collections.Generic;
using System.Linq;
namespace ConsoleApp1
{
class Program
{
    /// <summary>
    /// Demo: for x = 3 and y = x^3, repeatedly back-propagates with <c>createGraph: true</c>
    /// so that each gradient can be differentiated again. It prints the first, second and
    /// third derivatives, which analytically are 3x^2 = 27, 6x = 18 and 6.
    /// </summary>
    static void Main(string[] args)
    {
        var x = new Variable(3f);

        // Note: '^' is overloaded as power here (Math.Pow(x, 3f)), not XOR.
        Variable current = x ^ 3f;

        foreach (string label in new[] { "gy", "ggy", "gggy" })
        {
            // Keep the graph of the gradient computation so the next iteration can
            // differentiate the gradient itself.
            current.Backward(createGraph: true);
            current = x.Grad;

            // Reset x so the next pass does not accumulate onto this derivative.
            x.ClearGrad();

            Console.WriteLine($"{label}:{current.Data}");
        }

        Console.Read();
    }
}
/// <summary>
/// Scoped override of the global <see cref="Enable"/> switch. Construct it in a
/// <c>using</c> statement: the new value applies for the scope, and the previous value
/// is restored when the instance is disposed.
/// </summary>
class BackpropConfig : IDisposable
{
    /// <summary>
    /// Process-wide flag. While true, functions record their inputs/outputs and creator
    /// links so a backward pass can traverse them. While false, only the forward value
    /// is computed.
    /// </summary>
    public static bool Enable = true;

    // Value of Enable captured at construction and put back by Dispose.
    private readonly bool previous;

    /// <param name="enable">Value <see cref="Enable"/> takes until this scope is disposed.</param>
    public BackpropConfig(bool enable)
    {
        previous = Enable;
        Enable = enable;
    }

    /// <summary>Restores the value <see cref="Enable"/> had when this scope was opened.</summary>
    public void Dispose() => Enable = previous;
}
/// <summary>
/// Scalar node of the computation graph. It holds a value, an accumulated gradient and
/// the <see cref="Function"/> that produced it. The gradient is itself a
/// <see cref="Variable"/>, so it can be differentiated again when the backward pass
/// runs with <c>createGraph: true</c>.
/// </summary>
class Variable
{
    public float Data;
    public string Name;

    // Accumulated gradient. It is null until a backward pass reaches this node, and
    // again after ClearGrad.
    public Variable Grad;

    // Function whose output this is. It is null for leaves (constructed directly or via
    // the implicit float conversion).
    private Function Creator;

    // Depth in the graph: the creator's Generation + 1, or 0 for leaves. Backward visits
    // functions from the highest generation down.
    public int Generation;

    public Variable(float data, string name = "")
    {
        this.Data = data;
        this.Name = name;
        this.Grad = null;
        this.Creator = null;
        this.Generation = 0;
    }

    /// <summary>
    /// Records <paramref name="func"/> as this variable's producer and places the variable
    /// one generation after it.
    /// </summary>
    public void SetCreator(Function func)
    {
        this.Creator = func;
        this.Generation = func.Generation + 1;
    }

    /// <summary>Drops the accumulated gradient so the next backward pass starts from scratch.</summary>
    public void ClearGrad()
    {
        this.Grad = null;
    }

    /// <summary>
    /// Reverse-mode pass. Seeds <see cref="Grad"/> with 1 if it is unset. It then visits
    /// every function reachable through creator links, each exactly once, highest
    /// <see cref="Generation"/> first. For each function it adds the gradient into the
    /// function's inputs.
    /// </summary>
    /// <param name="retainGrad">
    /// When false, each visited function's output gradients are reset to null once
    /// consumed. This includes this variable's own gradient, if it has a creator.
    /// </param>
    /// <param name="createGraph">
    /// Value of <see cref="BackpropConfig.Enable"/> while gradients are computed. When
    /// true, the gradient arithmetic is itself recorded, so the resulting gradients can be
    /// differentiated again.
    /// </param>
    public void Backward(bool retainGrad = false, bool createGraph = false)
    {
        if (this.Grad == null)
        {
            this.Grad = new Variable(1f);
        }

        // Pending functions kept in ascending Generation order and popped from the end.
        // 'seen' ensures each function is queued (and therefore processed) only once.
        var funcs = new List<Function>();
        var seen = new HashSet<Function>();
        AddFunc(funcs, seen, this.Creator);

        while (funcs.Count != 0)
        {
            Function f = funcs[funcs.Count - 1];
            funcs.RemoveAt(funcs.Count - 1);

            Variable[] y = f.Outputs;
            Variable[] gys = new Variable[y.Length];
            for (int i = 0; i < gys.Length; i++)
            {
                gys[i] = y[i].Grad;
            }

            using (new BackpropConfig(createGraph))
            {
                Variable[] gx = f.Backward(gys);
                Variable[] x = f.Inputs;
                for (int i = 0; i < x.Length; i++)
                {
                    // Accumulate rather than overwrite: an input may feed several functions,
                    // or the same function more than once (e.g. x * x).
                    x[i].Grad = x[i].Grad == null ? gx[i] : x[i].Grad + gx[i];
                    AddFunc(funcs, seen, x[i].Creator);
                }
            }

            if (!retainGrad)
            {
                for (int i = 0; i < y.Length; i++)
                {
                    y[i].Grad = null;
                }
            }
        }
    }

    // Queues func unless it is null (a leaf) or was already queued. It is inserted before
    // the first pending function with an equal or higher Generation, or appended if none
    // exists, which keeps the list sorted by Generation.
    private static void AddFunc(List<Function> funcs, HashSet<Function> seen, Function func)
    {
        if (func == null || !seen.Add(func))
        {
            return;
        }

        int index = funcs.FindIndex(pending => pending.Generation >= func.Generation);
        if (index < 0)
        {
            funcs.Add(func);
        }
        else
        {
            funcs.Insert(index, func);
        }
    }

    public static implicit operator Variable(float d)
    {
        return new Variable(d);
    }

    public static Variable operator -(Variable a)
    {
        return Function.Neg(a);
    }

    public static Variable operator +(Variable a, Variable b)
    {
        return Function.Add(a, b);
    }

    public static Variable operator -(Variable a, Variable b)
    {
        return Function.Sub(a, b);
    }

    public static Variable operator *(Variable a, Variable b)
    {
        return Function.Mul(a, b);
    }

    public static Variable operator /(Variable a, Variable b)
    {
        return Function.Div(a, b);
    }

    /// <summary>
    /// Power with a constant exponent. Caution: C# keeps the precedence of XOR for '^',
    /// which binds more loosely than + - * /. For example, <c>x ^ 2f * 3f</c> means
    /// <c>x ^ (2f * 3f)</c>. Parenthesize accordingly.
    /// </summary>
    public static Variable operator ^(Variable a, float b)
    {
        return Function.Pow(a, b);
    }
}
/// <summary>
/// Base class for differentiable operations. Subclasses implement <see cref="Forward"/>
/// on raw floats and <see cref="Backward"/> on Variables. The static helpers create one
/// instance per call and run it.
/// </summary>
abstract class Function
{
    // Set by Call only while BackpropConfig.Enable is true; Backward implementations read Inputs.
    public Variable[] Inputs;
    public Variable[] Outputs;

    // Largest Generation among the inputs, recorded while backprop is enabled. Outputs get Generation + 1.
    public int Generation = 0;

    // Runs Forward on the inputs' values and wraps the results in new Variables. While
    // backprop is enabled, it also links the outputs back to this function.
    private Variable[] Call(params Variable[] inputs)
    {
        float[] xs = new float[inputs.Length];
        for (int i = 0; i < xs.Length; i++)
        {
            xs[i] = inputs[i].Data;
        }

        float[] ys = this.Forward(xs);
        Variable[] outputs = new Variable[ys.Length];
        for (int i = 0; i < outputs.Length; i++)
        {
            outputs[i] = new Variable(ys[i]);
        }

        if (BackpropConfig.Enable)
        {
            // Must run before SetCreator, which derives each output's Generation from this
            // value. Variable.Backward relies on it to visit functions in topological order.
            foreach (Variable input in inputs)
            {
                if (input.Generation > this.Generation)
                {
                    this.Generation = input.Generation;
                }
            }

            foreach (Variable output in outputs)
            {
                output.SetCreator(this);
            }

            this.Outputs = outputs;
            this.Inputs = inputs;
        }

        return outputs;
    }

    /// <summary>Computes the output values from the input values.</summary>
    protected abstract float[] Forward(float[] x);

    /// <summary>
    /// Given the gradients w.r.t. the outputs, returns the gradients w.r.t. each input, in
    /// input order. It is built from Variable operations so that it can be recorded when
    /// backprop is enabled.
    /// </summary>
    public abstract Variable[] Backward(Variable[] gy);

    public static Variable Neg(Variable x)
    {
        return new Neg().Call(x)[0];
    }

    public static Variable Add(Variable x0, Variable x1)
    {
        return new Add().Call(x0, x1)[0];
    }

    public static Variable Sub(Variable x0, Variable x1)
    {
        return new Sub().Call(x0, x1)[0];
    }

    public static Variable Mul(Variable x0, Variable x1)
    {
        return new Mul().Call(x0, x1)[0];
    }

    public static Variable Div(Variable x0, Variable x1)
    {
        return new Div().Call(x0, x1)[0];
    }

    public static Variable Pow(Variable x, float c)
    {
        return new Pow(c).Call(x)[0];
    }
}
/// <summary>Unary negation: y = -x, so dy/dx = -1.</summary>
class Neg : Function
{
    protected override float[] Forward(float[] x) => new float[] { -x[0] };

    // The upstream gradient with its sign flipped, expressed as a Variable operation so it
    // can itself be recorded.
    public override Variable[] Backward(Variable[] gy) => new Variable[] { -gy[0] };
}
/// <summary>Binary addition: y = x0 + x1.</summary>
class Add : Function
{
    protected override float[] Forward(float[] x) => new float[] { x[0] + x[1] };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable upstream = gy[0];
        // Both partial derivatives are 1, so each input receives the upstream gradient as is.
        return new Variable[] { upstream, upstream };
    }
}
/// <summary>Binary subtraction: y = x0 - x1.</summary>
class Sub : Function
{
    protected override float[] Forward(float[] x) => new float[] { x[0] - x[1] };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable upstream = gy[0];
        // dy/dx0 = 1, dy/dx1 = -1.
        Variable negated = -upstream;
        return new Variable[] { upstream, negated };
    }
}
/// <summary>Binary multiplication: y = x0 * x1.</summary>
class Mul : Function
{
    protected override float[] Forward(float[] x) => new float[] { x[0] * x[1] };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable upstream = gy[0];
        Variable lhs = this.Inputs[0];
        Variable rhs = this.Inputs[1];

        // Product rule: each input's gradient is the upstream gradient times the other input.
        Variable gradLhs = upstream * rhs;
        Variable gradRhs = upstream * lhs;
        return new Variable[] { gradLhs, gradRhs };
    }
}
/// <summary>Binary division: y = x0 / x1.</summary>
class Div : Function
{
    protected override float[] Forward(float[] x)
    {
        return new[] { x[0] / x[1] };
    }

    public override Variable[] Backward(Variable[] gy)
    {
        Variable x0 = this.Inputs[0];
        Variable x1 = this.Inputs[1];

        // dy/dx0 = 1 / x1. The upstream gradient is divided, not multiplied, by x1.
        Variable gx0 = gy[0] / x1;

        // dy/dx1 = -x0 / x1^2
        Variable gx1 = gy[0] * (-x0 / Pow(x1, 2f));

        return new[] { gx0, gx1 };
    }
}
/// <summary>Power with a constant exponent: y = x^c.</summary>
class Pow : Function
{
    // Exponent fixed at construction. It is not a graph input, so no gradient is produced for it.
    private readonly float c;

    public Pow(float c) => this.c = c;

    protected override float[] Forward(float[] x) => new float[] { MathF.Pow(x[0], c) };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable x = this.Inputs[0];
        // Power rule: dy/dx = c * x^(c-1), then chain with the upstream gradient.
        Variable local = c * Pow(x, c - 1);
        return new Variable[] { local * gy[0] };
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment