Skip to content

Instantly share code, notes, and snippets.

@duvenaud
Created June 28, 2021 01:33
Show Gist options
  • Save duvenaud/b05b7d7221e12e7a4c6b98397a3e2dc0 to your computer and use it in GitHub Desktop.
Save duvenaud/b05b7d7221e12e7a4c6b98397a3e2dc0 to your computer and use it in GitHub Desktop.
Playing around with linear endomorphisms in Dex
import linalg
-- Pointwise `Add` for strictly lower-triangular tables: row `i` is indexed by
-- `(..<i)`. Each method works row by row, deferring to `Add` on the rows.
instance [Add a] Add (i:n => (..<i) => a)
  add = \lhs rhs. for row. lhs.row + rhs.row
  sub = \lhs rhs. for row. lhs.row - rhs.row
  zero = for _. zero
-- Pointwise `Add` for strictly upper-triangular tables: row `i` is indexed by
-- `(i<..)`. Each method works row by row, deferring to `Add` on the rows.
instance [Add a] Add (i:n => (i<..) => a)
  add = \lhs rhs. for row. lhs.row + rhs.row
  sub = \lhs rhs. for row. lhs.row - rhs.row
  zero = for _. zero
-- Scalar multiplication for lower-triangular tables (row `i` indexed by `(..i)`):
-- every row is scaled by the same factor.
instance [VSpace a] VSpace (i:n => (..i) => a)
  scaleVec = \factor tbl. for row. factor .* tbl.row
-- Scalar multiplication for upper-triangular tables (row `i` indexed by `(i..)`):
-- every row is scaled by the same factor.
instance [VSpace a] VSpace (i:n => (i..) => a)
  scaleVec = \factor tbl. for row. factor .* tbl.row
-- Scalar multiplication for strictly lower-triangular tables (row `i` indexed by
-- `(..<i)`): every row is scaled by the same factor.
instance [VSpace a] VSpace (i:n => (..<i) => a)
  scaleVec = \factor tbl. for row. factor .* tbl.row
-- Scalar multiplication for strictly upper-triangular tables (row `i` indexed by
-- `(i<..)`): every row is scaled by the same factor.
instance [VSpace a] VSpace (i:n => (i<..) => a)
  scaleVec = \factor tbl. for row. factor .* tbl.row
-- A type `m` whose values act as linear maps from `v` to `v`.
-- Both `m` and `v` are required to be vector spaces.
interface [VSpace m, VSpace v] LinearEndo m v
  apply: m -> v -> v -- Since m is a VSpace, don't need separate "compose" func.
  determinant': m -> Float  -- Determinant of the map.
  diag: m -> v  -- Diagonal of the map, returned as a `v`.
  solve': m -> v -> v  -- In the Float instance, `solve' a b` is `b / a`, i.e. the inverse of `apply a`.
-- A scalar acts on a scalar by multiplication; it is its own determinant and
-- diagonal, and solving is division.
instance LinearEndo (Float) (Float)
  apply = \a x. a * x
  determinant' = \a. a
  diag = \a. a
  solve' = \a rhs. rhs / a
-- A table of Floats acts as a diagonal map on a table of the same shape:
-- elementwise product to apply, elementwise division to solve.
instance LinearEndo (n=>Float) (n=>Float)
  apply = \d x. for i. d.i * x.i
  determinant' = \d. prod d
  diag = \d. d
  solve' = \d rhs. for i. rhs.i / d.i
-- Dense square matrices acting on vectors by matrix-vector product.
-- `apply`: entry i of the result is the dot product of row i of `x` with `y`.
-- (Summing the per-row tables `x.i.j .* y.j` over i would instead produce
-- y.j times the j-th column sum of x, which is not a matrix-vector product.)
-- `determinant` and `solve` are not defined in this file; presumably they come
-- from the `linalg` import.
instance LinearEndo (n=>n=>Float) (n=>Float)
  apply = \x y. for i. sum for j. x.i.j * y.j
  determinant' = determinant
  diag = \x. for i. x.i.i
  solve' = solve
-- `LowerTriMat` (not defined in this file) acting on vectors.
-- `%inject col` is used to index the full-length vector with a row-relative
-- column index. The determinant is the product of the diagonal, and solving
-- is delegated to `forward_substitute`.
instance LinearEndo (LowerTriMat n Float) (n=>Float)
  apply = \tri vec. for row. sum for col. tri.row.col .* vec.(%inject col)
  determinant' = \tri. prod (lowerTriDiag tri)
  diag = \tri. lowerTriDiag tri
  solve' = \tri rhs. forward_substitute tri rhs
-- `UpperTriMat` (not defined in this file) acting on vectors.
-- `%inject col` is used to index the full-length vector with a row-relative
-- column index. The determinant is the product of the diagonal, and solving
-- is delegated to `backward_substitute`.
instance LinearEndo (UpperTriMat n Float) (n=>Float)
  apply = \tri vec. for row. sum for col. tri.row.col .* vec.(%inject col)
  determinant' = \tri. prod (upperTriDiag tri)
  diag = \tri. upperTriDiag tri
  solve' = \tri rhs. backward_substitute tri rhs
-- Storage for a skew-symmetric matrix: row `i` holds only the entries indexed by `(i<..)`.
def SkewSymmetricMat (n:Type) (v:Type) : Type = i:n=>(i<..)=>v
-- Means that x.i.j = -x.j.i. Todo: Use a newtype
-- Product of a `SkewSymmetricMat` with a table of vectors, using only the stored entries.
-- Output row i is the sum, over the stored columns j of row i, of x.i.j .* (y.j - y.i).
-- NOTE(review): this never reads the stored entries of earlier rows, so it may not equal
-- the product with the full skew-symmetric matrix (x.i.j = -x.j.i) -- confirm the formula.
def skewSymmetricProd [VSpace v] (x: SkewSymmetricMat n Float) (y: n=>v) : n=>v =
  for row.
    sum for col.
      neighbour = y.(%inject col)
      x.row.col .* (neighbour - y.row)
-- Skew-symmetric matrices acting on vectors via `skewSymmetricProd`.
-- The diagonal is always zero; `determinant'` and `solve'` are still `todo`.
instance LinearEndo (SkewSymmetricMat n Float) (n=>Float)
  apply = \mat vec. skewSymmetricProd mat vec
  determinant' = todo
  diag = \mat. zero
  solve' = todo
-------- Application 1: Gaussians ------------
-- Types for which a value can be drawn from a random `Key`.
interface HasStandardNormal a:Type
  randNormal : Key -> a  -- Draw one value of type `a` from the given key.
-- Scalar case: a draw is whatever `randn` returns for the key.
instance HasStandardNormal Float32
  randNormal = \key. randn key
-- Table case: each element is drawn from its own key, derived from the
-- parent key and the element's index with `ixkey`.
instance [HasStandardNormal a] HasStandardNormal (n=>a)
  randNormal = \key.
    for i.
      elementKey = ixkey key i
      randNormal elementKey
-- Draws one sample as `mean + apply covroot noise`, where `noise` comes from
-- `randNormal key`.
--   (mean, covroot): the offset and the linear map applied to the noise.
--   key: random key consumed to generate the noise.
def gaussianSample [VSpace v, VSpace m, LinearEndo m v, HasStandardNormal v]
    ((mean, covroot) : (v & m)) (key:Key) : v =
  standardDraw = randNormal key
  scaledDraw = apply covroot standardDraw
  mean + scaledDraw
-- Example: print one sample with mean 1.0 and covroot 2.0, using the key `newKey 0`.
:p gaussianSample (1.0, 2.0) (newKey 0)
-- Vector spaces equipped with a Float-valued product of two vectors.
interface [VSpace v] InnerProd v
  innerProd : v->v->Float  -- Combine two vectors into a single Float.
-- For scalars the inner product is ordinary multiplication.
instance InnerProd Float
  innerProd = \a b. a * b
-- For tables the inner product is the sum of the elementwise inner products.
instance [InnerProd a] InnerProd (n=>a)
  innerProd = \xs ys.
    terms = for i. innerProd xs.i ys.i
    sum terms
-- Log-density at `x` of the Gaussian whose samples are `mean + apply covroot noise`
-- for standard-normal `noise`.
--   (mean, covroot): location, and a root of the covariance.
--   x: point at which the density is evaluated.
-- The quadratic term is the squared norm of the whitened residual
-- `solve' covroot (x - mean)`. This centres `x` before solving, and it holds for
-- any invertible root rather than only for symmetric ones.
-- NOTE(review): `size v` is applied to the vector type `v`; confirm it yields the dimension.
-- NOTE(review): `log (determinant' covroot)` is only finite when the determinant is positive.
def gaussianlogpdf [VSpace m, VSpace v, InnerProd v, LinearEndo m v]
    ((mean, covroot) : (v & m)) (x:v) : Float =
  d = size v
  centered = x - mean
  whitened = solve' covroot centered
  covpart = -0.5 * innerProd whitened whitened
  normpart = -log (determinant' covroot)
  constpart = - 0.5 * (IToF d) * log (2.0 * pi)
  covpart + normpart + constpart
--------------- Application 2: SDEs ---------------
-- Alias used for the time argument of the drift and diffusion functions below.
Time = Float
-- Dynamics of a simple Monte Carlo estimator of the KL divergence between
-- two SDEs that share a diffusion function. (If not, divergence is infinite)
-- Returns the squared difference of the two drifts at (state, t), passed
-- through `solve'` with the diffusion evaluated at (state, t).
def radonNikodym [Mul s, VSpace s, LinearEndo m s]
    (drift1: s->Time->s)
    (drift2: s->Time->s)
    (diffusion: s->Time->m)
    (state: s) (t: Time) : s =
  driftGap = (drift1 state t) - (drift2 state t)
  noiseMap = diffusion state t
  solve' noiseMap (sq driftGap)
-- Drift and diffusion product.
-- DiffusionProd takes state and noise and returns dstate/dtime
-- NOTE(review): no definition of `DiffusionProd` is visible in this file.
-- Drift: maps a state and a time to a value of the state type.
-- NOTE(review): this header has a `->` after `(v:Type)` that the next two aliases do not -- confirm the syntax.
def Drift (v:Type) -> (_:VSpace v) ?=> : Type = v->Time->v
-- Diffusion: maps a state and a time to a linear map on states.
-- NOTE(review): `LinearEndo` is declared as an interface in this file but is used here as a result type; `m` may be intended -- confirm.
def Diffusion (m:Type) (v:Type) (_:VSpace v) ?=> (_:VSpace m) ?=> : Type = v->Time->(LinearEndo m v)
-- SDE: a (drift, diffusion) pair.
def SDE (m:Type) (v:Type) (_:VSpace v) ?=> (_:VSpace m) ?=> : Type = (Drift v & Diffusion m v )
-- Takes a state, a time and a vector, and returns a vector of the same type.
def SkewSymmetricProd (v:Type) : Type = v->Time->v->v -- How to express matrix?
-- Takes a state and a time, and returns a Float.
def NegEnergyFunc (v:Type) : Type = v->Time->Float
-- Triple of (negative energy, skew-symmetric product, diffusion product).
-- NOTE(review): `DiffusionProd` is not defined in the visible source -- confirm where it comes from.
def StationarySDE (v:Type) : Type = (NegEnergyFunc v & SkewSymmetricProd v & DiffusionProd v)
-- Builds a (drift, diffusion) pair from the three parts of a `StationarySDE`.
-- NOTE(review): the result type is written `SDE v`, while `SDE` is declared with two
-- explicit type parameters `(m:Type) (v:Type)` -- confirm.
def stationarySDEPartsToSDE [Mul v, VSpace v] (sta:StationarySDE v) : SDE v =
  -- From Section 2.1 of "A Complete Recipe for Stochastic Gradient MCMC"
  -- https://arxiv.org/pdf/1506.04696.pdf
  (negEnergyFunc, skewSymmetricProd, diffusionProd) = sta
  -- varProd: half of the diffusion product applied twice to `vec`.
  varProd = \state time vec.
    0.5 .* (diffusionProd state time $ diffusionProd state time vec)
  drift = \state time.
    -- Negative energy as a function of the state only, at the fixed `time`.
    curNE = \state. negEnergyFunc state time
    negenergygrad = (grad curNE) state
    -- NOTE(review): `skewSymmetricProd + varProd` adds two functions; this needs an
    -- `Add` instance for function types -- confirm one is in scope.
    t1 = (skewSymmetricProd + varProd) state time negenergygrad
    gammapart = \state. (skewSymmetricProd + varProd) state time one
    -- `jvp` of `gammapart` at `state`, with tangent `one`.
    t3 = jvp gammapart state one
    t1 + t3
  -- The diffusion part is passed through unchanged.
  (drift, diffusionProd)
-- Unfortunately, we can't convert back to a matrix
-- because we can't say what the size will be.
-- NOTE(review): unfinished sketch -- the return type is the placeholder `??` and no body follows the `=`.
def jacobian [VSpace a, VSpace b] (f:a->b) (x:a) : ?? =
-- Unfortunately, we can't convert back to a matrix
-- because we can't say what the size will be.
-- We can get around this using Lists.
-- NOTE(review): unfinished sketch -- the body ends on a binding and never produces the
-- declared `List (List Float)`; `d` is unused; and `apply m` passes the type parameter `m`
-- where the value argument `x` may be intended -- confirm.
def linearEndoToMatrix [LinearEndo m v] (x:m) : (List (List Float)) =
  d = size v
  mat = jacobian (apply m) zero
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment