Skip to content

Instantly share code, notes, and snippets.

@johnnychen94
Last active July 30, 2021 11:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save johnnychen94/b0bf2c336bc6991cf31d81dbb2f86f85 to your computer and use it in GitHub Desktop.
Save johnnychen94/b0bf2c336bc6991cf31d81dbb2f86f85 to your computer and use it in GitHub Desktop.
diffwarp: proof of concept on rotation
using ImageTransformations
using StaticArrays
using Interpolations
using ImageCore
using ImageShow
using TestImages
using ChainRules
using ChainRules: NoTangent, ZeroTangent, @not_implemented
using ChainRulesTestUtils
using Zygote
# Test image used throughout the script: the "cameraman" test image resized to
# 32×32, with every pixel converted to a plain `Float64` value.
img = map(Float64, imresize(testimage("cameraman"), (32, 32)))
"""
    ϕ(p, θ)

Rotate the 2-D point `p` by the angle `θ` (in radians, as consumed by `sincos`).

Only `p[1]` and `p[2]` are read. The result is a freshly allocated 2-element
vector equal to `[cos θ  -sin θ; sin θ  cos θ] * p[1:2]`.
"""
function ϕ(p, θ)
    s, c = sincos(θ)
    x, y = p[1], p[2]
    return [x * c - y * s, x * s + y * c]
end
# Reverse-mode rule for `ϕ`.
#
# Returns the rotated point `q = ϕ(p, θ)` together with a pullback. Given the
# cotangent `dLdq` of `q`, the pullback returns
# `(NoTangent(), dLdp, dLdθ)`, where
#   * `dLdp` is the transposed rotation matrix applied to `dLdq`, and
#   * `dLdθ` is the inner product of `dLdq` with `∂q/∂θ`.
function ChainRules.rrule(::typeof(ϕ), p, θ)
    s, c = sincos(θ)
    x, y = p[1], p[2]
    rotated = [x * c - y * s, x * s + y * c]

    function ϕ_pullback(dLdq)
        g1, g2 = dLdq[1], dLdq[2]
        # ∂q/∂θ: derivative of each rotated coordinate with respect to the angle.
        dqdθ = [-x * s - y * c, x * c - y * s]
        dLdθ = sum(dLdq .* dqdθ)
        # Rᵀ * dLdq, written out component by component.
        dLdp = [g1 * c + g2 * s, -g1 * s + g2 * c]
        return NoTangent(), dLdp, dLdθ
    end

    return rotated, ϕ_pullback
end
"""
    τ(X, q)

Sample the array `X` at the (possibly non-integer) position `q`.

`X` is wrapped in a linear B-spline interpolant that is extrapolated with
`zero(eltype(X))`; the components of `q` are splatted as the coordinates of the
query point.
"""
function τ(X, q)
    itp = interpolate(X, BSpline(Linear()))
    fillvalue = zero(eltype(X))
    etp = extrapolate(itp, fillvalue)
    return etp(q...)
end
# Reverse-mode rule for `τ`.
#
# Returns the interpolated value `τ(X, q)` and a pullback. Given the cotangent
# `dLdYp` of that value, the pullback returns `(NoTangent(), dLdX, dLdq)`:
#   * `dLdq` is `dLdYp` times the interpolant's gradient at `q`;
#   * `dLdX` is marked as not implemented (see the message below).
function ChainRules.rrule(::typeof(τ), X, q)
    itp = interpolate(X, BSpline(Linear()))
    etp = extrapolate(itp, zero(eltype(X)))
    value = etp(q...)

    function τ_pullback(dLdYp)
        ∇etp = Interpolations.gradient(etp, q...)
        dLdq = dLdYp .* ∇etp
        dLdX = @not_implemented(
            "Interpolations doesn't yet support gradient to coefficients"
        )
        return NoTangent(), dLdX, dLdq
    end

    return value, τ_pullback
end
# Do some pixel level tests: pick one angle and one pixel, then check both
# hand-written rrules with ChainRulesTestUtils.
θ = 0.2 # rotation angle handed to `ϕ` (consumed by `sincos`)
p = [10, 10] # pixel coordinate that gets rotated
q = ϕ(p, θ) # rotated position; used below as the sampling point for `τ`
# test our rrule
# FIXME: `check_inferred=false` disables the type-inference check for this rule.
test_rrule(ϕ, p, θ; check_inferred=false) # FIXME
test_rrule(τ, img, q; check_inferred=false)
"""
    f_single_pixel(θ)

Sample the global image `img` at the global point `p` after rotating it by `θ`,
i.e. compute `τ(img, ϕ(p, θ))`.
"""
function f_single_pixel(θ)
    rotated = ϕ(p, θ)
    return τ(img, rotated)
end
# Value of the image at the unrotated pixel `p`; it does not depend on θ, so
# subtracting it below leaves the gradient with respect to θ unchanged.
img_p = img[p...]
# Differentiate the single-pixel sample with respect to θ through the rrules
# defined for `ϕ` and `τ`. The trailing comment records the value the author
# observed.
Zygote.gradient(θ) do θ
f_single_pixel(θ) - img_p
end # (-0.08925456466517724,)
# Now let's put the warp together
"""
    simple_rotate(X, θ)

Warp `X` by the rotation `ϕ(⋅, θ)`.

For every index `I` of the output (allocated with `similar(X)`), the index is
turned into a coordinate vector, rotated with `ϕ`, and `X` is sampled at the
rotated position with `τ`.
"""
function simple_rotate(X, θ)
    warped = similar(X)
    for I in CartesianIndices(warped)
        coords = collect(Tuple(I))
        warped[I] = τ(X, ϕ(coords, θ))
    end
    return warped
end
# Reverse-mode rule for `simple_rotate`.
#
# Returns the warped image `Y = simple_rotate(X, θ)` and a pullback. Given the
# cotangent `dLdY` of `Y`, the pullback returns `(NoTangent(), dLdX, dLdθ)`:
#   * `dLdθ` is the sum over all output pixels of the per-pixel contribution
#     obtained by chaining the pullbacks of `τ` and `ϕ`;
#   * `dLdX` is marked as not implemented (see the message below).
function ChainRules.rrule(::typeof(simple_rotate), X, θ)
    Y = simple_rotate(X, θ)

    function gradient_simple_rotate(dLdY)
        # Accumulate through a `Ref` so the threaded loop body mutates a cell
        # instead of rebinding a captured variable.
        dLdθ = Ref(zero(eltype(dLdY)))
        lk = ReentrantLock()
        Threads.@threads for I in CartesianIndices(Y)
            p = collect(Tuple(I))
            # The forward pass computes Y[I] = τ(X, ϕ(p, θ)), so the pullback of
            # `τ` has to be built at the rotated position `q`, not at `p`.
            q, dϕ = rrule(ϕ, p, θ)
            _, dτ = rrule(τ, X, q)
            # Index the cotangent with the Cartesian index of this pixel.
            _, _, dLdq = dτ(dLdY[I])
            _, _, dθ = dϕ(dLdq)
            # No `return` inside the threaded body: every iteration must reach
            # the locked accumulation.
            lock(lk) do
                dLdθ[] += dθ
            end
        end
        dLdX = @not_implemented(
            "Interpolations doesn't yet support gradient to coefficients"
        )
        return NoTangent(), dLdX, dLdθ[]
    end

    return Y, gradient_simple_rotate
end
# Rotate the test image and convert every pixel to `Gray` (handy for viewing the
# result in an interactive session).
simple_rotate(img, θ) .|> Gray
# Call the rrule by hand: `imgr` is the rotated image, and
# `gradient_simple_rotate` is its pullback, evaluated here on an all-ones
# cotangent of the same size as `img`.
imgr, gradient_simple_rotate = rrule(simple_rotate, img, θ)
gradient_simple_rotate(ones(eltype(img), size(img)))
# Check the rule with ChainRulesTestUtils.
# FIXME: `check_inferred=false` disables the type-inference check for this rule.
test_rrule(simple_rotate, img, θ; check_inferred=false) # FIXME
# do some simple experiment: starting from θ = 0.5, run gradient descent on the
# squared difference between the rotated image and the original image, and
# collect every intermediate result as a frame of an animation.
g(θ) = simple_rotate(img, θ)
g(θ) .|> Gray
θ = 0.5
lr = 1e-1 # gradient-descent step size
# Concretely typed frame buffer: each frame is `Gray.(out)` for a Float64 matrix.
outs = Matrix{Gray{Float64}}[]
for _ in 1:100
    # Without this declaration the assignment to `θ` below would create a new
    # loop-local variable when the file is run non-interactively, and the read
    # of `θ` in `Zygote.gradient(θ)` would then hit an undefined variable.
    global θ
    dθ, = Zygote.gradient(θ) do θ
        sum(abs2, g(θ) - img)
    end
    θ = θ - lr * dθ
    out = g(θ)
    println("loss: ", sum(abs2, out - img), ", dθ=", dθ)
    push!(outs, Gray.(out))
end
ImageShow.gif(outs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment