Skip to content

Instantly share code, notes, and snippets.

@jcoffi
Forked from alg0trader/alpaca_downloader.py
Created November 26, 2021 23:31
Show Gist options
  • Save jcoffi/dec2f5b704d5c69d340b3e68997e5fab to your computer and use it in GitHub Desktop.
Save jcoffi/dec2f5b704d5c69d340b3e68997e5fab to your computer and use it in GitHub Desktop.
alpaca stock downloader
###############################################################################
# module: stock_downloader.py
# description: This script downloads historical bar data for a given stock
# (or for every active asset) from the Alpaca API.
# author: Austin Schaller
###############################################################################
import pytz
import os, time
import numpy as np
import pandas as pd
from enum import Enum
import alpaca_trade_api as tradeapi
from datetime import datetime, timezone, tzinfo
## CLASSES ####################################################################
class TimeFrame(Enum):
    """
    Bar resolutions understood by this downloader.

    Used in place of the Alpaca REST TimeFrame so that additional
    resolutions can be requested. Each member's value is the resolution
    string associated with that bar size.
    """
    # Daily and hourly bars
    Day = "1Day"
    Hour = "1Hour"
    # Minute-based bars
    Minute1 = "1Min"
    Minute5 = "5Min"
    Minute15 = "15Min"
    # One-second bars
    Sec = "1Sec"
## FUNCTIONS ##################################################################
def get_today_dt(tz='America/New_York'):
    '''
    Return the current time in the requested timezone as an ISO-8601 string.

    input:
    - tz (string): timezone name understood by pytz (US Eastern by default)
    output:
    - datetime (ISO string): current datetime stamp, truncated to whole
      seconds and carrying the UTC offset of `tz`
    '''
    # datetime.utcnow() is deprecated since Python 3.12; request an aware
    # UTC timestamp directly instead of localizing a naive one.
    utc_now = datetime.now(timezone.utc)
    local_now = utc_now.astimezone(pytz.timezone(tz))
    return local_now.replace(microsecond=0).isoformat()
def get_days_between(start, end):
    '''
    Count the whole days between two ISO datetime stamps.

    input:
    - start (ISO string): start datetime stamp, e.g. '2021-01-04T07:00:00-04:00'
    - end (ISO string or 'today'): end datetime stamp; 'today' resolves to
      the current time via get_today_dt()
    output:
    - days (int): number of days between start and end datetime stamps
      (negative when end precedes start)
    '''
    iso_fmt = '%Y-%m-%dT%H:%M:%S%z'
    start_dt = datetime.strptime(start, iso_fmt)
    # Resolve 'today' to an ISO string first so the end stamp is parsed
    # exactly once, whichever form the caller supplied.
    end_iso = get_today_dt() if end == 'today' else end
    end_dt = datetime.strptime(end_iso, iso_fmt)
    return (end_dt - start_dt).days
def _normalize_bars(bars_df):
    '''
    Return a copy of a bars frame with the columns
    ['Open', 'High', 'Low', 'Close', 'Volume'] and a UTC 'Datetime' index.
    '''
    ohlcv = ['Open', 'High', 'Low', 'Close', 'Volume']
    if bars_df.empty:
        # An empty response may carry no columns at all, so build the
        # expected shape explicitly instead of renaming.
        empty_index = pd.DatetimeIndex([], tz='UTC', name='Datetime')
        return pd.DataFrame(columns=ohlcv, index=empty_index, dtype=float)
    by_lower_name = {str(col).lower(): col for col in bars_df.columns}
    wanted = [name.lower() for name in ohlcv]
    if all(name in by_lower_name for name in wanted):
        # Select by name so extra columns in the response are ignored.
        bars_df = bars_df[[by_lower_name[name] for name in wanted]].copy()
    else:
        # Unknown naming: fall back to positional assignment, which raises
        # ValueError unless there are exactly five columns.
        bars_df = bars_df.copy()
    bars_df.columns = ohlcv
    # utc=True keeps aware stamps as-is (converted to UTC) and treats naive
    # stamps as UTC -- NOTE(review): assumes the API reports UTC; confirm.
    bars_df.index = pd.to_datetime(bars_df.index, utc=True)
    bars_df.index.name = 'Datetime'
    return bars_df


def get_stock_data(api, symbol, start, end, timeframe, limit=1000, tz='America/New_York', adjustment='raw'):
    '''
    Download bars for one symbol, requesting further pages until the range
    is exhausted.

    input:
    - api (alpaca_trade_api.REST): Alpaca API instance (v2)
    - symbol (string): stock symbol
    - start (ISO string): start datetime stamp
    - end (ISO string or 'today'): end datetime stamp; 'today' resolves to
      the current time in `tz`
    - timeframe (alpaca_trade_api.rest.TimeFrame): timeframe resolution
    - limit (int): rows requested per api.get_bars call (the original notes
      restrict this to 1-10,000, endpoints included)
    - tz (string): timezone the returned index is converted to
    - adjustment (string): passed through to api.get_bars
    output:
    - df (pd.DataFrame): Open/High/Low/Close/Volume bars indexed by
      'Datetime' in `tz`; empty (with the same columns) when the API
      returns no bars
    '''
    if end == 'today':
        end = get_today_dt(tz=tz)

    # Get initial bar data to see if `limit` rows is enough to encompass our range
    first = _normalize_bars(
        api.get_bars(symbol, timeframe, start, end, limit=limit, adjustment=adjustment).df)
    pages = [first]

    # A full first page means more bars may remain after its last timestamp.
    if not first.empty and len(first) >= limit:
        last_dt = first.index[-1]
        while True:
            raw = api.get_bars(symbol, timeframe, start=last_dt.isoformat(), end=end,
                               limit=limit, adjustment=adjustment).df
            page = _normalize_bars(raw)
            # The request restarts at the last bar already held; keep only
            # strictly newer rows so that bar is not duplicated.
            page = page[page.index > last_dt]
            if page.empty:
                # Nothing new came back; stopping here also guarantees
                # termination (e.g. limit=1 would otherwise loop forever).
                break
            pages.append(page)
            if len(raw) < limit:
                # A short page is the final page.
                break
            last_dt = page.index[-1]

    # Concatenate once rather than on every iteration.
    accumulated_df = pd.concat(pages) if len(pages) > 1 else first
    accumulated_df.index = accumulated_df.index.tz_convert(tz)  # convert timezone
    return accumulated_df
def get_all_stock_data(api, symbols, export_path, start, end, timeframe, limit=1000, tz='America/New_York', fmt='npy', adjustment='raw'):
    '''
    Download and export bars for every symbol in `symbols`, printing a
    running progress counter.

    input:
    - api (alpaca_trade_api.REST): Alpaca API instance (v2)
    - symbols (list of string): stock symbols to download
    - export_path (string): directory the exported files are written to
    - start (ISO string): start datetime stamp
    - end (ISO string or 'today'): end datetime stamp
    - timeframe (alpaca_trade_api.rest.TimeFrame): timeframe resolution
    - limit (int): forwarded to get_stock_data
    - tz (string): timezone forwarded to get_stock_data
    - fmt (string): export format forwarded to export_data ('npy' or 'csv')
    - adjustment (string): forwarded to get_stock_data
    output:
    - None; one file per symbol is written through export_data
    '''
    total = len(symbols)
    # enumerate drives the counter so the progress line actually advances.
    for done, sym in enumerate(symbols, start=1):
        bars_df = get_stock_data(api, sym, start, end, timeframe, limit, tz, adjustment)
        export_data(bars_df, sym, export_path, fmt, remove_tz=False)
        print('Progress: (%d / %d)\r' % (done, total), end='')
def get_stock_tickers(api, tradable_only=True):
    '''
    List the active assets known to Alpaca.

    input:
    - api (alpaca_trade_api.REST): Alpaca API instance
    - tradable_only (bool): include only assets that are tradable
    by Alpaca
    output:
    - list (Asset): active assets; restricted to those whose `tradable`
      attribute is truthy when tradable_only is set, otherwise exactly what
      api.list_assets(status='active') returned
    '''
    assets = api.list_assets(status='active')
    if tradable_only:
        assets = [asset for asset in assets if asset.tradable]
    return assets
def export_data(df, symbol, export_path, format='npy', remove_tz=True):
    '''
    Write a stock dataframe to <export_path>/<symbol>.<format>.

    input:
    - df (pd.DataFrame): dataframe of stock data (left unmodified)
    - symbol (string): stock symbol, used as the file name
    - export_path (string): desired export path; created if missing
    - format (string): export format ('npy' or 'csv'). NOTE: 'npy' files are
      written with DataFrame.to_pickle, so read them back with
      pd.read_pickle rather than np.load.
    - remove_tz (bool): drop timezone information from the index before
      writing (the index is converted to UTC and made naive)
    raises:
    - ValueError: if `format` is neither 'npy' nor 'csv'
    '''
    # Validate before touching the filesystem so a bad format has no side effects.
    if format not in ('csv', 'npy'):
        raise ValueError("Unrecognized export format, only 'npy' and 'csv' is supported.")

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(export_path, exist_ok=True)

    # Only tz-aware indexes can be converted; a naive index is already "removed".
    if remove_tz and getattr(df.index, 'tz', None) is not None:
        # Re-index a shallow copy so the caller's frame keeps its timezone.
        df = df.copy(deep=False)
        df.index = df.index.tz_convert(None)

    target = os.path.join(export_path, symbol + '.' + format)
    if format == 'csv':
        df.to_csv(target)
    else:
        df.to_pickle(target)
def print_summary(bars_df, start_time, end_time, fmt):
    '''
    Print how long the data capture took as HH:MM:SS.ss.

    input:
    - bars_df: accepted but not used here
    - start_time (float): capture start, in seconds
    - end_time (float): capture end, in seconds
    - fmt: accepted but not used here
    '''
    elapsed = end_time - start_time
    # Peel off whole hours, then whole minutes; what is left is seconds.
    whole_hours, remainder = divmod(elapsed, 3600)
    whole_minutes, secs = divmod(remainder, 60)
    template = "Elapsed time: {:0>2}:{:0>2}:{:05.2f}"
    print(template.format(int(whole_hours), int(whole_minutes), secs))
## MAIN #######################################################################
if __name__ == "__main__":
    # Parameters
    symbol = 'AMC'
    end = 'today'
    start = '2021-01-04T07:00:00-04:00' # Jan 1, 2016 is earliest
    fmt = 'npy'
    timeframe = TimeFrame.Minute1
    get_all_stocks = False

    # Fail with a clear message instead of a TypeError from os.path.join(None, ...)
    data_root = os.getenv('STOCK_DATA_PATH')
    if data_root is None:
        raise SystemExit('STOCK_DATA_PATH environment variable is not set; '
                         'it must point to the export root directory.')
    export_path = os.path.join(data_root, 'alpaca')

    # Connect to the Alpaca API
    api = tradeapi.REST(api_version='v2')

    # Report the end stamp that will actually be used, not always "now"
    end_display = get_today_dt() if end == 'today' else end
    print('Start datetime:\t%s' % start)
    print('End datetime: \t%s' % end_display)
    print('(%d days between)' % get_days_between(start, end))

    start_time = time.time()
    if not get_all_stocks:
        # Get barset data of indicated stock in snapshots of 10,000 minutes at a time
        bars_df = get_stock_data(api, symbol, start, end, timeframe, limit=10000)
        end_time = time.time()
        print_summary(bars_df, start_time, end_time, fmt)
        export_data(bars_df, symbol, export_path, format=fmt, remove_tz=False)
    else:
        # Get data for every active, tradable asset from the Alpaca API
        asset_list = get_stock_tickers(api)
        start_time = time.time()
        total = len(asset_list)
        for done, asset in enumerate(asset_list, start=1):
            # Best-effort: one failing symbol is reported and skipped so the
            # rest of the download continues.
            try:
                bars_df = get_stock_data(api, asset.symbol, start, end, timeframe, limit=10000)
                export_data(bars_df, asset.symbol, export_path, format=fmt, remove_tz=False)
            except Exception as e:
                print('Error: %s' % str(e))
                print('Stock: %s' % asset.symbol)
                print('------------------------')
            # Count every processed asset so the counter reaches the total
            # even when some symbols fail.
            print('Progress: (%d / %d)\r' % (done, total), end='')
        end_time = time.time()
        print('')
        # print_summary does not use its dataframe argument
        print_summary(None, start_time, end_time, fmt)
    print('Done!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment