Skip to content

Instantly share code, notes, and snippets.

@karchie
Created June 27, 2013 12:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save karchie/5876071 to your computer and use it in GitHub Desktop.
Save karchie/5876071 to your computer and use it in GitHub Desktop.
Scott Purdy's hotgym example driver, modified to use ipython and pylab to graph prediction and errors. A good deal of library installation and reconfiguration is needed to make this work; read the comment at the top of the file and contact me if you can't get it working.
#!/usr/bin/env python
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2013, Numenta, Inc. Unless you have purchased from
# Numenta, Inc. a separate commercial license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
"""A simple client to create a CLA model for hotgym, originally by
Scott Purdy. Modified by Kevin Archie to graph predicted and actual
values."""
# NOTE: this code requires IPython with pylab; I've been using the
# Tornado-hosted notebook, but presumably the Qt console would also
# work. It also requires a more modern matplotlib than the one
# embedded in nupic; I found it was sufficient to delete the nupic
# site-packages/matplotlib, thus exposing the system Python version
# of the library.
import csv
import datetime
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display
from nupic.data.datasethelpers import findDataset
from nupic.frameworks.opf.modelfactory import ModelFactory
from nupic.data.inference_shifter import InferenceShifter
import model_params
# Dataset location; runHotgym resolves it to a real file via findDataset().
DATA_PATH = "extra/hotgym/hotgym.csv"
def createModel():
    """Build and return a model from the parameters in ``model_params``."""
    params = model_params.MODEL_PARAMS
    return ModelFactory.create(params)
def should_consolidate(interval, offset, current):
    """Decide whether the oldest window of samples should be consolidated.

    Consolidation is due once at least two intervals' worth of predictions
    have accumulated past ``offset``.

    Args:
        interval: Consolidation window size; a falsy value (``None`` or 0)
            disables consolidation.
        offset: Prediction number up to which data is already consolidated.
        current: The current prediction number.

    Returns:
        True if consolidation should happen now, False otherwise. A warning
        is printed when the backlog already exceeds two intervals, i.e. an
        earlier consolidation point was missed.
    """
    if not interval:
        return False

    backlog = current - offset
    threshold = interval * 2
    if backlog < threshold:
        return False

    if backlog > threshold:
        # Single pre-formatted argument: prints the same text under both
        # Python 2 and Python 3.
        print('WARNING: missed consolidation; current interval %s to %s'
              % (offset, current))
    return True
def consolidate(interval, offset, obs, preds):
    """Collapse the oldest ``interval`` samples into one error value.

    Args:
        interval: Number of leading samples to consolidate.
        offset: Prediction number at which ``obs``/``preds`` currently start.
        obs: Observed values (list).
        preds: Predicted values (list), parallel to ``obs``.

    Returns:
        A tuple ``(new_offset, remaining_obs, remaining_preds, distance)``
        where ``distance`` is the Euclidean norm of ``preds - obs`` over the
        consolidated samples and the remaining lists have those samples
        removed.
    """
    # Slice [:interval] so the error covers exactly the samples that are
    # dropped below; the previous [0:interval-1] silently left out the last
    # sample of every window.
    window_preds = np.array(preds[:interval])
    window_obs = np.array(obs[:interval])
    distance = np.linalg.norm(window_preds - window_obs)
    return offset + interval, obs[interval:], preds[interval:], distance
def runHotgym(**opts):
    """Feed the hotgym dataset through a model, optionally plotting results.

    Keyword options (all optional):
        plot: Callable invoked as ``plot([xs, xs, err_xs], [ys, zs, errs])``
            where ``ys`` are observed consumption values, ``zs`` the
            corresponding predictions, and ``err_xs``/``errs`` the
            consolidated error positions and values.
        plot_interval: When given, plot only for the first few predictions
            (numbers 1-9) and then every ``plot_interval`` predictions;
            when absent, plot after every record.
        max_predictions: Stop once the prediction number exceeds this value.
        consolidate_by: Window size used to fold old samples into a single
            error value; consolidation is disabled when absent.
    """
    model = createModel()
    model.enableInference({'predictionSteps': [1, 5],
                           'predictedField': 'consumption',
                           'numRecords': 4000})
    # Not used yet; kept for the alignment TODO inside the loop.
    shifter = InferenceShifter()

    # Resolve the dataset once and reuse the path for the message and open().
    dataset_path = findDataset(DATA_PATH)
    print('Using dataset %s' % dataset_path)

    plotting = 'plot' in opts
    consolidate_by = opts.get('consolidate_by')
    ys = []
    zs = []
    errs = [[], []]
    consolidated_to = 0

    with open(dataset_path) as fin:
        reader = csv.reader(fin)
        headers = next(reader)
        # skip the additional header lines
        next(reader)
        next(reader)

        for record in reader:
            modelInput = _to_model_input(headers, record)

            # TODO: make this work so that predictions appear aligned with the
            # timestep that they are predicting (rather than the time step where
            # the prediction is generated). program hangs when I try it
            # result = shifter.shift(model.run(modelInput))
            result = model.run(modelInput)
            prediction_number = result.predictionNumber

            if should_consolidate(consolidate_by, consolidated_to,
                                  prediction_number):
                consolidated_to, ys, zs, distance = consolidate(
                    consolidate_by, consolidated_to, ys, zs)
                errs[0].append(consolidated_to)
                errs[1].append(distance)

            if plotting:
                ys.append(modelInput['consumption'])
                # A falsy prediction is recorded as 0.
                zs.append(result.inferences['prediction'][0] or 0)

            if ('max_predictions' in opts
                    and opts['max_predictions'] < prediction_number):
                break

            if _plot_due(opts, prediction_number):
                xs = list(range(consolidated_to, prediction_number + 1))
                opts['plot']([xs, xs, errs[0]], [ys, zs, errs[1]])


def _to_model_input(headers, record):
    """Turn one CSV row into the model input dict, parsing typed fields."""
    model_input = dict(zip(headers, record))
    model_input["consumption"] = float(model_input["consumption"])
    model_input["timestamp"] = datetime.datetime.strptime(
        model_input["timestamp"], "%Y-%m-%d %H:%M:%S.%f")
    return model_input


def _plot_due(opts, prediction_number):
    """Return True when the plot callback should fire for this prediction."""
    if 'plot' not in opts:
        return False
    if 'plot_interval' not in opts:
        return True
    # Prediction number 0 is never plotted when an interval is configured.
    if not prediction_number:
        return False
    return (prediction_number < 10
            or prediction_number % opts['plot_interval'] == 0)
def update_plot(fig, axes, xs, ys):
    """Redraw all three panels and re-display the figure.

    Args:
        fig: Figure handed to ``display``.
        axes: Indexable collection of at least three axes objects.
        xs: Three x-value sequences; ``xs[0]``/``xs[1]`` go with the
            consumption and prediction series, ``xs[2]`` with the
            consolidated errors.
        ys: Three y-value sequences parallel to ``xs``: consumption,
            prediction, and consolidated error values.
    """
    clear_output()

    series_ax = axes[0]
    series_ax.clear()
    series_ax.plot(xs[0], ys[0], label='consumption')
    series_ax.plot(xs[1], ys[1], 'm', label='prediction')

    error_ax = axes[1]
    error_ax.clear()
    error_ax.plot(xs[1], np.subtract(ys[1], ys[0]), label='error')

    consolidated_ax = axes[2]
    # Clear like the other panels; otherwise every refresh stacks another
    # vlines collection on top of the ones already drawn.
    consolidated_ax.clear()
    consolidated_ax.vlines(xs[2], 0, ys[2], label='RMS error')

    display(fig)
def doplot(fig, axes):
    """Return a two-argument ``(xs, ys)`` callback bound to ``fig``/``axes``.

    The callback forwards to ``update_plot`` and returns its result.
    """
    def redraw(xs, ys):
        return update_plot(fig, axes, xs, ys)

    return redraw
if __name__ == "__main__":
    # Script entry point: three stacked panels, redrawn through the
    # callback built by doplot while runHotgym streams the data.
    predictions = 80000
    fig, axes = plt.subplots(3, 1, figsize=(15, 10))
    runHotgym(
        plot=doplot(fig, axes),
        max_predictions=predictions,
        plot_interval=100,
        consolidate_by=500,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment