Skip to content

Instantly share code, notes, and snippets.

@tomron
Last active November 3, 2023 15:14
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomron/8798256fcee5438edd58c17654adf443 to your computer and use it in GitHub Desktop.
Save tomron/8798256fcee5438edd58c17654adf443 to your computer and use it in GitHub Desktop.
A nicer seasonal decompose chart using plotly.
import plotly.graph_objects as go
import plotly.tools as tls
from plotly.subplots import make_subplots
from statsmodels.tsa.seasonal import seasonal_decompose
def plotSeasonalDecompose(
        x,
        model='additive',
        filt=None,
        period=None,
        two_sided=True,
        extrapolate_trend=0,
        title="Seasonal Decomposition"):
    """
    Decompose a time series and plot its components with plotly.

    The series is passed to ``statsmodels.tsa.seasonal.seasonal_decompose``
    and the observed, trend, seasonal and residual components are drawn as
    four stacked line charts.

    :param x: Time series.
    :param title: Title shown above the figure.
    :return: The plotly figure holding the four subplots.

    ``model``, ``filt``, ``period``, ``two_sided`` and ``extrapolate_trend``
    are forwarded unchanged to ``seasonal_decompose``; see
    https://www.statsmodels.org/stable/generated/statsmodels.tsa.seasonal.seasonal_decompose.html

    Example -
        import pandas as pd
        from datetime import datetime
        import PlotTimeSeries
        s = pd.DataFrame(list(range(1, 11))*10,
                         index=pd.date_range(start=datetime(2010, 1, 1), periods=100))
        fig = PlotTimeSeries.plotSeasonalDecompose(s)
        fig.show()
    """
    result = seasonal_decompose(
        x, model=model, filt=filt, period=period,
        two_sided=two_sided, extrapolate_trend=extrapolate_trend)
    fig = make_subplots(
        rows=4, cols=1,
        subplot_titles=["Observed", "Trend", "Seasonal", "Residuals"])
    # Order matches subplot_titles: one component per row, top to bottom.
    components = (result.observed, result.trend, result.seasonal, result.resid)
    for row, component in enumerate(components, start=1):
        # Each component is plotted against its own index. Components without
        # an ``index`` attribute (presumably plain arrays when ``x`` was not a
        # pandas object) are plotted against their position instead.
        x_values = getattr(component, "index", None)
        if x_values is None:
            x_values = list(range(len(component)))
        fig.add_trace(
            go.Scatter(x=x_values, y=component, mode='lines'),
            row=row, col=1,
        )
    # The title argument was previously accepted but never applied.
    fig.update_layout(title=title)
    return fig
@slmg
Copy link

slmg commented Nov 15, 2021

Thanks for the gist, variation:

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.tsa.seasonal import DecomposeResult, seasonal_decompose


def plot_seasonal_decompose(result: DecomposeResult, title="Seasonal Decomposition"):
    """Plot the four components of a seasonal decomposition as stacked lines.

    :param result: Decomposition whose ``observed``, ``trend``, ``seasonal``
        and ``resid`` attributes are plotted, one per row.
    :param title: Title shown above the figure.
    :return: The plotly figure holding the four subplots.
    """
    fig = make_subplots(
        rows=4,
        cols=1,
        subplot_titles=["Observed", "Trend", "Seasonal", "Residuals"],
    )
    # Order matches subplot_titles: one component per row, top to bottom.
    components = (result.observed, result.trend, result.seasonal, result.resid)
    for row, component in enumerate(components, start=1):
        # Components can be numpy arrays, which have no ``index`` attribute;
        # in that case plot against the sample position instead of raising
        # AttributeError.
        x_values = getattr(component, "index", None)
        if x_values is None:
            x_values = list(range(len(component)))
        fig.add_trace(
            go.Scatter(x=x_values, y=component, mode="lines"),
            row=row,
            col=1,
        )
    return fig.update_layout(
        height=900, title=title, margin=dict(t=100), title_x=0.5, showlegend=False
    )

@chrimaho
Copy link

Please be careful about using the code above. They both have bugs.

If you run it as you have presented above, it will throw:

  1. Indentation Error:
    IndentationError: unexpected indent
      <[command-1424241324418648]()>, line 33
        fig = make_subplots(
        ^
  2. Attribute Error:
    AttributeError: 'numpy.ndarray' object has no attribute 'index'
    <[command-1424241324418648]()> in plot_seasonal_decompose(result, title)
         12         )
         13         .add_trace(
    ---> 14             go.Scatter(x=result.seasonal.index, y=result.observed, mode="lines"),
         15             row=1,
         16             col=1,

I have now fixed it, please see below.

I've also added another optional parameter to include dates, which can be added to the x-axis for ease of reference.

from plotly.subplots import make_subplots
from statsmodels.tsa.seasonal import DecomposeResult, seasonal_decompose

def plot_seasonal_decompose(result: DecomposeResult, dates: "pd.Series" = None, title: str = "Seasonal Decomposition"):
    """Plot the four components of a seasonal decomposition as stacked lines.

    :param result: Decomposition whose ``observed``, ``trend``, ``seasonal``
        and ``resid`` attributes are plotted, one per row.
    :param dates: Optional x-axis values shared by all four subplots. When
        omitted, the components are plotted against their position
        (0, 1, 2, ...).
    :param title: Title shown (in bold) above the figure.
    :return: The plotly figure holding the four subplots.
    """
    # The ``dates`` annotation is quoted so that defining this function does
    # not require pandas to be imported first. The positional fallback uses a
    # plain list rather than numpy, which this snippet never imports.
    if dates is not None:
        x_values = dates
    else:
        x_values = list(range(len(result.observed)))

    fig = make_subplots(
        rows=4,
        cols=1,
        subplot_titles=["Observed", "Trend", "Seasonal", "Residuals"],
    )
    # (trace name, component) pairs in the same order as subplot_titles.
    components = (
        ('Observed', result.observed),
        ('Trend', result.trend),
        ('Seasonal', result.seasonal),
        ('Residual', result.resid),
    )
    for row, (name, component) in enumerate(components, start=1):
        fig.add_trace(
            go.Scatter(x=x_values, y=component, mode="lines", name=name),
            row=row,
            col=1,
        )
    return fig.update_layout(
        height=900, title=f'<b>{title}</b>', margin={'t':100}, title_x=0.5, showlegend=False
    )

Now, it will work like this:

import pandas as pd
from statsmodels.tsa.seasonal import seasonal_decompose
# Example: load a CSV over HTTP, decompose one column and plot it.
data = pd.read_csv("https://raw.githubusercontent.com/swilsonmfc/pandas/main/AirPassengers.csv")
# Additive decomposition of the '#Passengers' column with a period of 12
# observations (presumably monthly data with a yearly cycle -- TODO confirm).
decomposition = seasonal_decompose(data['#Passengers'], model='additive', period=12)
# The 'Month' column supplies the x-axis values for every subplot.
fig = plot_seasonal_decompose(decomposition, dates=data['Month'])
fig.show()

image

@AdemYoussef
Copy link

@chrimaho Thank you!

@joslinmartinez
Copy link

Nice !!

@till90
Copy link

till90 commented Oct 11, 2022

@chrimaho
I have modified your code so that it takes a DataFrame and plots the decomposition of each column side by side.



def plot_seasonal_decompose(title:str="Seasonal Decomposition", df:pd.DataFrame=None, model:str="additive", period:int=12):
    """Plot the seasonal decomposition of every column of ``df`` side by side.

    The figure has one subplot column per DataFrame column and four rows
    (observed, trend, seasonal, residual). Each column is decomposed
    independently with ``seasonal_decompose``.

    :param title: Title shown (in bold) above the figure.
    :param df: DataFrame whose columns are the series to decompose; its
        index supplies the x-axis values. Required despite the ``None``
        default, which is kept only so existing keyword calls still work.
    :param model: Decomposition model forwarded to ``seasonal_decompose``.
    :param period: Period forwarded to ``seasonal_decompose``.
    :return: The plotly figure.
    :raises ValueError: If ``df`` is not provided.
    """
    if df is None:
        raise ValueError("df is required: pass the DataFrame whose columns should be decomposed")

    # With several rows, the first len(df.columns) titles land on the top
    # row, so each subplot column is headed by its DataFrame column name.
    fig = make_subplots(rows=4, cols=len(df.columns), subplot_titles=list(df.columns))

    # DataFrame.items() replaces iteritems(), which was removed in pandas 2.0.
    for col, (column_name, series) in enumerate(df.items(), start=1):
        decomposition = seasonal_decompose(series, model=model, period=period)
        # (trace name, component) pairs, one per subplot row from top to bottom.
        components = (
            ('Observed', decomposition.observed),
            ('Trend', decomposition.trend),
            ('Seasonal', decomposition.seasonal),
            ('Residual', decomposition.resid),
        )
        for row, (name, component) in enumerate(components, start=1):
            fig.add_trace(
                go.Scatter(x=series.index, y=component, mode="lines", name=name),
                row=row,
                col=col,
            )

    # Label each row once, on the y-axis of the left-most subplot.
    for row, label in enumerate(["Observed", "Trend", "Seasonal", "Residuals"], start=1):
        fig.update_yaxes(title_text=label, row=row, col=1)

    fig.update_layout(
        height=900, title=f'<b>{title}</b>', margin={'t':100}, title_x=0.5, showlegend=False
    )

    return fig
# Example call. NOTE(review): `df` is not defined in this snippet; the caller
# is expected to supply a DataFrame with one series per column.
fig = plot_seasonal_decompose(df=df)
fig.show()

newplot(5)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment