Last active
June 11, 2022 23:15
-
-
Save mintaow/ab2cc75f6b7b8163c99032f83be3151e to your computer and use it in GitHub Desktop.
Visualization Template: Double-axis Time Series Plot with Auxiliary Lines/Bands
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook) | |
# Seaborn version: 0.11.2 (My Local Jupyter Notebook) | |
# Python version: 3.7.4 (My Local Jupyter Notebook) | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import matplotlib.ticker as ticker | |
from datetime import datetime | |
%matplotlib inline | |
sns.set(style="white",context="talk") | |
def double_axis_line_plot(df_1, col_x, col_1_y, line_1_label, df_2, col_2_y, line_2_label,
                          title='TITLE', xlabel='XLABEL', y_1_label='LEFT_YLABEL',
                          y_2_label='RIGHT_YLABEL', figsize=(14, 6),
                          band=('2019/03/20', '2019/03/23')):
    '''
    This lineplot function visualizes the (time series) trend of key metrics col_1_y versus col_2_y over col_x.
    col_1_y corresponds to the left y-axis and col_2_y is tied to the right y-axis.
    An optional shaded vertical band can be drawn to highlight an x-range of interest.
    Input:
        df_1: pandas dataframe. The dataframe for the first variable.
        col_x: string. The name of the column used as the x-axis (looked up in both df_1 and df_2).
        col_1_y: string. The name of the column in df_1 plotted against the left y-axis
        line_1_label: string. The label for col_1_y that appears in the legend
        df_2: pandas dataframe. The dataframe for the second variable.
        col_2_y: string. The name of the column in df_2 plotted against the right y-axis
        line_2_label: string. The label for col_2_y that appears in the legend
        title: string. The title of the diagram
        xlabel: string. The name of the x-axis
        y_1_label: string. The title of the left y-axis
        y_2_label: string. The title of the right y-axis
        figsize: tuple. The width and height of the figure.
        band: tuple (start, end) or None. The x-range to shade. Each edge is either a
            'YYYY/MM/DD' string or a value matplotlib can place on the x-axis directly
            (e.g. a datetime). The default keeps the span this template originally
            hard-coded; pass None to draw no band.
    Output:
        fig: figure object
    '''
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # Lineplot for the first variable (left axis)
    sns.lineplot(
        x=col_x, y=col_1_y,
        data=df_1, ax=ax,
        color='royalblue',
        label=line_1_label,
    )
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(y_1_label, fontsize=16)
    # Set the legend location (upper left) and fontsize
    ax.legend(loc=2, fontsize=14)
    # Put a major tick at every multiple of 2 x-units to thin out crowded tick labels
    ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
    # Rotate the x-ticks to save some space
    plt.setp(ax.get_xticklabels(), fontsize=14, rotation=45, ha="right", rotation_mode="anchor")

    # Lineplot for the second variable (right axis)
    ax_y = ax.twinx()
    sns.lineplot(
        x=col_x,
        y=col_2_y,
        data=df_2,
        ax=ax_y,
        label=line_2_label,
        color='black',
        alpha=0.4,
    )
    # Same tick spacing and label rotation as on the left axes
    ax_y.xaxis.set_major_locator(ticker.MultipleLocator(2))
    plt.setp(ax_y.get_xticklabels(), fontsize=14, rotation=45, ha="right", rotation_mode="anchor")
    ax_y.set_ylabel(y_2_label, fontsize=16)
    # Make it a dashed lineplot (https://stackoverflow.com/questions/51963725/how-to-plot-a-dashed-line-on-seaborn-lineplot)
    # This must happen before the legend is rebuilt so the legend handle is dashed as well.
    ax_y.lines[0].set_linestyle("--")
    # Set the legend location (upper right) and fontsize
    ax_y.legend(loc=1, fontsize=14)

    # Add the auxiliary shaded area (if requested)
    if band is not None:
        band_start, band_end = [
            datetime.strptime(edge, '%Y/%m/%d') if isinstance(edge, str) else edge
            for edge in band
        ]
        ax.axvspan(xmin=band_start, xmax=band_end, color='darkgrey', alpha=0.2)
    # Optional auxiliary lines, kept as template examples:
    # ax.axhline(y = 0.5, ls = "--", c = "black",alpha = 0.6)
    # ax.axvline(x = datetime.strptime('2019/09/28', '%Y/%m/%d'), ls = "--", c = "black",alpha = 0.6)

    # Format the graph-level features
    # pad increases the space between title and graph (https://stackoverflow.com/questions/16419670/increase-distance-between-title-and-plot-in-matplolib)
    ax.set_title(title, fontsize=18, pad=10)
    # Faint the grid lines
    ax.grid(linestyle="--", alpha=0.4)

    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load the data
# The first dataset: NYC public taxis records from seaborn in March, 2019
df = sns.load_dataset("taxis")
# pd.to_datetime is vectorized and accepts both raw strings and already-parsed
# timestamps, so this works whether or not the loader has parsed the column.
df['pickup'] = pd.to_datetime(df['pickup'])
# Truncate each pickup timestamp to midnight so rides can be grouped per calendar day
# while the key stays a datetime64 column (no second conversion needed later).
df['date'] = df['pickup'].dt.normalize()
taxis = (
    df.groupby(by='date')
    .agg(
        num_rides=pd.NamedAgg(column='pickup', aggfunc='count'),
        avg_price=pd.NamedAgg(column='total', aggfunc='mean'),
        max_price=pd.NamedAgg(column='total', aggfunc='max'),
        # String alias instead of np.median: same result, no numpy-callable warning
        med_price=pd.NamedAgg(column='total', aggfunc='median'),
        avg_distance=pd.NamedAgg(column='distance', aggfunc='mean'),
    )
    .sort_index()
    .reset_index()
)
# Remove the abnormal date - Feb 28th. Filtering by date states the intent explicitly
# instead of relying on that day happening to be the first row.
taxis = taxis[taxis['date'] >= pd.Timestamp('2019-03-01')]

# The second dataset: NYC public weather data from National Oceanic and Atmospheric Administration in March, 2019
weather = pd.read_csv("https://raw.githubusercontent.com/mintaow/MyMediumWork/main/data/climate_data_nyc_noaa.csv")
weather['date'] = pd.to_datetime(weather['date'])

fig = double_axis_line_plot(
    df_1=taxis,
    col_x='date',
    col_1_y='num_rides',
    line_1_label='Number of Taxis Rides in NYC',
    df_2=weather,
    col_2_y='precip',
    line_2_label='Precipitation (Inch)',
    title='Understanding Demand Fluctuation: Daily Number of Taxis Rides in NYC',
    xlabel='Date (March, 2019)',
    y_1_label='# Rides',
    y_2_label='Inch',
    figsize=(12, 6)
)
# fig.savefig(
#     fname = "../double_axis_time_series_plot_with_band_taxis.png", # path&filename for the output
#     dpi = 300, # make it a high-resolution graph
#     bbox_inches='tight' # sometimes default savefig method cuts off the x-axis and x-ticks. This param avoids that
# )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment