-
-
Save MichaelWS/e5eb873e32b089a4487e to your computer and use it in GitHub Desktop.
HDF5 data sources for zipline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Leverages the work of briancappello and the Quantopian team
(especially twiecki, eddie, and fawce).
""" | |
import pandas as pd | |
from zipline.gens.utils import hash_args | |
from zipline.sources.data_source import DataSource | |
import tables | |
import datetime | |
import numpy as np | |
def get_time(time_str):
    """Parse an ``"H:MM"`` / ``"HH:MM"`` string into a :class:`datetime.time`.

    Args:
        time_str: two colon-separated integers, hour then minute.

    Returns:
        ``datetime.time(hour, minute)``.

    Raises:
        ValueError: if ``time_str`` does not have exactly two colon-separated
            integer parts, or the hour/minute is out of range.
    """
    parts = time_str.split(":")
    if len(parts) != 2:
        raise ValueError("expected 'HH:MM', got %r" % (time_str,))
    # A generator unpack works on both Python 2 and 3; int() raises
    # ValueError for non-numeric parts.
    hour, minute = (int(part) for part in parts)
    # Explicit checks instead of assert, so validation survives `python -O`.
    if not (0 <= hour < 24 and 0 <= minute < 60):
        raise ValueError("time out of range: %r" % (time_str,))
    return datetime.time(hour, minute)
class DataSourceTablesOHLC(DataSource):
    """Zipline data source yielding intraday OHLC bars stored in an HDF5 file.

    The file holds one table per date under ``/<root>/`` (default ``/TD/``).
    Each table's node name must parse with ``np.datetime64``. Its rows must
    provide ``symbol``, ``dt``, ``open``, ``high``, ``low``, ``close`` and
    ``volume`` columns.

    Keyword arguments:
        symbols: optional collection of symbols to keep (``None`` keeps all).
        start, end: inclusive bounds on the table dates, interpreted as UTC.
        tz_in: time zone the bar timestamps are localized to before being
            converted to UTC (default ``"US/Eastern"``).
        start_time, end_time: inclusive ``"HH:MM"`` intraday window
            (defaults ``"9:30"`` and ``"16:00"``).
        root: name of the HDF5 group holding the date tables (default "TD").
    """

    def __init__(self, data, **kwargs):
        assert isinstance(data, tables.file.File)
        self.data = data

        # Optional symbol filter; None means "all symbols".
        self.sids = kwargs.get('symbols')
        self.tz_in = kwargs.get('tz_in', "US/Eastern")

        # Date bounds are UTC timestamps. A missing bound becomes NaT, and
        # every comparison against NaT is False, so it filters nothing.
        self.start = pd.Timestamp(np.datetime64(kwargs.get('start')))
        self.start = self.start.tz_localize('utc')
        self.end = pd.Timestamp(np.datetime64(kwargs.get('end')))
        self.end = self.end.tz_localize('utc')

        # Intraday window (wall-clock time), applied to every date table.
        self.start_time = get_time(kwargs.get("start_time", "9:30"))
        self.end_time = get_time(kwargs.get("end_time", "16:00"))

        self._raw_data = None
        self.arg_string = hash_args(data, **kwargs)
        self.root_node = "/" + kwargs.get('root', "TD") + "/"

    @property
    def instance_hash(self):
        """Hash of the constructor arguments, identifying this source."""
        return self.arg_string

    def _walk_nodes(self):
        """Iterate every node below ``self.root_node`` (PyTables 2.x or 3.x)."""
        walk = getattr(self.data, 'walk_nodes', None)
        if walk is None:
            # PyTables < 3.0 only offers the camelCase spelling.
            walk = self.data.walkNodes
        return walk(self.root_node)

    def raw_data_gen(self):
        """Yield one ``TRADE`` event dict per bar inside the date/time window.

        Each event carries a running VWAP. It is accumulated per symbol and
        reset for every date table.
        """
        price_cols = ("open", "high", "low", "close")
        for date_node in self._walk_nodes():
            # Groups are containers; only leaf tables hold bars.
            if isinstance(date_node, tables.group.Group):
                continue
            table_dt = pd.Timestamp(np.datetime64(date_node.name))
            table_dt = table_dt.tz_localize("utc")
            if table_dt < self.start or table_dt > self.end:
                continue

            session_date = table_dt.date()
            start_ts = pd.Timestamp(
                datetime.datetime.combine(session_date, self.start_time))
            end_ts = pd.Timestamp(
                datetime.datetime.combine(session_date, self.end_time))

            # Cumulative volume and price*volume per symbol, this date only.
            volumes = {}
            price_volumes = {}
            for row in date_node.iterrows():
                sid = row["symbol"]
                if self.sids is not None and sid not in self.sids:
                    continue
                if sid not in volumes:
                    volumes[sid] = 0
                    price_volumes[sid] = 0

                # NOTE(review): fromtimestamp() converts with the *host's*
                # local time zone, yet the result is then localized to tz_in.
                # That is only consistent when the host runs in tz_in (or the
                # data was written under the same convention). Confirm.
                dt = datetime.datetime.fromtimestamp(row["dt"])
                if dt < start_ts or dt > end_ts:
                    continue

                price = row["close"]
                volume = row["volume"]
                volumes[sid] += volume
                price_volumes[sid] += price * volume
                if volumes[sid]:
                    vwap = price_volumes[sid] / volumes[sid]
                else:
                    # No volume traded yet today: avoid 0/0 and fall back to
                    # the current bar's price.
                    vwap = price

                event = {
                    "sid": sid,
                    "type": "TRADE",
                    "symbol": sid,
                    "dt": pd.Timestamp(dt).tz_localize(
                        self.tz_in).tz_convert('utc'),
                    "price": price,
                    "volume": volume,
                    "vwap": vwap,
                }
                for field in price_cols:
                    event[field] = row[field]
                yield event

    @property
    def raw_data(self):
        """Lazily create (once) and return the event generator."""
        if self._raw_data is None:
            self._raw_data = self.raw_data_gen()
        return self._raw_data

    @property
    def mapping(self):
        """Target field -> (converter, key in the raw event dict)."""
        return {
            'sid': (lambda x: x, 'sid'),
            'dt': (lambda x: x, 'dt'),
            'open': (float, 'open'),
            'high': (float, 'high'),
            'low': (float, 'low'),
            'close': (float, 'close'),
            'price': (float, 'price'),
            'volume': (int, 'volume'),
            'vwap': (float, 'vwap')
        }
class DataSourceTablesSignal(DataSource):
    """Zipline data source yielding ``CUSTOM`` signal events from an HDF5 file.

    The file holds one table per date under ``/<root>/`` (default
    ``/signal/``). Each table's node name must parse as a timestamp, and its
    rows must provide ``symbol`` and ``signal`` columns. Every row of a table
    is emitted with that table's date as its (UTC) ``dt``.

    Keyword arguments:
        symbols: optional collection of symbols to keep (``None`` keeps all).
        start, end: inclusive bounds on the table dates, interpreted as UTC.
        root: name of the HDF5 group holding the date tables.
    """

    def __init__(self, data, **kwargs):
        assert isinstance(data, tables.file.File)
        # raw_data_gen reads self.data. self.h5file is kept as an alias of
        # the original attribute name.
        self.data = data
        self.h5file = data

        # Optional symbol filter; None means "all symbols".
        self.sids = kwargs.get('symbols')

        # Date bounds are UTC timestamps. A missing bound becomes NaT, and
        # every comparison against NaT is False, so it filters nothing.
        self.start = pd.Timestamp(np.datetime64(kwargs.get('start')))
        self.start = self.start.tz_localize('utc')
        self.end = pd.Timestamp(np.datetime64(kwargs.get('end')))
        self.end = self.end.tz_localize('utc')

        self.arg_string = hash_args(data, **kwargs)
        self._raw_data = None
        # Previously `+"/" + ...`: unary plus on a str raised TypeError.
        self.root_node = "/" + kwargs.get('root', "signal") + "/"

    @property
    def instance_hash(self):
        """Hash of the constructor arguments, identifying this source."""
        return self.arg_string

    def _walk_nodes(self):
        """Iterate every node below ``self.root_node`` (PyTables 2.x or 3.x)."""
        walk = getattr(self.data, 'walk_nodes', None)
        if walk is None:
            # PyTables < 3.0 only offers the camelCase spelling.
            walk = self.data.walkNodes
        return walk(self.root_node)

    def raw_data_gen(self):
        """Yield one ``CUSTOM`` event per matching row of each in-range table."""
        for date_node in self._walk_nodes():
            # Groups are containers; only leaf tables hold signals.
            if isinstance(date_node, tables.group.Group):
                continue
            # UTC-aware, so it compares with the bounds and with UTC trade
            # events.
            dt = pd.Timestamp(date_node.name).tz_localize('utc')
            if dt < self.start or dt > self.end:
                continue
            # date_node is already the table; no getNode() round-trip needed.
            for row in date_node.iterrows():
                sid = row["symbol"]
                if self.sids is not None and sid not in self.sids:
                    continue
                # "symbol" is included because mapping() reads that key.
                yield {"sid": sid, "type": "CUSTOM", "symbol": sid,
                       "dt": dt, "signal": row["signal"]}

    @property
    def raw_data(self):
        """Lazily create (once) and return the event generator."""
        if self._raw_data is None:
            self._raw_data = self.raw_data_gen()
        return self._raw_data

    @property
    def mapping(self):
        """Target field -> (converter, key in the raw event dict)."""
        return {
            'sid': (lambda x: x, 'symbol'),
            'dt': (lambda x: x, 'dt'),
            'signal': (lambda x: x, 'signal'),
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment