Skip to content

Instantly share code, notes, and snippets.

@automata
Last active September 6, 2023 02:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save automata/a95828a38b4fee78b77e853e7d5dc2ea to your computer and use it in GitHub Desktop.
Save automata/a95828a38b4fee78b77e853e7d5dc2ea to your computer and use it in GitHub Desktop.
from random import random_float64
from math import tanh
@value
@register_passable("trivial")
struct Value:
    # Scalar node of a tiny reverse-mode autograd graph (micrograd style).
    #
    # Every operation copies its operands into freshly allocated heap slots
    # and keeps type-erased pointers to those copies in `l` / `r`. Because
    # operands are copied per use, the structure reachable from a result is a
    # tree: gradients are accumulated in the heap copies (inspect them with
    # `print_backward`), not in the variables the expression was built from.
    #
    # NOTE: the operand slots allocated by `new` are never freed.
    #
    # Field order matters: `@value` synthesises a memberwise initializer that
    # takes its arguments in declaration order (r, l, data, grad, op, _id).
    var r: Pointer[Int]      # right operand (type-erased Pointer[Value]) or null
    var l: Pointer[Int]      # left / only operand (type-erased Pointer[Value]) or null
    var data: Float64        # forward value
    var grad: Float64        # accumulated gradient of the root w.r.t. this node
    var op: StringLiteral    # "", "+", "*" or "tanh"
    var _id: Float64         # pseudo-identity used by __eq__

    fn __init__(data: Float64) -> Value:
        # Leaf node: no operands, zero gradient, empty op tag.
        return Value(Pointer[Int].get_null(), Pointer[Int].get_null(), data, 0.0, "", random_float64())

    fn __eq__(self, other : Value) -> Bool:
        # Identity is approximated by a random Float64 drawn at construction;
        # copies of a node share the id and therefore compare equal.
        if self._id == other._id:
            return True
        return False

    # Add
    fn __add__(self, other: Value) -> Value:
        return self.new(self.data + other.data, other, "+")

    fn __radd__(self, other:Value) -> Value:
        return self + other

    fn __add__(self, other: Float64) -> Value:
        return self + Value(other)

    fn __radd__(self, other: Float64) -> Value:
        return self + Value(other)

    @staticmethod
    fn backward_add(inout node: Value):
        # d(l + r)/dl = d(l + r)/dr = 1, so both operands receive node.grad.
        var l = node.l.bitcast[Value]().load(0)
        var r = node.r.bitcast[Value]().load(0)
        l.grad += node.grad
        r.grad += node.grad
        # Write each updated copy back to its own slot. (The right operand
        # used to be stored through node.l, clobbering the left operand.)
        node.l.bitcast[Value]().store(0, l)
        node.r.bitcast[Value]().store(0, r)
        Value._backward(l)
        Value._backward(r)

    # Mul
    fn __mul__(self, other: Value) -> Value:
        return self.new(self.data * other.data, other, "*")

    fn __rmul__(self, other: Value) -> Value:
        return self * other

    fn __mul__(self, other: Float64) -> Value:
        return self * Value(other)

    fn __rmul__(self, other: Float64) -> Value:
        return self * Value(other)

    @staticmethod
    fn backward_mul(inout node: Value):
        # d(l * r)/dl = r and d(l * r)/dr = l.
        var left = node.l.bitcast[Value]().load(0)
        var right = node.r.bitcast[Value]().load(0)
        left.grad += right.data * node.grad
        right.grad += left.data * node.grad
        node.l.bitcast[Value]().store(0, left)
        node.r.bitcast[Value]().store(0, right)
        Value._backward(left)
        Value._backward(right)

    # Neg
    fn __neg__(self) -> Value:
        return self * -1

    # Sub
    fn __sub__(self, other: Value) -> Value:
        return self + (-other)

    fn __sub__(self, other: Float64) -> Value:
        return self + (-Value(other))

    # Tanh
    fn tanh(self) -> Value:
        return self.new(tanh(self.data), "tanh")

    @staticmethod
    fn backward_tanh(inout node: Value):
        # d tanh(x)/dx = 1 - tanh(x)^2; the single operand lives in node.l.
        var left = node.l.bitcast[Value]().load(0)
        left.grad += (1 - tanh(left.data)**2) * node.grad
        node.l.bitcast[Value]().store(0, left)
        Value._backward(left)

    # Value alloc
    fn new(self, data: Float64, op: StringLiteral) -> Value:
        # Unary result: `self` is copied to the heap and becomes the left
        # (only) operand of the returned node.
        let l = Pointer[Value].alloc(1)
        l.store(self)
        # Memberwise initializer order is (r, l, ...): null goes first so the
        # operand ends up in `l`, where backward_tanh reads it.
        return Value(Pointer[Int].get_null(), l.bitcast[Int](), data, 0.0, op, random_float64())

    fn new(self, data: Float64, right: Value, op: StringLiteral) -> Value:
        # Binary result: `self` becomes the left operand, `right` the right one.
        let l = Pointer[Value].alloc(1)
        l.store(self)
        let r = Pointer[Value].alloc(1)
        r.store(right)
        # Memberwise initializer order is (r, l, ...).
        return Value(r.bitcast[Int](), l.bitcast[Int](), data, 0.0, op, random_float64())

    # Autograd
    @staticmethod
    fn _backward(inout node: Value):
        # Dispatch on the op tag; leaves (op == "") have nothing to propagate.
        if node.op == "":
            return
        elif node.op == "+":
            Value.backward_add(node)
        elif node.op == "*":
            Value.backward_mul(node)
        elif node.op == "tanh":
            Value.backward_tanh(node)

    fn backward(inout self):
        # Seed d(self)/d(self) = 1 and propagate through the whole expression.
        #
        # The seed must be set before anything is copied out of `self`: the
        # previous implementation took by-value topological snapshots first,
        # so the root snapshot still had grad == 0 and every gradient came
        # out as zero. The backward_* helpers already recurse into their
        # operands, and operands are copied per use (tree, not DAG), so one
        # call on the root visits every node exactly once; replaying stale
        # snapshots on top of that would only risk double-counting.
        self.grad = 1.0
        Value._backward(self)

    fn build_topo(inout self, v : Value, inout visited : DynamicVector[Value], inout topo : DynamicVector[Value]):
        # Appends by-value snapshots of `v` and everything reachable from it
        # to `topo` in post-order (operands before the node using them).
        # `visited` de-duplicates by `__eq__`, i.e. by `_id`.
        var is_in_visited = False
        for i in range(len(visited)):
            if v == visited[i]:
                is_in_visited = True
                break
        if not is_in_visited:
            visited.push_back(v)
            # It's pushing back, so visit in reverse, first right then left
            if v.r:
                self.build_topo(v.r.bitcast[Value]().load(0), visited, topo)
            if v.l:
                self.build_topo(v.l.bitcast[Value]().load(0), visited, topo)
            topo.push_back(v)

    @staticmethod
    fn reverse(vec : DynamicVector[Value]) -> DynamicVector[Value]:
        # Returns a new vector holding the elements of `vec` in reverse order.
        # The constructor argument is passed as in the original; presumably it
        # reserves capacity rather than creating elements -- TODO confirm.
        var reversed : DynamicVector[Value] = DynamicVector[Value](len(vec))
        for i in range(len(vec)-1, -1, -1):
            reversed.push_back(vec[i])
        return reversed

    fn show(self, label : StringLiteral):
        # One-line dump of this node under the given label.
        print("<Value", label, "::", "data:", self.data, "grad:", self.grad, "op:", self.op, ">")

    @staticmethod
    fn print_backward(node: Value):
        # Prints `node` as "left (grad) op right (grad) = data" (or the unary
        # form when only the left operand exists), then recurses into the
        # left and right operands in that order. Leaves print nothing.
        if node.l and node.r:
            let left = node.l.bitcast[Value]().load(0)
            let right = node.r.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, right.data, "(", right.grad, ")", "=", node.data)
        elif node.l:
            let left = node.l.bitcast[Value]().load(0)
            print(left.data, "(", left.grad, ")", node.op, "=", node.data)
        if node.l:
            Value.print_backward(node.l.bitcast[Value]().load(0))
        if node.r:
            Value.print_backward(node.r.bitcast[Value]().load(0))
# Demo: build (a + b) * c from three leaves, i.e. (1 + 2) * 7 = 21.
var a = Value(1)
var b = Value(2)
var c = Value(7)

# Forward pass: each operator copies its operands and records the op tag.
var sum_ab = a + b
var scaled = sum_ab * c

# Backward pass from the final result, then dump every non-leaf node
# together with its operands' values and gradients.
scaled.backward()
Value.print_backward(scaled)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment