
plot_residuals

yohou.plotting.evaluation.plot_residuals(y_pred, y_truth, *, columns=None, groups=None, facet_by='member', facet_n_cols=2, color_palette=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, resampler=None, marker_size=4, marker_opacity=0.6, n_bins=30)

Create diagnostic plots for model residuals.

When a single column is selected, the function creates a 4-panel layout: residuals over time, residuals vs. fitted values, a histogram of residuals, and a Q-Q plot for checking the normality assumption. When multiple columns are resolved (through columns or groups), it produces a faceted layout showing residuals over time for each column.
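The Q-Q panel plots the ordered residuals against quantiles of a reference normal distribution; points close to a straight line suggest roughly normal residuals. This page does not show which plotting positions yohou uses, so the standard-library sketch below assumes the common (i - 0.5) / n positions against a standard normal. `qq_points` is a hypothetical helper, not part of yohou.

```python
from __future__ import annotations

from statistics import NormalDist


def qq_points(residuals: list[float]) -> list[tuple[float, float]]:
    """Hypothetical helper: (theoretical, sample) pairs for a normal Q-Q plot.

    ASSUMPTION: (i - 0.5) / n plotting positions and a standard normal
    reference; yohou's own choice may differ.
    """
    n = len(residuals)
    if n == 0:
        raise ValueError("qq_points needs at least one residual")
    sample = sorted(residuals)  # empirical quantiles, ascending
    std_normal = NormalDist(mu=0.0, sigma=1.0)
    theoretical = [std_normal.inv_cdf((i - 0.5) / n) for i in range(1, n + 1)]
    return list(zip(theoretical, sample))


points = qq_points([0.3, -1.2, 0.8, -0.1, 0.5])
for theo, obs in points:
    print(f"{theo:+.3f}  {obs:+.2f}")
```

Scaling residuals is unnecessary here: a normal sample with any mean and spread still falls on a straight line, only with a different slope and intercept.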

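Residuals are computed internally as y_truth - y_pred for each matching non-time column, so a negative residual means the model over-predicted. The subtraction is row-by-row rather than a join on "time", so both frames should contain the same time stamps in the same order.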

Parameters

y_pred (DataFrame, required): Predicted values with a "time" column.
y_truth (DataFrame, required): Ground-truth values with a "time" column.
columns (str | list[str] | None, default None): Column(s) to compute residuals for. When groups is set, acts as a member suffix filter (e.g. ["a"] selects y__a). When None, uses all non-time columns common to both frames. A single resolved column triggers the 4-panel diagnostics; multiple columns produce facets.
groups (list[str] | None, default None): Panel group prefixes to facet by. If both groups and columns are None and y_truth contains panel columns, panel mode is enabled automatically.
facet_by (Literal['group', 'member'] | None, default "member"): Faceting axis for panel data. "group" creates one subplot per group, "member" one per member. The current implementation treats None as "member". Ignored for non-panel data.
facet_n_cols (int, default 2): Number of columns in the faceted grid when multiple target columns are resolved.
color_palette (list[str] | None, default None): Custom color palette. If None, the yohou palette is used.
show_legend (bool, default True): Whether to show the legend. Applied to faceted layouts only.
title (str | None, default None): Plot title. Faceted layouts default to "Residual Diagnostics".
x_label (str | None, default None): X-axis label for faceted layouts; defaults to "Time".
y_label (str | None, default None): Y-axis label for faceted layouts; defaults to "Residuals".
width (int | None, default None): Plot width in pixels.
height (int | None, default None): Plot height in pixels.
resampler (bool | Literal['widget'] | None, default None): Enable plotly-resampler for large datasets in faceted layouts. True creates a FigureResampler, "widget" a FigureWidgetResampler.
marker_size (float, default 4): Marker size for scatter plots.
marker_opacity (float, default 0.6): Marker opacity.
n_bins (int, default 30): Number of histogram bins (single-column diagnostics only).
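The column rules described above can be sketched in plain Python. `resolve_target_columns` mirrors the non-panel branch of the source code on this page (groups is None). `filter_panel_members` models the panel suffix filter under the assumption, taken from the y__a example, that panel columns are named <group>__<member>. All three helpers are hypothetical and not part of yohou's API.

```python
from __future__ import annotations


def resolve_target_columns(
    pred_cols: list[str],
    truth_cols: list[str],
    columns: str | list[str] | None = None,
) -> list[str]:
    """Hypothetical helper: non-panel column resolution for plot_residuals."""
    if columns is not None:
        targets = [columns] if isinstance(columns, str) else list(columns)
        for col in targets:
            if col not in pred_cols:
                raise ValueError(f"Column '{col}' not found in y_pred")
            if col not in truth_cols:
                raise ValueError(f"Column '{col}' not found in y_truth")
        return targets
    # No explicit selection: every non-time column present in both frames.
    common = [c for c in pred_cols if c != "time" and c in truth_cols]
    if not common:
        raise ValueError("No common non-time columns found between y_pred and y_truth")
    return common


def filter_panel_members(
    available: list[str],
    groups: list[str],
    members: str | list[str] | None = None,
) -> list[str]:
    """Hypothetical helper; ASSUMES panel columns are named '<group>__<member>'."""
    if isinstance(members, str):
        members = [members]
    selected = []
    for col in available:
        group, sep, member = col.partition("__")
        if not sep or group not in groups:
            continue  # not a panel column, or not in a requested group
        if members is None or member in members:
            selected.append(col)
    return selected


def layout_for(targets: list[str]) -> str:
    # One resolved column -> 4-panel diagnostics; several -> facets.
    return "4-panel diagnostics" if len(targets) == 1 else "faceted residuals over time"


print(layout_for(resolve_target_columns(["time", "y", "z"], ["time", "y"])))
print(filter_panel_members(["time", "y__a", "y__b", "x__a"], groups=["y"], members=["a"]))
```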

Returns

Figure: Plotly figure object.

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_residuals
>>> dates = pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True)
>>> y_truth = pl.DataFrame({"time": dates, "y": [100 + i for i in range(91)]})
>>> y_pred = pl.DataFrame({"time": dates, "y": [100 + i + (i % 3) for i in range(91)]})
>>> fig = plot_residuals(y_pred, y_truth)
>>> len(fig.data) > 0
True

See Also

plot_forecast : Plot forecasts with historical data.

Source Code

def plot_residuals(
    y_pred: pl.DataFrame,
    y_truth: pl.DataFrame,
    *,
    columns: str | list[str] | None = None,
    groups: list[str] | None = None,
    facet_by: Literal["group", "member"] | None = "member",
    facet_n_cols: int = 2,
    color_palette: list[str] | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    resampler: bool | Literal["widget"] | None = None,
    marker_size: float = 4,
    marker_opacity: float = 0.6,
    n_bins: int = 30,
) -> go.Figure:
    """Plot diagnostic plots for model residuals.

    When a single column is selected, creates a 4-panel layout with
    residuals over time, residuals vs fitted values, histogram of
    residuals, and Q-Q plot for checking normality assumptions.  When
    multiple columns are resolved (through *columns* or
    *groups*), produces a faceted layout showing residuals
    over time for each column.

    Residuals are computed internally as ``y_truth - y_pred`` for matching
    non-time columns.

    Parameters
    ----------
    y_pred : pl.DataFrame
        Predicted values with ``"time"`` column.
    y_truth : pl.DataFrame
        Ground-truth values with ``"time"`` column.
    columns : str | list[str] | None, default=None
        Column(s) to compute residuals for.  When *groups* is
        set, acts as a member postfix filter (e.g. ``["a"]`` selects
        ``y__a``).  When *None*, uses all common non-time columns. A
        single match triggers 4-panel diagnostics, multiple produce facets.
    groups : list[str] | None, default=None
        Panel group prefixes to facet by.
    facet_by : Literal["group", "member"] | None, default="member"
        Faceting axis for panel data. ``"group"`` creates one subplot per
        group, ``"member"`` one per member. ``None`` disables faceting.
        Ignored for non-panel data.
    facet_n_cols : int, default=2
        Number of columns in the faceted grid when multiple target columns
        are resolved.
    color_palette : list[str] | None, default=None
        Custom color palette. If None, uses yohou palette.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.
    resampler : bool | Literal["widget"] | None, default=None
        Enable plotly-resampler for large datasets. ``"figure"`` creates a
        ``FigureResampler``, ``"widget"`` a ``FigureWidgetResampler``.
    marker_size : float, default=4
        Marker size for scatter plots.
    marker_opacity : float, default=0.6
        Marker opacity.
    n_bins : int, default=30
        Number of bins for histogram (single-column diagnostics).

    Returns
    -------
    go.Figure
        Plotly figure object.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_residuals

    >>> dates = pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True)
    >>> y_truth = pl.DataFrame({"time": dates, "y": [100 + i for i in range(91)]})
    >>> y_pred = pl.DataFrame({"time": dates, "y": [100 + i + (i % 3) for i in range(91)]})

    >>> fig = plot_residuals(y_pred, y_truth)
    >>> len(fig.data) > 0
    True

    See Also
    --------
    [`plot_forecast`][yohou.plotting.plot_forecast] : Plot forecasts with historical data.
    """
    validate_plotting_data(y_pred)
    validate_plotting_data(y_truth)
    validate_plotting_params(width=width, height=height)

    # Auto-detect panel data
    _, _panel_groups = inspect_panel(y_truth)
    if groups is None and columns is None and _panel_groups:
        groups = []

    # Resolve target columns
    if groups is not None:
        target_cols = validate_plotting_data(y_pred, columns=columns, groups=groups)
        missing = [c for c in target_cols if c not in y_truth.columns]
        if missing:
            msg = f"Columns {missing} not found in y_truth"
            raise ValueError(msg)
    elif columns is not None:
        target_cols = [columns] if isinstance(columns, str) else list(columns)
        for col in target_cols:
            if col not in y_pred.columns:
                msg = f"Column '{col}' not found in y_pred"
                raise ValueError(msg)
            if col not in y_truth.columns:
                msg = f"Column '{col}' not found in y_truth"
                raise ValueError(msg)
    else:
        pred_cols = [c for c in y_pred.columns if c != "time"]
        truth_cols = [c for c in y_truth.columns if c != "time"]
        common = [c for c in pred_cols if c in truth_cols]
        if not common:
            msg = "No common non-time columns found between y_pred and y_truth"
            raise ValueError(msg)
        target_cols = common

    # Compute residuals
    residual_exprs = [(pl.col(c) - y_pred[c]).alias(c) for c in target_cols]
    residuals_df = y_truth.select("time", *residual_exprs)

    # Single column: full 4-panel diagnostics
    if len(target_cols) == 1:
        return _render_residual_diagnostics(
            residuals_df,
            y_pred,
            target_cols[0],
            color_palette=color_palette,
            title=title,
            width=width,
            height=height,
            marker_size=marker_size,
            marker_opacity=marker_opacity,
            n_bins=n_bins,
        )

    # Multiple columns: faceted residuals over time

    if groups is not None:
        pn_cols = resolve_panel_columns(residuals_df, groups, columns)
        _, all_members = _group_panel_columns(pn_cols)
        member_palette = resolve_color_palette(color_palette, len(all_members))
        legend_tracker = LegendTracker()

        def _render_residual_scatter(ctx: RenderContext) -> None:
            """Render residuals over time for a single panel column."""
            base = [c for c in ctx.sub_df.columns if c != "time"][0]
            color = member_palette[ctx.entity_idx % len(member_palette)]
            ctx.fig.add_trace(
                go.Scatter(
                    x=ctx.sub_df["time"],
                    y=ctx.sub_df[base],
                    mode="markers",
                    name=ctx.display_name,
                    legendgroup=ctx.display_name,
                    marker={
                        "size": marker_size,
                        "color": color,
                        "opacity": marker_opacity,
                    },
                    showlegend=legend_tracker.should_show(ctx.display_name),
                    hovertemplate=_make_hovertemplate(ctx.display_name, "Time", "Residual", decimals=3),
                ),
                row=ctx.row,
                col=ctx.col,
            )
            ctx.fig.add_hline(
                y=0,
                line={"dash": "dash", "color": "#DC2626", "width": 1},
                row=ctx.row,
                col=ctx.col,
            )

        effective_facet_by = facet_by or "member"
        fig = facet_figure(
            residuals_df,
            _render_residual_scatter,
            groups=groups,
            columns=columns,
            facet_by=effective_facet_by,
            facet_n_cols=facet_n_cols,
            title=title or "Residual Diagnostics",
            x_label=x_label or "Time",
            y_label=y_label or "Residuals",
            width=width,
            height=height,
            resampler=resampler,
        )
        fig.update_layout(showlegend=show_legend)
        return fig

    # Non-panel multi-column: column-mode facet_figure
    _colors = resolve_color_palette(color_palette, len(target_cols))
    _col_colors = dict(zip(target_cols, _colors, strict=False))

    def _render_residual(ctx: RenderContext) -> None:
        """Render residual scatter for one column into a subplot."""
        base = ctx.display_name
        col_color = _col_colors[base]
        ctx.fig.add_trace(
            go.Scatter(
                x=residuals_df["time"],
                y=residuals_df[base],
                mode="markers",
                marker={"size": marker_size, "color": col_color, "opacity": marker_opacity},
                showlegend=False,
            ),
            row=ctx.row,
            col=ctx.col,
        )
        ctx.fig.add_hline(
            y=0,
            line={"dash": "dash", "color": "#DC2626", "width": 1},
            row=ctx.row,
            col=ctx.col,
        )

    fig = facet_figure(
        residuals_df,
        _render_residual,
        columns=target_cols,
        facet_n_cols=facet_n_cols,
        title=title or "Residual Diagnostics",
        x_label=x_label or "Time",
        y_label=y_label or "Residuals",
        width=width,
        height=height,
        shared_xaxes=True,
        resampler=resampler,
    )
    fig.update_layout(showlegend=show_legend)

    return fig

Tutorials

The following example notebooks use this component:

  • Forecasting Workflow (Getting-Started): Evaluate forecasters with cross-validation, search hyperparameters with GridSearchCV, and inspect residuals to diagnose model weaknesses.
  • How to Visualize Forecast Evaluation Results (Visualization): Use plot_calibration, plot_score_per_step, and plot_forecast to diagnose forecast accuracy and interval calibration visually.
  • How to Visualize Forecasts (Visualization): Plot point forecasts, compare multiple models, render prediction interval bands, inspect residual diagnostics, and check interval calibration.