Skip to content

Instantly share code, notes, and snippets.

@MichaelWS
Last active November 15, 2016 14:23
Show Gist options
  • Save MichaelWS/e5eb873e32b089a4487e to your computer and use it in GitHub Desktop.
Save MichaelWS/e5eb873e32b089a4487e to your computer and use it in GitHub Desktop.
hdf5 for zipline
"""
Leverages the work of briancappello and the Quantopian team
(especially twiecki, eddie, and fawce).
"""
import pandas as pd
from zipline.gens.utils import hash_args
from zipline.sources.data_source import DataSource
import tables
import datetime
import numpy as np
def get_time(time_str):
    """Parse an ``"HH:MM"`` string into a :class:`datetime.time`.

    Parameters
    ----------
    time_str : str
        Wall-clock time such as ``"9:30"`` or ``"16:00"``.

    Returns
    -------
    datetime.time
        The parsed time (seconds and microseconds are zero).

    Raises
    ------
    ValueError
        If ``time_str`` is not exactly two ``:``-separated integers, or the
        hour/minute is outside 0-23 / 0-59.
    """
    # Materialize a real list: on Python 3 ``map`` returns an iterator, so
    # calling ``len`` on it raises TypeError.
    parts = [int(part) for part in time_str.split(":")]
    if len(parts) != 2:
        raise ValueError("expected 'HH:MM', got %r" % (time_str,))
    hour, minute = parts
    # Minutes run 0-59; raise explicitly rather than relying on assert,
    # which is stripped when Python runs with -O.
    if not (0 <= hour < 24 and 0 <= minute < 60):
        raise ValueError("time out of range: %r" % (time_str,))
    return datetime.time(hour, minute)
class DataSourceTablesOHLC(DataSource):
    """Zipline data source emitting OHLC trade bars read from a PyTables file.

    Every leaf under ``/<root>/`` is treated as one day's table; its node
    name is parsed with ``np.datetime64`` to get the day. Rows must provide
    ``symbol``, ``dt``, ``open``, ``high``, ``low``, ``close`` and ``volume``
    columns.

    Keyword arguments
    -----------------
    symbols : collection, optional
        Only rows whose ``symbol`` is in this collection are emitted;
        ``None`` (the default) emits every symbol.
    start, end : date-like
        Inclusive range of day tables to read, localized to UTC.
    tz_in : str
        Timezone of the bar wall-clock times (default ``"US/Eastern"``).
    start_time, end_time : str
        Inclusive ``"HH:MM"`` intraday window (defaults ``"9:30"`` and
        ``"16:00"``).
    root : str
        Name of the group holding the day tables (default ``"TD"``).
    """

    def __init__(self, data, **kwargs):
        assert isinstance(data, tables.file.File)
        self.data = data
        # ``None`` means "no symbol filter".
        self.sids = kwargs.get('symbols')
        self.tz_in = kwargs.get('tz_in', "US/Eastern")
        self.start = pd.Timestamp(np.datetime64(kwargs.get('start')))
        self.start = self.start.tz_localize('utc')
        self.end = pd.Timestamp(np.datetime64(kwargs.get('end')))
        self.end = self.end.tz_localize('utc')
        self.start_time = get_time(kwargs.get("start_time", "9:30"))
        self.end_time = get_time(kwargs.get("end_time", "16:00"))
        self._raw_data = None
        self.arg_string = hash_args(data, **kwargs)
        self.root_node = "/" + kwargs.get('root', "TD") + "/"

    @property
    def instance_hash(self):
        """Hash of the constructor arguments identifying this source."""
        return self.arg_string

    def raw_data_gen(self):
        """Yield one raw event dict per bar inside the date/time window.

        A running volume-weighted average price (``vwap``) is kept per
        symbol and restarts with each day table.
        """
        cols = ("open", "high", "low", "close")
        for date_node in self.data.walkNodes(self.root_node):
            # Groups are containers; only the leaves hold bar rows.
            if isinstance(date_node, tables.group.Group):
                continue
            dt64 = np.datetime64(date_node.name)
            table_dt = pd.Timestamp(dt64).tz_localize("utc")
            if table_dt < self.start or table_dt > self.end:
                continue
            day = table_dt.date()
            start_ts = pd.Timestamp(
                datetime.datetime.combine(day, self.start_time))
            end_ts = pd.Timestamp(
                datetime.datetime.combine(day, self.end_time))
            volumes = {}
            price_volumes = {}
            for row in date_node.iterrows():
                sid = row["symbol"]
                if self.sids is not None and sid not in self.sids:
                    continue
                if sid not in volumes:
                    volumes[sid] = 0
                    price_volumes[sid] = 0
                # NOTE(review): fromtimestamp() converts with the *host's*
                # local timezone and the result is then labelled as tz_in
                # below; that is only correct if the host runs in tz_in or
                # the stored values were written that way -- confirm.
                dt = datetime.datetime.fromtimestamp(row["dt"])
                if dt < start_ts or dt > end_ts:
                    continue
                price = row["close"]
                volume = row["volume"]
                volumes[sid] += volume
                price_volumes[sid] += price * volume
                cum_volume = volumes[sid]
                if cum_volume:
                    vwap = price_volumes[sid] / cum_volume
                else:
                    # No volume traded yet today: VWAP is undefined, so use
                    # the bar's close instead of dividing by zero.
                    vwap = price
                # Bar times are wall-clock in tz_in; events are emitted in UTC.
                event_dt = pd.Timestamp(dt).tz_localize(self.tz_in)
                event = {
                    "sid": sid,
                    "type": "TRADE",
                    "symbol": sid,
                    "dt": event_dt.tz_convert('utc'),
                    "price": price,
                    "volume": volume,
                    "vwap": vwap,
                }
                for field in cols:
                    event[field] = row[field]
                yield event

    @property
    def raw_data(self):
        """Lazily created, cached generator of raw events.

        The same generator object is returned on every access, so repeated
        reads continue where the previous one stopped instead of restarting.
        """
        if self._raw_data is None:
            self._raw_data = self.raw_data_gen()
        return self._raw_data

    @property
    def mapping(self):
        """Map each output field to ``(converter, raw_event_key)``."""
        return {
            'sid': (lambda x: x, 'sid'),
            'dt': (lambda x: x, 'dt'),
            'open': (float, 'open'),
            'high': (float, 'high'),
            'low': (float, 'low'),
            'close': (float, 'close'),
            'price': (float, 'price'),
            'volume': (int, 'volume'),
            'vwap': (float, 'vwap')
        }
class DataSourceTablesSignal(DataSource):
    """Zipline data source emitting per-day signal values from PyTables.

    Every leaf under ``/<root>/`` is treated as one day's table; its node
    name is parsed with ``pd.Timestamp`` and used (as UTC) for every event
    from that table. Rows must provide ``symbol`` and ``signal`` columns.

    Keyword arguments
    -----------------
    symbols : collection, optional
        Only rows whose ``symbol`` is in this collection are emitted;
        ``None`` (the default) emits every symbol.
    start, end : date-like
        Inclusive range of day tables to read, localized to UTC.
    root : str
        Name of the group holding the day tables (default ``"signal"``).
    """

    def __init__(self, data, **kwargs):
        assert isinstance(data, tables.file.File)
        # raw_data_gen reads ``self.data``; ``h5file`` is kept as an alias
        # for code that used the earlier attribute name.
        self.data = data
        self.h5file = data
        # ``None`` means "no symbol filter".
        self.sids = kwargs.get('symbols')
        # Normalize the bounds the same way the OHLC source does, so they
        # compare cleanly with the UTC table dates; a missing bound becomes
        # NaT, which filters nothing (comparisons with NaT are False).
        self.start = pd.Timestamp(np.datetime64(kwargs.get('start')))
        self.start = self.start.tz_localize('utc')
        self.end = pd.Timestamp(np.datetime64(kwargs.get('end')))
        self.end = self.end.tz_localize('utc')
        self.arg_string = hash_args(data, **kwargs)
        self._raw_data = None
        self.root_node = "/" + kwargs.get('root', "signal") + "/"

    @property
    def instance_hash(self):
        """Hash of the constructor arguments identifying this source."""
        return self.arg_string

    def raw_data_gen(self):
        """Yield one raw event dict per signal row inside the date range."""
        for date_node in self.data.walkNodes(self.root_node):
            # Groups are containers; only the leaves hold signal rows.
            if isinstance(date_node, tables.group.Group):
                continue
            dt = pd.Timestamp(date_node.name).tz_localize("utc")
            if dt < self.start or dt > self.end:
                continue
            # walkNodes already yields the leaf itself; iterate it directly.
            for row in date_node.iterrows():
                sid = row["symbol"]
                if self.sids is not None and sid not in self.sids:
                    continue
                # "symbol" is included because ``mapping`` reads it.
                yield {"sid": sid, "type": "CUSTOM", "symbol": sid,
                       "dt": dt, "signal": row["signal"]}

    @property
    def raw_data(self):
        """Lazily created, cached generator of raw events.

        The same generator object is returned on every access, so repeated
        reads continue where the previous one stopped instead of restarting.
        """
        if self._raw_data is None:
            self._raw_data = self.raw_data_gen()
        return self._raw_data

    @property
    def mapping(self):
        """Map each output field to ``(converter, raw_event_key)``."""
        return {
            'sid': (lambda x: x, 'symbol'),
            'dt': (lambda x: x, 'dt'),
            'signal': (lambda x: x, 'signal'),
        }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment