Skip to content

Instantly share code, notes, and snippets.

@fawce
Created November 12, 2012 01:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save fawce/4057021 to your computer and use it in GitHub Desktop.
Save fawce/4057021 to your computer and use it in GitHub Desktop.
Rough draft, CSV data source for zipline
"""
Generator-style DataSource that loads from CSV.
"""
import pytz
import csv
import mmap
import os.path
from datetime import datetime
from zipline import ndict
from zipline.gens.utils import hash_args, \
assert_trade_protocol
from zipline.utils.date_utils import UN_EPOCH, EPOCH
import zipline.protocol as zp
class CSVTradeGen(object):
    """Iterable trade source backed by a CSV file.

    Reads rows from the CSV file at ``file_path``, keeps the rows whose
    datetime lies in ``[start_date, end_date]`` (inclusive) and, when
    ``sid_range`` is given, whose sid is one of the requested sids.
    Each surviving row is yielded as an ndict tagged with the trade
    event type and a ``source_id`` string for downstream sorting.

    The instance is its own iterator; iteration starts lazily on the
    first ``next()`` call and can be restarted with ``rewind()``.
    """

    def __init__(self, file_path, sid_range, start_date, end_date):
        """Validate the arguments and record the query parameters.

        file_path  -- path to an existing CSV file.
        sid_range  -- list of sids to keep, or None (an empty list is
                      treated the same as None: no sid filtering).
        start_date -- datetime whose tzinfo is pytz.utc.
        end_date   -- datetime whose tzinfo is pytz.utc.

        Raises AssertionError when any of the checks below fails.
        """
        assert os.path.exists(file_path)
        # ``None`` is not a type, so it cannot appear in the isinstance
        # tuple directly; use type(None) to accept a missing sid_range.
        assert isinstance(sid_range, (list, type(None)))
        assert isinstance(start_date, datetime)
        assert isinstance(end_date, datetime)
        assert start_date.tzinfo == pytz.utc
        assert end_date.tzinfo == pytz.utc

        self.file_path = file_path
        if sid_range:
            # frozenset gives constant-time membership tests per row.
            self.sid_range = frozenset(sid_range)
        else:
            self.sid_range = None
        self.start_date = start_date
        self.end_date = end_date

        # Create unique identifier string that can be used to break
        # sorting ties deterministically. It is built from the actual
        # arguments (including the sid filter) so that sources which
        # differ only in sid_range do not share an identifier.
        self.argstring = hash_args(file_path, sid_range, start_date, end_date)
        self.namestring = self.__class__.__name__ + self.argstring

        # Created lazily by next() / rewind().
        self.iterator = None

    def __iter__(self):
        """Return self; the instance implements the iterator protocol."""
        return self

    def next(self):
        """Return the next event, starting the generator on first use.

        Raises StopIteration when the underlying generator is exhausted.
        """
        if self.iterator is None:
            self.iterator = self._gen()
        # The builtin next() works with both the Python 2 and Python 3
        # iterator protocols.
        return next(self.iterator)

    # Python 3 spelling of the iterator protocol method.
    __next__ = next

    def rewind(self):
        """Restart iteration from the beginning of the file."""
        self.iterator = self._gen()

    def get_hash(self):
        """Return the identifier string used as each event's source_id."""
        return self.namestring

    def _gen(self):
        """Yield ndict events that fulfill the datasource protocol."""
        # Set up internal iterator. This outputs raw dictionaries.
        cursor = self.create_csv_iterator(
            self.sid_range,
            self.start_date,
            self.end_date
        )
        for event in cursor:
            # Construct a new event that fulfills the datasource protocol.
            event['type'] = zp.DATASOURCE_TYPE.TRADE
            event['source_id'] = self.namestring
            payload = ndict(event)
            assert_trade_protocol(payload)
            yield payload

    def create_csv_iterator(self, sid_range, start_date, end_date):
        """
        Returns an iterator that spits out raw objects loaded from a
        csv file.

        sid_range  -- collection of sids to keep; a falsy value disables
                      sid filtering.
        start_date -- rows earlier than this are skipped.
        end_date   -- iteration stops at the first row later than this.
        """
        # csv file fields are:
        # ['datetime','sid','volume','high','low','close','open']
        # datetime is the datetime in unix time (ms since the epoch)
        # ndict output objects have the same properties, except:
        # datetime - python datetime object with tzinfo of pytz.utc
        # NOTE(review): UN_EPOCH receives the raw csv string here;
        # confirm it accepts a string rather than a number.

        # Open the file in read (text) mode. Create a dictionary
        # reader, which will behave like an iterator and produce
        # dictionaries. Assumes the file has a header, also
        # assumes the file is sorted ASCENDING by day,sid.
        with open(self.file_path, 'r') as csv_file:
            csv_reader = csv.DictReader(csv_file)
            for row in csv_reader:
                row['sid'] = int(row['sid'])

                # limit the data to the date range [start, end], inclusive
                row['datetime'] = UN_EPOCH(row['datetime'])
                if row['datetime'] < start_date:
                    continue
                if row['datetime'] > end_date:
                    # The file is assumed sorted ascending, so nothing
                    # later can be in range. A plain return ends the
                    # generator; raising StopIteration inside one is
                    # turned into RuntimeError by PEP 479.
                    return

                # limit the data to sids in the range
                if sid_range and row['sid'] not in sid_range:
                    continue

                for key in ['high', 'low', 'close', 'open']:
                    row[key] = float(row[key])
                row['volume'] = int(row['volume'])

                # add price alias for closing price of bar
                row['price'] = row['close']
                yield row
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment