Skip to content

Instantly share code, notes, and snippets.

@mintaow
Last active June 11, 2022 23:15
Show Gist options
  • Save mintaow/ab2cc75f6b7b8163c99032f83be3151e to your computer and use it in GitHub Desktop.
Save mintaow/ab2cc75f6b7b8163c99032f83be3151e to your computer and use it in GitHub Desktop.
Visualization Template: Double-axis Time Series Plot with Auxiliary Lines/Bands
# Matplotlib version: 3.5.2 (My Local Jupyter Notebook)
# Seaborn version: 0.11.2 (My Local Jupyter Notebook)
# Python version: 3.7.4 (My Local Jupyter Notebook)
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from datetime import datetime
%matplotlib inline
# Apply seaborn's "white" style and "talk" context as the global plotting theme
# for every figure drawn after this point.
sns.set(style="white",context="talk")
def double_axis_line_plot(
    df_1, col_x, col_1_y, line_1_label,
    df_2, col_2_y, line_2_label,
    title='TITLE', xlabel='XLABEL',
    y_1_label='LEFT_YLABEL', y_2_label='RIGHT_YLABEL',
    figsize=(14, 6),
    band=('2019/03/20', '2019/03/23'),
):
    '''
    Draw a double-axis line plot of col_1_y (left y-axis) versus col_2_y (right y-axis) over col_x.
    An optional shaded vertical band can be added to highlight a range on the x-axis.

    Input:
        df_1: pandas dataframe. The dataframe for the first variable.
        col_x: string. The name of the column used as the x-axis (looked up in both df_1 and df_2).
        col_1_y: string. The name of the column in df_1 plotted against the left y-axis.
        line_1_label: string. The legend label for col_1_y.
        df_2: pandas dataframe. The dataframe for the second variable.
        col_2_y: string. The name of the column in df_2 plotted against the right y-axis.
        line_2_label: string. The legend label for col_2_y.
        title: string. The title of the diagram.
        xlabel: string. The name of the x-axis.
        y_1_label: string. The title of the left y-axis.
        y_2_label: string. The title of the right y-axis.
        figsize: tuple. The width and height of the figure.
        band: tuple (start, end) or None. The x-range to shade. Each edge may be a
            'YYYY/MM/DD' string (parsed with datetime.strptime) or any value that
            Axes.axvspan accepts directly (e.g. a datetime). Pass None to draw no band.
            The default keeps the band this template originally hard-coded.
    Output:
        fig: figure object
    '''
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize)

    # Lineplot for the first variable (left axis)
    sns.lineplot(
        x=col_x, y=col_1_y,
        data=df_1, ax=ax,
        color='royalblue',
        label=line_1_label,
    )
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(y_1_label, fontsize=16)
    # Left-axis legend goes to the upper-left corner (loc=2)
    ax.legend(loc=2, fontsize=14)
    # Put a major tick on every multiple of 2 x-units to thin out the tick labels
    ax.xaxis.set_major_locator(ticker.MultipleLocator(2))
    # Rotate the x-ticks to save some space
    plt.setp(ax.get_xticklabels(), fontsize=14, rotation=45, ha="right", rotation_mode="anchor")

    # Lineplot for the second variable (right axis), sharing the same x-axis
    ax_y = ax.twinx()
    sns.lineplot(
        x=col_x,
        y=col_2_y,
        data=df_2,
        ax=ax_y,
        label=line_2_label,
        color='black',
        alpha=0.4,
    )
    # Apply the same tick spacing and label rotation on the twin axis
    ax_y.xaxis.set_major_locator(ticker.MultipleLocator(2))
    plt.setp(ax_y.get_xticklabels(), fontsize=14, rotation=45, ha="right", rotation_mode="anchor")
    ax_y.set_ylabel(y_2_label, fontsize=16)
    # Make it a dashed line after plotting; this must happen before the legend is
    # built so that the legend handle is dashed as well
    # (https://stackoverflow.com/questions/51963725/how-to-plot-a-dashed-line-on-seaborn-lineplot)
    ax_y.lines[0].set_linestyle("--")
    # Right-axis legend goes to the upper-right corner (loc=1)
    ax_y.legend(loc=1, fontsize=14)

    # Add the auxiliary shaded area (if requested)
    if band is not None:
        band_start, band_end = (
            datetime.strptime(edge, '%Y/%m/%d') if isinstance(edge, str) else edge
            for edge in band
        )
        ax.axvspan(
            xmin=band_start,
            xmax=band_end,
            color='darkgrey',
            alpha=0.2,
        )
    # Other auxiliary lines can be added the same way, e.g.:
    # ax.axhline(y = 0.5, ls = "--", c = "black", alpha = 0.6)
    # ax.axvline(x = datetime.strptime('2019/09/28', '%Y/%m/%d'), ls = "--", c = "black", alpha = 0.6)

    # Format the graph-level features
    # The pad parameter increases the space between title and graph
    # (https://stackoverflow.com/questions/16419670/increase-distance-between-title-and-plot-in-matplolib)
    ax.set_title(title, fontsize=18, pad=10)
    # Faint the grid lines
    ax.grid(linestyle="--", alpha=0.4)

    plt.show()
    return fig
# ====================================================================================
# Illustration
# Load the data
# The first dataset: NYC public taxis records from seaborn in March, 2019
df = sns.load_dataset("taxis")
# pd.to_datetime works whether `pickup` arrives as strings or as already-parsed
# datetimes, unlike a per-row datetime.strptime which only accepts strings.
df['pickup'] = pd.to_datetime(df['pickup'])
df['date'] = df['pickup'].dt.date
# Aggregate the individual rides into one row of daily metrics per date
taxis = df.groupby(by='date').agg(
    num_rides=pd.NamedAgg(column='pickup', aggfunc='count'),
    avg_price=pd.NamedAgg(column='total', aggfunc='mean'),
    max_price=pd.NamedAgg(column='total', aggfunc='max'),
    med_price=pd.NamedAgg(column='total', aggfunc='median'),
    avg_distance=pd.NamedAgg(column='distance', aggfunc='mean'),
).sort_values(by='date').reset_index()
taxis['date'] = pd.to_datetime(taxis['date'])
# Remove the abnormal date - Feb 28th (the first row after sorting by date)
taxis = taxis.iloc[1:]

# The second dataset: NYC public weather data from National Oceanic and Atmospheric Administration in March, 2019
weather = pd.read_csv("https://raw.githubusercontent.com/mintaow/MyMediumWork/main/data/climate_data_nyc_noaa.csv")
weather['date'] = pd.to_datetime(weather['date'])

fig = double_axis_line_plot(
    df_1=taxis,
    col_x='date',
    col_1_y='num_rides',
    line_1_label='Number of Taxis Rides in NYC',
    df_2=weather,
    col_2_y='precip',
    line_2_label='Precipitation (Inch)',
    title='Understanding Demand Fluctuation: Daily Number of Taxis Rides in NYC',
    xlabel='Date (March, 2019)',
    y_1_label='# Rides',
    y_2_label='Inch',
    figsize=(12, 6),
)
# To save the figure to disk:
# fig.savefig(
#     fname = "../double_axis_time_series_plot_with_band_taxis.png", # path&filename for the output
#     dpi = 300, # make it a high-resolution graph
#     bbox_inches='tight' # sometimes default savefig method cuts off the x-axis and x-ticks. This param avoids that
# )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment