# OnlineStat for a log-weighted Gaussian
using PositiveFactorizations
using LinearAlgebra
using StatsFuns
using Random
using OnlineStatsBase
using Statistics
const OSB = OnlineStatsBase
const bessel = OSB.bessel
mutable struct Gaussian{T<:Number} <: OnlineStat{Union{Tuple, NamedTuple, AbstractVector}}
    value::Matrix{T}
    A::Matrix{T}    # x'x/n
    b::Vector{T}    # 1'x/n
    log∑w::T        # log of the sum of weights
    log∑w²::T       # log of the sum of squared weights
    L::Matrix{T}    # Lower Cholesky factor
    Lsync::Bool     # Indicates whether L is up to date
    weight
    n::Int
end
# Kish's Effective Sample Size, see https://en.wikipedia.org/wiki/Effective_sample_size
n_eff(o::Gaussian) = exp(2 * o.log∑w - o.log∑w²)
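# A quick check of the log-space formula (a sketch with hypothetical local
# variables, not part of the stat): Kish's ESS is (∑w)² / ∑w², which for
# weights stored in log-space is exp(2 * log∑w - log∑w²).
#
#   w = [0.5, 1.0, 2.0]
#   logw = log.(w)
#   kish = sum(w)^2 / sum(abs2, w)                          # ≈ 2.33
#   kish ≈ exp(2 * logsumexp(logw) - logsumexp(2 .* logw))  # true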
function Gaussian(::Type{T}, p::Int=0; weight = EqualWeight()) where T<:Number
    # Most negative finite T, acting as log(0⁺): the weight sums start near zero
    logε = nextfloat(typemin(T))
    Gaussian(zeros(T,p,p), zeros(T,p,p), zeros(T,p), logε, logε, zeros(T,p,p), false, weight, 0)
end

Gaussian(p::Int=0; weight = EqualWeight()) = Gaussian(Float64, p; weight=weight)
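# For example (a sketch; any OnlineStatsBase weight should work here):
#
#   o32 = Gaussian(Float32, 3)                        # explicit eltype and dimension
#   oexp = Gaussian(weight=ExponentialWeight(0.01))   # dimension inferred at the first fit!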
# The computation for γ,
#
# γ = min(o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w), 0.99)
#
# could use some explanation. Let's first look at the first argument inside the
# `min`,
#
# o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w)
#
# The first factor,
#
# o.weight(o.n += 1)
#
# is to allow the use of built-in weights from OnlineStats.jl. The last factor,
#
# logistic(ℓ - o.log∑w)
#
# takes into account that each observation has its own log-weight. But now
# we've counted the weights twice, so there's a question of how to correct for
# this. We could take a square root to get the geometric mean of the two
# effects, but this would leave neither of them working as expected.
#
# With OnlineStats.EqualWeight, the weight for the nth observation would be 1/n.
# So we multiply by n (the second factor).
#
# The `logistic` factor is typically O(1/n), but there are no guarantees about
# this. So the `min` is to account for the (hopefully rare) case where
#
# o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w) > 1
#
# Finally, allowing γ = 1 causes problems, because it makes the distribution
# collapse to a point. So as a heuristic, we cap γ at 0.99 to prevent this.
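# To see how the factors interact, here's a worked case (not part of the
# implementation): suppose every observation has the same log-weight ℓ. After
# n-1 observations, log∑w = ℓ + log(n-1), so
#
#   logistic(ℓ - o.log∑w) = logistic(-log(n-1)) = 1/n
#
# With EqualWeight, o.weight(n) == 1/n, so the three factors give
#
#   γ = (1/n) * n * (1/n) = 1/n
#
# recovering exactly the unweighted EqualWeight behavior, as we'd hope.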
function OSB._fit!(o::Gaussian{T}, xℓ) where {T}
    o.Lsync = false
    (x, ℓ) = xℓ
    γ = min(o.weight(o.n += 1) * o.n * logistic(ℓ - o.log∑w), 0.99)
    o.log∑w = logaddexp(o.log∑w, ℓ)
    o.log∑w² = logaddexp(o.log∑w², 2ℓ)
    if isempty(o.A)
        # Lazily allocate storage once we see the dimension of the data
        p = length(x)
        o.b = zeros(T, p)
        o.A = zeros(T, p, p)
        o.L = zeros(T, p, p)
        o.value = zeros(T, p, p)
    end
    OSB.smooth!(o.b, x, γ)
    OSB.smooth_syr!(o.A, x, γ)
end
OSB.nvars(o::Gaussian) = size(o.A, 1)
function OSB.value(o::Gaussian)
    o.value[:] = Matrix(Symmetric(o.A - o.b * o.b'))
    o.value
end
function OSB._merge!(o::Gaussian, o2::Gaussian)
    o.Lsync = false
    o.n += o2.n
    # γ is the fraction of the combined weight contributed by o2, so that
    # smooth!(a, b, γ) = a + γ*(b - a) forms the weight-proportional average
    γ = logistic(o2.log∑w - o.log∑w)
    o.log∑w = logaddexp(o.log∑w, o2.log∑w)
    o.log∑w² = logaddexp(o.log∑w², o2.log∑w²)
    OSB.smooth!(o.A, o2.A, γ)
    OSB.smooth!(o.b, o2.b, γ)
end
Statistics.cov(o::Gaussian) = value(o)
Statistics.mean(o::Gaussian) = o.b
Statistics.var(o::Gaussian; kw...) = diag(value(o; kw...))
function Statistics.cor(o::Gaussian; kw...)
    value(o; kw...)
    v = 1.0 ./ sqrt.(diag(o.value))
    rmul!(o.value, Diagonal(v))
    lmul!(Diagonal(v), o.value)
    o.value
end
function LinearAlgebra.cholesky(o::Gaussian)
    # Reuse the cached factor if nothing has changed since the last call
    o.Lsync && return Cholesky(LowerTriangular(o.L), :L, 0)
    copyto!(o.L, value(o))
    C = cholesky!(Positive, o.L)
    o.Lsync = true
    return C
end
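# Cache behavior (a sketch, for an already-fitted `o` like the one in the
# checks below): the factorization is reused until the next fit! or merge!
# flips Lsync back to false.
#
#   C1 = cholesky(o)              # factors value(o) and caches it in o.L
#   C2 = cholesky(o)              # cheap: reuses the cached factor
#   fit!(o, (randn(3), randn()))
#   C3 = cholesky(o)              # refactors, since the stat changed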
# Sample via x = μ + L*z with z ~ N(0, I)
function Random.rand!(rng::AbstractRNG, x::AbstractArray, o::Gaussian{T}) where {T}
    randn!(rng, x)
    L = cholesky(o).L
    lmul!(L, x)
    x .+= mean(o)
    return x
end

function Base.rand(rng::AbstractRNG, o::Gaussian{T}) where {T}
    x = Vector{T}(undef, OSB.nvars(o))
    rand!(rng, x, o)
end
# Some checks
#
# o = Gaussian()
# for j in 1:1000
#     fit!(o, (randn(3), randn()))
# end
# value(o)
# mean(o)
# cov(o)
# cholesky(o)
# rand!(zeros(3), o)
# rand(o)
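# Demo: adaptive importance sampling against a Beta(4,2) target.
#
# `Pullback(t, Beta(4,2))` pulls the Beta density back through the 𝕀 transform,
# so the target lives on all of ℝ, where a Gaussian proposal makes sense.
# `logdensity(p, q, x)` below is then the log density of the target relative to
# the proposal at x, i.e. the log importance weight of a draw x from q.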
using MeasureTheory
using TransformVariables
t = as𝕀
p = Pullback(t, Beta(4,2))
q0 = Normal(0,10)
logdensity(p, q0, randn())
o = Gaussian(weight=EqualWeight())

# Warm up: sample from the initial proposal until we have a few effective samples
while min(nobs(o), n_eff(o)) < 10
    x = rand(q0)
    ℓ = logdensity(p, q0, x)
    fit!(o, (x, ℓ))
end
# Adapt: repeatedly refit the proposal to the current estimate until the
# effective sample size reaches 1000
while n_eff(o) < 1000
    n = n_eff(o)
    μ = mean(o)[1]
    σ = std(o)[1]
    q = Normal(μ, σ)
    # Train in minibatches of 100 effective samples each
    while n_eff(o) < n + 100
        x = rand(q)
        ℓ = logdensity(p, q, x)
        fit!(o, (x, ℓ))
    end
end
# Compare the fitted proposal to the target, both pushed forward to (0,1)
using UnicodePlots
xx = 0.01:0.01:0.99
q = Normal(mean(o)[1], std(o)[1])
μμ = [density(Pushforward(t, p), x) for x in xx];
νν = [density(Pushforward(t, q), x) for x in xx];
factor = sum(μμ) / sum(νν)  # rough rescaling so the two curves are comparable
plt = lineplot(xx, νν * factor);
lineplot!(plt, xx, μμ)
# julia> lineplot!(plt, xx, μμ)
# ┌────────────────────────────────────────┐
# 3 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠀⠀⠀⠀⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⡤⠒⠛⠥⡉⢢⡀⠀⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⠤⡟⠉⠀⠀⠀⠀⠈⢢⡑⡄⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡠⡗⠊⠀⠀⠀⠀⠀⠀⠀⠀⠑⣵⠀⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡮⠊⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⡇⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡴⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢹⠀⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣔⠕⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⣇⠀⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢻⡄⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⠔⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠸⣧⠀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣤⠞⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢿⡀│
# │⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⡶⠛⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⢇│
# 0 │⣀⣀⣀⣀⣀⣤⠴⠞⠋⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│
# └────────────────────────────────────────┘
# 0 1