Skip to content

Instantly share code, notes, and snippets.

@mcabbott
Last active June 27, 2021 21:04
Show Gist options
  • Save mcabbott/c6cdc73d45ed3e35c3fd8966863993f8 to your computer and use it in GitHub Desktop.
Save mcabbott/c6cdc73d45ed3e35c3fd8966863993f8 to your computer and use it in GitHub Desktop.
julia> using Zygote, ForwardDiff, LinearAlgebra
julia> mutable struct Mu{T}; x::T; end # Contains mutable state: `x` is overwritten on every call (see the method below)
julia> (m::Mu)(y) = (m.x = m.x * y).^2; # Updates m.x to m.x * y, then returns its elementwise square -- so results depend on the order of calls
# Zygote's `map` sometimes reverses the iteration order in the gradient (backward pass), while its `broadcast` does not:
# Ground truth is ForwardDiff, for which `map` and `broadcast` agree.
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3])
([32.98385416666667 11.838802083333334; 85.16987847222222 30.568446180555554], [[24.4215625 8.855572916666668; 34.42333333333333 12.48277777777778], [14.897847222222223 5.66920138888889; 8.639479166666668 3.2876562500000004], [8.69388888888889 4.678750000000001; 4.699027777777779 2.528854166666667]])
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3]) # WRONG answer
([21.650000000000006 7.877777777777779; 55.87916666666668 20.33263888888889], [[0.29333333333333333 0.16999999999999998; 0.41333333333333333 0.24000000000000002], [1.5850000000000002 0.7623148148148148; 0.9191666666666667 0.44208333333333333], [18.0787962962963 8.11027777777778; 9.771550925925927 4.383587962962964]])
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3])
([21.650000000000006 7.877777777777779; 55.87916666666668 20.33263888888889], [[0.29333333333333333 0.16999999999999998; 0.41333333333333333 0.24000000000000002], [1.5850000000000002 0.7623148148148148; 0.9191666666666667 0.44208333333333333], [18.0787962962963 8.11027777777778; 9.771550925925927 4.383587962962964]])
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple # ground truth, matches map
([32.983854166666674 11.838802083333334; 85.16987847222224 30.56844618055556],)
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple
([32.983854166666674 11.838802083333334; 85.16987847222224 30.56844618055556],)
# The absolute iteration order matters, not just the order of the backward pass relative to the forward pass:
julia> rev_map(f,x) = reverse(map(f,reverse(x))); # Calls f on the elements of x in reverse order, but returns the results in the original order
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple
([55.717187499999994 20.049913194444443; 143.84765625 51.751779513888884],)
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3])
([55.71718750000001 20.049913194444443; 143.84765624999997 51.751779513888884], [[26.08166666666666 14.036249999999997; 14.09708333333333 7.5865624999999985], [23.12895833333333 10.118090277777776; 13.412812499999998 5.86765625], [13.664965277777776 5.154079861111111; 19.26111111111111 7.266111111111111]])
# Zygote's `map` actually only reverses vectors; it gets the wrong answer for matrices:
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i j/2; i/3 j/4] for i in 1:2, j in 1:2])[1] |> tuple # WRONG
([18.82037037037037 6.842592592592592; 48.58865740740741 17.665625],)
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i j/2; i/3 j/4] for i in 1:2, j in 1:2])[1] |> tuple # WRONG
([18.82037037037037 6.842592592592592; 48.58865740740741 17.665625],)
julia> ForwardDiff.gradient(x -> begin ys = [[i j/2; i/3 j/4] for i in 1:2, j in 1:2]; m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple
([41.67507233796297 15.022499517746914; 107.57283227237654 38.771191406250004],)
# ... and the wrong answer for tuples:
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], ntuple(i -> [i i/2; i/3 i/4], 3))[1] |> tuple # WRONG
([32.98385416666667 11.838802083333334; 85.16987847222222 30.568446180555554],)
julia> ForwardDiff.gradient(x -> begin ys = ntuple(i -> [i i/2; i/3 i/4], 3); m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple
([55.717187499999994 20.049913194444443; 143.84765625 51.751779513888884],)
# Version with an array of numbers, not an array of arrays:
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, [1, 1/2, 1/3])
(0.2555555555555556, [0.02555555555555556, 0.011111111111111113, 0.0016666666666666668])
julia> gradient((x,ys) -> begin m = Mu(x); sum(broadcast(m, ys)) end, 0.1, [1, 1/2, 1/3])
(0.061111111111111116, [0.020000000000000004, 0.030000000000000006, 0.009166666666666668])
julia> gradient((x,ys) -> begin m = Mu(x); sum(m, ys) end, 0.1, [1, 1/2, 1/3])
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::Base.RefValue{Any})
julia> ForwardDiff.derivative(x -> begin ys = [1, 1/2, 1/3]; m = Mu(x); sum(map(m, ys)) end, 0.1)
0.25555555555555554
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, (1, 1/2, 1/3)) # tuple
(0.2555555555555556, (0.02555555555555556, 0.011111111111111113, 0.0016666666666666668))
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, [1 1/2 1/3]) # matrix
(0.061111111111111116, [0.020000000000000004 0.030000000000000006 0.009166666666666668])
# Compare a struct which contains parameters, but not state:
julia> struct Do{T}; x::T; end # Immutable: holds parameters in `x`, but no state
julia> (d::Do)(y) = dot(d.x, y)^2; # Reads d.x but never writes it, so the order of calls does not affect results
julia> gradient((x,ys) -> begin d = Do(x); sum(map(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]])
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]])
julia> gradient((x,ys) -> begin d = Do(x); sum(broadcast(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]])
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]])
julia> gradient((x,ys) -> begin d = Do(x); sum(d, ys) end, [0.1, 0.2], [[1,2], [3,4], [5,6]])
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]])
julia> ForwardDiff.gradient(x -> begin ys = [[1,2], [3,4], [5,6]]; d = Do(x); sum(map(d, ys)) end, [0.1, 0.2]) |> tuple
([24.6, 31.200000000000003],)
# Neither the absolute order, nor the reverse in the gradient, matters here:
julia> gradient((x,ys) -> begin d = Do(x); sum(rev_map(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]])
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]])
# Version without arrays of arrays:
julia> gradient((x,ys) -> begin d = Do(x); sum(map(d, ys)) end, 0.1, [1,2,3])
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001])
julia> gradient((x,ys) -> begin d = Do(x); sum(broadcast(d, ys)) end, 0.1, [1,2,3])
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001])
julia> gradient((x,ys) -> begin d = Do(x); sum(d, ys) end, 0.1, [1,2,3])
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001])
julia> ForwardDiff.derivative(x -> begin ys=[1,2,3]; d = Do(x); sum(d, ys) end, 0.1)
2.8000000000000003
# Versions
(@v1.7) pkg> st Zygote
Status `~/.julia/environments/v1.7/Project.toml`
[e88e6eb3] Zygote v0.6.14
(@v1.7) pkg> st ChainRules # after https://github.com/JuliaDiff/ChainRules.jl/pull/441
Status `~/.julia/environments/v1.7/Project.toml`
[082447d4] ChainRules v0.8.15
# How easy is it to test for the presence of mutable state?
# Candidate checks, applied to a parameter-only `Do` and to a stateful `Mu`:
d = Do([0.1, 0.2]);
Base.issingletontype(typeof(d)) # false -- it contains parameters
ismutable(d) # false
m = Mu([0.1 0.2; 0.3 0.4]);
Base.issingletontype(typeof(m)) # false -- so issingletontype cannot tell Do and Mu apart
ismutable(m) # true
ismutable(x -> m(x)) # false -- not a good check: the closure object itself is immutable, even though calling it mutates m
"""
    anymutable(x) -> Bool

Heuristically report whether `x` holds mutable state. The check recurses through the
fields of immutable wrappers (structs, tuples, closure objects) until it finds a value
for which `ismutable` is `true`.

Limitations:
- Arrays are deliberately treated as parameters rather than state. They return `false`
  without their elements being inspected, so e.g. a vector of mutable structs is not
  detected.
- A closure only "contains" what it stores as fields. A closure that refers to a global
  binding (e.g. `x -> m(x)` typed at top level) has no fields and returns `false`.
- NOTE(review): `ismutable` may also be `true` for some special types (e.g. `String`,
  `Symbol`), which would make values containing them count as mutable -- confirm for
  the Julia version in use.
"""
anymutable(x::AbstractArray) = false
function anymutable(x)
    ismutable(x) && return true
    # `any` over an empty field list is already `false`, so no separate emptiness check
    # is needed. The `isdefined` guard skips unassigned reference fields (possible after
    # an incomplete `new()`), for which `getfield` would throw `UndefRefError`.
    return any(n -> isdefined(x, n) && anymutable(getfield(x, n)), fieldnames(typeof(x)))
end
# Results these calls produce, given the definitions of Do, Mu, d, m and anymutable above:
anymutable(d) # false -- Do is immutable and its only field is an array
anymutable(x -> d(x)) # false -- at global scope the closure stores no fields; it refers to the global `d`
anymutable(m) # true -- Mu is a mutable struct
anymutable(x -> m(x)) # fails in global scope: returns false, since the closure stores no fields and refers to the global `m`
let m = Mu([0.1 0.2; 0.3 0.4])
# Here the closure stores the local `m` as a field, so the recursion reaches the mutable Mu and returns true
anymutable(x -> m(x))
end
# Xref https://github.com/FluxML/Zygote.jl/pull/807
# and https://github.com/FluxML/Zygote.jl/pull/1001
# and https://github.com/FluxML/Zygote.jl/pull/1011
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment