Skip to content

Instantly share code, notes, and snippets.

@Qu3tzal
Created May 18, 2023 08:02
Show Gist options
  • Save Qu3tzal/f59063a2e8b8df2d27a435cabfa500b6 to your computer and use it in GitHub Desktop.
Save Qu3tzal/f59063a2e8b8df2d27a435cabfa500b6 to your computer and use it in GitHub Desktop.
Quickly compare the prices of a list of stocks or indices.
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import argparse
def parse_args():
    """Parse the command-line options of the price comparison script.

    Reads ``sys.argv`` through ``argparse`` and returns the resulting
    namespace with the attributes ``history``, ``granularity``, ``stocks``,
    ``indices``, ``log``, ``timecolor``, ``pricelines`` and ``theme``.
    ``stocks`` and ``indices`` are ``None`` when the flag is not given.
    """
    # Accepted values for the history window and the sampling interval.
    history_periods = ["1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"]
    sampling_intervals = ["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h", "1d", "5d", "1wk", "1mo", "3mo"]
    known_indices = ["SP500", "NASDAQ", "DOWJONES"]
    plot_themes = ["dark", "light", "seaborn"]

    cli = argparse.ArgumentParser(
        description='Quickly compare a list of stocks or indices prices.',
    )

    # What to download.
    cli.add_argument(
        '--history',
        type=str,
        default='ytd',
        choices=history_periods,
        help='History period of the price data',
    )
    cli.add_argument(
        '--granularity',
        type=str,
        default='1d',
        choices=sampling_intervals,
        help='Granularity of the price data',
    )
    cli.add_argument('--stocks', nargs='+', help='List of stocks to compare')
    cli.add_argument(
        '--indices',
        nargs='+',
        choices=known_indices,
        help='List of indices to compare',
    )

    # How to draw it.
    cli.add_argument('--log', action='store_true', help='Use log scale for the y-axis')
    cli.add_argument('--timecolor', action='store_true', help='Use color to indicate time')
    cli.add_argument('--pricelines', action='store_true', help='Adds lines at regular price intervals')
    cli.add_argument(
        '--theme',
        type=str,
        default='dark',
        choices=plot_themes,
        help='Theme of the plot',
    )

    return cli.parse_args()
def _download_prices(tickers, period, interval):
    """Download price data and return it with (field, ticker) column pairs.

    The rest of the script looks columns up as ``frame["Close", ticker]``.
    When a single ticker comes back with flat columns, wrap them into a
    two-level index; when the columns are already a MultiIndex, leave them
    alone so no third level is added.
    """
    frame = yf.download(tickers, period=period, interval=interval, progress=False)
    if len(tickers) == 1 and not isinstance(frame.columns, pd.MultiIndex):
        frame.columns = pd.MultiIndex.from_product([frame.columns, tickers])
    return frame


def _tick_label(date, granularity):
    """Format one x-axis timestamp according to the sampling granularity."""
    if granularity in ("1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h"):
        # Intraday data: show the time of day above the date.
        return date.strftime("%H:%M\n%d/%m/%Y")
    if granularity in ("1d", "5d", "1wk", "1mo", "3mo"):
        return date.strftime("%d/%m/%Y")
    # Not reachable through the CLI choices; kept for programmatic callers.
    return date.strftime("%m/%Y")


def main(args):
    """Download and plot the close prices of the requested stocks and indices.

    Args:
        args: Namespace as produced by ``parse_args`` with the attributes
            ``history``, ``granularity``, ``stocks``, ``indices``, ``log``,
            ``timecolor``, ``pricelines`` and ``theme``.

    Prints a message and returns without plotting when neither stocks nor
    indices are given. Otherwise draws indices as dashed lines, stocks as
    solid lines with a high/low band, and shows the figure.
    """
    print("Hello!")

    # Treat a missing list and an empty list the same way.
    stocks = args.stocks or []
    indices = args.indices or []

    if not stocks and not indices:
        print("You need to specify at least one stock or index to compare.")
        return

    if not stocks:
        print("You are comparing the following indices: {}".format(indices))
    elif not indices:
        print("You are comparing the following stocks: {}".format(stocks))
    else:
        print("You are comparing the following stocks: {} with the following indices {}".format(stocks, indices))
        print("WARNING: You are comparing stocks and indices on the same plot. This can make reading the plot difficult. Consider using --stocks and --indices separately.")

    if len(stocks) > 10:
        print("WARNING: You are comparing more than 10 stocks. This can make reading the plot difficult.")

    if args.theme is not None:
        if args.theme == "seaborn":
            # Imported lazily so seaborn is only required for this theme.
            import seaborn as sns
            sns.set_theme()
            print("WARNING: When using the seaborn theme consider disabling the --timecolor and --pricelines flags.")
        else:
            themes_map = {
                "dark": "dark_background",
                "light": "default",
            }
            plt.style.use(themes_map[args.theme])

    # Get the data for the indices.
    indices_data = None
    if indices:
        indices_ticker_map = {
            "SP500": "^GSPC",
            "NASDAQ": "^IXIC",
            "DOWJONES": "^DJI",
        }
        indices_tickers = [indices_ticker_map[name] for name in indices]
        indices_data = _download_prices(indices_tickers, args.history, args.granularity)
        for name in indices:
            symbol = indices_ticker_map[name]
            close_price = indices_data["Close", symbol].to_numpy()
            plt.plot(close_price, linestyle='dashed', label=name + " ({})".format(symbol))

    # Get the data for the stocks. Skipped entirely in indices-only mode,
    # where there is nothing to download.
    stocks_data = None
    if stocks:
        stocks_data = _download_prices(stocks, args.history, args.granularity)
        for ticker in stocks:
            close_price = stocks_data["Close", ticker].to_numpy()
            high_price = stocks_data["High", ticker].to_numpy()
            low_price = stocks_data["Low", ticker].to_numpy()
            plt.plot(close_price, label=ticker)
            plt.fill_between(range(len(close_price)), high_price, low_price, alpha=0.3)

    # Series are plotted against their row position; the tick labels come
    # from the stocks' dates when available, otherwise from the indices'.
    timeline = (stocks_data if stocks_data is not None else indices_data).index

    # Try to only keep about a dozen labels.
    step = max(1, len(timeline) // 12)
    xticks_indices = list(range(0, len(timeline), step))
    xticks_labels = [_tick_label(timeline[pos], args.granularity) for pos in xticks_indices]

    if args.log:
        plt.yscale("log")

    if args.timecolor:
        # Shade every other interval between consecutive ticks.
        for span_start, span_end in zip(xticks_indices[::2], xticks_indices[1::2]):
            plt.axvspan(span_start, span_end, facecolor='grey', alpha=0.1)

    if args.pricelines:
        plt.grid(True)

    plt.xticks(xticks_indices, xticks_labels, rotation=45)
    plt.xlabel("Date")
    plt.ylabel("Price (at close, with high/low)")
    plt.legend()
    plt.title("Comparison of close prices of stocks ({}/{})".format(args.history, args.granularity))
    plt.margins(x=0, tight=True)
    plt.show()
# Script entry point: parse the command line, then build and show the plot.
if __name__ == "__main__":
    main(parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment