
plot_forecast

yohou.plotting.forecasting.plot_forecast(y_test=None, y_pred=None, *, y_train=None, columns=None, coverage_rates=None, n_history=None, groups=None, facet_by='member', facet_n_cols=2, color_palette=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, connect_gaps=False, resampler=None, line_width=2.0, band_opacity=0.25, show_transition=True)

Plot forecasts with historical data and optional prediction intervals.

Accepts separate DataFrames for actuals and predictions following an sklearn-like API. Automatically detects interval columns from y_pred when coverage_rates is provided.
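
The interval-column convention can be illustrated with the regular expression that the single-model path of the source listing uses (`^.+_(lower|upper)_[\d.]+$`). The sketch below separates point-forecast columns from interval columns. The helper `split_forecast_columns` is hypothetical and not part of yohou's public API.

```python
import re

# Same pattern as the single-model path in the source listing: any name
# ending in _lower_<number> or _upper_<number> is treated as an interval column.
INTERVAL_PATTERN = re.compile(r"^.+_(lower|upper)_[\d.]+$")


def split_forecast_columns(columns: list[str]) -> tuple[list[str], list[str]]:
    """Hypothetical helper: split y_pred column names into point and interval columns."""
    point_cols: list[str] = []
    interval_cols: list[str] = []
    for name in columns:
        if name in ("time", "vintage_time"):
            continue  # time columns carry no forecast values
        if INTERVAL_PATTERN.match(name):
            interval_cols.append(name)
        else:
            point_cols.append(name)
    return point_cols, interval_cols


point_cols, interval_cols = split_forecast_columns(
    ["time", "y", "y_lower_0.9", "y_upper_0.9", "y_lower_0.95", "y_upper_0.95"]
)
print(point_cols)     # ['y']
print(interval_cols)  # ['y_lower_0.9', 'y_upper_0.9', 'y_lower_0.95', 'y_upper_0.95']
```

Because the match is purely name-based, a genuine target column named, for example, `demand_upper_2` would also be classified as an interval column, so such names are best avoided for point forecasts.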

When y_pred is a dict[str, pl.DataFrame], each entry is treated as a separate model and plotted with a distinct color for side-by-side comparison.

Categorical support: When y_pred contains class-probability columns ({target}_proba_{class} pattern), renders a stacked area chart of predicted probabilities with ground-truth class markers from y_test. When y_pred contains categorical (string) columns, renders a step chart comparing predicted and actual class labels over time.
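
The `{target}_proba_{class}` naming pattern can be parsed with plain string operations. The sketch below is a hypothetical helper, not yohou's internal `_detect_prediction_mode`; it groups probability columns by target and assumes the target name itself does not contain `_proba_`.

```python
def group_proba_columns(columns: list[str]) -> dict[str, list[str]]:
    """Hypothetical helper: map each target to its class labels from {target}_proba_{class} names."""
    groups: dict[str, list[str]] = {}
    for name in columns:
        # partition splits at the first "_proba_"; sep is "" when it is absent.
        target, sep, cls = name.partition("_proba_")
        if sep and target and cls:
            groups.setdefault(target, []).append(cls)
    return groups


print(group_proba_columns(["time", "weather_proba_rain", "weather_proba_sun", "weather_proba_snow"]))
# {'weather': ['rain', 'sun', 'snow']}
```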

Parameters

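y_test : DataFrame | None, default None
    Actual test values with a 'time' column. When None, only the forecast and interval traces are rendered (no "Actual" line).

y_pred : DataFrame | dict[str, DataFrame], default None
    Forecast values with a 'time' column. May also contain interval columns named {col}_lower_{rate} and {col}_upper_{rate}. If a dict, keys are model names and values are prediction DataFrames. Although the signature defaults to None, y_pred is required: passing None raises TypeError.

y_train : DataFrame | None, default None
    Historical training data with a 'time' column. If provided, it is shown before the forecast period.

columns : str | list[str] | None, default None
    Target column(s) to plot from y_test (from y_pred when y_test is None). When None, all non-time columns are used. Associated interval columns ({col}_lower_{rate} / {col}_upper_{rate}) are kept automatically.

coverage_rates : list[float] | None, default None
    Coverage rates to display intervals for (e.g., [0.9, 0.95]). For each rate, {col}_lower_{rate} and {col}_upper_{rate} are looked up in y_pred. Column names are built with Python's default float formatting, so a rate of 0.9 matches y_lower_0.9 but not y_lower_0.90. In the single-model path, rates without both columns are skipped, and a rate of 0 is drawn as a dashed median line rather than a band.

n_history : int | None, default None
    Number of most recent observations from y_train to show. If None, all are shown.

groups : list[str] | None, default None
    Panel group prefixes to plot. If None and panel data is detected, all groups are plotted in faceted subplots.

facet_by : Literal['group', 'member'] | None, default "member"
    Faceting axis for panel data: "group" creates one subplot per group, "member" one per member, and None disables faceting. Ignored for non-panel data.

facet_n_cols : int, default 2
    Number of columns in the facet grid, used for panel data and when several target columns are plotted.

color_palette : list[str] | None, default None
    Custom color palette. In the single-model path, the second entry colors the forecast and the third the actuals (indices wrap around for shorter palettes); the training history reuses the actual color at 40% opacity. Because the training line and interval bands derive RGBA colors from these entries, they must be 6-digit hex strings such as "#1f77b4".

show_legend : bool, default True
    Whether to show the legend.

title : str | None, default None
    Plot title.

x_label : str | None, default None
    X-axis label.

y_label : str | None, default None
    Y-axis label.

width : int | None, default None
    Plot width in pixels.

height : int | None, default None
    Plot height in pixels.

connect_gaps : bool, default False
    Whether to draw lines across missing values (passed to Plotly's connectgaps).

resampler : bool | Literal['widget'] | None, default None
    Enable plotly-resampler for large datasets. True or "widget" creates a FigureWidgetResampler; False or None uses a plain go.Figure.

line_width : float, default 2.0
    Width of line traces, in pixels.

band_opacity : float, default 0.25
    Opacity of prediction interval bands. In the single-model path with several coverage rates, the narrowest interval uses this value and wider intervals fade linearly, down to 55% of it for the widest.

show_transition : bool, default True
    Whether to connect the last training point to the first forecast point. In the single-model path this prepends the last training observation to the forecast trace, so the connector uses the forecast line style; it has no effect when y_train is None.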
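
Two details of the single-model path in the source listing are easy to miss: interval column names are built with an f-string, so each rate uses Python's default float formatting, and band opacity decreases linearly from band_opacity for the narrowest interval to 55% of it for the widest. The sketch below reproduces both computations in plain Python; `interval_columns` and `band_opacities` are illustrative names, not yohou functions.

```python
def interval_columns(col: str, rate: float) -> tuple[str, str]:
    """Names looked up in y_pred for one coverage rate (plain f-string formatting)."""
    return f"{col}_lower_{rate}", f"{col}_upper_{rate}"


def band_opacities(coverage_rates: list[float], band_opacity: float = 0.25) -> dict[float, float]:
    """Per-rate band opacity, following the formula in the single-model source path."""
    sorted_rates = sorted(coverage_rates)
    n_rates = len(sorted_rates)
    # Smallest rate (narrowest band) gets full band_opacity; the largest gets 55% of it.
    return {
        rate: band_opacity * (1.0 - 0.45 * idx / max(1, n_rates - 1))
        for idx, rate in enumerate(sorted_rates)
    }


print(interval_columns("y", 0.90))  # ('y_lower_0.9', 'y_upper_0.9')
print(band_opacities([0.95, 0.8, 0.9]))  # roughly {0.8: 0.25, 0.9: 0.19375, 0.95: 0.1375}
```

In the source, a rate of 0 (the median line) also takes a slot in this sorted order, which shifts the opacities of the remaining bands.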

Returns

Figure
    Plotly figure object; a FigureWidgetResampler when resampler is True or "widget".

Raises

TypeError
    If y_pred is None or an input is not a Polars DataFrame.
ValueError
    If a DataFrame is empty or missing its 'time' column, or if y_test is None for non-panel class-probability or categorical predictions.

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_forecast
>>> # Create sample data
>>> y_train = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True),
...     "y": [100 + i for i in range(91)],
... })
>>> y_test = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
...     "y": [191 + i for i in range(30)],
... })
>>> y_pred = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
...     "y": [190 + i + (i % 3) for i in range(30)],
... })
>>> fig = plot_forecast(y_test, y_pred, y_train=y_train)
>>> len(fig.data) >= 2
True

Multi-model comparison:

>>> y_pred_b = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
...     "y": [192 + i for i in range(30)],
... })
>>> fig = plot_forecast(y_test, {"Model A": y_pred, "Model B": y_pred_b})
>>> len(fig.data) >= 3
True

See Also

plot_residuals : Plot residual diagnostics.
plot_score_per_step : Score by horizon step.

Source Code

def plot_forecast(
    y_test: pl.DataFrame | None = None,
    y_pred: pl.DataFrame | dict[str, pl.DataFrame] | None = None,
    *,
    y_train: pl.DataFrame | None = None,
    columns: str | list[str] | None = None,
    coverage_rates: list[float] | None = None,
    n_history: int | None = None,
    groups: list[str] | None = None,
    facet_by: Literal["group", "member"] | None = "member",
    facet_n_cols: int = 2,
    color_palette: list[str] | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    connect_gaps: bool = False,
    resampler: bool | Literal["widget"] | None = None,
    line_width: float = 2.0,
    band_opacity: float = 0.25,
    show_transition: bool = True,
) -> go.Figure:
    """Plot forecasts with historical data and optional prediction intervals.

    Accepts separate DataFrames for actuals and predictions following an
    sklearn-like API. Automatically detects interval columns from y_pred
    when coverage_rates is provided.

    When *y_pred* is a ``dict[str, pl.DataFrame]``, each entry is treated as
    a separate model and plotted with a distinct color for side-by-side
    comparison.

    **Categorical support**: When *y_pred* contains class-probability
    columns (``{target}_proba_{class}`` pattern), renders a stacked area
    chart of predicted probabilities with ground-truth class markers from
    *y_test*. When *y_pred* contains categorical (string) columns, renders
    a step chart comparing predicted and actual class labels over time.

    Parameters
    ----------
    y_test : pl.DataFrame | None, default=None
        Actual test values with 'time' column.  When ``None``, only
        the forecast and interval traces are rendered (no "Actual" line).
    y_pred : pl.DataFrame | dict[str, pl.DataFrame]
        Forecast values with 'time' column. May also contain interval columns
        named ``{col}_lower_{rate}`` and ``{col}_upper_{rate}``.
        If a dict, keys are model names and values are prediction DataFrames.
    y_train : pl.DataFrame | None, default=None
        Historical training data with 'time' column. If provided, shown
        before the forecast period.
    columns : str | list[str] | None, default=None
        Target column(s) to plot from *y_test*.  When ``None``, all
        non-time columns are used.  Associated interval columns
        (``{col}_lower_{rate}`` / ``{col}_upper_{rate}``) are kept
        automatically.
    coverage_rates : list[float] | None, default=None
        Coverage rates to display intervals for (e.g., [0.9, 0.95]).
        Looks for ``{col}_lower_{rate}`` / ``{col}_upper_{rate}`` in y_pred.
    n_history : int | None, default=None
        Number of historical observations to show from y_train. If None, shows all.
    groups : list[str] | None, default=None
        Panel group prefixes to plot. If None and panel data is detected,
        plots all groups. Creates faceted subplots.
    facet_by : Literal["group", "member"] | None, default="member"
        Faceting axis for panel data.  ``"group"`` creates one subplot per
        group, ``"member"`` one per member.  ``None`` disables faceting.
        Ignored for non-panel data.
    facet_n_cols : int, default=2
        Number of columns in facet grid for panel data.
    color_palette : list[str] | None, default=None
        Custom color palette.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.
    connect_gaps : bool, default=False
        Whether to connect gaps in the data with lines.
    resampler : bool | Literal["widget"] | None, default=None
        Enable plotly-resampler for large datasets.  ``True`` or
        ``"widget"`` creates a ``FigureWidgetResampler``; ``False`` or
        ``None`` uses a plain ``go.Figure``.
    line_width : float, default=2.0
        Width of line traces.
    band_opacity : float, default=0.25
        Opacity of prediction interval bands.
    show_transition : bool, default=True
        Whether to show a dashed connector between the last training
        point and the first forecast point.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
    TypeError
        If inputs are not Polars DataFrames.
    ValueError
        If DataFrames are empty or missing 'time' column.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_forecast

    >>> # Create sample data
    >>> y_train = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True),
    ...     "y": [100 + i for i in range(91)],
    ... })
    >>> y_test = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
    ...     "y": [191 + i for i in range(30)],
    ... })
    >>> y_pred = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
    ...     "y": [190 + i + (i % 3) for i in range(30)],
    ... })

    >>> fig = plot_forecast(y_test, y_pred, y_train=y_train)
    >>> len(fig.data) >= 2
    True

    Multi-model comparison:

    >>> y_pred_b = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 4, 1), pl.date(2020, 4, 30), "1d", eager=True),
    ...     "y": [192 + i for i in range(30)],
    ... })
    >>> fig = plot_forecast(y_test, {"Model A": y_pred, "Model B": y_pred_b})
    >>> len(fig.data) >= 3
    True

    See Also
    --------
    [`plot_residuals`][yohou.plotting.plot_residuals] : Plot residual diagnostics.
    [`plot_score_per_step`][yohou.plotting.plot_score_per_step] : Score by horizon step.
    """
    # Validate inputs
    if y_pred is None:
        raise TypeError("y_pred is required")
    if y_test is not None:
        validate_plotting_data(y_test)
    if isinstance(y_pred, dict):
        for _name, pred_df in y_pred.items():
            validate_plotting_data(pred_df)  # ty: ignore[invalid-argument-type]
    else:
        validate_plotting_data(y_pred)
    if y_train is not None:
        validate_plotting_data(y_train)
    validate_plotting_params(width=width, height=height)

    # Semantic colors always come from the effective palette: slot 0 = history,
    # slot 1 = single-model forecast, slot 2 = actual, slot 3+ = model comparison.
    eff_palette = color_palette if color_palette is not None else _PALETTE
    forecast_color = eff_palette[1 % len(eff_palette)]
    actual_color = eff_palette[2 % len(eff_palette)]

    # Auto-detect prediction type from the first prediction DataFrame
    _first_pred = next(iter(y_pred.values())) if isinstance(y_pred, dict) else y_pred
    prediction_mode = _detect_prediction_mode(_first_pred)  # ty: ignore[invalid-argument-type]

    # Detect panel data (fall back to y_pred when y_test is not provided)
    _panel_source: pl.DataFrame = y_test if y_test is not None else _first_pred  # ty: ignore[invalid-assignment]
    _, _panels = inspect_panel(_panel_source)
    is_panel = bool(_panels)

    # For panel data, delegate to faceted handler (single or multi-model)
    if is_panel:
        return _plot_forecast_panel(
            y_test=y_test,
            y_pred=y_pred,
            y_train=y_train,
            coverage_rates=coverage_rates,
            n_history=n_history,
            groups=groups,
            facet_by=facet_by,
            facet_n_cols=facet_n_cols,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
            line_width=line_width,
            band_opacity=band_opacity,
            show_transition=show_transition,
            connect_gaps=connect_gaps,
            resampler=resampler,
            show_legend=show_legend,
            prediction_mode=prediction_mode,
        )

    # Non-panel class-probability predictions
    if prediction_mode == "class_proba":
        if y_test is None:
            raise ValueError("y_test is required for class-probability predictions")
        return _plot_forecast_class_proba(
            y_test=y_test,
            y_pred=y_pred,
            columns=columns,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
            facet_n_cols=facet_n_cols,
        )

    # Non-panel categorical predictions
    if prediction_mode == "categorical":
        if y_test is None:
            raise ValueError("y_test is required for categorical predictions")
        return _plot_forecast_categorical(
            y_test=y_test,
            y_pred=y_pred,
            y_train=y_train,
            n_history=n_history,
            columns=columns,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
            facet_n_cols=facet_n_cols,
        )

    # Multi-model dict: delegate to dedicated helper
    if isinstance(y_pred, dict):
        return _plot_forecast_multi_model(
            y_test=y_test,
            y_preds=y_pred,  # ty: ignore[invalid-argument-type]
            y_train=y_train,
            coverage_rates=coverage_rates,
            n_history=n_history,
            columns=columns,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
            facet_n_cols=facet_n_cols,
            line_width=line_width,
            band_opacity=band_opacity,
            show_transition=show_transition,
            connect_gaps=connect_gaps,
            resampler=resampler,
            show_legend=show_legend,
        )

    # Non-panel, single-model case
    interval_pattern = re.compile(r"^.+_(lower|upper)_[\d.]+$")
    pred_value_cols = [c for c in y_pred.columns if c not in ("time", "vintage_time") and not interval_pattern.match(c)]
    test_value_cols = [c for c in y_test.columns if c != "time"] if y_test is not None else list(pred_value_cols)

    # Apply columns filter
    if columns is not None:
        col_list = [columns] if isinstance(columns, str) else list(columns)
        test_value_cols = [c for c in col_list if c in test_value_cols]
    plot_columns = test_value_cols

    def _render_forecast(ctx: RenderContext) -> None:
        """Render train/interval/actual/forecast traces for one target column."""
        col = ctx.display_name
        pred_col = col if col in pred_value_cols else (pred_value_cols[0] if pred_value_cols else None)

        # Training data
        if y_train is not None and col in y_train.columns:
            train_df = y_train.tail(n_history) if n_history is not None else y_train
            _hex = actual_color.lstrip("#")
            _rgb = tuple(int(_hex[i : i + 2], 16) for i in (0, 2, 4))
            _train_color = f"rgba({_rgb[0]}, {_rgb[1]}, {_rgb[2]}, 0.4)"
            ctx.fig.add_trace(
                go.Scatter(
                    x=train_df["time"],
                    y=train_df[col],
                    mode="lines",
                    line={"color": _train_color, "width": line_width},
                    connectgaps=connect_gaps,
                    name=f"{col} (Train)",
                    legendrank=0,
                    hovertemplate=_make_hovertemplate(f"{col} Train", "Time", "Value"),
                ),
                row=ctx.row,
                col=ctx.col,
            )

        # Prediction intervals (rendered before Actual so actual sits on top)
        if coverage_rates:
            interval_base = pred_col if pred_col is not None else col
            _hex = forecast_color.lstrip("#")
            rgb = tuple(int(_hex[i : i + 2], 16) for i in (0, 2, 4))
            sorted_rates = sorted(coverage_rates)
            n_rates = len(sorted_rates)
            for sort_idx, rate in enumerate(sorted_rates):
                lower_col = f"{interval_base}_lower_{rate}"
                upper_col = f"{interval_base}_upper_{rate}"
                if lower_col in y_pred.columns and upper_col in y_pred.columns:
                    t = y_pred["time"].to_list()
                    y_upper = y_pred[upper_col].to_list()
                    y_lower = y_pred[lower_col].to_list()
                    if rate == 0:
                        ctx.fig.add_trace(
                            go.Scatter(
                                x=t,
                                y=y_upper,
                                mode="lines",
                                line={"dash": "dash", "width": line_width * 0.75, "color": forecast_color},
                                name=f"{col} (Median)",
                                legendrank=11 + sort_idx,
                                hoverinfo="skip",
                            ),
                            row=ctx.row,
                            col=ctx.col,
                        )
                    else:
                        rate_opacity = band_opacity * (1.0 - 0.45 * sort_idx / max(1, n_rates - 1))
                        rgba = f"rgba({rgb[0]}, {rgb[1]}, {rgb[2]}, {rate_opacity:.3f})"
                        x_band = t + t[::-1]
                        y_band = y_upper + y_lower[::-1]
                        ctx.fig.add_trace(
                            go.Scatter(
                                x=x_band,
                                y=y_band,
                                fill="toself",
                                fillcolor=rgba,
                                mode="lines",
                                line={"width": 0, "color": rgba},
                                name=f"{col} ({rate:.0%} PI)",
                                legendrank=11 + sort_idx,
                                hoverinfo="skip",
                            ),
                            row=ctx.row,
                            col=ctx.col,
                        )

        # Actual test data (prepend last train point to close the gap)
        if y_test is not None and col in y_test.columns:
            _x_actual = y_test["time"]
            _y_actual = y_test[col]
            if y_train is not None and col in y_train.columns:
                _x_actual = pl.concat([pl.Series("time", [y_train["time"][-1]]), _x_actual])
                _y_actual = pl.concat([pl.Series([y_train[col][-1]], dtype=_y_actual.dtype), _y_actual])
            ctx.fig.add_trace(
                go.Scatter(
                    x=_x_actual,
                    y=_y_actual,
                    mode="lines",
                    line={"color": actual_color, "width": line_width},
                    connectgaps=connect_gaps,
                    name=f"{col} (Actual)",
                    legendrank=1,
                    hovertemplate=_make_hovertemplate(f"{col} Actual", "Time", "Value"),
                ),
                row=ctx.row,
                col=ctx.col,
            )

        # Forecast
        if pred_col is not None and pred_col in y_pred.columns:
            x_forecast = y_pred["time"]
            forecast_y = y_pred[pred_col]
            if show_transition and y_train is not None and col in y_train.columns:
                last_train_time = y_train["time"][-1]
                last_train_val = y_train[col][-1]
                x_forecast = pl.concat([pl.Series("time", [last_train_time]), y_pred["time"]])
                forecast_y = pl.concat([pl.Series([last_train_val], dtype=forecast_y.dtype), forecast_y])
            ctx.fig.add_trace(
                go.Scatter(
                    x=x_forecast,
                    y=forecast_y,
                    mode="lines",
                    line={"color": forecast_color, "width": line_width},
                    connectgaps=connect_gaps,
                    name=f"{col} (Forecast)",
                    legendrank=10,
                    hovertemplate=_make_hovertemplate(f"{col} Forecast", "Time", "Value"),
                ),
                row=ctx.row,
                col=ctx.col,
            )

    fig = facet_figure(
        y_test if y_test is not None else y_pred,
        _render_forecast,
        columns=plot_columns,
        facet_n_cols=facet_n_cols,
        title=title or "Forecast",
        x_label=x_label or "Time",
        y_label=y_label or "Value",
        width=width,
        height=height,
        resampler=resampler,
    )
    fig.update_layout(showlegend=show_legend)

    return fig
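
The training trace and the interval bands above derive semi-transparent colors by slicing a hex string into red, green, and blue pairs. The standalone sketch below (hex_to_rgba is a hypothetical name) shows the same conversion and why color_palette entries need to be 6-digit hex strings: a named color such as "red" or a 3-digit form such as "#f00" cannot be parsed this way.

```python
def hex_to_rgba(color: str, alpha: float) -> str:
    """Convert '#RRGGBB' to a Plotly 'rgba(r, g, b, a)' string, mirroring the source's slicing."""
    hex_digits = color.lstrip("#")
    if len(hex_digits) != 6:
        # Explicit check added for a clearer error; the source slices without checking.
        raise ValueError(f"expected a 6-digit hex color, got {color!r}")
    r, g, b = (int(hex_digits[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r}, {g}, {b}, {alpha})"


print(hex_to_rgba("#1f77b4", 0.4))  # rgba(31, 119, 180, 0.4)
```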

Tutorials

The following example notebooks use this component:

  • How to Tune Fourier Seasonality Terms (Data-Features)
    Explore how the Fourier harmonic count affects seasonal fit quality, compare Fourier and Pattern seasonality, and tune harmonics jointly with GridSearchCV.

  • How to Forecast with CatBoost (Forecasting-Models)
    Plug CatBoostRegressor into PointReductionForecaster as a drop-in sklearn estimator, compare the gradient-boosted model against a Ridge linear baseline, and demonstrate the direct reduction strategy with tree-based models.

  • How to Choose a Decomposition Strategy (Forecasting-Models)
    Build 2- and 3-component DecompositionPipeline forecasters that chain trend, seasonality, and residual models with target pre-transformation.

  • How to Use Lagged Forecasts as Features (Forecasting-Models)
    Compare ForecastedFeatureForecaster strategies (actual, predicted, rewind) and tune the split ratio for chaining feature and target forecasters.

  • Observe-Predict Workflow (Getting-Started)
    Walk through a test set in batches, updating forecasts as new data arrives with observe_predict.

  • Panel Data Forecasting (Getting-Started)
    Forecast multiple related time series simultaneously using the __ naming convention, LocalPanelForecaster, and per-group scoring.

  • How to Configure LocalPanelForecaster (Panel-Data)
    Wrap any forecaster with LocalPanelForecaster for fully independent per-group clones, parallel fitting via n_jobs, and selective group operations.

  • How to Run Panel Cross-Validation (Panel-Data)
    Run time series cross-validation on panel data with GridSearchCV, selective group observation, rewind operations, and groupwise performance comparison.

  • How to Forecast Panel Prediction Intervals (Panel-Data)
    Combine conformal and quantile regression intervals on panel data with per-group coverage analysis, calibration plots, and groupwise interval scoring.

  • How to Visualize Forecasts (Visualization)
    Plot point forecasts, compare multiple models, render prediction interval bands, inspect residual diagnostics, and check interval calibration.

Each notebook can be viewed or opened in marimo.