# Kalman parameter estimation in PyTorch
#
# Source: GitHub Gist rebcabin/e1636fee1ec8fbab73038d0352cbcbf1
# (created March 7, 2020 22:28).
import torch
import toolz
def kalman(b,  # number of rows and cols in Z; number of rows in z
           n,  # number of rows and cols in P; number of rows in x
           Z,  # b x b observation covariance
           x,  # n x 1, current state
           P,  # n x n, current covariance
           A,  # b x n, current observation partials
           z   # b x 1, current observation vector
           ):
    """Recurrent Kalman filter for parameter estimation (no dynamics).

    Performs one measurement update of the estimate ``x`` and its
    covariance ``P`` from a single observation ``z`` with partials ``A``.

    Args:
        b: Observation dimension. Not read by the computation; kept so
            existing call sites stay valid.
        n: State dimension. Not read by the computation; kept so
            existing call sites stay valid.
        Z: Observation covariance, intended shape b x b.
        x: Current state estimate, intended shape n x 1.
        P: Current state covariance, intended shape n x n.
        A: Observation partials, intended shape b x n.
        z: Observation vector, intended shape b x 1.

    Returns:
        A tuple ``(x, p)`` holding the updated state and covariance.

    The inputs are not modified; new tensors are returned. If
    ``D = Z + A P A^T`` is singular, ``torch.linalg.solve`` raises.
    """
    # Transcription of the following Wolfram sketch (the intermediate
    # matrices are not necessary in Wolfram, but we need them in Python):
    #
    #   noInverseKalman[Z_][{x_, P_}, {A_, z_}] :=
    #     Module[{PAT, D, Res, DiRes, KRes, AP, DiAP, KAP},
    #        1. PAT   = P.Transpose[A]       (* n x b *)
    #        2. D     = Z + A.PAT            (* b x b *)
    #        3. Res   = z - A.x              (* b x 1 *)
    #        4. DiRes = LinearSolve[D, Res]  (* b x 1 *)
    #        5. KRes  = PAT.DiRes            (* n x 1 *)
    #        6. AP    = A.P                  (* b x n *)
    #        7. DiAP  = LinearSolve[D, AP]   (* b x n *)
    #        8. KAP   = PAT.DiAP             (* n x n *)
    #        9. x    <- x + KRes
    #       10. P    <- P - KAP

    # 1. PAT (n x b) = P (n x n) . A^T (n x b)
    pat = torch.matmul(P, torch.t(A))

    # 2. D (b x b) = A (b x n) . PAT (n x b) + Z (b x b)
    d = torch.add(torch.matmul(A, pat), Z)

    # 3. Res (b x 1) = z (b x 1) - A (b x n) . x (n x 1)
    res = torch.sub(z, torch.matmul(A, x))

    # 4. DiRes (b x 1) solves D . DiRes = Res.  A linear solve is used
    #    rather than forming D^-1 explicitly, as the sketch prescribes.
    dires = torch.linalg.solve(d, res)

    # 5. KRes (n x 1) = PAT (n x b) . DiRes (b x 1)
    kres = torch.matmul(pat, dires)

    # 6. AP (b x n) = A (b x n) . P (n x n)
    ap = torch.matmul(A, P)

    # 7. DiAP (b x n) solves D . DiAP = AP.
    diap = torch.linalg.solve(d, ap)

    # 8. KAP (n x n) = PAT (n x b) . DiAP (b x n)
    kap = torch.matmul(pat, diap)

    # 9. x (n x 1) <- x + KRes
    x = torch.add(x, kres)

    # 10. P (n x n) <- P - KAP
    p = torch.sub(P, kap)

    return (x, p)
def normal_equations():
    """Produce the estimate by linear regression without covariance.

    Solves the least-squares problem for five hard-coded observations
    ``zs`` with the normal equations, ``(A^T A)^-1 A^T zs``. Each row of
    ``A`` is ``[1, t, t^2, t^3]`` for t = 0, 1, -1, -2, 2. Every
    intermediate is printed, along with the expected answer.

    Returns:
        The length-4 tensor of estimated coefficients.
    """
    print("----------------------------------------------------------------")
    print("The Normal Equations for Linear Regression")
    # x0 is only printed; the closed-form solution needs no initial guess.
    x0 = torch.zeros(4)
    print({'x0': x0})
    a = torch.tensor([[1., 0., 0., 0.],
                      [1., 1., 1., 1.],
                      [1., -1., 1., -1.],
                      [1., -2., 4., -8.],
                      [1., 2., 4., 8.]])
    print({'A': a})
    zs = torch.tensor([-2.28442, -4.83168, -10.4601, 1.40488, -40.8079])
    print({'zs': zs})
    at = torch.t(a)
    print({'at': at})
    ata = torch.matmul(at, a)
    print({'ata': ata})
    # Reuse the product computed above instead of forming A^T A again.
    atai = torch.inverse(ata)
    print({'atai': atai})
    atai_at = torch.matmul(atai, at)
    print({'atai_at': atai_at})
    atai_at_zs = torch.matmul(atai_at, zs)
    print({'expect': torch.tensor([-2.9751, 7.2700, -4.2104, -4.4558])})
    print({'atai_at_zs': atai_at_zs})
    return atai_at_zs
def kalman_sample_by_hand():
    """Step the Kalman recurrence through five observations, one at a time.

    Starts from a zero state with covariance 1000 * I and unit observation
    covariance, applies ``kalman`` once per observation, and prints the
    state and covariance after every step (x1/p1 through x5/p5).

    The original note said "Verify against equation 1 in" without naming
    the reference, so the comparison target is not known from this file.
    """
    print("----------------------------------------------------------------")
    print("Explicit intermediate variables in a recurrence over five data.")
    observations = [-2.28442, -4.83168, -10.4601, 1.40488, -40.8079]
    partial_rows = [[1., 0., 0., 0.],
                    [1., 1., 1., 1.],
                    [1., -1., 1., -1.],
                    [1., -2., 4., -8.],
                    [1., 2., 4., 8.]]
    # Shape (5, 1): indexing one entry yields a length-1 observation.
    zs = torch.tensor(observations).reshape(-1, 1)
    # Shape (5, 1, 4): indexing one entry yields a 1 x 4 row of partials.
    aas = torch.tensor(partial_rows).unsqueeze(1)
    Z = torch.tensor([[1.0]])
    state = torch.zeros(4, 1)
    covariance = 1000. * torch.eye(4)
    for step, (row, observation) in enumerate(zip(aas, zs), start=1):
        state, covariance = kalman(1, 4, Z, state, covariance, row, observation)
        print({'x{}'.format(step): state, 'p{}'.format(step): covariance})
def kalman_with_random_data():
    """Verify against ground truth [-3, 9, -4, -5].

    Draws 10000 abscissae uniformly from [-2, 2), forms the observation
    row [1, t, t^2, t^3] for each, and synthesizes an observation from
    the ground truth plus standard-normal noise. The Kalman update is
    then folded over all samples, starting from a zero state with
    covariance 1000 * I, and the final state and covariance are printed.
    """
    print("----------------------------------------------------------------")
    print("Recurrence over large-ish data set.")
    ground_truth = torch.tensor([[-3.0], [9.0], [-4.0], [-5.0]])
    Z = torch.tensor([[1.0]])

    seed = torch.random.initial_seed()
    print({'seed': seed})

    trials = 10000
    # Draw every abscissa before any noise so the random stream is
    # consumed in the same order: all uniform draws, then all normal draws.
    abscissae = [torch.rand(1) * 4.0 - 2.0 for _ in range(trials)]
    rows = [torch.tensor([[1.0, t, t ** 2, t ** 3]]) for t in abscissae]
    observations = [torch.add(torch.matmul(row, ground_truth), torch.randn(1))
                    for row in rows]

    # Explicit left fold of the update over the (row, observation) pairs.
    estimate = torch.zeros(4, 1)
    covariance = 1000. * torch.eye(4)
    for row, observation in zip(rows, observations):
        estimate, covariance = kalman(1, 4, Z, estimate, covariance,
                                      row, observation)
    print({'xtrials': estimate, 'ptrials': covariance})
if __name__ == "__main__":
    # Script entry point: run the three demonstrations defined in this
    # file, in definition order.
    normal_equations()
    kalman_sample_by_hand()
    kalman_with_random_data()
# End of gist (GitHub sign-in / comment footer omitted).