Created
August 12, 2020 09:36
-
-
Save twiecki/e4199e05fc01bfb24b17e7498062ce6c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click
import os
import sys
import warnings

# pygments is an optional dependency: when it is missing, ``PYGMENTS`` is
# False and ``_run`` echoes the algorithm text without syntax highlighting.
try:
    from pygments import highlight
    from pygments.lexers import PythonLexer
    from pygments.formatters import TerminalFormatter
    PYGMENTS = True
except ImportError:
    PYGMENTS = False
import logbook
import pandas as pd
import six
from toolz import concatv
from trading_calendars import get_calendar
from zipline.data import bundles
from zipline.data.benchmarks import get_benchmark_returns_from_file
from zipline.data.data_portal import DataPortal
from zipline.finance import metrics
from zipline.finance.trading import SimulationParameters
from zipline.pipeline.data import USEquityPricing
from zipline.pipeline.loaders import USEquityPricingLoader
import zipline.utils.paths as pth
from zipline.extensions import load
from zipline.errors import SymbolNotFound
from zipline.algorithm import TradingAlgorithm, NoBenchmark
from zipline.finance.blotter import Blotter

# Module-level logger.
log = logbook.Logger(__name__)
class _RunAlgoError(click.ClickException, ValueError):
    """Signal an error that should have a different message if invoked from
    the cli.

    Parameters
    ----------
    pyfunc_msg : str
        The message that will be shown when called as a python function.
    cmdline_msg : str, optional
        The message that will be shown on the command line. If not provided,
        this will be the same as ``pyfunc_msg``.
    """
    exit_code = 1

    def __init__(self, pyfunc_msg, cmdline_msg=None):
        if cmdline_msg is None:
            cmdline_msg = pyfunc_msg

        # The click base class is initialized with the command-line message;
        # the python-facing message is kept separately for ``__str__``.
        super(_RunAlgoError, self).__init__(cmdline_msg)
        self.pyfunc_msg = pyfunc_msg

    def __str__(self):
        # ``str(exc)`` yields the message meant for python callers.
        return self.pyfunc_msg
def _run(handle_data,
         initialize,
         before_trading_start,
         analyze,
         algofile,
         algotext,
         defines,
         data_frequency,
         capital_base,
         bundle,
         bundle_timestamp,
         start,
         end,
         output,
         trading_calendar,
         print_algo,
         metrics_set,
         local_namespace,
         environ,
         blotter,
         benchmark_spec,
         choose_loader,
         ):
    """Run a backtest for the given algorithm.

    This is shared between the cli and :func:`zipline.run_algo`.

    Returns
    -------
    perf
        The result of ``TradingAlgorithm(...).run()``. It is also echoed to
        stdout when ``output == '-'`` or pickled to ``output`` unless
        ``output`` is ``os.devnull``.

    Raises
    ------
    _RunAlgoError
        If there are no trading days between ``start`` and ``end``, if
        ``defines`` are passed without ``algotext``, if a string
        ``metrics_set`` or ``blotter`` cannot be loaded, or if the algorithm
        raises ``NoBenchmark``.
    ValueError
        If an entry of ``defines`` is not of the form ``name=value`` or its
        value fails to evaluate.
    """
    bundle_data = bundles.load(
        bundle,
        environ,
        bundle_timestamp,
    )

    if trading_calendar is None:
        trading_calendar = get_calendar('XNYS')

    # date parameter validation
    if trading_calendar.session_distance(start, end) < 1:
        raise _RunAlgoError(
            'There are no trading days between %s and %s' % (
                start.date(),
                end.date(),
            ),
        )

    benchmark_sid, benchmark_returns = benchmark_spec.resolve(
        asset_finder=bundle_data.asset_finder,
        start_date=start,
        end_date=end,
    )

    if algotext is not None:
        if local_namespace:
            # ``get_ipython`` is only defined when running inside IPython.
            ip = get_ipython()  # noqa
            namespace = ip.user_ns
        else:
            namespace = {}

        for assign in defines:
            try:
                # Split on the first '=' only, so the value itself may
                # contain '=' (e.g. ``flag=a==b``).
                name, value = assign.split('=', 1)
            except ValueError:
                raise ValueError(
                    'invalid define %r, should be of the form name=value' %
                    assign,
                )
            try:
                # evaluate in the same namespace so names may refer to
                # each other
                # NOTE(review): ``eval`` executes arbitrary code; ``defines``
                # must only ever come from a trusted source.
                namespace[name] = eval(value, namespace)
            except Exception as e:
                raise ValueError(
                    'failed to execute definition for name %r: %s' % (name, e),
                )
    elif defines:
        raise _RunAlgoError(
            'cannot pass define without `algotext`',
            "cannot pass '-D' / '--define' without '-t' / '--algotext'",
        )
    else:
        namespace = {}
        if algofile is not None:
            algotext = algofile.read()

    if print_algo:
        if PYGMENTS:
            highlight(
                algotext,
                PythonLexer(),
                TerminalFormatter(),
                outfile=sys.stdout,
            )
        else:
            click.echo(algotext)

    first_trading_day = \
        bundle_data.equity_minute_bar_reader.first_trading_day

    data = DataPortal(
        bundle_data.asset_finder,
        trading_calendar=trading_calendar,
        first_trading_day=first_trading_day,
        equity_minute_reader=bundle_data.equity_minute_bar_reader,
        equity_daily_reader=bundle_data.equity_daily_bar_reader,
        adjustment_reader=bundle_data.adjustment_reader,
    )

    pipeline_loader = USEquityPricingLoader.without_fx(
        bundle_data.equity_daily_bar_reader,
        bundle_data.adjustment_reader,
    )

    if choose_loader is None:
        # Default loader lookup: only USEquityPricing columns are served.
        def choose_loader(column):
            if column in USEquityPricing.columns:
                return pipeline_loader
            raise ValueError(
                "No PipelineLoader registered for column %s." % column
            )

    if isinstance(metrics_set, six.string_types):
        try:
            metrics_set = metrics.load(metrics_set)
        except ValueError as e:
            raise _RunAlgoError(str(e))

    if isinstance(blotter, six.string_types):
        try:
            blotter = load(Blotter, blotter)
        except ValueError as e:
            raise _RunAlgoError(str(e))

    # The algorithm is supplied either as callables or as script text,
    # never both.
    if algotext is None:
        algo_kwargs = {
            'initialize': initialize,
            'handle_data': handle_data,
            'before_trading_start': before_trading_start,
            'analyze': analyze,
        }
    else:
        algo_kwargs = {
            'algo_filename': getattr(algofile, 'name', '<algorithm>'),
            'script': algotext,
        }

    try:
        perf = TradingAlgorithm(
            namespace=namespace,
            data_portal=data,
            get_pipeline_loader=choose_loader,
            trading_calendar=trading_calendar,
            sim_params=SimulationParameters(
                start_session=start,
                end_session=end,
                trading_calendar=trading_calendar,
                capital_base=capital_base,
                data_frequency=data_frequency,
            ),
            metrics_set=metrics_set,
            blotter=blotter,
            benchmark_returns=benchmark_returns,
            benchmark_sid=benchmark_sid,
            **algo_kwargs
        ).run()
    except NoBenchmark:
        raise _RunAlgoError(
            (
                'No ``benchmark_spec`` was provided, and'
                ' ``zipline.api.set_benchmark`` was not called in'
                ' ``initialize``.'
            ),
            (
                "Neither '--benchmark-symbol' nor '--benchmark-sid' was"
                " provided, and ``zipline.api.set_benchmark`` was not called"
                " in ``initialize``. Did you mean to pass '--no-benchmark'?"
            ),
        )

    if output == '-':
        click.echo(str(perf))
    elif output != os.devnull:  # make the zipline magic not write any data
        perf.to_pickle(output)

    return perf
# All of the loaded extensions. We don't want to load an extension twice.
_loaded_extensions = set()


def load_extensions(default, extensions, strict, environ, reload=False):
    """Load all of the given extensions. This should be called by run_algo
    or the cli.

    Parameters
    ----------
    default : bool
        Load the default extension (~/.zipline/extension.py)?
    extensions : iterable[str]
        The paths to the extensions to load. If the path ends in ``.py`` it is
        treated as a script and executed. If it does not end in ``.py`` it is
        treated as a module to be imported.
    strict : bool
        Should failure to load an extension raise. If this is false it will
        still warn.
    environ : mapping
        The environment to use to find the default extension path.
    reload : bool, optional
        Reload any extensions that have already been loaded.
    """
    if default:
        default_extension_path = pth.default_extension(environ=environ)
        pth.ensure_file(default_extension_path)
        # put the default extension first so other extensions can depend on
        # the order they are loaded
        extensions = concatv([default_extension_path], extensions)

    for ext in extensions:
        if ext in _loaded_extensions and not reload:
            continue
        try:
            # load all of the zipline extensions
            if ext.endswith('.py'):
                with open(ext) as f:
                    # each script runs in its own fresh namespace
                    ns = {}
                    six.exec_(compile(f.read(), ext, 'exec'), ns, ns)
            else:
                __import__(ext)
        except Exception as e:
            if strict:
                # if `strict` we should raise the actual exception and fail
                raise
            # without `strict` we should just log the failure
            warnings.warn(
                'Failed to load extension: %r\n%s' % (ext, e),
                stacklevel=2
            )
        else:
            # only successfully loaded extensions are remembered, so a failed
            # one is retried on the next call
            _loaded_extensions.add(ext)
def run_algorithm(start,
                  end,
                  initialize,
                  capital_base,
                  handle_data=None,
                  before_trading_start=None,
                  analyze=None,
                  data_frequency='daily',
                  bundle='quantopian-quandl',
                  bundle_timestamp=None,
                  trading_calendar=None,
                  metrics_set='default',
                  benchmark_returns=None,
                  default_extension=True,
                  extensions=(),
                  strict_extensions=True,
                  environ=os.environ,
                  blotter='default',
                  choose_loader=None,
                  ):
    """
    Run a trading algorithm.

    Parameters
    ----------
    start : datetime
        The start date of the backtest.
    end : datetime
        The end date of the backtest.
    initialize : callable[context -> None]
        The initialize function to use for the algorithm. This is called once
        at the very beginning of the backtest and should be used to set up
        any state needed by the algorithm.
    capital_base : float
        The starting capital for the backtest.
    handle_data : callable[(context, BarData) -> None], optional
        The handle_data function to use for the algorithm. This is called
        every minute when ``data_frequency == 'minute'`` or every day
        when ``data_frequency == 'daily'``.
    before_trading_start : callable[(context, BarData) -> None], optional
        The before_trading_start function for the algorithm. This is called
        once before each trading day (after initialize on the first day).
    analyze : callable[(context, pd.DataFrame) -> None], optional
        The analyze function to use for the algorithm. This function is called
        once at the end of the backtest and is passed the context and the
        performance data.
    data_frequency : {'daily', 'minute'}, optional
        The data frequency to run the algorithm at.
    bundle : str, optional
        The name of the data bundle to use to load the data to run the backtest
        with. This defaults to 'quantopian-quandl'.
    bundle_timestamp : datetime, optional
        The datetime to lookup the bundle data for. This defaults to the
        current time.
    trading_calendar : TradingCalendar, optional
        The trading calendar to use for your backtest.
    metrics_set : iterable[Metric] or str, optional
        The set of metrics to compute in the simulation. If a string is passed,
        resolve the set with :func:`zipline.finance.metrics.load`.
    benchmark_returns : pd.Series, optional
        Series of returns to use as the benchmark. It is wrapped with
        ``BenchmarkSpec.from_returns`` before being handed to the backtest.
    default_extension : bool, optional
        Should the default zipline extension be loaded. This is found at
        ``$ZIPLINE_ROOT/extension.py``
    extensions : iterable[str], optional
        The names of any other extensions to load. Each element may either be
        a dotted module path like ``a.b.c`` or a path to a python file ending
        in ``.py`` like ``a/b/c.py``.
    strict_extensions : bool, optional
        Should the run fail if any extensions fail to load. If this is false,
        a warning will be raised instead.
    environ : mapping[str -> str], optional
        The os environment to use. Many extensions use this to get parameters.
        This defaults to ``os.environ``.
    blotter : str or zipline.finance.blotter.Blotter, optional
        Blotter to use with this algorithm. If passed as a string, we look for
        a blotter construction function registered with
        ``zipline.extensions.register`` and call it with no parameters.
        Default is a :class:`zipline.finance.blotter.SimulationBlotter` that
        never cancels orders.
    choose_loader : callable[column -> PipelineLoader], optional
        Function used to pick the pipeline loader for a column; it is passed
        through to ``_run``. When omitted, ``_run`` falls back to a loader
        that only serves ``USEquityPricing`` columns.

    Returns
    -------
    perf : pd.DataFrame
        The daily performance of the algorithm.

    See Also
    --------
    zipline.data.bundles.bundles : The available data bundles.
    """
    load_extensions(default_extension, extensions, strict_extensions, environ)

    benchmark_spec = BenchmarkSpec.from_returns(benchmark_returns)

    # The algorithm is given as callables here, so the script-related
    # arguments (algofile, algotext, defines) are left empty and nothing is
    # written to disk (output=os.devnull).
    return _run(
        handle_data=handle_data,
        initialize=initialize,
        before_trading_start=before_trading_start,
        analyze=analyze,
        algofile=None,
        algotext=None,
        defines=(),
        data_frequency=data_frequency,
        capital_base=capital_base,
        bundle=bundle,
        bundle_timestamp=bundle_timestamp,
        start=start,
        end=end,
        output=os.devnull,
        trading_calendar=trading_calendar,
        print_algo=False,
        metrics_set=metrics_set,
        local_namespace=False,
        environ=environ,
        blotter=blotter,
        benchmark_spec=benchmark_spec,
        choose_loader=choose_loader,
    )
class BenchmarkSpec(object):
    """
    Helper for different ways we can get benchmark data for the Zipline CLI and
    zipline.utils.run_algo.run_algorithm.

    Parameters
    ----------
    benchmark_returns : pd.Series, optional
        Series of returns to use as the benchmark.
    benchmark_file : str or file
        File containing a csv with `date` and `return` columns, to be read as
        the benchmark.
    benchmark_sid : int, optional
        Sid of the asset to use as a benchmark.
    benchmark_symbol : str, optional
        Symbol of the asset to use as a benchmark. Symbol will be looked up as
        of the end date of the backtest.
    no_benchmark : bool
        Flag indicating that no benchmark is configured. Benchmark-dependent
        metrics will be calculated using a dummy benchmark of all-zero returns.
    """

    def __init__(self,
                 benchmark_returns,
                 benchmark_file,
                 benchmark_sid,
                 benchmark_symbol,
                 no_benchmark):
        self.benchmark_returns = benchmark_returns
        self.benchmark_file = benchmark_file
        self.benchmark_sid = benchmark_sid
        self.benchmark_symbol = benchmark_symbol
        self.no_benchmark = no_benchmark

    @classmethod
    def from_cli_params(cls,
                        benchmark_sid,
                        benchmark_symbol,
                        benchmark_file,
                        no_benchmark):
        """Build a spec from the benchmark-related command line options."""
        return cls(
            benchmark_returns=None,
            benchmark_sid=benchmark_sid,
            benchmark_symbol=benchmark_symbol,
            benchmark_file=benchmark_file,
            no_benchmark=no_benchmark,
        )

    @classmethod
    def from_returns(cls, benchmark_returns):
        """Build a spec from an optional series of benchmark returns.

        When ``benchmark_returns`` is None the spec is flagged as
        ``no_benchmark``, so ``resolve`` supplies the all-zero dummy
        benchmark.
        """
        return cls(
            benchmark_returns=benchmark_returns,
            benchmark_file=None,
            benchmark_sid=None,
            benchmark_symbol=None,
            # "no benchmark" holds exactly when no returns were supplied.
            no_benchmark=benchmark_returns is None,
        )

    def resolve(self, asset_finder, start_date, end_date):
        """
        Resolve inputs into values to be passed to TradingAlgorithm.

        Returns a pair of ``(benchmark_sid, benchmark_returns)`` with at most
        one non-None value. Both values may be None if no benchmark source has
        been configured.

        Sources are tried in order: explicit returns, file, sid, symbol, and
        finally the ``no_benchmark`` flag.

        Parameters
        ----------
        asset_finder : zipline.assets.AssetFinder
            Asset finder for the algorithm to be run.
        start_date : pd.Timestamp
            Start date of the algorithm to be run.
        end_date : pd.Timestamp
            End date of the algorithm to be run.

        Returns
        -------
        benchmark_sid : int
            Sid to use as benchmark.
        benchmark_returns : pd.Series
            Series of returns to use as benchmark.

        Raises
        ------
        _RunAlgoError
            If ``benchmark_symbol`` cannot be found by ``asset_finder``.
        """
        if self.benchmark_returns is not None:
            benchmark_sid = None
            benchmark_returns = self.benchmark_returns
        elif self.benchmark_file is not None:
            benchmark_sid = None
            benchmark_returns = get_benchmark_returns_from_file(
                self.benchmark_file,
            )
        elif self.benchmark_sid is not None:
            benchmark_sid = self.benchmark_sid
            benchmark_returns = None
        elif self.benchmark_symbol is not None:
            try:
                asset = asset_finder.lookup_symbol(
                    self.benchmark_symbol,
                    as_of_date=end_date,
                )
                benchmark_sid = asset.sid
                benchmark_returns = None
            except SymbolNotFound:
                # Interpolate the symbol so the message names what was
                # actually missing.
                raise _RunAlgoError(
                    "Symbol %s as a benchmark not found in this bundle."
                    % self.benchmark_symbol
                )
        elif self.no_benchmark:
            benchmark_sid = None
            benchmark_returns = self._zero_benchmark_returns(
                start_date=start_date,
                end_date=end_date,
            )
        else:
            log.warn(
                "No benchmark configured. "
                "Assuming algorithm calls set_benchmark."
            )
            log.warn(
                "Pass --benchmark-sid, --benchmark-symbol, or"
                " --benchmark-file to set a source of benchmark returns."
            )
            log.warn(
                "Pass --no-benchmark to use a dummy benchmark "
                "of zero returns.",
            )
            benchmark_sid = None
            benchmark_returns = None

        return benchmark_sid, benchmark_returns

    @staticmethod
    def _zero_benchmark_returns(start_date, end_date):
        """Return an all-zero series indexed by the UTC date range
        from ``start_date`` to ``end_date``.
        """
        return pd.Series(
            index=pd.date_range(start_date, end_date, tz='utc'),
            data=0.0,
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment