-
-
Save mcabbott/c6cdc73d45ed3e35c3fd8966863993f8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> using Zygote, ForwardDiff, LinearAlgebra | |
julia> mutable struct Mu{T}; x::T; end # Contains mutable state | |
julia> (m::Mu)(y) = (m.x = m.x * y).^2; | |
# Zygote's map sometimes reverses the iteration order in the gradient, while its broadcast does not: | |
# Ground truth is ForwardDiff, for which these agree. | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3]) | |
([32.98385416666667 11.838802083333334; 85.16987847222222 30.568446180555554], [[24.4215625 8.855572916666668; 34.42333333333333 12.48277777777778], [14.897847222222223 5.66920138888889; 8.639479166666668 3.2876562500000004], [8.69388888888889 4.678750000000001; 4.699027777777779 2.528854166666667]]) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3]) # WRONG answer | |
([21.650000000000006 7.877777777777779; 55.87916666666668 20.33263888888889], [[0.29333333333333333 0.16999999999999998; 0.41333333333333333 0.24000000000000002], [1.5850000000000002 0.7623148148148148; 0.9191666666666667 0.44208333333333333], [18.0787962962963 8.11027777777778; 9.771550925925927 4.383587962962964]]) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3]) | |
([21.650000000000006 7.877777777777779; 55.87916666666668 20.33263888888889], [[0.29333333333333333 0.16999999999999998; 0.41333333333333333 0.24000000000000002], [1.5850000000000002 0.7623148148148148; 0.9191666666666667 0.44208333333333333], [18.0787962962963 8.11027777777778; 9.771550925925927 4.383587962962964]]) | |
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple # ground truth, matches map | |
([32.983854166666674 11.838802083333334; 85.16987847222224 30.56844618055556],) | |
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple | |
([32.983854166666674 11.838802083333334; 85.16987847222224 30.56844618055556],) | |
# The absolute iteration order matters, not just the order of the backward pass relative to the forward pass: | |
julia> rev_map(f,x) = reverse(map(f,reverse(x))); | |
julia> ForwardDiff.gradient(x -> begin ys = [[i i/2; i/3 i/4] for i in 1:3]; m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple | |
([55.717187499999994 20.049913194444443; 143.84765625 51.751779513888884],) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i i/2; i/3 i/4] for i in 1:3]) | |
([55.71718750000001 20.049913194444443; 143.84765624999997 51.751779513888884], [[26.08166666666666 14.036249999999997; 14.09708333333333 7.5865624999999985], [23.12895833333333 10.118090277777776; 13.412812499999998 5.86765625], [13.664965277777776 5.154079861111111; 19.26111111111111 7.266111111111111]]) | |
# Zygote's map actually reverses the iteration order only for vectors; it gets the wrong answer for matrices: | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i j/2; i/3 j/4] for i in 1:2, j in 1:2])[1] |> tuple # WRONG | |
([18.82037037037037 6.842592592592592; 48.58865740740741 17.665625],) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, broadcast(m, ys)) end, [0.1 0.2; 0.3 0.4], [[i j/2; i/3 j/4] for i in 1:2, j in 1:2])[1] |> tuple # WRONG | |
([18.82037037037037 6.842592592592592; 48.58865740740741 17.665625],) | |
julia> ForwardDiff.gradient(x -> begin ys = [[i j/2; i/3 j/4] for i in 1:2, j in 1:2]; m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple | |
([41.67507233796297 15.022499517746914; 107.57283227237654 38.771191406250004],) | |
# ... and the wrong answer for tuples: | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(sum, map(m, ys)) end, [0.1 0.2; 0.3 0.4], ntuple(i -> [i i/2; i/3 i/4], 3))[1] |> tuple # WRONG | |
([32.98385416666667 11.838802083333334; 85.16987847222222 30.568446180555554],) | |
julia> ForwardDiff.gradient(x -> begin ys = ntuple(i -> [i i/2; i/3 i/4], 3); m = Mu(x); sum(sum, rev_map(m, ys)) end, [0.1 0.2; 0.3 0.4]) |> tuple | |
([55.717187499999994 20.049913194444443; 143.84765625 51.751779513888884],) | |
# Version with an array of numbers, not an array of arrays: | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, [1, 1/2, 1/3]) | |
(0.2555555555555556, [0.02555555555555556, 0.011111111111111113, 0.0016666666666666668]) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(broadcast(m, ys)) end, 0.1, [1, 1/2, 1/3]) | |
(0.061111111111111116, [0.020000000000000004, 0.030000000000000006, 0.009166666666666668]) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(m, ys) end, 0.1, [1, 1/2, 1/3]) | |
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::Base.RefValue{Any}) | |
julia> ForwardDiff.derivative(x -> begin ys = [1, 1/2, 1/3]; m = Mu(x); sum(map(m, ys)) end, 0.1) | |
0.25555555555555554 | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, (1, 1/2, 1/3)) # tuple | |
(0.2555555555555556, (0.02555555555555556, 0.011111111111111113, 0.0016666666666666668)) | |
julia> gradient((x,ys) -> begin m = Mu(x); sum(map(m, ys)) end, 0.1, [1 1/2 1/3]) # matrix | |
(0.061111111111111116, [0.020000000000000004 0.030000000000000006 0.009166666666666668]) | |
# Compare a struct which contains parameters, but not state: | |
julia> struct Do{T}; x::T; end | |
julia> (d::Do)(y) = dot(d.x, y)^2; | |
julia> gradient((x,ys) -> begin d = Do(x); sum(map(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]]) | |
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]]) | |
julia> gradient((x,ys) -> begin d = Do(x); sum(broadcast(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]]) | |
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]]) | |
julia> gradient((x,ys) -> begin d = Do(x); sum(d, ys) end, [0.1, 0.2], [[1,2], [3,4], [5,6]]) | |
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]]) | |
julia> ForwardDiff.gradient(x -> begin ys = [[1,2], [3,4], [5,6]]; d = Do(x); sum(map(d, ys)) end, [0.1, 0.2]) |> tuple | |
([24.6, 31.200000000000003],) | |
# Neither the absolute order, nor the reverse in the gradient, matters here: | |
julia> gradient((x,ys) -> begin d = Do(x); sum(rev_map(d, ys)) end, [0.1, 0.2], [[1,2], [3,4], [5,6]]) | |
([24.6, 31.200000000000003], [[0.1, 0.2], [0.22000000000000003, 0.44000000000000006], [0.3400000000000001, 0.6800000000000002]]) | |
# Version without arrays of arrays: | |
julia> gradient((x,ys) -> begin d = Do(x); sum(map(d, ys)) end, 0.1, [1,2,3]) | |
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001]) | |
julia> gradient((x,ys) -> begin d = Do(x); sum(broadcast(d, ys)) end, 0.1, [1,2,3]) | |
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001]) | |
julia> gradient((x,ys) -> begin d = Do(x); sum(d, ys) end, 0.1, [1,2,3]) | |
(2.8000000000000003, [0.020000000000000004, 0.04000000000000001, 0.06000000000000001]) | |
julia> ForwardDiff.derivative(x -> begin ys=[1,2,3]; d = Do(x); sum(d, ys) end, 0.1) | |
2.8000000000000003 | |
# Versions | |
(@v1.7) pkg> st Zygote | |
Status `~/.julia/environments/v1.7/Project.toml` | |
[e88e6eb3] Zygote v0.6.14 | |
(@v1.7) pkg> st ChainRules # after https://github.com/JuliaDiff/ChainRules.jl/pull/441 | |
Status `~/.julia/environments/v1.7/Project.toml` | |
[082447d4] ChainRules v0.8.15 | |
# How easy is it to test for the presence of mutable state? | |
# Probe what Base's built-in predicates report for the two callable structs from the
# session above: `Do` (immutable, holds parameters) and `Mu` (mutable, holds state).
d = Do([0.1, 0.2]);
Base.issingletontype(typeof(d)) # false -- it contains parameters (a field), so it is not a singleton
ismutable(d) # false -- `Do` is an immutable struct
m = Mu([0.1 0.2; 0.3 0.4]);
Base.issingletontype(typeof(m)) # false -- so this predicate cannot tell `Do` and `Mu` apart
ismutable(m) # true -- `Mu` is a mutable struct
# false -- not a good check: `ismutable` inspects only the closure object itself, which is an
# immutable struct, not whatever the closure refers to
ismutable(x -> m(x))
"""
    anymutable(x) -> Bool

Heuristically check whether `x` (typically a function, closure, or callable struct)
carries mutable state. It returns `true` if `x` itself is mutable or if any of its
fields, searched recursively, is. For a closure, the fields are its captured variables.

Some values are deliberately not counted as state:
- arrays are treated as parameters (like the weights held by an immutable struct),
  even though their contents can be mutated;
- `String` and `Symbol`, for which `ismutable` returns `true` for technical reasons
  although they cannot be mutated;
- type objects (`ismutable(Int)` is `true` because `DataType` is a mutable struct).
"""
function anymutable end

anymutable(::AbstractArray) = false
anymutable(::Union{String, Symbol}) = false
anymutable(::Type) = false

function anymutable(x)
    ismutable(x) && return true
    # Recurse into the fields. Fields left `#undef` by an incomplete constructor hold
    # no state, and `getfield` would throw on them, so skip them.
    # `any` over an empty field list is `false`, so field-less values need no special case.
    return any(fieldnames(typeof(x))) do name
        isdefined(x, name) && anymutable(getfield(x, name))
    end
end
# Apply `anymutable` to the same objects as above.
anymutable(d) # false -- `Do`'s only field is an array, which counts as parameters
anymutable(x -> d(x)) # false -- `d` is a global, so the closure captures nothing
anymutable(m) # true -- `Mu` is a mutable struct
# Returns false in global scope: `m` is a global variable, so the closure has no
# captured fields for `anymutable` to inspect.
anymutable(x -> m(x))
# Inside `let`, `m` is a local variable, so the closure captures it as a field and
# `anymutable` finds the mutable `Mu` inside: this returns true.
let m = Mu([0.1 0.2; 0.3 0.4])
anymutable(x -> m(x))
end
# Xref https://github.com/FluxML/Zygote.jl/pull/807 | |
# and https://github.com/FluxML/Zygote.jl/pull/1001 | |
# and https://github.com/FluxML/Zygote.jl/pull/1011 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment