Skip to content

Instantly share code, notes, and snippets.

@darsnack
Created November 18, 2022 15:10
Show Gist options
  • Save darsnack/014a445d16ae2645421ccb8dc4ed0399 to your computer and use it in GitHub Desktop.
Save darsnack/014a445d16ae2645421ccb8dc4ed0399 to your computer and use it in GitHub Desktop.
Forward mode AD for Functors.jl
# A custom `Functors.fmap` walk that recurses only into a caller-selected
# subset of a node's children.
#
# Fields:
#   filter_set — callable: given a node `x`, returns the collection of children
#                that should be recursed into (e.g. `differentiable` below).
#   prune      — applied to each child *not* in the filtered set instead of
#                recursing (default `identity`, i.e. keep the child as-is).
#   rebuilder  — applied to the reconstruction closure returned by
#                `Functors.functor`, letting callers override how the mapped
#                children are reassembled (default `identity`).
struct FilteredWalk{F, T, G}
filter_set::F
prune::T
rebuilder::G
end
# Keyword convenience constructor; the defaults keep pruned children unchanged
# and use the standard Functors reconstruction.
FilteredWalk(filter_set; prune = identity, rebuilder = identity) =
FilteredWalk(filter_set, prune, rebuilder)
"""
    (walk::FilteredWalk)(recurse, x, ys...)

Custom `fmap` walk: recurse only into the children of `x` selected by
`walk.filter_set(x)`, applying `walk.prune` to every other child. Any
companion trees `ys` are decomposed with the same functor layout as `x` so
their matching children are passed along to `recurse`. The mapped children
are reassembled via `walk.rebuilder(re)`.
"""
function (walk::FilteredWalk)(recurse, x, ys...)
    children, re = Functors.functor(x)
    # Children the filter marks as worth descending into.
    selected = walk.filter_set(x)
    # Decompose each companion tree with x's functor so leaves line up.
    companions = map(y -> first(Functors.functor(typeof(x), y)), ys)
    mapped = map(children, companions...) do child, rest...
        child ∈ selected ? recurse(child, rest...) : walk.prune(child)
    end
    return walk.rebuilder(re)(mapped)
end
"""
    differentiable(x)

Return the children of `x` that are both trainable (per Optimisers.jl) and
numeric — i.e. the leaves a gradient should flow into.
"""
function differentiable(x)
    return _differentiable(Optimisers.trainable(x))
end
"""
    _differentiable(xs::NamedTuple)

Keep only the numeric (optimisable) entries of `xs`, preserving their names.
"""
function _differentiable(xs::NamedTuple)
    numeric = [k => v for (k, v) in pairs(xs) if Optimisers.isnumeric(v)]
    return (; numeric...)
end
# Tuple method: keep only numeric leaves. `filter` on a `Tuple` returns a
# `Tuple`, so this matches the original splatting comprehension exactly.
_differentiable(xs::Tuple) = filter(Optimisers.isnumeric, xs)
"""
    normalize_grad(g)

Rescale a (possibly nested) gradient structure `g` to unit Euclidean norm,
measured across all leaves jointly. `nothing` leaves are passed through
untouched. If the total norm is zero, `g` is returned unchanged rather than
producing NaNs from a division by zero.
"""
function normalize_grad(g)
    # Accumulate the squared norm over every non-`nothing` leaf.
    # `sum(abs2, gi)` avoids allocating the temporary array `gi.^2` did.
    total = Ref(0f0)
    fmap(gi -> !isnothing(gi) && (total[] += sum(abs2, gi)), g)
    gnorm = sqrt(total[])
    # Guard: an all-zero (or all-`nothing`) gradient would otherwise become
    # NaN everywhere.
    iszero(gnorm) && return g
    return fmap(gi -> isnothing(gi) ? gi : gi ./ gnorm, g)
end
"""
    sample_direction(model, sample = Flux.rand32, normalize = true)

Sample a random perturbation direction for every trainable, numeric leaf of
`model`, producing `nothing` for every other leaf so the result mirrors a
gradient structure. When `normalize` is true (the default), the direction is
rescaled to unit norm via `normalize_grad`.

NOTE(review): `Flux.rand32` draws *uniform* samples — pass `Flux.randn32`
for Gaussian perturbations.
"""
function sample_direction(model, sample = Flux.rand32, normalize = true)
    # Walk only the differentiable (trainable + numeric) children; pruned
    # leaves become `nothing` so the output lines up with a gradient.
    walk = FilteredWalk(differentiable;
                        prune = _ -> nothing,
                        rebuilder = _ -> identity)
    directions = fmap(x -> sample(size(x)...), model;
                      exclude = Optimisers.isnumeric,
                      walk = walk)
    # Bug fix: `normalize` was previously ignored and normalization always ran.
    return normalize ? normalize_grad(directions) : directions
end
"""
    seed_fwddiff(model::T, direction) where T

Replace each differentiable leaf `x` of `model` with `ForwardDiff.Dual`
numbers tagged by `T`, seeded with the matching leaf of `direction` as the
single partial. Leaves outside the differentiable set are left untouched by
the `FilteredWalk`.
"""
function seed_fwddiff(model::T, direction) where T
    walk = FilteredWalk(differentiable)
    dualize = function (x, d)
        S = eltype(x)
        # One partial per element, converted to the leaf's element type S.
        seeds = ForwardDiff.Partials{1, S}.(tuple.(d))
        return ForwardDiff.Dual{T, S, 1}.(x, seeds)
    end
    return fmap(dualize, model, direction;
                exclude = Optimisers.isnumeric,
                walk = walk)
end
# Wrap every differentiable leaf of `model` in a first-order Diffractor
# tangent bundle carrying the matching leaf of `direction` as its tangent.
#
# Non-differentiable children are pruned into `ZeroBundle{1}` (zero tangent),
# and the rebuilder collects the mapped children into a plain tuple so the
# whole struct `T` can be wrapped in a `CompositeBundle`.
#
# NOTE(review): this relies on Diffractor internals (`TangentBundle`,
# `CompositeBundle` field layout) that are version-dependent — the gist
# discussion reports errors on Julia nightly; confirm against the installed
# Diffractor version.
function seed_diffractor(model::T, direction) where T
walk = FilteredWalk(differentiable;
prune = x -> Diffractor.ZeroBundle{1}(x),
rebuilder = _ -> (xs -> (xs...,)))
partials = fmap(model, direction;
exclude = Optimisers.isnumeric,
walk = walk) do x, p
return Diffractor.TangentBundle{1}(x, (p,))
end
return Diffractor.CompositeBundle{1, T}(partials)
end
"""
    fwddiff_dirgradient(f, x, direction = sample_direction(x))

Directional gradient of `f` at `x` along `direction` using ForwardDiff dual
numbers: seed `x` with `direction`, evaluate `f`, then scale each direction
leaf by the resulting partial. `nothing` leaves stay `nothing`.
"""
function fwddiff_dirgradient(f, x, direction = sample_direction(x))
    seeded = seed_fwddiff(x, direction)
    out = f(seeded)
    return fmap(d -> isnothing(d) ? nothing : ForwardDiff.partials(out) .* d,
                direction)
end
"""
    diffractor_dirgradient(f, x, direction = sample_direction(x))

Directional gradient of `f` at `x` along `direction` via Diffractor's
first-order forward mode (`∂☆{1}`): push the seeded bundle through `f`, then
scale each direction leaf by the scalar partial. `nothing` leaves stay
`nothing`.
"""
function diffractor_dirgradient(f, x, direction = sample_direction(x))
    seeded = seed_diffractor(x, direction)
    out = Diffractor.∂☆{1}()(Diffractor.ZeroBundle{1}(f), seeded)
    return fmap(d -> isnothing(d) ? nothing : out.partials[1] .* d,
                direction)
end
@darsnack
Copy link
Author

darsnack commented Jan 1, 2023

I think that’s just a Float64 vs. Float32 issue? It samples perturbations from a Float32 distribution.

@mcabbott
Copy link

mcabbott commented Jan 1, 2023

The error is the same with diffractor_dirgradient(sum, ones(Float32, 3)). Julia nightly, didn't try others.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment