Last active
November 3, 2023 15:14
-
-
Save tomron/8798256fcee5438edd58c17654adf443 to your computer and use it in GitHub Desktop.
A nicer seasonal decompose chart using plotly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from statsmodels.tsa.seasonal import seasonal_decompose | |
import plotly.tools as tls | |
def plotSeasonalDecompose(
    x,
    model="additive",
    filt=None,
    period=None,
    two_sided=True,
    extrapolate_trend=0,
    title="Seasonal Decomposition",
):
    """
    Decompose a time series and plot its components as four stacked subplots.

    The series is decomposed with ``statsmodels.tsa.seasonal.seasonal_decompose``
    and the observed, trend, seasonal and residual components are drawn in
    rows 1-4 of a single plotly figure.

    :param x: Time series. If the decomposition result carries an index
        (pandas input), that index is used for the x-axis; otherwise
        (e.g. ndarray input) integer positions 0..n-1 are used.
    :param model, filt, period, two_sided, extrapolate_trend: Passed straight
        through to ``seasonal_decompose``. See the documentation of those
        parameters here -
        https://www.statsmodels.org/stable/generated/statsmodels.tsa.seasonal.seasonal_decompose.html
    :param title: Figure title.
    :return: plotly ``Figure``; call ``.show()`` on it to display.

    Example -
        import pandas as pd
        from datetime import datetime
        import PlotTimeSeries
        s = pd.DataFrame(list(range(1, 11))*10,
                         index=pd.date_range(start=datetime(2010, 1, 1), periods=100))
        fig = PlotTimeSeries.plotSeasonalDecompose(s)
        fig.show()
    """
    # The original snippet used these names without importing them.
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots

    result = seasonal_decompose(
        x,
        model=model,
        filt=filt,
        period=period,
        two_sided=two_sided,
        extrapolate_trend=extrapolate_trend,
    )

    # Pandas input yields pandas components that carry an index. ndarray input
    # yields plain arrays with no ``.index`` (the reported AttributeError), so
    # fall back to positional x values in that case.
    x_values = getattr(result.observed, "index", None)
    if x_values is None:
        x_values = list(range(len(result.observed)))

    components = [
        ("Observed", result.observed),
        ("Trend", result.trend),
        ("Seasonal", result.seasonal),
        ("Residuals", result.resid),
    ]
    fig = make_subplots(
        rows=len(components),
        cols=1,
        subplot_titles=[name for name, _ in components],
    )
    for row, (name, values) in enumerate(components, start=1):
        fig.add_trace(
            go.Scatter(x=x_values, y=values, mode="lines", name=name),
            row=row,
            col=1,
        )
    # ``title`` was previously accepted but never used.
    fig.update_layout(title=title)
    return fig
Please be careful when using the code above: both versions have bugs.
If you run it as you have presented above, it will throw:
- Indentation Error:
IndentationError: unexpected indent <[command-1424241324418648]()>, line 33 fig = make_subplots( ^
- Attribute Error:
AttributeError: 'numpy.ndarray' object has no attribute 'index' <[command-1424241324418648]()> in plot_seasonal_decompose(result, title) 12 ) 13 .add_trace( ---> 14 go.Scatter(x=result.seasonal.index, y=result.observed, mode="lines"), 15 row=1, 16 col=1,
I have now fixed it; please see below.
I've also added an optional `dates` parameter whose values are used as the x-axis labels, for ease of reference.
from plotly.subplots import make_subplots
from statsmodels.tsa.seasonal import DecomposeResult, seasonal_decompose
def plot_seasonal_decompose(
    result: "DecomposeResult",
    dates: "pd.Series | None" = None,
    title: str = "Seasonal Decomposition",
):
    """
    Plot an already-computed seasonal decomposition as four stacked subplots.

    :param result: Output of ``statsmodels.tsa.seasonal.seasonal_decompose``.
    :param dates: Optional x-axis values (e.g. a column of dates), one per
        observation. When omitted, integer positions 0..n-1 are used.
    :param title: Figure title, rendered bold and centred.
    :return: plotly ``Figure`` with Observed / Trend / Seasonal / Residuals rows.
    :raises ValueError: if ``dates`` does not have one entry per observation.
    """
    # ``np`` and ``go`` were used by the snippet without being imported.
    # The annotations above are strings so that ``pd`` / ``DecomposeResult``
    # need not be in scope when the function is defined.
    import numpy as np
    import plotly.graph_objects as go

    n_obs = len(result.observed)
    if dates is not None and len(dates) != n_obs:
        # Plotly would otherwise silently truncate to the shorter length.
        raise ValueError(
            f"dates has {len(dates)} entries but the decomposition has {n_obs}"
        )
    x_values = dates if dates is not None else np.arange(n_obs)

    return (
        make_subplots(
            rows=4,
            cols=1,
            subplot_titles=["Observed", "Trend", "Seasonal", "Residuals"],
        )
        .add_trace(
            go.Scatter(x=x_values, y=result.observed, mode="lines", name="Observed"),
            row=1,
            col=1,
        )
        .add_trace(
            go.Scatter(x=x_values, y=result.trend, mode="lines", name="Trend"),
            row=2,
            col=1,
        )
        .add_trace(
            go.Scatter(x=x_values, y=result.seasonal, mode="lines", name="Seasonal"),
            row=3,
            col=1,
        )
        .add_trace(
            go.Scatter(x=x_values, y=result.resid, mode="lines", name="Residual"),
            row=4,
            col=1,
        )
        .update_layout(
            height=900,
            title=f"<b>{title}</b>",
            margin={"t": 100},
            title_x=0.5,
            showlegend=False,
        )
    )
Now, it will work like this:
# Usage example: decompose the AirPassengers series and plot it, using the
# "Month" column as x-axis labels. period=12 presumably reflects a yearly
# cycle in monthly data. Fetching the CSV requires network access.
import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose

data = pd.read_csv(
    "https://raw.githubusercontent.com/swilsonmfc/pandas/main/AirPassengers.csv"
)
decomposition = seasonal_decompose(
    data["#Passengers"],
    model="additive",
    period=12,
)
fig = plot_seasonal_decompose(decomposition, dates=data["Month"])
fig.show()
@chrimaho Thank you!
Nice !!
@chrimaho
I have modified your code so that it takes a DataFrame and plots the decompositions of multiple columns side by side.
def plot_seasonal_decompose(
    title: str = "Seasonal Decomposition",
    df: "pd.DataFrame | None" = None,
    model: str = "additive",
    period: int = 12,
):
    """
    Decompose every column of ``df`` and plot the results side by side.

    Each column gets its own subplot column. Rows 1-4 hold the observed,
    trend, seasonal and residual components, with the row labels on the
    y-axes of the first column.

    :param title: Figure title, rendered bold and centred.
    :param df: DataFrame whose columns are the series to decompose. Each
        column's index is used as its x-axis.
    :param model: Decomposition model passed to ``seasonal_decompose``
        (default ``"additive"``, as before).
    :param period: Seasonal period passed to ``seasonal_decompose``
        (default 12, the value previously hard-coded).
    :return: plotly ``Figure``.
    :raises ValueError: if ``df`` is missing or has no columns.
    """
    # ``go`` was used by the snippet without being imported.
    import plotly.graph_objects as go

    if df is None or len(df.columns) == 0:
        raise ValueError("df must be a DataFrame with at least one column")

    component_names = ("Observed", "Trend", "Seasonal", "Residuals")
    fig = make_subplots(
        rows=len(component_names),
        cols=len(df.columns),
        subplot_titles=[str(name) for name in df.columns],
    )

    # DataFrame.iteritems() was removed in pandas 2.0; items() is the
    # long-standing equivalent.
    for col, (_, series) in enumerate(df.items(), start=1):
        decomposition = seasonal_decompose(series, model=model, period=period)
        components = (
            decomposition.observed,
            decomposition.trend,  # previously mislabelled as 'Observed'
            decomposition.seasonal,
            decomposition.resid,
        )
        for row, (name, values) in enumerate(zip(component_names, components), start=1):
            fig.add_trace(
                go.Scatter(x=series.index, y=values, mode="lines", name=name),
                row=row,
                col=col,
            )

    # Label each component row once, on the left-most subplot column.
    for row, name in enumerate(component_names, start=1):
        fig.update_yaxes(title_text=name, row=row, col=1)

    fig.update_layout(
        height=900,
        title=f"<b>{title}</b>",
        margin={"t": 100},
        title_x=0.5,
        showlegend=False,
    )
    return fig
# Render one decomposition column per DataFrame column.
# NOTE(review): `df` is not defined in this snippet; it is assumed to be a
# DataFrame of time series created earlier. Confirm.
fig = plot_seasonal_decompose(title="Seasonal Decomposition", df=df)
fig.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for the gist, variation: