
plot_decomposition

yohou.plotting.forecasting.plot_decomposition(y, components, *, method=None, columns=None, groups=None, show_original=True, period='auto', periods=None, model='additive', robust=True, trend_window=None, seasonal_window=None, low_pass_window=None, two_sided=True, extrapolate_trend=0, color_palette=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, connect_gaps=False, resampler=None, line_width=2.0, line_dash='solid')

Plot time series decomposition as vertically stacked subplots.

Displays the original series and its decomposed components (e.g. trend, seasonal, residual) in separate panels sharing the same time axis.

There are two modes of operation:

Pre-computed mode (default) - pass components as a dict mapping component names to DataFrames produced by a Yohou decomposition pipeline.

Decomposition mode - pass components as a list or tuple of component names and set method to "stl", "mstl", or "classical". The function runs the decomposition internally and renders the requested components.
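
The choice between the two modes can be sketched as below. This follows the `isinstance` check in the source listing further down the page; `detect_mode` is a hypothetical helper name used only for illustration and is not part of the Yohou API.

```python
def detect_mode(components):
    """Classify the `components` argument the way the source listing does.

    Hypothetical helper for illustration only. A list or tuple whose first
    element is a string (or which is empty) selects decomposition mode;
    anything else is treated as pre-computed input and must be a dict.
    """
    is_sequence = isinstance(components, (list, tuple))
    if is_sequence and (not components or isinstance(components[0], str)):
        return "decomposition"
    return "pre-computed"


print(detect_mode(["trend", "seasonal"]))  # list of names
print(detect_mode({"trend": object()}))    # dict of component frames
```

Note that an empty list or tuple still counts as decomposition mode, so `method` is required in that case too.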

Parameters

Name Type Description Default
y DataFrame

Original time series with "time" column.

required
components dict[str, DataFrame] | list[str] | tuple[str, ...]

dict - mapping of component names to DataFrames (pre-computed mode). Each DataFrame must have a "time" column plus value columns matching y.

list/tuple of str - component names to compute and display. Requires method to be set. Valid names: "observed", "trend", "seasonal", "residual", "seasonal_adjusted"; names of the form "seasonal_<period>" are also accepted.

required
method {"stl", "mstl", "classical"} | None

Decomposition backend. Required when components is a list or tuple. None means pre-computed dict mode.

None
columns str | list[str] | None

Value columns to plot. None uses all numeric non-time columns.

None
groups list[str] | None

Panel group prefixes to include. For panel data, returns one figure per member with groups overlaid by colour.

None
show_original bool

Include the original series as the first subplot.

True
period int | str

Seasonal period (STL and classical; MSTL uses periods instead). "auto" infers it from the sampling interval.

'auto'
periods list[int] | str | None

Seasonal periods for MSTL. Required when method="mstl".

None
model {"additive", "multiplicative"}

Decomposition model. STL/MSTL use a log-transform approximation for multiplicative; classical uses native statsmodels support.

"additive"
robust bool

Use robust fitting (STL/MSTL only, down-weights outliers).

True
trend_window int | None

Trend smoother window (STL only).

None
seasonal_window int | None

Seasonal smoother window (STL only).

None
low_pass_window int | None

Low-pass filter window (STL only).

None
two_sided bool

Two-sided (centered) moving average for trend (classical only).

True
extrapolate_trend int | str

Extrapolate trend at edges (classical only). 0 leaves NaN.

0
color_palette list[str] | None

Custom color palette.

None
show_legend bool

Whether to show the legend.

True
title str | None

Plot title.

None
x_label str | None

X-axis label on bottom subplot. Defaults to "Time".

None
y_label str | None

Y-axis label.

None
width int | None

Plot width in pixels.

None
height int | None

Plot height in pixels.

None
connect_gaps bool

Whether to connect gaps with lines.

False
resampler bool | Literal['widget'] | None

Enable plotly-resampler for large datasets.

None
line_width float

Width of component line traces.

2.0
line_dash str

Dash style for component lines.

'solid'
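
Several parameters above apply to one backend only. According to the source listing further down, passing such a parameter with a different method does not raise; it triggers a `UserWarning` and the value is ignored. A minimal sketch of that behaviour follows; `warn_on_mismatched_params` is a hypothetical stand-alone helper, not a Yohou function.

```python
import warnings


def warn_on_mismatched_params(
    method,
    *,
    trend_window=None,
    seasonal_window=None,
    low_pass_window=None,
    two_sided=True,
    extrapolate_trend=0,
):
    """Warn once for each option that the chosen method ignores.

    Hypothetical helper that follows the checks in the source listing.
    """
    stl_only = {
        "trend_window": trend_window,
        "seasonal_window": seasonal_window,
        "low_pass_window": low_pass_window,
    }
    if method != "stl":
        for name, value in stl_only.items():
            if value is not None:
                warnings.warn(
                    f"'{name}' is only used with method='stl' "
                    f"(got method='{method}'); ignored.",
                    UserWarning,
                    stacklevel=2,
                )
    if method != "classical":
        # These two are flagged only when they differ from their defaults.
        if two_sided is not True:
            warnings.warn("'two_sided' is only used with method='classical'; ignored.",
                          UserWarning, stacklevel=2)
        if extrapolate_trend != 0:
            warnings.warn("'extrapolate_trend' is only used with method='classical'; ignored.",
                          UserWarning, stacklevel=2)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_on_mismatched_params("classical", trend_window=13)  # STL-only option
    warn_on_mismatched_params("stl", two_sided=False)        # classical-only option

for w in caught:
    print(w.message)
```

The window parameters are flagged for `method="mstl"` as well, since the check is `method != "stl"`.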

Returns

Type Description
Figure | dict[str, Figure]

Plotly figure (or dict of figures for panel data).
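
Because the return type differs between single-series and panel input, calling code may want to normalise it before iterating. A minimal sketch, assuming only that the panel result is a dict keyed by strings; `as_figure_dict` and the `"series"` key are hypothetical, and plain objects stand in for Plotly figures so the snippet runs without any plotting library:

```python
def as_figure_dict(result, default_key="series"):
    """Wrap a single figure in a dict so both return shapes iterate alike.

    `default_key` is an arbitrary label chosen for this sketch.
    """
    if isinstance(result, dict):
        return dict(result)
    return {default_key: result}


# Stand-ins for the two possible return values of plot_decomposition.
single = object()
panel = {"store_a": object(), "store_b": object()}

for name in as_figure_dict(panel):
    print(name)
for name in as_figure_dict(single):
    print(name)
```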

Raises

Type Description
TypeError

If y is not a Polars DataFrame, or components is neither a dict nor a list/tuple of component names.

ValueError

If components is a list or tuple and method is not set, a DataFrame is empty, a component name is unknown, components is an empty dict, or method="mstl" is given without periods.

ImportError

When statsmodels is not installed.
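
The component-name checks behind the `ValueError` cases can be sketched as follows. This follows the validation in the source listing further down (including the `seasonal_<period>` pattern and the special handling of `"observed"`); `normalize_components` is a hypothetical name and the error messages are shortened.

```python
import re

VALID_NAMES = {"trend", "seasonal", "residual", "seasonal_adjusted"}


def normalize_components(components, show_original=True):
    """Validate requested component names (hypothetical helper).

    Returns the names to compute and the effective `show_original` flag.
    """
    names = list(components)

    # "observed" is not computed; it just switches the original panel on.
    if "observed" in names:
        show_original = True
        names.remove("observed")

    unknown = {
        n for n in names
        if n not in VALID_NAMES and not re.match(r"^seasonal_\w+$", n)
    }
    if unknown:
        raise ValueError(f"Unknown components: {sorted(unknown)}")

    if not names and not show_original:
        raise ValueError("components must contain at least one displayable component")

    return names, show_original


print(normalize_components(["observed", "trend", "seasonal_12"], show_original=False))
# (['trend', 'seasonal_12'], True)
```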

Examples

Pre-computed mode:

>>> import polars as pl
>>> from yohou.plotting import plot_decomposition
>>> dates = pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 12, 31), "1d", eager=True)
>>> y = pl.DataFrame({"time": dates, "y": list(range(len(dates)))})
>>> comps = {
...     "trend": pl.DataFrame({"time": dates, "y": [i * 0.5 for i in range(len(dates))]}),
...     "residual": pl.DataFrame({"time": dates, "y": [i * 0.5 for i in range(len(dates))]}),
... }
>>> fig = plot_decomposition(y, comps)
>>> len(fig.data) >= 3
True

STL mode:

>>> df = pl.DataFrame({
...     "time": pl.date_range(pl.date(2018, 1, 1), pl.date(2022, 12, 31), "1mo", eager=True),
...     "y": [100 + 10 * (i % 12) + i * 0.5 for i in range(60)],
... })
>>> fig = plot_decomposition(df, ["trend", "seasonal"], method="stl")

Classical mode:

>>> fig = plot_decomposition(df, ["trend", "seasonal"], method="classical")
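
Multiplicative model, illustrated:

The `model` parameter states that STL/MSTL handle `"multiplicative"` through a log-transform approximation. The exact implementation is not shown on this page, so the sketch below only illustrates the general idea, assuming a strictly positive series: decompose `log(y)` additively, then exponentiate each component so that they multiply back to `y`. The additive step here is a deliberately simple moving-average decomposition rather than STL, and all function names are hypothetical.

```python
import math


def additive_decompose(values, period):
    """Toy additive decomposition (stand-in for STL, illustration only)."""
    n = len(values)
    half = period // 2
    trend = [None] * n  # undefined at the edges, like a centered average
    for t in range(half, n - half):
        window = values[t - half : t + half + 1]
        if period % 2 == 0:
            # Even period: half-weight the two end points (2 x m average).
            total = sum(window) - 0.5 * (window[0] + window[-1])
        else:
            total = sum(window)
        trend[t] = total / period

    # Average the detrended values for each position in the cycle.
    buckets = [[] for _ in range(period)]
    for t, (v, tr) in enumerate(zip(values, trend)):
        if tr is not None:
            buckets[t % period].append(v - tr)
    phase_means = [sum(b) / len(b) for b in buckets]
    offset = sum(phase_means) / period  # center the seasonal around zero
    seasonal = [phase_means[t % period] - offset for t in range(n)]

    residual = [
        None if tr is None else v - tr - s
        for v, tr, s in zip(values, trend, seasonal)
    ]
    return trend, seasonal, residual


def multiplicative_via_log(values, period):
    """Decompose log(values) additively, then map components back with exp."""
    logs = [math.log(v) for v in values]  # requires values > 0
    parts = additive_decompose(logs, period)
    return tuple([None if x is None else math.exp(x) for x in part] for part in parts)


series = [(100 + 2 * t) * (1 + 0.2 * math.sin(2 * math.pi * t / 12)) for t in range(48)]
trend, seasonal, residual = multiplicative_via_log(series, period=12)

t = 20
print(math.isclose(trend[t] * seasonal[t] * residual[t], series[t]))
```

In the back-transformed result the components combine by multiplication, so a seasonal value of 1.0 means "no seasonal effect" rather than 0.0.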

See Also

plot_forecast : Forecast visualization.

plot_seasonality : Seasonal pattern analysis.

Source Code

def plot_decomposition(
    y: pl.DataFrame,
    components: dict[str, pl.DataFrame] | list[str] | tuple[str, ...],
    *,
    method: Literal["stl", "mstl", "classical"] | None = None,
    columns: str | list[str] | None = None,
    groups: list[str] | None = None,
    show_original: bool = True,
    period: int | str = "auto",
    periods: list[int] | str | None = None,
    model: Literal["additive", "multiplicative"] = "additive",
    robust: bool = True,
    trend_window: int | None = None,
    seasonal_window: int | None = None,
    low_pass_window: int | None = None,
    two_sided: bool = True,
    extrapolate_trend: int | str = 0,
    color_palette: list[str] | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    connect_gaps: bool = False,
    resampler: bool | Literal["widget"] | None = None,
    line_width: float = 2.0,
    line_dash: str = "solid",
) -> go.Figure | dict[str, go.Figure]:
    """Plot time series decomposition as vertically stacked subplots.

    Displays the original series and its decomposed components (e.g. trend,
    seasonal, residual) in separate panels sharing the same time axis.

    There are two modes of operation:

    **Pre-computed mode** (default) - pass *components* as a ``dict`` mapping
    component names to DataFrames produced by a Yohou decomposition pipeline.

    **Decomposition mode** - pass *components* as a ``list`` or ``tuple`` of
    component names and set *method* to ``"stl"``, ``"mstl"``, or
    ``"classical"``.  The function runs the decomposition internally and
    renders the requested components.

    Parameters
    ----------
    y : pl.DataFrame
        Original time series with ``"time"`` column.
    components : dict[str, pl.DataFrame] | list[str] | tuple[str, ...]
        **dict** - mapping of component names to DataFrames (pre-computed
        mode).  Each DataFrame must have a ``"time"`` column plus value
        columns matching *y*.

        **list/tuple of str** - component names to compute and display.
        Requires *method* to be set.
        Valid names: ``"observed"``, ``"trend"``, ``"seasonal"``,
        ``"residual"``, ``"seasonal_adjusted"``; names of the form
        ``"seasonal_<period>"`` are also accepted.
    method : {"stl", "mstl", "classical"} or None
        Decomposition backend.  Required when *components* is a list.
        ``None`` means pre-computed dict mode.
    columns : str | list[str] | None
        Value columns to plot.  ``None`` uses all numeric non-time columns.
    groups : list[str] | None
        Panel group prefixes to include.  For panel data, returns one
        figure per member with groups overlaid by colour.
    show_original : bool
        Include the original series as the first subplot.
    period : int | str
        Seasonal period (STL and classical; MSTL uses *periods*).
        ``"auto"`` infers it from the sampling interval.
    periods : list[int] | str | None
        Seasonal periods for MSTL.  **Required** when ``method="mstl"``.
    model : {"additive", "multiplicative"}
        Decomposition model.  STL/MSTL use a log-transform approximation
        for multiplicative; classical uses native statsmodels support.
    robust : bool
        Use robust fitting (STL/MSTL only, down-weights outliers).
    trend_window : int | None
        Trend smoother window (STL only).
    seasonal_window : int | None
        Seasonal smoother window (STL only).
    low_pass_window : int | None
        Low-pass filter window (STL only).
    two_sided : bool
        Two-sided (centered) moving average for trend (classical only).
    extrapolate_trend : int | str
        Extrapolate trend at edges (classical only).  ``0`` leaves NaN.
    color_palette : list[str] | None
        Custom color palette.
    show_legend : bool
        Whether to show the legend.
    title : str | None
        Plot title.
    x_label : str | None
        X-axis label on bottom subplot.  Defaults to ``"Time"``.
    y_label : str | None
        Y-axis label.
    width : int | None
        Plot width in pixels.
    height : int | None
        Plot height in pixels.
    connect_gaps : bool
        Whether to connect gaps with lines.
    resampler : bool | Literal["widget"] | None
        Enable plotly-resampler for large datasets.
    line_width : float
        Width of component line traces.
    line_dash : str
        Dash style for component lines.

    Returns
    -------
    go.Figure | dict[str, go.Figure]
        Plotly figure (or dict of figures for panel data).

    Raises
    ------
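    TypeError
        If *y* is not a Polars DataFrame, or *components* is neither a
        ``dict`` nor a ``list``/``tuple`` of component names.
    ValueError
        If *components* is a list or tuple and *method* is not set, a
        DataFrame is empty, a component name is unknown, *components* is
        an empty ``dict``, or ``method="mstl"`` is given without *periods*.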
    ImportError
        When ``statsmodels`` is not installed.

    Examples
    --------
    Pre-computed mode:

    >>> import polars as pl
    >>> from yohou.plotting import plot_decomposition

    >>> dates = pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 12, 31), "1d", eager=True)
    >>> y = pl.DataFrame({"time": dates, "y": list(range(len(dates)))})
    >>> comps = {
    ...     "trend": pl.DataFrame({"time": dates, "y": [i * 0.5 for i in range(len(dates))]}),
    ...     "residual": pl.DataFrame({"time": dates, "y": [i * 0.5 for i in range(len(dates))]}),
    ... }
    >>> fig = plot_decomposition(y, comps)
    >>> len(fig.data) >= 3
    True

    STL mode:

    >>> df = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2018, 1, 1), pl.date(2022, 12, 31), "1mo", eager=True),
    ...     "y": [100 + 10 * (i % 12) + i * 0.5 for i in range(60)],
    ... })
    >>> fig = plot_decomposition(df, ["trend", "seasonal"], method="stl")  # doctest: +SKIP

    Classical mode:

    >>> fig = plot_decomposition(df, ["trend", "seasonal"], method="classical")  # doctest: +SKIP

    See Also
    --------
    [`plot_forecast`][yohou.plotting.plot_forecast] : Forecast visualization.
    [`plot_seasonality`][yohou.plotting.plot_seasonality] : Seasonal pattern analysis.
    """
    import warnings  # noqa: PLC0415

    validate_plotting_data(y)
    validate_plotting_params(width=width, height=height)

    decomp_mode = isinstance(components, list | tuple) and (not components or isinstance(components[0], str))

    # Decomposition mode
    if decomp_mode:
        if method is None:
            msg = (
                "method is required when components is a list/tuple. "
                "Use method='stl', method='mstl', or method='classical'."
            )
            raise ValueError(msg)

        if method == "mstl" and periods is None:
            msg = "periods is required when method='mstl'."
            raise ValueError(msg)

        # Warn on mismatched params
        _stl_only = {
            "trend_window": trend_window,
            "seasonal_window": seasonal_window,
            "low_pass_window": low_pass_window,
        }
        _classical_only = {"two_sided": two_sided, "extrapolate_trend": extrapolate_trend}

        if method != "stl":
            for pname, pval in _stl_only.items():
                if pval is not None:
                    warnings.warn(
                        f"'{pname}' is only used with method='stl' (got method='{method}'); ignored.",
                        UserWarning,
                        stacklevel=2,
                    )
        if method != "classical":
            if two_sided is not True:
                warnings.warn(
                    f"'two_sided' is only used with method='classical' (got method='{method}'); ignored.",
                    UserWarning,
                    stacklevel=2,
                )
            if extrapolate_trend != 0:
                warnings.warn(
                    f"'extrapolate_trend' is only used with method='classical' (got method='{method}'); ignored.",
                    UserWarning,
                    stacklevel=2,
                )

        components_list = list(components)

        # "observed" in the list maps to show_original
        if "observed" in components_list:
            show_original = True
            components_list.remove("observed")

        # Validate component names early
        valid_names = {"trend", "seasonal", "residual", "seasonal_adjusted"}
        unknown = {c for c in components_list if c not in valid_names and not re.match(r"^seasonal_\w+$", c)}
        if unknown:
            all_valid = sorted(valid_names | {"observed"})
            msg = f"Unknown components: {unknown}. Valid: {all_valid} (also seasonal_<period>)"
            raise ValueError(msg)

        if not components_list and not show_original:
            msg = "components must contain at least one displayable component"
            raise ValueError(msg)

        value_cols = validate_plotting_data(y, columns=columns, exclude=["time"])

        if method == "mstl":
            components = _mstl_to_component_dict(
                y,
                components_list,
                value_cols,
                periods,  # ty: ignore[invalid-argument-type]
                robust,
                model=model,
            )
            title = title or "MSTL Decomposition"
        elif method == "classical":
            components = _classical_to_component_dict(
                y,
                components_list,
                value_cols,
                period=period,
                model=model,
                two_sided=two_sided,
                extrapolate_trend=extrapolate_trend,
            )
            title = title or "Classical Decomposition"
        else:
            # method == "stl"
            components = _stl_to_component_dict(
                y,
                components_list,
                value_cols,
                period=period,
                model=model,
                robust=robust,
                trend_window=trend_window,
                seasonal_window=seasonal_window,
                low_pass_window=low_pass_window,
            )
            title = title or "STL Decomposition"

        # Fall through to the shared dict plotting below

    # Dict validation
    if not isinstance(components, dict):
        msg = (
            "components must be a dict[str, pl.DataFrame] (pre-computed mode) "
            "or a list[str]/tuple[str, ...] with method set (decomposition mode)"
        )
        raise TypeError(msg)

    if not components:
        msg = "components dict must be non-empty"
        raise ValueError(msg)
    for comp_df in components.values():
        validate_plotting_data(comp_df)

    value_cols = validate_plotting_data(y, columns=columns)

    # Detect panel data
    _, panel_groups = inspect_panel(y)
    is_panel = bool(panel_groups)

    # Auto-enter panel mode when panel columns are detected
    if is_panel and groups is None and columns is None:
        groups = []

    if is_panel and groups is not None:
        return _plot_decomposition_panel(
            y=y,
            components=components,
            value_cols=value_cols,
            decomp_mode=decomp_mode,
            show_original=show_original,
            groups=groups,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
            connect_gaps=connect_gaps,
            line_width=line_width,
            line_dash=line_dash,
            resampler=resampler,
            show_legend=show_legend,
        )

    # Build subplot structure (shared by both modes)
    panel_names: list[str] = []
    if show_original:
        panel_names.append("Original")
    panel_names.extend(_format_component_label(name) if decomp_mode else name for name in components)

    n_rows = len(panel_names)
    fig = _create_subplots(
        resampler,
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.04,
    )

    # Set component names as y-axis titles
    for idx, pname in enumerate(panel_names):
        yaxis_key = f"yaxis{idx + 1}" if idx > 0 else "yaxis"
        fig.layout[yaxis_key].title = {"text": pname, "font": {"size": 12}}

    colors = resolve_color_palette(color_palette, len(value_cols))
    row_offset = 0

    # Original series
    if show_original:
        row_offset = 1
        for i, col in enumerate(value_cols):
            if col in y.columns:
                fig.add_trace(
                    go.Scatter(
                        x=y["time"],
                        y=y[col],
                        mode="lines",
                        line={
                            "color": colors[i % len(colors)],
                            "width": line_width,
                            "dash": line_dash,
                        },
                        name=col,
                        legendgroup=col,
                        showlegend=True,
                        connectgaps=connect_gaps,
                    ),
                    row=1,
                    col=1,
                )

    # Component panels
    _meta_cols = {"time", "vintage_time"}
    for comp_idx, (comp_name, comp_df) in enumerate(components.items()):
        row = comp_idx + 1 + row_offset
        # Resolve value columns for this component: prefer names matching y,
        # otherwise fall back to the component's own numeric columns (handles
        # renamed columns from transformers, e.g. "log_off_0.0_tourists").
        comp_value_cols = [c for c in value_cols if c in comp_df.columns]
        if not comp_value_cols:
            comp_value_cols = [c for c in comp_df.columns if c not in _meta_cols]
        for i, col in enumerate(comp_value_cols):
            # Use the original value_cols name for legend when available,
            # otherwise use the component column name.
            legend_col = value_cols[i] if i < len(value_cols) else col
            display_name = _format_component_label(comp_name) if decomp_mode else legend_col
            fig.add_trace(
                go.Scatter(
                    x=comp_df["time"],
                    y=comp_df[col],
                    mode="lines",
                    line={
                        "color": colors[i % len(colors)],
                        "width": line_width,
                        "dash": line_dash,
                    },
                    name=display_name,
                    legendgroup=display_name if decomp_mode else legend_col,
                    showlegend=(comp_idx == 0 and not show_original),
                    connectgaps=connect_gaps,
                ),
                row=row,
                col=1,
            )

    title_default = title or "Time Series Decomposition"
    default_height = max(300 * n_rows, 400)

    fig = apply_default_layout(
        fig,
        title=title_default,
        x_label=None,
        y_label=None,
        width=width,
        height=height or default_height,
    )

    # When the caller supplies a y-label, apply it to every subplot row
    # so it appears on the left.  Otherwise keep the per-row titles set
    # above from panel_names (e.g. Trend, Seasonal, Residual).
    if y_label is not None:
        fig.update_yaxes(title_text=y_label)

    # Show x-axis label on bottom subplot only
    x_label_text = x_label if x_label is not None else "Time"
    bottom_xaxis = f"xaxis{n_rows}" if n_rows > 1 else "xaxis"
    fig.layout[bottom_xaxis].title = {"text": x_label_text}

    return fig
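
The subplot bookkeeping in the listing above (panel order, Plotly axis keys, default height) can be reproduced in isolation. The sketch below is a hypothetical pure-Python helper, not part of Yohou; it uses the component names as given, which corresponds to pre-computed mode (decomposition mode additionally passes names through `_format_component_label`, whose output is not shown on this page).

```python
def subplot_layout(component_names, show_original=True):
    """Compute the panel order and layout keys used for stacked subplots.

    Hypothetical helper that follows the arithmetic in the source listing.
    """
    panel_names = (["Original"] if show_original else []) + list(component_names)
    n_rows = len(panel_names)

    # Plotly names the first axis "yaxis" and later ones "yaxis2", "yaxis3", ...
    yaxis_keys = ["yaxis" if i == 0 else f"yaxis{i + 1}" for i in range(n_rows)]

    # Only the bottom subplot carries the x-axis title.
    bottom_xaxis = f"xaxis{n_rows}" if n_rows > 1 else "xaxis"

    # 300 px per row, with a 400 px floor for single-row figures.
    default_height = max(300 * n_rows, 400)

    return {
        "panel_names": panel_names,
        "yaxis_keys": yaxis_keys,
        "bottom_xaxis": bottom_xaxis,
        "default_height": default_height,
    }


print(subplot_layout(["trend", "seasonal"]))
```

With two components and the original series shown, this gives three rows, a bottom axis key of `xaxis3`, and a default height of 900 pixels.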

Tutorials

The following example notebooks use this component:

  • Decomposition (Data-Features): Chain PolynomialTrendForecaster, PatternSeasonalityForecaster, and FourierSeasonalityForecaster inside DecompositionPipeline with component visualisation.

  • Forecast Visualization (Visualization): Visualise point forecasts from single and multiple models, decomposition pipeline components, and time weight decay functions with interactive Plotly.

  • Seasonal Analysis (Visualization): Seasonal overlays, subseasonal structure, ACF/PACF correlation patterns, and STL decomposition for monthly, quarterly, and long-cycle datasets.