Last active
February 7, 2021 14:03
-
-
Save harujoh/a6fa0c9a448524c68b81ef796ff84537 to your computer and use it in GitHub Desktop.
トップダウン型の自動微分
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
namespace ConsoleApp1 | |
{ | |
/// <summary>
/// Demo entry point: differentiates y = x^3 three times at x = 3 by repeatedly
/// back-propagating through the gradient graph.
/// </summary>
class Program
{
    /// <summary>
    /// Prints the first, second and third derivatives of x^3 at x = 3
    /// (mathematically 3x^2 = 27, 6x = 18 and 6), then waits for input.
    /// </summary>
    static void Main(string[] args)
    {
        var x = new Variable(3f);
        var y = x ^ 3f; // x to the power 3 via the overloaded '^' (power, not XOR)

        // createGraph: true records the gradient computation itself as graph nodes,
        // so each gradient can be back-propagated again for the next derivative.
        y.Backward(createGraph: true);
        var gy = x.Grad;
        Console.WriteLine("gy:" + gy.Data);

        // Backward accumulates into an existing Grad, so reset it before the next pass.
        x.ClearGrad();
        gy.Backward(createGraph: true);
        var ggy = x.Grad;
        Console.WriteLine("ggy:" + ggy.Data);

        x.ClearGrad();
        ggy.Backward(createGraph: true);
        var gggy = x.Grad;
        // Label now matches the "gy:" / "ggy:" format used above.
        Console.WriteLine("gggy:" + gggy.Data);

        Console.Read();
    }
}
/// <summary>
/// Scoped override of the global <see cref="Enable"/> flag. The constructor remembers
/// the current value and installs a new one; <see cref="Dispose"/> puts the remembered
/// value back. Meant to be used with a <c>using</c> statement so scopes can nest.
/// </summary>
class BackpropConfig : IDisposable
{
    /// <summary>
    /// Global switch read by <see cref="Function"/> calls: when true they record their
    /// inputs/outputs (building a graph); when false they only compute values.
    /// </summary>
    public static bool Enable = true;

    // Value of Enable captured at construction time; restored on Dispose.
    private readonly bool previous;

    public BackpropConfig(bool enable)
    {
        previous = Enable;
        Enable = enable;
    }

    public void Dispose() => Enable = previous;
}
/// <summary>
/// A scalar node of the computation graph. Holds a value, the gradient accumulated by
/// <see cref="Backward"/>, and (when created with back-propagation enabled) the
/// <see cref="Function"/> that produced it.
/// </summary>
class Variable
{
    /// <summary>Scalar value of this node.</summary>
    public float Data;
    /// <summary>Optional label; nothing in the computation reads it.</summary>
    public string Name;
    /// <summary>
    /// Gradient with respect to this variable, or null until <see cref="Backward"/> reaches it.
    /// It is itself a <see cref="Variable"/>, so with <c>createGraph: true</c> it can be
    /// back-propagated again.
    /// </summary>
    public Variable Grad;
    // Function that produced this variable; null for leaves (constants and user inputs).
    private Function Creator;
    /// <summary>Graph depth: 0 for leaves, otherwise the creator's generation + 1.</summary>
    public int Generation;

    public Variable(float data, string name = "")
    {
        this.Data = data;
        this.Name = name;
        this.Grad = null;
        this.Creator = null;
        this.Generation = 0;
    }

    /// <summary>Links this variable to the function that produced it and updates its depth.</summary>
    public void SetCreator(Function func)
    {
        this.Creator = func;
        this.Generation = func.Generation + 1;
    }

    /// <summary>Discards the accumulated gradient.</summary>
    public void ClearGrad()
    {
        this.Grad = null;
    }

    /// <summary>
    /// Back-propagates from this variable, accumulating into <see cref="Grad"/> of every
    /// variable reachable through creator functions. Calling it on a leaf only seeds
    /// <see cref="Grad"/>.
    /// </summary>
    /// <param name="retainGrad">
    /// When false, the gradients of function outputs (including this variable) are reset
    /// to null once propagated; gradients of leaves are kept.
    /// </param>
    /// <param name="createGraph">
    /// When true, the gradient computations are recorded as graph nodes, so the resulting
    /// gradients can be back-propagated again (higher-order derivatives).
    /// </param>
    public void Backward(bool retainGrad = false, bool createGraph = false)
    {
        if (this.Grad == null)
        {
            // Seed: d(this)/d(this) = 1.
            this.Grad = new Variable(1f);
        }

        // Pending functions, kept sorted by ascending Generation so the last entry is the
        // deepest one. A function's generation is lower than that of every function consuming
        // its outputs, so it is popped only after all of them and sees a complete output gradient.
        List<Function> funcs = new List<Function>();
        // Guards against queueing (and back-propagating) a shared function more than once.
        HashSet<Function> seen = new HashSet<Function>();

        void AddFunc(Function creator)
        {
            if (creator == null || !seen.Add(creator))
            {
                return;
            }
            int index = funcs.FindIndex(pending => pending.Generation > creator.Generation);
            if (index < 0)
            {
                funcs.Add(creator);
            }
            else
            {
                funcs.Insert(index, creator);
            }
        }

        AddFunc(this.Creator);

        while (funcs.Count != 0)
        {
            Function f = funcs[funcs.Count - 1];
            funcs.RemoveAt(funcs.Count - 1);

            Variable[] y = f.Outputs;
            Variable[] gys = new Variable[y.Length];
            for (int i = 0; i < gys.Length; i++)
            {
                gys[i] = y[i].Grad;
            }

            using (new BackpropConfig(createGraph))
            {
                Variable[] gx = f.Backward(gys);
                Variable[] x = f.Inputs;
                for (int i = 0; i < x.Length; i++)
                {
                    if (x[i].Grad == null)
                    {
                        x[i].Grad = gx[i];
                    }
                    else
                    {
                        // '+' builds a new Variable instead of mutating the current gradient,
                        // which may be shared (Add.Backward hands the same instance to both inputs).
                        x[i].Grad += gx[i];
                    }
                    AddFunc(x[i].Creator);
                }
            }

            if (!retainGrad)
            {
                for (int i = 0; i < y.Length; i++)
                {
                    y[i].Grad = null;
                }
            }
        }
    }

    /// <summary>Wraps a constant as a leaf variable so floats can be mixed into expressions.</summary>
    public static implicit operator Variable(float d)
    {
        return new Variable(d);
    }

    public static Variable operator -(Variable a)
    {
        return Function.Neg(a);
    }

    public static Variable operator +(Variable a, Variable b)
    {
        return Function.Add(a, b);
    }

    public static Variable operator -(Variable a, Variable b)
    {
        return Function.Sub(a, b);
    }

    public static Variable operator *(Variable a, Variable b)
    {
        return Function.Mul(a, b);
    }

    public static Variable operator /(Variable a, Variable b)
    {
        return Function.Div(a, b);
    }

    /// <summary>
    /// Power a^b. Caution: C# gives '^' lower precedence than '*', '+' and comparisons, so
    /// <c>2 * x ^ 3</c> parses as <c>(2 * x) ^ 3</c>; parenthesize inside larger expressions.
    /// </summary>
    public static Variable operator ^(Variable a, float b)
    {
        return Function.Pow(a, b);
    }
}

/// <summary>
/// Base class for differentiable scalar operations. Each factory method below creates a
/// fresh instance per call; when <see cref="BackpropConfig.Enable"/> is true the instance
/// records its inputs and outputs so <see cref="Variable.Backward"/> can walk back through it.
/// </summary>
abstract class Function
{
    /// <summary>Inputs of the call; only set when back-propagation was enabled.</summary>
    public Variable[] Inputs;
    /// <summary>Outputs of the call; only set when back-propagation was enabled.</summary>
    public Variable[] Outputs;
    /// <summary>Largest generation among the inputs; orders functions during back-propagation.</summary>
    public int Generation = 0;

    private Variable[] Call(params Variable[] inputs)
    {
        float[] xs = new float[inputs.Length];
        for (int i = 0; i < xs.Length; i++)
        {
            xs[i] = inputs[i].Data;
        }

        float[] ys = this.Forward(xs);
        Variable[] outputs = new Variable[ys.Length];
        for (int i = 0; i < outputs.Length; i++)
        {
            outputs[i] = new Variable(ys[i]);
        }

        if (BackpropConfig.Enable)
        {
            // Generation is only meaningful for recorded functions (the original computed it
            // in the disabled branch instead). It must be set before SetCreator copies it.
            for (int i = 0; i < inputs.Length; i++)
            {
                this.Generation = Math.Max(this.Generation, inputs[i].Generation);
            }
            for (int i = 0; i < outputs.Length; i++)
            {
                outputs[i].SetCreator(this);
            }
            this.Outputs = outputs;
            this.Inputs = inputs;
        }

        return outputs;
    }

    /// <summary>Computes output values from input values.</summary>
    protected abstract float[] Forward(float[] x);

    /// <summary>
    /// Maps output gradients <paramref name="gy"/> to one gradient per input, built from
    /// graph operations so the result is itself differentiable.
    /// </summary>
    public abstract Variable[] Backward(Variable[] gy);

    public static Variable Neg(Variable x)
    {
        return new Neg().Call(x)[0];
    }

    public static Variable Add(Variable x0, Variable x1)
    {
        return new Add().Call(x0, x1)[0];
    }

    public static Variable Sub(Variable x0, Variable x1)
    {
        return new Sub().Call(x0, x1)[0];
    }

    public static Variable Mul(Variable x0, Variable x1)
    {
        return new Mul().Call(x0, x1)[0];
    }

    public static Variable Div(Variable x0, Variable x1)
    {
        return new Div().Call(x0, x1)[0];
    }

    public static Variable Pow(Variable x, float c)
    {
        return new Pow(c).Call(x)[0];
    }
}
/// <summary>Unary negation y = -x; the local derivative is -1.</summary>
class Neg : Function
{
    protected override float[] Forward(float[] x) => new float[] { -x[0] };

    // Negate the upstream gradient (itself a graph operation).
    public override Variable[] Backward(Variable[] gy) => new Variable[] { -gy[0] };
}
/// <summary>Sum y = x0 + x1; both partial derivatives are 1.</summary>
class Add : Function
{
    protected override float[] Forward(float[] x)
    {
        float sum = x[0] + x[1];
        return new float[] { sum };
    }

    public override Variable[] Backward(Variable[] gy)
    {
        // The upstream gradient passes through unchanged; the same instance goes to both inputs.
        Variable upstream = gy[0];
        return new Variable[] { upstream, upstream };
    }
}
/// <summary>Difference y = x0 - x1; partial derivatives are +1 and -1.</summary>
class Sub : Function
{
    protected override float[] Forward(float[] x) => new float[] { x[0] - x[1] };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable upstream = gy[0];
        Variable negated = -upstream;
        return new Variable[] { upstream, negated };
    }
}
/// <summary>Product y = x0 * x1; each partial derivative is the other operand.</summary>
class Mul : Function
{
    protected override float[] Forward(float[] x) => new float[] { x[0] * x[1] };

    public override Variable[] Backward(Variable[] gy)
    {
        Variable upstream = gy[0];
        Variable lhs = this.Inputs[0];
        Variable rhs = this.Inputs[1];
        // Built in the same order as before: gradient for x0 first, then for x1.
        Variable gradLhs = upstream * rhs;
        Variable gradRhs = upstream * lhs;
        return new Variable[] { gradLhs, gradRhs };
    }
}
/// <summary>Quotient y = x0 / x1.</summary>
class Div : Function
{
    protected override float[] Forward(float[] x)
    {
        return new[] { x[0] / x[1] };
    }

    /// <summary>
    /// Partial derivatives: dy/dx0 = 1 / x1 and dy/dx1 = -x0 / x1^2, each scaled by the
    /// upstream gradient.
    /// </summary>
    public override Variable[] Backward(Variable[] gy)
    {
        Variable x0 = this.Inputs[0];
        Variable x1 = this.Inputs[1];
        // The numerator's gradient divides by x1 (the original multiplied by it).
        Variable gx0 = gy[0] / x1;
        Variable gx1 = gy[0] * (-x0 / Pow(x1, 2f));
        return new[] { gx0, gx1 };
    }
}
/// <summary>Power with a constant exponent: y = x^c, so dy/dx = c * x^(c - 1).</summary>
class Pow : Function
{
    // Constant exponent supplied at construction.
    private readonly float exponent;

    public Pow(float c)
    {
        exponent = c;
    }

    protected override float[] Forward(float[] x) => new float[] { MathF.Pow(x[0], exponent) };

    public override Variable[] Backward(Variable[] gy)
    {
        // 'Pow(...)' here is the inherited static Function.Pow, so the derivative is a graph node.
        Variable derivative = exponent * Pow(this.Inputs[0], exponent - 1);
        return new Variable[] { derivative * gy[0] };
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment