Created
June 28, 2021 01:33
-
-
Save duvenaud/b05b7d7221e12e7a4c6b98397a3e2dc0 to your computer and use it in GitHub Desktop.
Playing around with linear endomorphisms in Dex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import linalg | |
-- Add instance for strictly lower triangular tables (row i holds columns j < i).
-- All three methods act row by row, deferring to the Add instance of each row.
instance [Add a] Add (i:n => (..<i) => a)
  add = \lhs rhs. for r. lhs.r + rhs.r
  sub = \lhs rhs. for r. lhs.r - rhs.r
  zero = for _. zero
-- Add instance for strictly upper triangular tables (row i holds columns j > i).
-- All three methods act row by row, deferring to the Add instance of each row.
instance [Add a] Add (i:n => (i<..) => a)
  add = \lhs rhs. for r. lhs.r + rhs.r
  sub = \lhs rhs. for r. lhs.r - rhs.r
  zero = for _. zero
-- VSpace instance for lower triangular tables (row i holds columns j <= i).
-- Scaling is applied to each row in turn.
instance [VSpace a] VSpace (i:n => (..i) => a)
  scaleVec = \c tbl. for r. c .* tbl.r
-- VSpace instance for upper triangular tables (row i holds columns j >= i).
-- Scaling is applied to each row in turn.
instance [VSpace a] VSpace (i:n => (i..) => a)
  scaleVec = \c tbl. for r. c .* tbl.r
-- VSpace instance for strictly lower triangular tables (row i holds columns j < i).
-- Scaling is applied to each row in turn.
instance [VSpace a] VSpace (i:n => (..<i) => a)
  scaleVec = \c tbl. for r. c .* tbl.r
-- VSpace instance for strictly upper triangular tables (row i holds columns j > i).
-- Scaling is applied to each row in turn.
instance [VSpace a] VSpace (i:n => (i<..) => a)
  scaleVec = \c tbl. for r. c .* tbl.r
-- A representation `m` of linear maps from the vector space `v` to itself.
-- Both the representation and the vectors it acts on must be vector spaces.
interface [VSpace m, VSpace v] LinearEndo m v
  -- Apply the operator to a vector.
  -- Since m is a VSpace, don't need separate "compose" func.
  apply : m -> v -> v
  -- Determinant of the operator.
  determinant' : m -> Float
  -- Diagonal of the operator, returned as a vector.
  diag : m -> v
  -- `solve' a b` returns the `x` for which `apply a x` equals `b`.
  solve' : m -> v -> v
-- Scalars acting on scalars: application is multiplication, and the
-- determinant and diagonal are the scalar itself.
instance LinearEndo (Float) (Float)
  apply = \a x. a * x
  determinant' = \a. a
  diag = \a. a
  solve' = \a rhs. rhs / a
-- A table of Floats acting on a table of the same shape (a diagonal operator):
-- `apply` multiplies the two tables with `*`, the determinant is the product
-- of the entries, and solving divides entry by entry.
instance LinearEndo (n=>Float) (n=>Float)
  apply = \a x. a * x
  determinant' = \a. prod a
  diag = \a. a
  solve' = \a rhs. for k. rhs.k / a.k
-- Dense square matrices acting on vectors.
instance LinearEndo (n=>n=>Float) (n=>Float)
  -- Matrix-vector product: entry i of the result is sum_j x.i.j * y.j.
  -- (Summing over the row index instead would yield y.j * sum_i x.i.j, which
  -- is not a matrix-vector product; the row index must stay on the outside,
  -- as in the triangular instances.)
  apply = \x y. for i. sum for j. x.i.j * y.j
  -- Determinant and linear solve are delegated to the `linalg` routines.
  determinant' = determinant
  diag = \x. for i. x.i.i
  solve' = solve
-- Lower triangular matrices acting on vectors.
instance LinearEndo (LowerTriMat n Float) (n=>Float)
  -- Row r only stores some of the columns; each stored column index c is
  -- injected back into the full index set n to look up the vector entry.
  apply = \tri vec. for r. sum (for c. tri.r.c .* vec.(%inject c))
  -- Determinant is the product of the diagonal entries.
  determinant' = \tri. prod (lowerTriDiag tri)
  diag = \tri. lowerTriDiag tri
  solve' = \tri rhs. forward_substitute tri rhs
-- Upper triangular matrices acting on vectors.
instance LinearEndo (UpperTriMat n Float) (n=>Float)
  -- Row r only stores some of the columns; each stored column index c is
  -- injected back into the full index set n to look up the vector entry.
  apply = \tri vec. for r. sum (for c. tri.r.c .* vec.(%inject c))
  -- Determinant is the product of the diagonal entries.
  determinant' = \tri. prod (upperTriDiag tri)
  diag = \tri. upperTriDiag tri
  solve' = \tri rhs. backward_substitute tri rhs
-- Storage for a skew-symmetric matrix: only the entries strictly above the
-- diagonal (row i, columns j > i) are kept.
-- Means that x.i.j = -x.j.i.  Todo: Use a newtype
def SkewSymmetricMat (n:Type) (v:Type) : Type = i:n=>(i<..)=>v
-- Product of a SkewSymmetricMat with a vector, used as `apply` for that type.
-- For each row r it sums x.r.c .* (y.c - y.r) over the stored columns c,
-- injecting c into the full index set n to index y.
-- NOTE(review): with x.i.j = -x.j.i one would expect row r to also receive the
-- mirrored terms -x.c.r .* y.c from the rows above it, and no `- y.r` term.
-- This formula only visits the stored half; confirm the intended math.
def skewSymmetricProd [VSpace v] (x: SkewSymmetricMat n Float) (y: n=>v) : n=>v =
  for r. sum (for c. x.r.c .* (y.(%inject c) - y.r))
-- Skew-symmetric matrices acting on vectors.
-- The determinant and the solve are not implemented yet (`todo`).
instance LinearEndo (SkewSymmetricMat n Float) (n=>Float)
  apply = \mat vec. skewSymmetricProd mat vec
  determinant' = todo
  -- The diagonal is reported as all zeros, whatever the stored entries are.
  diag = \_. zero
  solve' = todo
-------- Application 1: Gaussians ------------ | |
-- Types for which a random value can be drawn from a PRNG key.
-- The Float32 instance draws with `randn`; tables draw one value per entry.
interface HasStandardNormal a:Type
  randNormal : Key -> a
-- Scalar case: draw directly with `randn`.
instance HasStandardNormal Float32
  randNormal = \k. randn k
-- Table case: each entry gets its own key, derived from the parent key and
-- the entry's index via `ixkey`, and is drawn with the element type's instance.
instance [HasStandardNormal a] HasStandardNormal (n=>a)
  randNormal = \k. for idx. randNormal $ ixkey k idx
-- Draws one sample as `mean + apply covroot noise`, where `noise` comes from
-- `randNormal key`. `covroot` can be any LinearEndo representation acting on v.
def gaussianSample [VSpace v, VSpace m, LinearEndo m v, HasStandardNormal v]
      ((mean, covroot) : (v & m)) (key:Key) : v =
  mean + apply covroot (randNormal key)

-- Scalar example: mean 1.0, covroot 2.0.
:p gaussianSample (1.0, 2.0) (newKey 0)
-- Vector spaces equipped with a Float-valued inner product.
interface [VSpace v] InnerProd v
  innerProd : v -> v -> Float
-- Scalars: the inner product is ordinary multiplication.
instance InnerProd Float
  innerProd = (*)
-- Tables: sum of the element-wise inner products.
instance [InnerProd a] InnerProd (n=>a)
  innerProd = \u w. sum (for k. innerProd u.k w.k)
-- Log-density at `x` of the Gaussian that `gaussianSample` draws from, i.e.
-- the distribution of `mean + apply covroot noise`.
--
-- The quadratic term is the squared norm of the whitened residual
-- `solve' covroot (x - mean)`. Solving against the centred residual (rather
-- than against `x` itself) is what makes the density depend on `mean`
-- correctly. Using the squared norm of a single solve, rather than two
-- successive solves, keeps the result valid for non-symmetric roots such as
-- triangular factors.
def gaussianlogpdf [VSpace m, VSpace v, InnerProd v, LinearEndo m v]
      ((mean, covroot) : (v & m)) (x:v) : Float =
  -- NOTE(review): `size` is applied to the vector-space type `v` here;
  -- confirm this yields the dimension of the space for the types in use.
  d = size v
  whitened = solve' covroot (x - mean)
  covpart = -0.5 * innerProd whitened whitened
  -- NOTE(review): assumes `determinant' covroot` is positive; a root with a
  -- negative determinant would need an absolute value here.
  normpart = -log (determinant' covroot)
  constpart = -0.5 * (IToF d) * log (2.0 * pi)
  covpart + normpart + constpart
--------------- Application 2: SDEs --------------- | |
Time = Float

-- Dynamics of a simple Monte Carlo estimator of the KL divergence between
-- two SDEs that share a diffusion function. (If not, divergence is infinite)
-- Squares the difference of the two drifts at (state, t) with `sq`, then
-- solves the diffusion operator at (state, t) against that squared difference.
def radonNikodym [Mul s, VSpace s, LinearEndo m s]
      (drift1: s->Time->s)
      (drift2: s->Time->s)
      (diffusion: s->Time->m)
      (state: s) (t: Time) : s =
  driftGap = (drift1 state t) - (drift2 state t)
  solve' (diffusion state t) (sq driftGap)
-- Drift and diffusion product.
-- A drift maps (state, time) to dstate/dtime.
def Drift (v:Type) -> (_:VSpace v) ?=> : Type = v->Time->v
-- NOTE(review): `LinearEndo m v` is an interface name used as a result type
-- here; the `diffusion` argument of `radonNikodym` returns plain `m` instead.
-- Confirm which form is intended.
def Diffusion (m:Type) (v:Type) (_:VSpace v) ?=> (_:VSpace m) ?=> : Type = v->Time->(LinearEndo m v)
def SDE (m:Type) (v:Type) (_:VSpace v) ?=> (_:VSpace m) ?=> : Type = (Drift v & Diffusion m v )
def SkewSymmetricProd (v:Type) : Type = v->Time->v->v -- How to express matrix?
-- DiffusionProd takes state and noise and returns dstate/dtime.
-- It was referenced by StationarySDE without being defined. The shape below
-- (state, time, vector -> vector) is the one `stationarySDEPartsToSDE` uses
-- when it calls `diffusionProd state time vec`.
def DiffusionProd (v:Type) : Type = v->Time->v->v
def NegEnergyFunc (v:Type) : Type = v->Time->Float
def StationarySDE (v:Type) : Type = (NegEnergyFunc v & SkewSymmetricProd v & DiffusionProd v)
-- Builds a (drift, diffusion product) pair from the three parts of a
-- StationarySDE: a negative energy function, a skew-symmetric product and a
-- diffusion product.
-- From Section 2.1 of "A Complete Recipe for Stochastic Gradient MCMC"
-- https://arxiv.org/pdf/1506.04696.pdf
-- NOTE(review): the return annotation `SDE v` passes one type argument while
-- `SDE` is declared with two (m and v), and the value returned is the pair
-- (drift, diffusionProd). Confirm the intended result type.
def stationarySDEPartsToSDE [Mul v, VSpace v] (sta:StationarySDE v) : SDE v =
  -- The local is named `skewProd` so it does not shadow the top-level
  -- `skewSymmetricProd` function.
  (negEnergyFunc, skewProd, diffusionProd) = sta
  -- Half of the diffusion product applied twice.
  -- NOTE(review): this applies the same product twice rather than the product
  -- and its transpose; confirm that is intended for non-symmetric diffusions.
  varProd = \state time vec.
    0.5 .* (diffusionProd state time $ diffusionProd state time vec)
  -- Pointwise sum of the two products. Functions themselves cannot be added
  -- with `+`, so the sum is taken on their results.
  sumProd = \state time vec.
    (skewProd state time vec) + (varProd state time vec)
  drift = \state time.
    curNE = \s. negEnergyFunc s time
    negenergygrad = (grad curNE) state
    t1 = sumProd state time negenergygrad
    -- Correction term: directional derivative (along `one`) of the summed
    -- product applied to `one`, taken with respect to the state.
    gammapart = \s. sumProd s time one
    t3 = jvp gammapart state one
    t1 + t3
  (drift, diffusionProd)
-- Unfortunately, we can't convert back to a matrix
-- because we can't say what the size will be.
-- TODO: unfinished stub. The return type is still `??` and there is no body.
def jacobian [VSpace a, VSpace b] (f:a->b) (x:a) : ?? =
-- Unfortunately, we can't convert back to a matrix
-- because we can't say what the size will be.
-- We can get around this using Lists.
-- TODO: unfinished. `mat` is computed but nothing is returned yet, and `d` is
-- not used.
def linearEndoToMatrix [LinearEndo m v] (x:m) : (List (List Float)) =
  d = size v
  -- Differentiate the action of the operator value `x`. `m` is its type and
  -- cannot be passed to `apply`.
  mat = jacobian (apply x) zero
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment