Skip to content

Instantly share code, notes, and snippets.

@slinderman
Created January 10, 2017 15:28
Show Gist options
  • Save slinderman/31fdf5acdc0affdf4dc60670069d314a to your computer and use it in GitHub Desktop.
Save slinderman/31fdf5acdc0affdf4dc60670069d314a to your computer and use it in GitHub Desktop.
Sample predictions test
import numpy as np
import matplotlib.pyplot as plt
from pylds.models import DefaultLDS
# Exogenous inputs, one row per time step (17 rows x 2 columns).
# As pasted, column 1 of row t repeats the observation data[t - 1]
# (row 0 is all zeros), i.e. it is a lag-1 copy of the series below.
inputs = np.array([
    [0.0, 0.0],
    [2.72785283, 7.53608657],
    [0.0, 7.23201033],
    [2.0261219, 7.07866193],
    [2.42989525, 7.67644583],
    [2.80870737, 7.83475645],
    [0.0, 7.9209767],
    [0.0, 7.79186987],
    [2.78714571, 8.37162161],
    [2.49732917, 8.14119562],
    [2.71444, 7.94740989],
    [2.54583288, 8.05844122],
    [2.51490101, 8.18854467],
    [1.61097519, 7.88220925],
    [2.58138165, 7.85400085],
    [2.45695553, 7.71646971],
    [2.86483477, 7.96989029],
])

# Observed series: one scalar observation per time step (17 rows x 1 column).
data = np.array([
    [7.53608657],
    [7.23201033],
    [7.07866193],
    [7.67644583],
    [7.83475645],
    [7.9209767],
    [7.79186987],
    [8.37162161],
    [8.14119562],
    [7.94740989],
    [8.05844122],
    [8.18854467],
    [7.88220925],
    [7.85400085],
    [7.71646971],
    [7.96989029],
    [7.68803383],
])
# Center the observations by subtracting their per-column mean; `data` is
# modified in place, so everything after this point sees centered values.
# NOTE(review): only `data` is centered -- the second column of `inputs`
# (a lag-1 copy of the raw observations) keeps its original offset.
# Confirm that is intended.
offset = data.mean(axis=0)
data -= offset

# Linear dynamical system with a 2-D latent state. Observation and input
# dimensions are taken from the arrays themselves (1 and 2 for this data).
model = DefaultLDS(D_obs=data.shape[1], D_latent=2, D_input=inputs.shape[1])
model.add_data(data, inputs=inputs)

# Fixed number of EM iterations; there is no convergence check.
# (The loop body must be indented -- the flattened original raised
# IndentationError here.)
n_em_iters = 50
for _ in range(n_em_iters):
    model.EM_step()
# Condition on the first T_given steps, then forecast the next T_predict.
T_given = 13
T_predict = 4

given_data, given_inputs = data[:T_given], inputs[:T_given]
# Inputs for the forecast horizon (the steps right after the given window).
future_inputs = inputs[T_given:T_given + T_predict]

# Smoothed observations over the conditioning window.
smooth_data = model.smooth(given_data, given_inputs)

# Forecast T_predict steps ahead with both noise flags disabled
# (presumably a noise-free / mean forecast -- confirm against pylds).
preds = model.sample_predictions(
    given_data,
    inputs=given_inputs,
    Tpred=T_predict,
    inputs_pred=future_inputs,
    states_noise=False,
    obs_noise=False,
)
# Stack the smoothed window and the forecast on one time axis.
# Assumes smooth_data has one row per conditioning step, so rows
# [0, T_given) are smoothed and the rows after that are predicted.
smooth_plus_pred = np.concatenate((smooth_data, preds))
t_last_given = T_given - 1  # last time step used for conditioning

plt.plot(data, label="true")
plt.plot(smooth_plus_pred[:T_given], label="smoothed")
# Start the prediction trace at the last smoothed point so the two curves join.
plt.plot(
    np.arange(t_last_given, T_given + T_predict),
    smooth_plus_pred[t_last_given:],
    label="pred",
)
plt.xlabel("Time")
plt.ylabel("Data")
plt.legend(loc="upper left")

# Mark the conditioning/forecast boundary. axvline spans the full axes height
# without changing the y autoscaling, so there is no need to save and restore
# the y-limits manually.
plt.axvline(t_last_given, color="k", linestyle="-")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment