-
-
Save ssfrr/6dcb548c06e18e54c35fc89874fad553 to your computer and use it in GitHub Desktop.
Complex Dual numbers in Julia
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using ForwardDiff: Dual, partials | |
"""
    onehot(n, N)

Return an `N`-tuple of floats whose `n`-th entry is `1.0` and whose other entries
are `0.0`. If `n` is not in `1:N`, every entry is `0.0`.
"""
onehot(n, N) = tuple(float.((1:N) .== n)...)
"""
    nseeds(x::AbstractArray)
    nseeds(x::AbstractArray, y::AbstractArray, rest::AbstractArray...)

Return the number of seed (perturbation) directions needed to dualize the given
array(s): one per element of a real array, two per element (real and imaginary
part) of a complex array, summed over all arrays given.
"""
nseeds(x::AbstractArray{<:Real}) = length(x)
nseeds(x::AbstractArray{<:Complex}) = 2 * length(x)
# At least two arrays are required here: a single-argument vararg method would also
# catch arrays whose eltype is neither Real nor Complex (e.g. `Vector{Any}`) and
# call itself forever; those now raise a MethodError instead.
nseeds(x::AbstractArray, y::AbstractArray, rest::AbstractArray...) =
    nseeds(x) + nseeds(y, rest...)
"""
    seedoffsets(xs::AbstractArray...)

Return a tuple with one entry per array in `xs`: the number of seed directions
used by all arrays before it, i.e. the 0-based start of its block of seeds.
"""
function seedoffsets(xs::AbstractArray...)
    # Accumulate in a Vector and splat into a tuple at the end (`cumsum` does not
    # accept tuples on Julia 0.6.2). offsets[k] == total seeds of xs[1:k-1].
    offsets = [0]
    for k in 1:(length(xs) - 1)
        push!(offsets, offsets[end] + nseeds(xs[k]))
    end
    return (offsets...,)
end
"""
    genseeds(x::AbstractArray)
    genseeds(xs::AbstractArray...)

For one array, return a Vector of one-hot seed tuples (one per seed direction of
`x`). For several arrays, return a tuple of such Vectors. All seed tuples have
length `nseeds(xs...)`, and each array's seeds start at its entry in
`seedoffsets(xs...)`, so different arrays never share a direction.
"""
genseeds(x::AbstractArray) = _genseeds(x, 0, nseeds(x))

function genseeds(xs::AbstractArray...)
    offsets = seedoffsets(xs...)
    total = nseeds(xs...)
    return _genseeds.(xs, offsets, total)
end
"""
    _genseeds(x, offset, N)

Return a Vector of `nseeds(x)` one-hot `N`-tuples. The first has its `1.0` at
position `offset + 1`, and each following tuple is shifted one position further.
"""
function _genseeds(x::AbstractArray, offset::Integer, N::Integer)
    # Broadcast `.+` rather than `+`: `range + scalar` is deprecated from Julia 0.7
    # (and later removed), while `.+` already works on 0.6.
    return onehot.((1:nseeds(x)) .+ offset, N)
end
"""
    dualize(x::AbstractArray)
    dualize(xs::AbstractArray...)

Return an array (or, for several inputs, a tuple of arrays) shaped like the
input(s), in which every element is replaced by ForwardDiff `Dual`s carrying
one-hot perturbations. Complex elements get one `Dual` for the real part and one
for the imaginary part.

All arrays that take part in the same dual computation must be dualized in a
single call, so that their perturbations are mutually orthogonal.
"""
dualize(x::AbstractArray) = _dualize(x, genseeds(x))

function dualize(xs::AbstractArray...)
    seedsets = genseeds(xs...)
    return _dualize.(xs, seedsets)
end
"""
    _dualize(x, seeds)

Return an array shaped like `x` in which element `i` (in linear order) is a `Dual`
with value `x[i]` and N-dimensional perturbations `seeds[i]`. For a real array,
`seeds` needs one tuple per element of `x`.
"""
function _dualize(x::AbstractArray{<:Real},
                  seeds::AbstractArray{<:NTuple{N, T}}) where {N, T}
    out = similar(x, Dual{Void, eltype(x), N})
    for k in 1:length(x)
        seed = seeds[k]
        out[k] = Dual(x[k], seed...)
    end
    return out
end
# Complex method of `_dualize`: element `x[i]` takes `seeds[2i-1]` for its real part
# and `seeds[2i]` for its imaginary part, so `seeds` needs `2 * length(x)` tuples.
# The seed element type `S` is independent of the component type `T` (as in the
# real method, which takes its value type from `eltype(x)`). This lets, e.g.,
# `Complex{Float32}` arrays be dualized with the Float64 seeds built by `onehot`.
function _dualize(x::AbstractArray{<:Complex{T}},
                  seeds::AbstractArray{<:NTuple{N, S}}) where {N, T, S<:Real}
    dual = similar(x, Complex{Dual{Void, T, N}})
    for i in 1:length(x)
        # Build both components the same way (splatting the seed tuple) and pair
        # them directly instead of computing `re + im * imdual` with Dual arithmetic.
        re = Dual(real(x[i]), seeds[2i - 1]...)
        imdual = Dual(imag(x[i]), seeds[2i]...)
        dual[i] = Complex(re, imdual)
    end
    return dual
end
"""
    grad_step!(ydual, alpha, xs...)

Take one in-place gradient-descent step of size `alpha` on the dualized arrays
`xs`, using the partials carried by `ydual` as the gradient. Pass the arrays in the
same order they were given to `dualize`, so that each one reads its own block of
partials. Returns `nothing`.
"""
function grad_step!(ydual, alpha, xs...)
    offsets = seedoffsets(xs...)
    # Broadcasting over the tuples `xs` and `offsets` performs one update per array;
    # a scalar `ydual`/`alpha` is reused for every array.
    grad_step_single.(ydual, alpha, xs, offsets)
    return nothing
end
# Real array: element `x[i]` is moved against partial number `offset + i` of `ydual`.
function grad_step_single(ydual, alpha, x::AbstractArray{<:Real}, offset)
    grads = partials(ydual)
    # TODO: iterate with eachindex(x) and map back to a linear index for generality
    for k in 1:length(x)
        x[k] -= alpha * grads[k + offset]
    end
    return nothing
end
# Complex array: the real and imaginary parts of `x[i]` use the interleaved partials
# `offset + 2i - 1` and `offset + 2i` of `ydual`.
function grad_step_single(ydual, alpha, x::AbstractArray{<:Complex}, offset)
    grads = partials(ydual)
    # TODO: iterate with eachindex(x) and map back to a linear index for generality
    for k in 1:length(x)
        pos = offset + 2k
        gre = grads[pos - 1]
        gim = grads[pos]
        x[k] -= alpha * (gre + im * gim)
    end
    return nothing
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run as cells in Juno
#
# Demo: fit a complex affine model `y ≈ W*x .+ b` by gradient descent. `W` and `b`
# are dualized together, so `loss` returns a Dual whose partials are consumed by
# `grad_step!` as the gradient.

# Complex samples with independent standard-normal real and imaginary parts.
# Defined here so the demo does not depend on the `randn(::Type{Complex}, ...)`
# method added by the test file (it makes the same RNG calls in the same order).
crandn(dims::Integer...) = randn(dims...) + im * randn(dims...)

## data
N = 100
x = crandn(5, N)
Wt = crandn(2, 5)
bt = crandn(2)
y = Wt*x .+ bt
# params
W, b = dualize(crandn(2, 5), crandn(2))
predict(x) = W*x .+ b
# loss WRT X and Y data: mean squared error over the samples (columns of `x`)
loss(x, y) = sum(abs2.(predict(x) .- y)) / size(x, 2)
## initial loss
res = loss(x, y)
res.value
## one gradient step
grad_step!(res, 0.1, W, b)
res = loss(x, y)
res.value
## 100 more gradient steps
for _ in 1:100
    # `global` so the loop updates the top-level `res` even outside REPL soft scope
    global res
    grad_step!(res, 0.1, W, b)
    res = loss(x, y)
end
res.value
##
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Base.Test | |
# Complex-valued `randn`/`rand` used by the tests below: real and imaginary parts
# are drawn independently. Not fully generic (and it is type piracy on Base), but
# enough here; Julia 0.7 should make it unnecessary.
function Base.Random.randn(::Type{Complex}, dims::Integer...)
    re = randn(dims...)
    return re + im * randn(dims...)
end
function Base.Random.rand(::Type{Complex}, dims::Integer...)
    re = rand(dims...)
    return re + im * rand(dims...)
end
# ForwardDiff's `==` on Duals compares only the values, so these helpers also compare
# the perturbations (partials) to actually check how elements were seeded.
samedual(a::Dual, b::Dual) = a == b && partials(a) == partials(b)
samedual(a::Complex, b::Complex) =
    samedual(real(a), real(b)) && samedual(imag(a), imag(b))
samedual(a::AbstractArray, b::AbstractArray) = size(a) == size(b) && all(samedual.(a, b))

@testset "onehot" begin
    @test onehot(1, 4) == (1, 0, 0, 0)
    @test onehot(3, 4) == (0, 0, 1, 0)
end

@testset "nseeds" begin
    @test nseeds(rand(2, 3)) == 6
    @test nseeds(rand(2, 3), rand(4)) == 10
    @test nseeds(rand(Complex, 2, 3)) == 12
    @test nseeds(rand(Complex, 2, 3), rand(Complex, 4)) == 20
    @test nseeds(rand(Complex, 2, 3), rand(4)) == 16
end

@testset "seedoffsets" begin
    @test seedoffsets(rand(2, 3)) == (0, )
    @test seedoffsets(rand(2, 3), rand(4), rand(2, 2)) == (0, 6, 10)
    @test seedoffsets(rand(Complex, 2, 3), rand(Complex, 4), rand(Complex, 2, 2)) ==
          (0, 12, 20)
    @test seedoffsets(rand(Complex, 2, 3), rand(4), rand(Complex, 2, 2)) == (0, 12, 16)
end

@testset "genseeds" begin
    @test genseeds(rand(2, 2)) == [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]
    @test genseeds(rand(2, 2), rand(2)) == (
        [(1, 0, 0, 0, 0, 0), (0, 1, 0, 0, 0, 0), (0, 0, 1, 0, 0, 0), (0, 0, 0, 1, 0, 0)],
        [(0, 0, 0, 0, 1, 0), (0, 0, 0, 0, 0, 1)])
    @test genseeds(rand(Complex, 2)) ==
          [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]
    @test genseeds(rand(Complex, 2, 2), rand(Complex, 2)) == (
        [(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
         (0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
         (0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0),
         (0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0),
         (0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0),
         (0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0),
         (0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0),
         (0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0)],
        [(0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0),
         (0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0),
         (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0),
         (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)])
end

@testset "dualize real" begin
    x1 = rand(3, 2)
    x2 = rand(2, 2)
    @test samedual(dualize(x1),
                   [Dual(x1[r,c], onehot(3(c-1)+r, 6)...) for r in 1:3, c in 1:2])
    d1, d2 = dualize(x1, x2)
    @test samedual(d1, [Dual(x1[r,c], onehot(3(c-1)+r, 10)...) for r in 1:3, c in 1:2])
    @test samedual(d2, [Dual(x2[r,c], onehot(2(c-1)+r+6, 10)...) for r in 1:2, c in 1:2])
end

@testset "dualize complex" begin
    c1 = rand(Complex, 3, 2)
    c2 = rand(Complex, 2, 2)
    # c1 has 3 rows, so each column uses 2*3 = 6 seeds
    @test samedual(dualize(c1),
                   [Complex(Dual(real(c1[r,c]), onehot(6(c-1)+2r-1, 12)...),
                            Dual(imag(c1[r,c]), onehot(6(c-1)+2r, 12)...))
                    for r in 1:3, c in 1:2])
    e1, e2 = dualize(c1, c2)
    @test samedual(e1,
                   [Complex(Dual(real(c1[r,c]), onehot(6(c-1)+2r-1, 20)...),
                            Dual(imag(c1[r,c]), onehot(6(c-1)+2r, 20)...))
                    for r in 1:3, c in 1:2])
    # c2 has 2 rows, so each column uses 2*2 = 4 seeds, placed after c1's 12 seeds
    @test samedual(e2,
                   [Complex(Dual(real(c2[r,c]), onehot(4(c-1)+2r-1+12, 20)...),
                            Dual(imag(c2[r,c]), onehot(4(c-1)+2r+12, 20)...))
                    for r in 1:2, c in 1:2])
end

@testset "grad_step! real" begin
    x1 = dualize(randn(3, 2))
    p1 = copy(x1)
    grad = Dual(randn(7)...)
    grad_step!(grad, 0.1, p1)
    @test p1[:] == x1[:] - partials(grad) * 0.1

    x1, x2 = dualize(randn(3, 2), randn(2, 2))
    p1 = copy(x1)
    p2 = copy(x2)
    grad = Dual(randn(11)...)
    grad_step!(grad, 0.1, p1, p2)
    @test p1[:] == x1[:] - partials(grad)[1:6] * 0.1
    @test p2[:] == x2[:] - partials(grad)[7:10] * 0.1
end

@testset "grad_step! complex" begin
    c1 = dualize(randn(Complex, 3, 2))
    p1 = copy(c1)
    grad = Dual(randn(13)...)
    grad_step!(grad, 0.1, p1)
    @test real(p1[:]) == real(c1[:]) - partials(grad)[1:2:end] * 0.1
    @test imag(p1[:]) == imag(c1[:]) - partials(grad)[2:2:end] * 0.1
end

@testset "scalar complex gradient step" begin
    z = Dual(2.0, 1, 0) + im*Dual(3.0, 0, 1)
    t = (5.0-3.0im)
    res = abs2(z-t)
    # d/d(re z) = 2(2 - 5) = -6, d/d(im z) = 2(3 + 3) = 12
    @test partials(res)[1] == -6.0
    @test partials(res)[2] == 12.0
    loss0 = res.value
    z -= 0.01(partials(res)[1] + im*partials(res)[2])
    res = abs2(z-t)
    @test res.value < loss0
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment