Skip to content

Instantly share code, notes, and snippets.

@KristofferC
Last active March 11, 2016 10:05
Show Gist options
  • Save KristofferC/56594a9c57d6f27df769 to your computer and use it in GitHub Desktop.
Save KristofferC/56594a9c57d6f27df769 to your computer and use it in GitHub Desktop.
# Wrapper around a collection of partial derivatives.
# `C` is the concrete type of the wrapped collection; `T` is filled in with
# that collection's element type by the outer constructor below.
immutable Partials{T,C}
    data::C
end

# Wrap collection `d`, deriving `T` from `eltype(d)` and `C` from `typeof(d)`.
function Partials(d)
    elt = eltype(d)
    coll = typeof(d)
    return Partials{elt,coll}(d)
end

# Return the wrapped collection.
data(p::Partials) = getfield(p, :data)
# A value paired with its partial derivatives. The partials share the value
# type `T` and carry the collection type `C`.
immutable GradientNumber{T,C}
    value::T
    partials::Partials{T,C}
end

# Return the stored value.
value(g::GradientNumber) = getfield(g, :value)
# Return the stored partials.
partials(g::GradientNumber) = getfield(g, :partials)
# Gradient number
# Sum of two gradient numbers: values add and partials add.
# The result must be built with the `GradientNumber` constructor; the operand
# `g1` is not callable.
@inline +(g1::GradientNumber, g2::GradientNumber) =
    GradientNumber(value(g1) + value(g2), partials(g1) + partials(g2))
# Same rule for two operands that share the value type `N`.
@inline +{N}(g1::GradientNumber{N}, g2::GradientNumber{N}) =
    GradientNumber(value(g1) + value(g2), partials(g1) + partials(g2))
# Adding a real constant shifts only the value; the partials are unchanged.
@inline +(g::GradientNumber, x::Real) = GradientNumber(value(g) + x, partials(g))
@inline +(x::Real, g::GradientNumber) = g + x
# Negation flips the sign of both the value and the partials.
@inline -(g::GradientNumber) = GradientNumber(-value(g), -partials(g))

# Difference of two gradient numbers: subtract values and partials separately.
@inline function -(g1::GradientNumber, g2::GradientNumber)
    v = value(g1) - value(g2)
    p = partials(g1) - partials(g2)
    return GradientNumber(v, p)
end

# Subtracting a constant moves only the value.
@inline -(g::GradientNumber, x::Real) = GradientNumber(value(g) - x, partials(g))

# Constant minus `g`: the partials change sign.
@inline function -(x::Real, g::GradientNumber)
    return GradientNumber(x - value(g), -partials(g))
end
# Multiplication by a Bool: `true` returns `g` itself. `false` returns
# `zero(g)`, negated when the sign bit of `value(g)` is set.
@inline function *(g::GradientNumber, x::Bool)
    x && return g
    return signbit(value(g)) == 0 ? zero(g) : -zero(g)
end
@inline *(x::Bool, g::GradientNumber) = g * x

# Product of two gradient numbers. `rhs` and `lhs` are passed as
# `_mul_partials`' `afactor`/`bfactor`, intended as the factors for
# `partials(g1)` and `partials(g2)` respectively (product rule).
@inline function *(g1::GradientNumber, g2::GradientNumber)
    lhs = value(g1)
    rhs = value(g2)
    return GradientNumber(lhs * rhs, _mul_partials(partials(g1), partials(g2), rhs, lhs))
end

# Scaling by a real constant scales the value and the partials alike.
@inline *(g::GradientNumber, x::Real) = GradientNumber(value(g) * x, partials(g) * x)
@inline *(x::Real, g::GradientNumber) = g * x
# Quotient of two gradient numbers. The value is `num / den`; the partials
# are computed by `_div_partials` from both operands' partials and values.
@inline function /(g1::GradientNumber, g2::GradientNumber)
    num = value(g1)
    den = value(g2)
    return GradientNumber(num / den, _div_partials(partials(g1), partials(g2), num, den))
end

# Constant divided by `g`. The value is `x / a`, and `-(x / a) / a` is the
# derivative of `x / a` with respect to `a`. `gradnum_from_deriv` is defined
# elsewhere; presumably it scales `g`'s partials by that derivative. Verify.
@inline function /(x::Real, g::GradientNumber)
    a = value(g)
    q = x / a
    return gradnum_from_deriv(g, q, -(q / a))
end

# Dividing by a constant divides the value and the partials by `x`.
@inline /(g::GradientNumber, x::Real) = GradientNumber(value(g) / x, partials(g) / x)
# Partials
# Arithmetic on partials: each operation is applied to the wrapped
# collections and the result is rewrapped with `Partials`.
@inline +(a::Partials, b::Partials) = Partials(data(a) + data(b))
@inline -(a::Partials, b::Partials) = Partials(data(a) - data(b))
@inline -(p::Partials) = Partials(-data(p))

# Scaling by a scalar, on either side, and division by a scalar.
@inline *(p::Partials, x::Number) = Partials(data(p) * x)
@inline *(x::Number, p::Partials) = p * x
@inline /(p::Partials, x::Number) = Partials(data(p) / x)
# Linear combination of two sets of partials: `a * afactor + b * bfactor`.
# `afactor` must scale `a` and `bfactor` must scale `b`. Both callers depend
# on this pairing:
#   - `*(g1, g2)` passes `(a2, a1)` so that d(a1*a2) = a2*g1' + a1*g2';
#   - `_div_partials` passes `1/bval` as the factor for `a`.
# Swapping the factors silently yields wrong derivatives.
@inline _mul_partials(a::Partials, b::Partials, afactor, bfactor) =
    Partials(data(a) * afactor + data(b) * bfactor)

# Partials of the quotient aval/bval (quotient rule):
#   d(a/b) = a' * (1/b) + b' * (-a/b^2)
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
    afactor = inv(bval)
    bfactor = -aval/(bval*bval)
    return _mul_partials(a, b, afactor, bfactor)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment