Skip to content

Instantly share code, notes, and snippets.

@younes-io
Created May 28, 2024 22:57
Show Gist options
  • Save younes-io/efb4fc3c5cd1c0e054fdd2f4cf8f5f38 to your computer and use it in GitHub Desktop.
Save younes-io/efb4fc3c5cd1c0e054fdd2f4cf8f5f38 to your computer and use it in GitHub Desktop.
### A Pluto.jl notebook ###
# v0.19.42
using Markdown
using InteractiveUtils
# ╔═╡ 4e348f70-7aa2-4d7f-b499-21fd7faaf8be
md"# I - Backpropagation done manually"
# ╔═╡ b45997c5-0c15-4c84-b6b3-3cc4518c968d
md"Define a `Value` struct for automatic differentiation with overloaded arithmetic operations and backward propagation"
# ╔═╡ 653411a3-99cc-468b-aa26-06b04318fea5
begin
    """
        Value(data, children=Set{Value}(), grad=0.0, _backward=() -> nothing)

    Scalar node of a reverse-mode automatic-differentiation graph.

    # Fields
    - `data`: the scalar held by the node (stored as `Float64`).
    - `children`: the operand nodes this node was computed from (empty for leaves).
    - `grad`: accumulated derivative of the final output with respect to this node.
    - `_backward`: zero-argument closure that adds this node's contribution to the
      `grad` of each child; it reads `grad` of the node it belongs to.
    """
    mutable struct Value
        data::Float64
        children::Set{Value}
        grad::Float64
        _backward::Function
        Value(data, children=Set{Value}(), grad=0.0, _backward=() -> nothing) =
            new(data, children, grad, _backward)
    end

    # Compact display used for `text/plain` output: only the value and its gradient.
    function Base.show(io::IO, ::MIME"text/plain", v::Value)
        print(io, "Value(data=$(v.data), grad=$(v.grad))")
    end

    # Binary operators are written once for (Value, Value); plain numbers are wrapped
    # into leaf `Value`s by the mixed-type methods. Dispatching (instead of branching on
    # `isa` and reassigning the arguments) keeps the captured variables unboxed.

    # Addition: d(x + y)/dx = d(x + y)/dy = 1.
    function Base.:+(x::Value, y::Value)
        out = Value(x.data + y.data, Set([x, y]))
        out._backward = () -> begin
            x.grad += out.grad
            y.grad += out.grad
        end
        return out
    end
    Base.:+(x::Value, y::Number) = x + Value(y)
    Base.:+(x::Number, y::Value) = Value(x) + y

    # Multiplication: d(x * y)/dx = y and d(x * y)/dy = x.
    function Base.:*(x::Value, y::Value)
        out = Value(x.data * y.data, Set([x, y]))
        out._backward = () -> begin
            x.grad += y.data * out.grad
            y.grad += x.data * out.grad
        end
        return out
    end
    Base.:*(x::Value, y::Number) = x * Value(y)
    Base.:*(x::Number, y::Value) = Value(x) * y

    # Negation and subtraction are expressed through `*` and `+`, so they need no
    # backward rule of their own.
    Base.:-(x::Value) = x * Value(-1)
    Base.:-(x::Value, y::Value) = x + (-y)
    Base.:-(x::Value, y::Number) = x + (-y)
    Base.:-(x::Number, y::Value) = x + (-y)

    # Power with a constant (non-differentiated) exponent: d(x^y)/dx = y * x^(y - 1).
    function Base.:^(x::Value, y::Number)
        out = Value(x.data^y, Set([x]))
        out._backward = () -> begin
            x.grad += y * (x.data^(y - 1)) * out.grad
        end
        return out
    end

    # Reciprocal: d(1/x)/dx = -1/x^2 = -(1/x)^2.
    # This must be a real graph node: Julia lowers a literal negative power such as
    # `v^(-1)` to `Base.literal_pow`, which calls `inv(v)`. Returning a fresh leaf here
    # would cut `x` out of the graph and silently drop its gradient.
    function Base.inv(x::Value)
        out = Value(inv(x.data), Set([x]))
        out._backward = () -> begin
            x.grad += -(out.data * out.data) * out.grad
        end
        return out
    end

    # Division is multiplication by the reciprocal. For a numeric denominator `inv`
    # yields a float, which `*` then wraps into a leaf.
    Base.:/(x::Value, y::Value) = x * inv(y)
    Base.:/(x::Value, y::Number) = x * inv(y)
    Base.:/(x::Number, y::Value) = x * inv(y)

    # Exponential: d(exp(x))/dx = exp(x), which is exactly `out.data`.
    function Base.exp(x::Value)
        out = Value(exp(x.data), Set([x]))
        out._backward = () -> begin
            x.grad += out.data * out.grad
        end
        return out
    end

    # Hyperbolic tangent: d(tanh(x))/dx = 1 - tanh(x)^2.
    # The forward value uses the built-in floating-point `tanh`; the hand-written
    # (exp(2d) - 1) / (exp(2d) + 1) form evaluates to Inf/Inf = NaN for large `d`.
    function Base.tanh(x::Value)
        out = Value(tanh(x.data), Set([x]))
        out._backward = () -> begin
            x.grad += (1 - out.data^2) * out.grad
        end
        return out
    end
end
# ╔═╡ 136d252e-6396-4292-bc4b-879cc0d6042c
begin
    # `using` is only valid at top level: inside a `let` block Julia rejects it with
    # `syntax: "using" expression not at top level`. It is therefore loaded first and
    # the computation is scoped by the nested `let`.
    using Flux

    let
        # Define the specific values for inputs, weights, and biases
        x1 = 2.0
        x2 = 0.0
        w1 = -3.0
        w2 = 1.0
        b = 6.8813735870195432

        # Forward pass: weighted sum of the inputs plus bias, squashed by tanh.
        function forward_pass(x1, x2, w1, w2, b)
            n = x1 * w1 + x2 * w2 + b
            return tanh(n) # output
        end

        # Custom loss function.
        # The loss is the output of the neuron itself, so dL/do = do/do = 1.
        function custom_loss(x1, x2, w1, w2, b)
            return forward_pass(x1, x2, w1, w2, b)
        end

        # Differentiate the loss with respect to all five arguments and print each
        # partial derivative, followed by the forward output `o`.
        function backward_propagation(o)
            grads = gradient(custom_loss, x1, x2, w1, w2, b)
            x1_g, x2_g, w1_g, w2_g, b_g = grads
            @show x1_g
            @show w1_g
            @show x2_g
            @show w2_g
            @show b_g
            println("-"^20)
            @show o
        end

        o = forward_pass(x1, x2, w1, w2, b)
        backward_propagation(o)
    end
end
# ╔═╡ 05d6d547-d7d7-4dca-9463-5c858e1a0503
md"Build a topological ordering of the graph's nodes, then propagate gradients through it in reverse order"
# ╔═╡ cee32873-9a0a-4a44-8d54-3430c24a586b
begin
    # Run reverse-mode backpropagation starting from `node`.
    #
    # The graph reachable through `children` is ordered so that every node comes
    # after all of its children; `node.grad` is seeded with 1 and each node's
    # `_backward` closure is then invoked from the output back towards the leaves.
    # Gradients are accumulated with `+=` by the closures and are not reset here.
    # Returns `nothing`.
    function backward(node::Value)
        # Concretely typed containers instead of `[]` / `Set([])` (both `Any`-typed).
        topo = Value[]
        visited = Set{Value}()

        # Depth-first post-order walk: a node is appended only after its children.
        function build_topo(n::Value)
            n in visited && return nothing
            push!(visited, n)
            for child in n.children
                build_topo(child)
            end
            push!(topo, n)
            return nothing
        end

        build_topo(node)
        node.grad = 1.0
        # Lazy reversal: walk from the output to the leaves without copying `topo`.
        for n in Iterators.reverse(topo)
            n._backward()
        end
        return nothing
    end
end
# ╔═╡ 53cfc7cc-6429-4eb7-884c-a63c8ed3b56a
md"##### Test `Value` operations"
# ╔═╡ 395c478b-bbc6-43ba-a1cb-c97e5ff6d386
# Unary minus on a `Value`: negation is implemented as multiplication by `Value(-1)`.
m = Value(9); -m
# ╔═╡ 85d9aff0-5fb9-42a7-9aa0-6a58ce9b98b5
# `20^2` is plain integer arithmetic; the division then mixes a `Value` (left) with a `Number` (right).
exp(m) / 20^2
# ╔═╡ e1b27db5-8c6e-467f-b2eb-7faa5f2d169b
# Division with a plain number on the left and a `Value` on the right.
1 / Value(-8)
# ╔═╡ a718eaae-c08c-4f85-a07b-fd92cb082908
md"**Forward pass**: initialize the inputs, weights, and bias, then compute the neuron's weighted sum `n = x1*w1 + x2*w2 + b`"
# ╔═╡ e9e1bc57-776a-4da0-a8c0-313c0dbc28fa
function init_vars()
    # Leaf nodes of the graph: two inputs, their two weights, and the neuron's bias.
    x1, x2 = Value(2.0), Value(0.0)
    w1, w2 = Value(-3.0), Value(1.0)
    b = Value(6.8813735870195432)

    # Weighted inputs, then the pre-activation n = x1*w1 + x2*w2 + b.
    x1w1, x2w2 = x1 * w1, x2 * w2
    n = x1w1 + x2w2 + b

    # Every node is handed back so callers can inspect its gradient later on.
    return (x1, x2, w1, w2, b, x1w1, x2w2, n)
end
# ╔═╡ d719a301-4c67-44a6-8930-d427dbf3e1e2
md"Backward pass gradients"
# ╔═╡ 05176d6f-9238-4acf-9a6f-12f0a9f014da
# Backpropagate from the output node `o`, then print the `(data, grad)` pair of every
# node in `vars` followed by that of `o`.
#
# `vars` is destructured positionally into eight nodes, in the order returned by
# `init_vars`: x1, x2, w1, w2, b, x1w1, x2w2, n.
# The value of the last `@show` — the tuple `(o.data, o.grad)` — is returned.
function backward_display(o::Value, vars::Array{Value})
# Fill in the `grad` field of every node reachable from `o`.
backward(o)
x1, x2, w1, w2, b, x1w1, x2w2, n = vars
# `@show` prints the expression text together with its value, so each line is labelled.
@show x1.data, x1.grad
@show w1.data, w1.grad
@show x2.data, x2.grad
@show w2.data, w2.grad
@show b.data, b.grad
@show x1w1.data, x1w1.grad
@show x2w2.data, x2w2.grad
@show n.data, n.grad
# Visual separator between the inner nodes and the output node.
println("-"^20)
@show o.data, o.grad
end
# ╔═╡ 0d900393-b207-4d71-af65-058780366d2a
md"#### Compute with the activation function $\tanh$"
# ╔═╡ 49a5088c-e9e7-4ab4-9fbb-67f99357ea40
let
    # Fresh graph for this cell; its last node is the neuron's pre-activation.
    vars = init_vars()
    neuron = last(vars)
    # Apply the tanh activation, backpropagate from it and print every gradient.
    backward_display(tanh(neuron), collect(vars))
end
# ╔═╡ 7afcadc6-73b2-40cf-a36d-44d3a536a717
md"""#### Compute with the exponential definition of $\tanh$
The hyperbolic tangent function $\tanh(x)$ is defined using exponential functions as follows:
$\tanh(x) = \frac{e^{2x} - 1}{e^{2x} + 1}$
cf. [Wikipedia](https://en.wikipedia.org/wiki/Hyperbolic_functions#Exponential_definitions)
"""
# ╔═╡ b3dd92fe-d7c1-435c-88b7-1572b0e2e546
let
    # Fresh graph for this cell; its last node is the neuron's pre-activation.
    vars = init_vars()
    n = last(vars)
    # tanh written out through its exponential definition, built from `Value` operations.
    numerator = exp(2n) - 1
    denominator = exp(2n) + 1
    backward_display(numerator / denominator, collect(vars))
end
# ╔═╡ 91b69722-e00f-4682-aa64-2dee5155bb0c
md"---"
# ╔═╡ 6f48bd47-05f2-435a-bd9a-40888c898877
md"# II - Backpropagation done with `Flux.jl`"
# ╔═╡ c71f6482-28c1-4009-b379-2f40298eb3df
md"Performs a forward pass through a single neuron and computes gradients using backpropagation with the `Flux.gradient()` function"
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
[compat]
Flux = "~0.14.15"
"""
# ╔═╡ 00000000-0000-0000-0000-000000000002
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "35f3301cf138e85a08e272baae73599ccd6a38b8"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.5.0"
weakdeps = ["ChainRulesCore", "Test"]
[deps.AbstractFFTs.extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsTestExt = "Test"
[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "4.0.4"
weakdeps = ["StaticArrays"]
[deps.Adapt.extensions]
AdaptStaticArraysExt = "StaticArrays"
[[deps.ArgCheck]]
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197"
version = "2.3.0"
[[deps.ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.Atomix]]
deps = ["UnsafeAtomics"]
git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
version = "0.1.0"
[[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
version = "0.3.40"
[deps.BangBang.extensions]
BangBangChainRulesCoreExt = "ChainRulesCore"
BangBangDataFramesExt = "DataFrames"
BangBangStaticArraysExt = "StaticArrays"
BangBangStructArraysExt = "StructArrays"
BangBangTypedTablesExt = "TypedTables"
[deps.BangBang.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[deps.Baselet]]
git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
version = "0.1.1"
[[deps.CEnum]]
git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.5.0"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "291821c1251486504f6bae435227907d734e94d2"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.66.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra"]
git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.23.0"
weakdeps = ["SparseArrays"]
[deps.ChainRulesCore.extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"
[[deps.CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"
[[deps.Compat]]
deps = ["TOML", "UUIDs"]
git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0"
weakdeps = ["Dates", "LinearAlgebra"]
[deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.1+0"
[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
version = "0.1.2"
[deps.CompositionsBase.extensions]
CompositionsBaseInverseFunctionsExt = "InverseFunctions"
[deps.CompositionsBase.weakdeps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.5.5"
[deps.ConstructionBase.extensions]
ConstructionBaseIntervalSetsExt = "IntervalSets"
ConstructionBaseStaticArraysExt = "StaticArrays"
[deps.ConstructionBase.weakdeps]
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[[deps.ContextVariablesX]]
deps = ["Compat", "Logging", "UUIDs"]
git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc"
uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
version = "0.1.3"
[[deps.DataAPI]]
git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.16.0"
[[deps.DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.20"
[[deps.DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
version = "1.0.0"
[[deps.Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[deps.DefineSingletons]]
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c"
uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52"
version = "0.1.2"
[[deps.DelimitedFiles]]
deps = ["Mmap"]
git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae"
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
version = "1.9.1"
[[deps.DiffResults]]
deps = ["StaticArraysCore"]
git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.1.0"
[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.15.1"
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.9.3"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"
[[deps.FLoops]]
deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"]
git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576"
uuid = "cc61a311-1640-44b5-9fba-1b764f453329"
version = "0.2.1"
[[deps.FLoopsBase]]
deps = ["ContextVariablesX"]
git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7"
uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6"
version = "0.1.1"
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.11.0"
[deps.FillArrays.extensions]
FillArraysPDMatsExt = "PDMats"
FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics"
[deps.FillArrays.weakdeps]
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[deps.Flux]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"]
git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.14.15"
[deps.Flux.extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxMetalExt = "Metal"
[deps.Flux.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.36"
weakdeps = ["StaticArrays"]
[deps.ForwardDiff.extensions]
ForwardDiffStaticArraysExt = "StaticArrays"
[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.4.10"
[[deps.Future]]
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
git-tree-sha1 = "38cb19b8a3e600e509dc36a6396ac74266d108c1"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "10.1.1"
[[deps.GPUArraysCore]]
deps = ["Adapt"]
git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950"
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.6"
[[deps.IRTools]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.14"
[[deps.InitialValues]]
git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3"
uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c"
version = "0.3.1"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.2.2"
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
uuid = "82899510-4779-5014-852e-03e436cf321d"
version = "1.0.0"
[[deps.JLLWrappers]]
deps = ["Artifacts", "Preferences"]
git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.5.0"
[[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"]
git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70"
uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec"
version = "0.2.4"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.19"
[deps.KernelAbstractions.extensions]
EnzymeExt = "EnzymeCore"
[deps.KernelAbstractions.weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "7.1.0"
[deps.LLVM.extensions]
BFloat16sExt = "BFloat16s"
[deps.LLVM.weakdeps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.29+0"
[[deps.LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.4.0+0"
[[deps.LibGit2]]
deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[deps.LibGit2_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"]
uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5"
version = "1.6.4+0"
[[deps.LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
version = "1.11.0+1"
[[deps.Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[deps.LinearAlgebra]]
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.27"
[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
LogExpFunctionsInverseFunctionsExt = "InverseFunctions"
[deps.LogExpFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
version = "0.4.17"
[[deps.MLUtils]]
deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"]
git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
version = "0.4.4"
[[deps.MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.13"
[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[deps.MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.28.2+1"
[[deps.MicroCollections]]
deps = ["BangBang", "InitialValues", "Setfield"]
git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e"
uuid = "128add7d-3638-4c79-886c-908ea0c25c34"
version = "0.1.4"
[[deps.Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.2.0"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[deps.MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
version = "2023.1.10"
[[deps.NNlib]]
deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.17"
[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
[deps.NNlib.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.2"
[[deps.NameResolution]]
deps = ["PrettyPrint"]
git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e"
uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391"
version = "0.1.5"
[[deps.NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
version = "1.2.0"
[[deps.OneHotArrays]]
deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"]
git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.5"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
version = "0.3.23+4"
[[deps.OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+2"
[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"
[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.3.3"
[[deps.OrderedCollections]]
git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.10.0"
[[deps.PrecompileTools]]
deps = ["Preferences"]
git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f"
uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
version = "1.2.1"
[[deps.Preferences]]
deps = ["TOML"]
git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.4.3"
[[deps.PrettyPrint]]
git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4"
uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98"
version = "0.2.0"
[[deps.Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[deps.ProgressLogging]]
deps = ["Logging", "SHA", "UUIDs"]
git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539"
uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
version = "0.1.4"
[[deps.REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[deps.RealDot]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
version = "0.1.0"
[[deps.Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2"
[[deps.Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[deps.Setfield]]
deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"]
git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
version = "1.1.1"
[[deps.ShowCases]]
git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5"
uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
version = "0.1.0"
[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
version = "0.9.4"
[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.2.1"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"
[[deps.SparseInverseSubset]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852"
uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada"
version = "0.1.2"
[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.4.0"
weakdeps = ["ChainRulesCore"]
[deps.SpecialFunctions.extensions]
SpecialFunctionsChainRulesCoreExt = "ChainRulesCore"
[[deps.SplittablesBase]]
deps = ["Setfield", "Test"]
git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5"
uuid = "171d559e-b47b-412a-8079-5efa626c420e"
version = "0.1.15"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "9ae599cd7529cfce7fea36cf00a62cfc56f0f37c"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.4"
weakdeps = ["ChainRulesCore", "Statistics"]
[deps.StaticArrays.extensions]
StaticArraysChainRulesCoreExt = "ChainRulesCore"
StaticArraysStatisticsExt = "Statistics"
[[deps.StaticArraysCore]]
git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.2"
[[deps.Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
version = "1.10.0"
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.7.0"
[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3"
[[deps.StructArrays]]
deps = ["ConstructionBase", "DataAPI", "Tables"]
git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be"
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
version = "0.6.18"
weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"]
[deps.StructArrays.extensions]
StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"
[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
[[deps.SuiteSparse_jll]]
deps = ["Artifacts", "Libdl", "libblastrampoline_jll"]
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"
[[deps.TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
version = "1.0.3"
[[deps.TableTraits]]
deps = ["IteratorInterfaceExtensions"]
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
version = "1.0.1"
[[deps.Tables]]
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"]
git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d"
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
version = "1.11.1"
[[deps.Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
version = "1.10.0"
[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.Transducers]]
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
version = "0.4.80"
[deps.Transducers.extensions]
TransducersBlockArraysExt = "BlockArrays"
TransducersDataFramesExt = "DataFrames"
TransducersLazyArraysExt = "LazyArrays"
TransducersOnlineStatsBaseExt = "OnlineStatsBase"
TransducersReferenceablesExt = "Referenceables"
[deps.Transducers.weakdeps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"
[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
version = "0.2.1"
[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.4"
[[deps.Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
version = "1.2.13+1"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.70"
[deps.Zygote.extensions]
ZygoteColorsExt = "Colors"
ZygoteDistancesExt = "Distances"
ZygoteTrackerExt = "Tracker"
[deps.Zygote.weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
[[deps.ZygoteRules]]
deps = ["ChainRulesCore", "MacroTools"]
git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.5"
[[deps.libblastrampoline_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"
[[deps.nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.52.0+1"
[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.4.0+2"
"""
# ╔═╡ Cell order:
# ╟─4e348f70-7aa2-4d7f-b499-21fd7faaf8be
# ╟─b45997c5-0c15-4c84-b6b3-3cc4518c968d
# ╠═653411a3-99cc-468b-aa26-06b04318fea5
# ╟─05d6d547-d7d7-4dca-9463-5c858e1a0503
# ╠═cee32873-9a0a-4a44-8d54-3430c24a586b
# ╟─53cfc7cc-6429-4eb7-884c-a63c8ed3b56a
# ╠═395c478b-bbc6-43ba-a1cb-c97e5ff6d386
# ╠═85d9aff0-5fb9-42a7-9aa0-6a58ce9b98b5
# ╠═e1b27db5-8c6e-467f-b2eb-7faa5f2d169b
# ╟─a718eaae-c08c-4f85-a07b-fd92cb082908
# ╠═e9e1bc57-776a-4da0-a8c0-313c0dbc28fa
# ╟─d719a301-4c67-44a6-8930-d427dbf3e1e2
# ╠═05176d6f-9238-4acf-9a6f-12f0a9f014da
# ╟─0d900393-b207-4d71-af65-058780366d2a
# ╠═49a5088c-e9e7-4ab4-9fbb-67f99357ea40
# ╟─7afcadc6-73b2-40cf-a36d-44d3a536a717
# ╠═b3dd92fe-d7c1-435c-88b7-1572b0e2e546
# ╟─91b69722-e00f-4682-aa64-2dee5155bb0c
# ╟─6f48bd47-05f2-435a-bd9a-40888c898877
# ╟─c71f6482-28c1-4009-b379-2f40298eb3df
# ╠═136d252e-6396-4292-bc4b-879cc0d6042c
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment