Skip to content

Instantly share code, notes, and snippets.

@QuantumFractal
Created September 23, 2023 22:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save QuantumFractal/7056bd0b4720c138d719a596a88617cc to your computer and use it in GitHub Desktop.
Save QuantumFractal/7056bd0b4720c138d719a596a88617cc to your computer and use it in GitHub Desktop.
Micrograd.jl Forward Prop
### A Pluto.jl notebook ###
# v0.19.26
using Markdown
using InteractiveUtils
# ╔═╡ 9d7b2fcb-7155-4636-97f6-4401b6f561af
begin
# Install GraphViz.jl from the `38-add-engine-attribute` revision (presumably a
# branch) of abelsiqueira's fork, then load it; `build_dot` renders graphs through
# `GraphViz.load`. `Pkg.add` runs again every time this cell is evaluated.
# NOTE(review): the `#revision` suffix inside `path=` mirrors Pkg REPL syntax; Pkg's
# documented API form for a git revision is
# `Pkg.add(url="https://github.com/abelsiqueira/GraphViz.jl", rev="38-add-engine-attribute")`
# — confirm this call resolves the intended revision.
import Pkg
Pkg.add(path="https://github.com/abelsiqueira/GraphViz.jl#38-add-engine-attribute");
using GraphViz
end
# ╔═╡ 8d04e8a2-5a40-11ee-1678-757abfdac18b
begin
    """
        Value(; scalar=0.0, grad=0.0, children=(), operation=nothing, label=nothing)

    Node of a scalar expression graph.

    - `scalar`: the numeric value held by this node.
    - `grad`: gradient associated with the node (`0.0` unless given).
    - `children`: operand nodes this value was computed from (empty for leaves).
    - `operation`: symbol of the operation that produced the node, or `nothing`.
    - `label`: optional display name, or `nothing`.

    The struct is immutable, so `grad` cannot be updated in place. Fields are
    abstractly typed (`Real`, `Tuple`), which trades speed for flexibility.
    """
    struct Value
        scalar::Real
        grad::Real
        children::Tuple
        operation::Union{Nothing,Symbol}
        label::Union{Nothing,Symbol}
    end

    # Keyword constructor with defaults; it forwards to the positional constructor,
    # which converts each argument to its declared field type.
    Value(; scalar = 0.0, grad = 0.0, children = (), operation = nothing, label = nothing) =
        Value(scalar, grad, children, operation, label)
end
# ╔═╡ 5d70a89c-7ff3-401d-99cf-85a342d4a0cd
begin
    # Shared builder for binary operators on `Value`: apply `f` to both scalars,
    # keep the two operands as children, and tag the result node with `sym`.
    _binary_value(f, sym::Symbol, a::Value, b::Value) =
        Value(scalar = f(a.scalar, b.scalar), children = (a, b), operation = sym)

    Base.:+(a::Value, b::Value) = _binary_value(+, :+, a, b)
    Base.:-(a::Value, b::Value) = _binary_value(-, :-, a, b)
    Base.:*(a::Value, b::Value) = _binary_value(*, :*, a, b)
end
# ╔═╡ db7e8ffc-cb6a-4ec3-bcaa-d580fb6c6bfa
begin
    """
        namedEx(v::Value, name::Symbol)
        namedEx(v::Value, grad::Real, name::Symbol)

    Return a copy of `v` labelled `name`, keeping its `scalar`, `children` and
    `operation`. The two-argument form keeps `v.grad`; the three-argument form
    replaces the gradient with `grad`.
    """
    namedEx(v::Value, grad::Real, name::Symbol) =
        Value(scalar = v.scalar, grad = grad, children = v.children,
              operation = v.operation, label = name)

    # Relabelling must not silently reset an existing gradient to 0.0.
    namedEx(v::Value, name::Symbol) = namedEx(v, v.grad, name)
end
# ╔═╡ 81d6ade0-65d4-4095-8830-6b1b3e149885
"""
    tanh(v::Value)

Return a new node whose scalar is the hyperbolic tangent of `v.scalar`, with `v` as
its only child and operation `:tanh`. The result carries no label.
"""
function Base.tanh(v::Value)
    # Extend Base.tanh (as the arithmetic operators extend Base.:+ etc.) instead of
    # defining a notebook-local `tanh` that would hide Base.tanh for plain numbers.
    # Base.tanh on the scalar also avoids the Inf/Inf = NaN overflow of the explicit
    # (exp(2x) - 1) / (exp(2x) + 1) formula for large |x|.
    t = Base.tanh(v.scalar)
    return Value(scalar = t, children = (v,), operation = :tanh)
end
# ╔═╡ 0be221b7-934e-4071-a926-04fee18ad5ad
let a = Value(scalar = 2.0), b = Value(scalar = 4.0)
    # Quick check of the `+` overload: a node with scalar 6.0 and children (a, b).
    a + b
end
# ╔═╡ 086bace9-76e9-472f-ade8-1d707e309ef5
begin
    # Leaf inputs, weights and bias of one neuron: o = tanh(x1*w1 + x2*w2 + b).
    x1, x2 = Value(scalar = 2.0, label = :x1), Value(scalar = 0.0, label = :x2)
    w1, w2 = Value(scalar = -3.0, label = :w1), Value(scalar = 1.0, label = :w2)
    b = Value(scalar = 8, label = :b)

    # Weighted inputs, their sum and the pre-activation, each relabelled for display.
    x1w1 = namedEx(x1 * w1, :x1w1)
    x2w2 = namedEx(x2 * w2, :x2w2)
    x1w1x2w2 = namedEx(x1w1 + x2w2, :x1w1x2w2)
    n = namedEx(x1w1x2w2 + b, :n)

    # Activation; `o` is this cell's result and the root rendered by `build_dot(o)`.
    o = namedEx(tanh(n), :output)
end
# ╔═╡ bb0f7b4b-4852-4f7d-b388-b61b01cc3217
"""
    trace(root)

Walk the graph below `root` through each node's `children` and return
`(nodes, edges)`. `nodes` is a `Set` of every reachable node. `edges` is a `Set` of
`(child, parent)` pairs. Nodes already in `nodes` (compared by `isequal`/`hash`) are
not expanded again. The walk is recursive, so its depth equals the graph's depth.
"""
function trace(root)
    nodes = Set()
    edges = Set()

    function visit(node)
        node in nodes && return
        push!(nodes, node)
        foreach(node.children) do child
            push!(edges, (child, node))
            visit(child)
        end
    end

    visit(root)
    return nodes, edges
end
# ╔═╡ 7e2d392f-d373-4e88-b81d-72ec2a8b4392
"""
    build_dot(root)

Render the expression graph reachable from `root` (as collected by `trace`) as a
Graphviz DOT digraph, print the DOT source, and return it loaded via `GraphViz.load`.

Each node becomes a `record` box showing its label, scalar and gradient. A node
produced by an operation also gets an operation node (id = node hash followed by the
operation symbol) pointing into it, and every child edge is drawn into that
operation node.
"""
function build_dot(root)
    nodes, edges = trace(root)

    # Build the DOT text in a buffer instead of repeated string concatenation.
    io = IOBuffer()
    print(io, "digraph {\nrankdir=\"LR\"\n")
    for node in nodes
        uid = hash(node)
        label = "$(node.label) | data $(node.scalar) | grad $(node.grad)"
        print(io, "\n\"$uid\" [label = \"$label\" shape=\"record\"]")
        if !isnothing(node.operation)
            opname = "$uid$(node.operation)"
            print(io, "\n\"$opname\" [label = \"$(node.operation)\"]")
            print(io, "\n\"$opname\" -> \"$uid\" ")
        end
    end
    for (child, parent) in edges
        # Edges enter the parent's operation node, not the parent's record box.
        print(io, "\n\"$(hash(child))\" -> \"$(hash(parent))$(parent.operation)\" ")
    end
    print(io, "\n}")
    dot = String(take!(io))
    print(dot)

    path = tempname() * ".dot"
    write(path, dot)
    try
        # do-block form of `open` closes the handle (the original leaked it).
        return open(GraphViz.load, path)
    finally
        rm(path; force = true)
    end
end
# ╔═╡ 4edeed1c-8c51-49d1-ab14-89d94fb57765
o |> build_dot  # render the example neuron's expression graph
# ╔═╡ 6f076a8c-8c62-42cc-9de5-e7bcac4085cc
exp(1.0)  # scratch cell: e^1 (Euler's number); exp(1) promotes to exp(1.0) anyway
# ╔═╡ Cell order:
# ╠═8d04e8a2-5a40-11ee-1678-757abfdac18b
# ╠═5d70a89c-7ff3-401d-99cf-85a342d4a0cd
# ╠═db7e8ffc-cb6a-4ec3-bcaa-d580fb6c6bfa
# ╠═81d6ade0-65d4-4095-8830-6b1b3e149885
# ╠═0be221b7-934e-4071-a926-04fee18ad5ad
# ╠═086bace9-76e9-472f-ade8-1d707e309ef5
# ╟─9d7b2fcb-7155-4636-97f6-4401b6f561af
# ╟─bb0f7b4b-4852-4f7d-b388-b61b01cc3217
# ╟─7e2d392f-d373-4e88-b81d-72ec2a8b4392
# ╠═4edeed1c-8c51-49d1-ab14-89d94fb57765
# ╠═6f076a8c-8c62-42cc-9de5-e7bcac4085cc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment