Created
March 10, 2017 06:06
-
-
Save arshpreetsingh/b222de8067a70fc49df9b18b9d1a5b5e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use the previous 3 daily price changes (context.lookback) to predict the next day's change.
# Use a random forest regressor. More here: http://scikit-learn.org/stable/user_guide.html
import numpy as np
from sklearn.ensemble import RandomForestRegressor
def initialize(context):
    """Set up algorithm state and register the scheduled jobs.

    Stores on ``context`` the traded asset (SPY, sid 8554), an unfitted
    random forest regressor, the number of prior daily changes used as
    model inputs (3), and how many days of history to train on (400).
    Then schedules ``create_model`` for weekly retraining and ``trade``
    for daily order placement.
    """
    context.security = sid(8554)  # SPY
    context.model = RandomForestRegressor()
    context.lookback = 3  # prior daily changes per sample
    context.history_range = 400  # days of price history to train on

    # Weekly retraining, ten minutes before the close on the week-end day.
    schedule_function(
        create_model,
        date_rule=date_rules.week_end(),
        time_rule=time_rules.market_close(minutes=10),
    )
    # Daily trading, one minute after the open.
    schedule_function(
        trade,
        date_rule=date_rules.every_day(),
        time_rule=time_rules.market_open(minutes=1),
    )
def create_model(context, data):
    """Refit ``context.model`` on recent daily price changes.

    Requests up to ``context.history_range`` daily prices for
    ``context.security`` and turns them into consecutive price changes.
    Each training sample uses ``context.lookback`` consecutive changes as
    features; its target is the change that immediately follows. The sample
    count is derived from the data actually returned, not from
    ``history_range``, so a shorter history still trains correctly.

    If there is not enough history for even one complete sample (fewer
    than ``lookback + 1`` changes), the model is left untouched.

    Args:
        context: Algorithm context providing ``security``, ``model``,
            ``lookback`` and ``history_range``.
        data: Bar data object used to request price history.
    """
    recent_prices = data.history(
        context.security, 'price', context.history_range, '1d'
    ).values
    price_changes = np.diff(recent_prices).tolist()

    lookback = context.lookback
    n_samples = len(price_changes) - lookback
    if n_samples < 1:
        # No complete (features, target) pair; fitting on an empty training
        # set would fail, so keep whatever model we already have.
        return

    # Sliding windows: inputs are changes [i, i + lookback), target is the
    # change at i + lookback.
    X = [price_changes[i:i + lookback] for i in range(n_samples)]
    Y = price_changes[lookback:]
    context.model.fit(X, Y)
def trade(context, data):
    """Target a full long or full short position from the model's forecast.

    Feeds the last ``context.lookback`` daily price changes of
    ``context.security`` to ``context.model`` as a single sample. A positive
    predicted change targets +100% of the portfolio in the security; any
    other prediction targets -100%. Returns without trading if the model
    has not been fitted yet, or if the history does not yield exactly
    ``lookback`` changes.

    Args:
        context: Algorithm context providing ``security``, ``model`` and
            ``lookback``.
        data: Bar data object used to request price history.
    """
    model = context.model
    # Fitted sklearn ensembles expose ``estimators_``; the attribute is absent
    # until ``fit`` runs. A bare ``if model:`` is not a fitted check: it calls
    # ``len(model)``, which on an unfitted ensemble raises AttributeError.
    if not hasattr(model, 'estimators_'):
        return

    recent_prices = data.history(
        context.security, 'price', context.lookback + 1, '1d'
    ).values
    price_changes = np.diff(recent_prices).tolist()
    if len(price_changes) != context.lookback:
        # create_model trains on exactly `lookback` features per sample.
        return

    # predict() takes a 2-D (n_samples, n_features) input. Pass one sample
    # and unwrap its single prediction to a plain float for record().
    prediction = float(model.predict([price_changes])[0])
    record(prediction=prediction)

    # Go long if we predict the price will rise, short otherwise.
    if prediction > 0:
        order_target_percent(context.security, 1.0)
    else:
        order_target_percent(context.security, -1.0)
def handle_data(context, data):
    """Per-bar callback; intentionally does nothing.

    Retraining and order placement run as the scheduled ``create_model``
    and ``trade`` functions, so there is no per-bar work.
    """
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment