Created
October 31, 2021 02:23
-
-
Save alg0trader/9304401c46318aca86750aca4ee69ab1 to your computer and use it in GitHub Desktop.
alpaca stock downloader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
############################################################################### | |
# module: stock_downloader.py | |
# description: This script will download a given stock from the Alpaca | |
# API. | |
# author: Austin Schaller | |
############################################################################### | |
import pytz | |
import os, time | |
import numpy as np | |
import pandas as pd | |
from enum import Enum | |
import alpaca_trade_api as tradeapi | |
from datetime import datetime, timezone, tzinfo | |
## CLASSES #################################################################### | |
class TimeFrame(Enum):
    '''
    Bar resolutions used by this module's downloader.

    Defined locally instead of using alpaca_trade_api's own TimeFrame so that
    extra resolutions can be listed. Each member's value is a resolution
    string; presumably api.get_bars reads it via `.value` -- confirm for your
    SDK version.
    '''
    Day = '1Day'          # one bar per day
    Hour = '1Hour'        # one bar per hour
    Minute1 = '1Min'      # one bar per minute
    Minute5 = '5Min'      # one bar per 5 minutes
    Minute15 = '15Min'    # one bar per 15 minutes
    Sec = '1Sec'          # one bar per second
## FUNCTIONS ################################################################## | |
def get_today_dt(tz='America/New_York'):
    '''
    Return the current time in timezone `tz` as an ISO-8601 string.

    input:
        - tz (string): timezone name understood by pytz.timezone
          ('America/New_York', i.e. US Eastern -- EST or EDT depending on
          the date -- by default)
    output:
        - datetime (ISO string): current timestamp truncated to whole
          seconds, including its UTC offset
          (e.g. '2021-10-31T02:23:00-04:00')
    '''
    # datetime.utcnow() is deprecated (Python 3.12) and returns a naive value
    # that then has to be localized; datetime.now(timezone.utc) is aware already.
    utc_now = datetime.now(timezone.utc)
    return utc_now.astimezone(pytz.timezone(tz)).replace(microsecond=0).isoformat()
def get_days_between(start, end):
    '''
    Return the number of whole days between two timezone-aware timestamps.

    input:
        - start (ISO string): start datetime stamp with UTC offset,
          e.g. '2021-01-04T07:00:00-04:00'
        - end (ISO string or 'today'): end datetime stamp in the same format;
          'today' is resolved with get_today_dt()
    output:
        - days (int): (end - start).days -- floored, so it is negative when
          end precedes start
    '''
    iso_fmt = '%Y-%m-%dT%H:%M:%S%z'
    start_dt = datetime.strptime(start, iso_fmt)
    # Resolve 'today' to a string first so both cases are parsed exactly once.
    # (Previously an explicit end was parsed twice: the second strptime
    # received a datetime and raised TypeError.)
    end_str = get_today_dt() if end == 'today' else end
    end_dt = datetime.strptime(end_str, iso_fmt)
    return (end_dt - start_dt).days
def _format_bars(bars_df):
    '''Return a copy of a raw bars frame with OHLCV column names and a 'Datetime' index name.'''
    # NOTE(review): assumes the first five columns are open/high/low/close/volume
    # in that order. Any extra trailing columns (some SDK versions append e.g.
    # trade_count / vwap) are dropped rather than breaking the 5-name
    # assignment -- confirm against your alpaca_trade_api version.
    renamed = bars_df.iloc[:, :5].set_axis(['Open', 'High', 'Low', 'Close', 'Volume'], axis=1)
    return renamed.rename_axis('Datetime')


def get_stock_data(api, symbol, start, end, timeframe, limit=1000, tz='America/New_York', adjustment='raw'):
    '''
    Download bars for one symbol, paging through the API until [start, end] is covered.

    input:
        - api (alpaca_trade_api.REST): Alpaca API instance (v2)
        - symbol (string): stock symbol
        - start (ISO string): start datetime stamp
        - end (ISO string or 'today'): end datetime stamp; 'today' is resolved
          with get_today_dt(tz=tz)
        - timeframe (TimeFrame): timeframe resolution passed to api.get_bars
        - limit (int): max rows requested per call, restricted to 1-10,000
          (endpoints included)
        - tz (string): timezone the returned index is converted to
        - adjustment (string): adjustment mode forwarded to api.get_bars
    output:
        - df (pd.DataFrame): Open/High/Low/Close/Volume columns indexed by
          'Datetime', converted to `tz`
    '''
    if end == 'today':
        end = get_today_dt(tz=tz)

    # First request: if it returns fewer than `limit` rows, the whole range fit in one page.
    first = api.get_bars(symbol, timeframe, start, end, limit=limit, adjustment=adjustment).df
    pages = [_format_bars(first)]

    if len(pages[0]) >= limit:
        while True:
            # Resume from the last timestamp received. The response starts at
            # that (already stored) bar, so drop its first row.
            last_dt = pages[-1].index[-1].isoformat()
            raw = api.get_bars(symbol, timeframe, start=last_dt, end=end,
                               limit=limit, adjustment=adjustment).df[1:]
            if raw.empty:
                # Nothing new. This also stops limit == 1 from raising
                # IndexError on the next pass.
                break
            page = _format_bars(raw)
            pages.append(page)
            # A short page (fewer than limit - 1 new rows) means the range is exhausted.
            if len(page) < limit - 1:
                break

    # Concatenate once at the end instead of re-copying on every page.
    bars_df = pd.concat(pages)
    bars_df.index = pd.to_datetime(bars_df.index)  # ensure datetime64 dtype
    bars_df.index = bars_df.index.tz_convert(tz)   # convert to the requested timezone
    return bars_df
def get_all_stock_data(api, symbols, export_path, start, end, timeframe, limit=1000, tz='America/New_York', fmt='npy', adjustment='raw'):
    '''
    Download bars for every symbol and export one file per symbol.

    input:
        - api (alpaca_trade_api.REST): Alpaca API instance (v2)
        - symbols (iterable of string): stock symbols to download
        - export_path (string): directory the files are written to
        - start (ISO string): start datetime stamp
        - end (ISO string or 'today'): end datetime stamp
        - timeframe (TimeFrame): timeframe resolution
        - limit (int): max rows requested per call, restricted to 1-10,000
          (endpoints included)
        - tz (string): timezone the index is converted to
        - fmt (string): export format ('npy' or 'csv')
        - adjustment (string): adjustment mode forwarded to api.get_bars
    output:
        - None; files are written via export_data and a progress counter
          is printed on a single (carriage-return) line
    '''
    # Materialize so generators work too; len() is needed for the progress display.
    symbols = list(symbols)
    total = len(symbols)
    # enumerate fixes the counter, which previously stayed at 1 for every symbol.
    for i, sym in enumerate(symbols, start=1):
        bars_df = get_stock_data(api, sym, start, end, timeframe, limit, tz, adjustment)
        export_data(bars_df, sym, export_path, fmt, remove_tz=False)
        print('Progress: (%d / %d)\r' % (i, total), end='')
def get_stock_tickers(api, tradable_only=True):
    '''
    List the active assets reported by Alpaca.

    input:
        - api (alpaca_trade_api.REST): Alpaca API instance
        - tradable_only (bool): keep only assets whose `tradable` attribute
          is truthy
    output:
        - list (Asset): assets from api.list_assets(status='active'),
          filtered to tradable ones when tradable_only is set
    '''
    assets = api.list_assets(status='active')
    if tradable_only:
        assets = [asset for asset in assets if asset.tradable]
    return assets
def export_data(df, symbol, export_path, format='npy', remove_tz=True): | |
''' | |
input: | |
- df (pd.DataFrame): dataframe of stock data | |
- export_path (string): desired export path | |
- format (string): export format ('npy' or 'csv') | |
''' | |
if not os.path.isdir(export_path): os.makedirs(export_path) | |
if remove_tz == True: df.index = df.index.tz_convert(None) | |
if(format == 'csv'): df.to_csv(os.path.join(export_path, symbol + '.csv')) | |
elif(format == 'npy'): df.to_pickle(os.path.join(export_path, symbol + '.npy')) | |
else: raise Exception("Unrecognized export format, only 'npy' and 'csv' is supported.") | |
def print_summary(bars_df, start_time, end_time, fmt):
    '''
    Print how long the download took, formatted as HH:MM:SS.ss.

    input:
        - bars_df: not used; kept for interface compatibility
        - start_time (float): start timestamp in seconds (e.g. time.time())
        - end_time (float): end timestamp in seconds
        - fmt: not used; kept for interface compatibility
    '''
    hours, leftover = divmod(end_time - start_time, 3600)
    minutes, seconds = divmod(leftover, 60)
    fields = (int(hours), int(minutes), seconds)
    print("Elapsed time: {:0>2}:{:0>2}:{:05.2f}".format(*fields))
## MAIN ####################################################################### | |
def main():
    '''
    Download bars for one symbol (or every tradable asset) and export them
    under $STOCK_DATA_PATH/alpaca.

    tradeapi.REST is constructed with only api_version, so presumably the SDK
    reads credentials from the environment -- confirm for your SDK version.
    '''
    # Parameters
    symbol = 'AMC'
    end = 'today'
    start = '2021-01-04T07:00:00-04:00'  # Jan 1, 2016 is earliest
    fmt = 'npy'
    timeframe = TimeFrame.Minute1
    get_all_stocks = False

    data_root = os.getenv('STOCK_DATA_PATH')
    if not data_root:
        # Fail with a clear message instead of a TypeError from os.path.join(None, ...).
        raise SystemExit('STOCK_DATA_PATH environment variable is not set')
    export_path = os.path.join(data_root, 'alpaca')

    # Connect to the Alpaca API
    api = tradeapi.REST(api_version='v2')

    print('Start datetime:\t%s' % start)
    print('End datetime: \t%s' % get_today_dt())
    print('(%d days between)' % get_days_between(start, end))

    start_time = time.time()
    if not get_all_stocks:
        # Fetch the indicated symbol in pages of up to 10,000 bars
        bars_df = get_stock_data(api, symbol, start, end, timeframe, limit=10000)
        print_summary(bars_df, start_time, time.time(), fmt)
        export_data(bars_df, symbol, export_path, format=fmt, remove_tz=False)
    else:
        # Fetch every tradable asset over [start, end]; failures are reported and skipped
        asset_list = get_stock_tickers(api)
        total = len(asset_list)
        for i, asset in enumerate(asset_list, start=1):
            try:
                bars_df = get_stock_data(api, asset.symbol, start, end, timeframe, limit=10000)
                export_data(bars_df, asset.symbol, export_path, format=fmt, remove_tz=False)
            except Exception as e:
                # Top-level boundary: report and continue with the next asset.
                print('Error: %s' % str(e))
                print('Stock: %s' % asset.symbol)
                print('------------------------')
            # Count every processed asset, not only successful ones.
            print('Progress: (%d / %d)\r' % (i, total), end='')
        print('')  # terminate the carriage-return progress line
        print_summary(None, start_time, time.time(), fmt)
    print('')
    print('Done!')


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment