@darsnack
Created November 18, 2022 15:10
Forward mode AD for Functors.jl
# Forward-mode directional gradients for Functors.jl-structured models,
# with ForwardDiff and Diffractor backends.
using Flux, Functors, ForwardDiff, Diffractor
using Functors: fmap
import Optimisers

# A Functors walk that recurses only into the children selected by
# `filter_set(x)`, applies `prune` to every other child, and post-processes
# the node's reconstructor with `rebuilder`.
struct FilteredWalk{F, T, G}
    filter_set::F
    prune::T
    rebuilder::G
end
FilteredWalk(filter_set; prune = identity, rebuilder = identity) =
    FilteredWalk(filter_set, prune, rebuilder)

function (walk::FilteredWalk)(recurse, x, ys...)
    children, re = Functors.functor(x)
    tchildren = walk.filter_set(x)
    ychildren = map(y -> Functors.functor(typeof(x), y)[1], ys)
    mchildren = map(children, ychildren...) do c, ycs...
        c ∈ tchildren ? recurse(c, ycs...) : walk.prune(c)
    end
    return walk.rebuilder(re)(mchildren)
end
# The trainable children of `x` that are numeric arrays.
differentiable(x) = _differentiable(Optimisers.trainable(x))
function _differentiable(xs::NamedTuple)
    diff_xs = [(k, v) for (k, v) in pairs(xs) if Optimisers.isnumeric(v)]
    return (; diff_xs...)
end
_differentiable(xs::Tuple) = tuple([x for x in xs if Optimisers.isnumeric(x)]...)

# Rescale a (possibly nested) gradient so its overall Euclidean norm is 1,
# skipping `nothing` leaves.
function normalize_grad(g)
    total = Ref(0f0)
    fmap(gi -> !isnothing(gi) && (total[] += sum(gi.^2)), g)
    gnorm = sqrt(total[])
    return fmap(gi -> isnothing(gi) ? gi : gi ./ gnorm, g)
end
# Sample a random direction with the same structure as `model`: one random
# array per trainable leaf, `nothing` elsewhere, optionally rescaled so the
# overall norm is 1.
function sample_direction(model, sample = Flux.rand32, normalize = true)
    walk = FilteredWalk(differentiable;
                        prune = _ -> nothing,
                        rebuilder = _ -> identity)
    directions = fmap(x -> sample(size(x)...), model;
                      exclude = Optimisers.isnumeric,
                      walk = walk)
    return normalize ? normalize_grad(directions) : directions
end
# Replace each trainable leaf of `model` with ForwardDiff duals whose single
# partial is seeded from the matching leaf of `direction`.
function seed_fwddiff(model::T, direction) where T
    walk = FilteredWalk(differentiable)
    model_dual = fmap(model, direction;
                      exclude = Optimisers.isnumeric,
                      walk = walk) do x, d
        S = eltype(x)
        partial = ForwardDiff.Partials{1, S}.(tuple.(d))
        return ForwardDiff.Dual{T, S, 1}.(x, partial)
    end
    return model_dual
end

# Same seeding for Diffractor: trainable leaves become first-order tangent
# bundles, all other children zero bundles, collected into a CompositeBundle.
function seed_diffractor(model::T, direction) where T
    walk = FilteredWalk(differentiable;
                        prune = x -> Diffractor.ZeroBundle{1}(x),
                        rebuilder = _ -> (xs -> (xs...,)))
    partials = fmap(model, direction;
                    exclude = Optimisers.isnumeric,
                    walk = walk) do x, p
        return Diffractor.TangentBundle{1}(x, (p,))
    end
    return Diffractor.CompositeBundle{1, T}(partials)
end
# Directional gradient of `f` at `x` along `direction`: one forward-mode pass
# gives the scalar directional derivative, which is then scattered back along
# the direction to produce a gradient-shaped result.
function fwddiff_dirgradient(f, x, direction = sample_direction(x))
    xdual = seed_fwddiff(x, direction)
    z = f(xdual)
    g = fmap(direction) do d
        isnothing(d) ? nothing : ForwardDiff.partials(z) .* d
    end
    return g
end

# The same directional gradient computed with Diffractor's forward mode (∂☆).
function diffractor_dirgradient(f, x, direction = sample_direction(x))
    bundle = seed_diffractor(x, direction)
    z = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), bundle)
    g = fmap(direction) do d
        isnothing(d) ? nothing : z.partials[1] .* d
    end
    return g
end
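
A minimal usage sketch (the Dense model, input, and loss here are made-up placeholders for illustration, not part of the gist):

julia> model = Flux.Dense(3 => 2);

julia> x = Flux.rand32(3);

julia> loss(m) = sum(abs2, m(x));

julia> dir = sample_direction(model);   # unit-norm random arrays for weight and bias, nothing for σ

julia> g = fwddiff_dirgradient(loss, model, dir);   # ⟨∇loss, dir⟩ .* dir, same structure as dir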
@mcabbott

With latest Diffractor I get this error:

julia> diffractor_dirgradient(sum, ones(3))
ERROR: MethodError: no method matching (Diffractor.TangentBundle{1})(::Vector{Float64}, ::Tuple{Vector{Float32}})

Closest candidates are:
  (Diffractor.TangentBundle{N})(::B, ::P) where {N, B, P<:Diffractor.AbstractTangentSpace}
   @ Diffractor ~/.julia/dev/Diffractor/src/tangent.jl:152

Stacktrace:
  [1] (::var"#97#101")(x::Vector{Float64}, p::Vector{Float32})
    @ Main ./REPL[178]:8
  [2] (::Functors.ExcludeWalk{Functors.AnonymousWalk{FilteredWalk{typeof(differentiable), var"#94#98", var"#95#99"}}, var"#97#101", typeof(Optimisers.isnumeric)})(recurse::Function, x::Vector{Float64}, ys::Vector{Float32})
    @ Functors ~/.julia/packages/Functors/orBYx/src/walks.jl:92
  [3] CachedWalk
    @ ~/.julia/packages/Functors/orBYx/src/walks.jl:132 [inlined]
  [4] fmap(walk::Functors.CachedWalk{Functors.ExcludeWalk{Functors.AnonymousWalk{FilteredWalk{typeof(differentiable), var"#94#98", var"#95#99"}}, var"#97#101", typeof(Optimisers.isnumeric)}, Functors.NoKeyword}, f::Function, x::Vector{Float64}, ys::Vector{Float32})
    @ Functors ~/.julia/packages/Functors/orBYx/src/maps.jl:1
  [5] fmap(f::Function, x::Vector{Float64}, ys::Vector{Float32}; exclude::Function, walk::FilteredWalk{typeof(differentiable), var"#94#98", var"#95#99"}, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/orBYx/src/maps.jl:11
  [6] kwcall(::NamedTuple{(:exclude, :walk), Tuple{typeof(Optimisers.isnumeric), FilteredWalk{typeof(differentiable), var"#94#98", var"#95#99"}}}, ::typeof(fmap), f::Function, x::Vector{Float64}, ys::Vector{Float32})
    @ Functors ~/.julia/packages/Functors/orBYx/src/maps.jl:3
  [7] seed_diffractor(model::Vector{Float64}, direction::Vector{Float32})
    @ Main ./REPL[178]:5
  [8] diffractor_dirgradient(f::Function, x::Vector{Float64}, direction::Vector{Float32})
    @ Main ./REPL[180]:2
  [9] diffractor_dirgradient(f::Function, x::Vector{Float64})
    @ Main ./REPL[180]:2

@darsnack
Author

darsnack commented Jan 1, 2023

I think that’s just a Float64 vs Float32 issue? sample_direction draws Float32 perturbations by default (the default sampler is Flux.rand32).
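
For example, the direction comes out Float32 even when the model is Float64:

julia> eltype(sample_direction(ones(3)))   # Float32, since Flux.rand32 is the default sampler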

@mcabbott

mcabbott commented Jan 1, 2023

The error is the same with diffractor_dirgradient(sum, ones(Float32, 3)). Julia nightly, didn't try others.
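
Judging by the "closest candidates" hint, the TangentBundle{1} constructor now wants an AbstractTangentSpace rather than a raw tuple of partials, so an untested guess at a fix would be to wrap the partials inside seed_diffractor, something like:

        return Diffractor.TangentBundle{1}(x, Diffractor.ExplicitTangent((p,)))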
