
plot_calibration

yohou.plotting.evaluation.plot_calibration(y_pred, y_truth, coverage_rates=None, *, columns=None, target=None, n_bins=10, groups=None, facet_by='member', facet_n_cols=2, color_palette=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, line_width=2.0, line_opacity=1.0, reference_color='#1e293b', reference_width=3.0, reference_dash='dash')

Plot calibration for interval or class-probability forecasts.

Automatically detects the prediction type from column names:

  • Interval predictions: columns named "{target}_upper_{rate}" / "{target}_lower_{rate}" define prediction intervals; the share of ground-truth values that fall inside each interval (empirical coverage) is plotted against the nominal coverage_rates.
  • Class-probability predictions: columns containing "_proba_" are binned and compared against empirical frequencies (reliability diagram).
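The detection rule can be shown with plain column names. The sketch below is hypothetical: `interval_column_names` and `classify_columns` are not yohou functions. They only mirror the `"_proba_"` substring check in the source listed on this page and the naming convention described above. Rendering each rate with `str()` (so 0.95 becomes "0.95") is an assumption that matches the column names used in the examples.

```python
def interval_column_names(target: str, rate: float) -> tuple[str, str]:
    """Build the (upper, lower) column names for one nominal coverage rate.

    Assumes the rate is rendered with str(), e.g. 0.95 -> "0.95".
    """
    return f"{target}_upper_{rate}", f"{target}_lower_{rate}"


def classify_columns(columns: list[str]) -> str:
    """Return "proba" or "interval" using the same substring rule as the source.

    Any column containing "_proba_" selects the reliability-diagram path;
    otherwise the interval path is taken.
    """
    if any("_proba_" in c for c in columns):
        return "proba"
    return "interval"


print(interval_column_names("y", 0.95))                         # ('y_upper_0.95', 'y_lower_0.95')
print(classify_columns(["time", "w_proba_sunny", "w_proba_rainy"]))  # proba
print(classify_columns(["y_upper_0.9", "y_lower_0.9"]))         # interval
```

Because the check is a plain substring test on y_pred, an interval target whose own name happens to contain "_proba_" would be routed to the class-probability path.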

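A well-calibrated model produces points close to the diagonal reference line, where empirical coverage equals nominal coverage (intervals) or observed frequency equals predicted probability (class probabilities).

In the class-probability path, only target, n_bins, color_palette, title, x_label, y_label, width, and height are used; the faceting, legend, and line-styling parameters apply to interval calibration only.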
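For interval predictions, each point on the curve pairs a nominal rate with an empirical coverage. A minimal sketch of that quantity, assuming it is the fraction of observations y with lower <= y <= upper; the inclusive bounds and the lack of missing-value handling are assumptions, since yohou's private `_compute_empirical_coverages` helper is not shown on this page and may differ. `empirical_coverage` is a hypothetical name.

```python
def empirical_coverage(truth, lower, upper) -> float:
    """Fraction of observations y with lower <= y <= upper (inclusive bounds assumed)."""
    if not (len(truth) == len(lower) == len(upper)):
        raise ValueError("truth, lower and upper must have the same length")
    if len(truth) == 0:
        raise ValueError("need at least one observation")
    inside = sum(lo <= y <= hi for y, lo, hi in zip(truth, lower, upper))
    return inside / len(truth)


# Toy data: a nominal 0.9 interval of [-1, 1] around five observations.
truth = [0.1, -1.2, 2.5, 0.4, -0.3]
lower = [-1.0] * 5
upper = [1.0] * 5

coverage = empirical_coverage(truth, lower, upper)
print(coverage)  # 3 of 5 observations inside -> 0.6
```

Plotted at x = 0.9, this point (y = 0.6) would fall below the diagonal, indicating intervals that are too narrow; points above the diagonal indicate intervals that are wider than necessary.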

Parameters

Name Type Description Default
y_pred DataFrame

Predicted values. Either prediction intervals with "{target}_upper_{rate}" / "{target}_lower_{rate}" columns, or class-probability predictions with "{target}_proba_{class}" columns.

required
y_truth DataFrame

Ground truth values.

required
coverage_rates list of float or None

Nominal coverage rates for interval calibration (e.g., [0.9, 0.95]). Required when y_pred contains interval columns; ignored for class-probability predictions.

None
columns str | list[str] | None

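Target column name(s) for interval calibration. If None (non-panel data), every y_truth column other than "time" and "vintage_time" is used. When groups is set, this instead acts as a filter on member postfixes. Ignored for class-probability predictions (use target instead).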

None
target str or None

Target column name for class-probability calibration. Required when multiple targets are present in the probability columns. Ignored for interval calibration.

None
n_bins int

Number of bins for class-probability calibration. Ignored for interval calibration.

10
groups list[str] | None

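Panel group prefixes for faceted subplots. When provided, each panel group gets its own subplot, with one line per member. If None and columns is also None, panel data in y_truth is detected automatically and faceted the same way.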

None
facet_by Literal['group', 'member'] | None

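Faceting axis for panel data. "group" creates one subplot per group, "member" one per member. None disables faceting. Ignored for non-panel data. Note that the source listed on this page does not reference facet_by: the interval path always lays out one subplot per group.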

"member"
facet_n_cols int

Number of columns in the facet grid when groups is used (capped at the number of groups).

2
color_palette list[str] | None

Custom color palette as hex codes. If None, the default yohou palette is used.

None
show_legend bool

Whether to show the legend.

True
title str | None

Plot title. Defaults to "Calibration plot" for intervals or "Reliability Diagram" for class probabilities.

None
x_label str | None

X-axis label.

None
y_label str | None

Y-axis label.

None
width int | None

Plot width in pixels.

None
height int | None

Plot height in pixels.

None
line_width float

Width of the calibration line.

2.0
line_opacity float

Opacity of the calibration line.

1.0
reference_color str

Color of the perfect-calibration reference line.

"#1e293b"
reference_width float

Width of the reference line.

3.0
reference_dash str

Plotly dash style of the reference line (e.g. "dash", "dot", or "solid").

"dash"
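For class probabilities, n_bins controls how predicted probabilities are grouped before being compared with observed frequencies. A minimal sketch of a standard reliability-diagram computation, assuming equal-width bins on [0, 1] and a one-vs-rest view of a single class; the binning used by yohou's private `_plot_calibration_class_proba` helper is not shown on this page and may differ (for example in how empty bins or the upper edge are handled). `reliability_bins` is a hypothetical name.

```python
def reliability_bins(probs, outcomes, n_bins=10):
    """Per non-empty bin: (mean predicted probability, observed frequency, count).

    probs    -- predicted probabilities for one class, each in [0, 1]
    outcomes -- 1 if that class was observed at the same step, else 0
    """
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")
    prob_sums = [0.0] * n_bins
    hit_counts = [0] * n_bins
    counts = [0] * n_bins
    for p, observed in zip(probs, outcomes):
        # Equal-width bins on [0, 1]; p == 1.0 goes into the last bin.
        idx = min(int(p * n_bins), n_bins - 1)
        prob_sums[idx] += p
        hit_counts[idx] += observed
        counts[idx] += 1
    return [
        (prob_sums[i] / counts[i], hit_counts[i] / counts[i], counts[i])
        for i in range(n_bins)
        if counts[i] > 0
    ]


# One-vs-rest view of a single class: predicted P(class) and whether it occurred.
probs = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9]
outcomes = [0, 0, 1, 1, 1, 0]
for mean_p, freq, n in reliability_bins(probs, outcomes, n_bins=2):
    print(f"predicted {mean_p:.2f}  observed {freq:.2f}  (n={n})")
```

Each tuple becomes one point of a reliability diagram (x = mean predicted probability, y = observed frequency); bins that receive no predictions are skipped, so fewer than n_bins points may appear.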

Returns

Type Description
Figure

Plotly figure object.

Raises

Type Description
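TypeError

If y_pred or y_truth is not a polars DataFrame.

ValueError

If interval columns are missing, coverage_rates is not provided for interval predictions, class-probability columns are ambiguous, a name in columns is not found in y_truth, or y_truth has no columns other than time columns.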
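Since missing interval columns surface as a ValueError, it can help to check column names before plotting. The helper below is hypothetical (not part of yohou) and assumes the "{target}_upper_{rate}" / "{target}_lower_{rate}" convention with each rate rendered by `str()`.

```python
def missing_interval_columns(pred_columns, targets, coverage_rates):
    """Return expected interval column names that are absent from pred_columns.

    Assumes the "{target}_upper_{rate}" / "{target}_lower_{rate}" convention,
    with each rate rendered by str() (e.g. 0.95 -> "0.95").
    """
    available = set(pred_columns)
    missing = []
    for target in targets:
        for rate in coverage_rates:
            for bound in ("upper", "lower"):
                name = f"{target}_{bound}_{rate}"
                if name not in available:
                    missing.append(name)
    return missing


columns = ["time", "y_upper_0.9", "y_lower_0.9", "y_upper_0.95"]
print(missing_interval_columns(columns, ["y"], [0.9, 0.95]))  # ['y_lower_0.95']
```

Under this assumed convention, a rate written as 0.90 still maps to "0.9", so a column literally named "y_upper_0.90" would be reported as missing; if yohou formats rates differently, adjust the f-string accordingly.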

Examples

Interval calibration:

>>> import polars as pl
>>> import numpy as np
>>> from yohou.plotting import plot_calibration
>>> # Create sample data
>>> n = 100
>>> y_truth = pl.DataFrame({"y": np.random.randn(n)})
>>> y_pred_int = pl.DataFrame({
...     "y_upper_0.9": np.random.randn(n) + 1.65,
...     "y_lower_0.9": np.random.randn(n) - 1.65,
...     "y_upper_0.95": np.random.randn(n) + 1.96,
...     "y_lower_0.95": np.random.randn(n) - 1.96,
... })
>>> # Plot calibration
>>> fig = plot_calibration(y_pred_int, y_truth, coverage_rates=[0.9, 0.95])
>>> len(fig.data)
2

Class-probability calibration (reliability diagram):

>>> import plotly.graph_objects as go
>>> from datetime import datetime
>>> y_pred_proba = pl.DataFrame({
...     "time": [datetime(2020, 1, i) for i in range(1, 11)],
...     "w_proba_sunny": [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.9, 0.85],
...     "w_proba_rainy": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.15],
... })
>>> y_truth_cat = pl.DataFrame({
...     "time": [datetime(2020, 1, i) for i in range(1, 11)],
...     "w": ["sunny", "sunny", "sunny", "rainy", "rainy", "rainy", "rainy", "rainy", "sunny", "sunny"],
... })
>>> fig = plot_calibration(y_pred_proba, y_truth_cat)
>>> isinstance(fig, go.Figure)
True

See Also

plot_forecast : Plot forecast with optional prediction intervals.
plot_residuals : Residual diagnostics with panel facets.

Source Code

def plot_calibration(
    y_pred: pl.DataFrame,
    y_truth: pl.DataFrame,
    coverage_rates: list[StrictFloat] | None = None,
    *,
    columns: str | list[str] | None = None,
    target: str | None = None,
    n_bins: int = 10,
    groups: list[str] | None = None,
    facet_by: Literal["group", "member"] | None = "member",
    facet_n_cols: int = 2,
    color_palette: list[str] | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    line_width: float = 2.0,
    line_opacity: float = 1.0,
    reference_color: str = "#1e293b",
    reference_width: float = 3.0,
    reference_dash: str = "dash",
) -> go.Figure:
    """Plot calibration for interval or class-probability forecasts.

    Automatically detects the prediction type from column names:

    - **Interval predictions**: columns named
      ``"{target}_upper_{rate}"`` / ``"{target}_lower_{rate}"`` define
      prediction intervals; the share of ground-truth values that fall
      inside each interval (empirical coverage) is plotted against the
      nominal *coverage_rates*.
    - **Class-probability predictions**: columns containing
      ``"_proba_"`` are binned and compared against empirical
      frequencies (reliability diagram).

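    A well-calibrated model produces points close to the diagonal
    reference line, where empirical coverage equals nominal coverage
    (intervals) or observed frequency equals predicted probability
    (class probabilities).

    In the class-probability path, only *target*, *n_bins*,
    *color_palette*, *title*, *x_label*, *y_label*, *width*, and
    *height* are used; the faceting, legend, and line-styling
    parameters apply to interval calibration only.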

    Parameters
    ----------
    y_pred : pl.DataFrame
        Predicted values.  Either prediction intervals with
        ``"{target}_upper_{rate}"`` / ``"{target}_lower_{rate}"``
        columns, or class-probability predictions with
        ``"{target}_proba_{class}"`` columns.
    y_truth : pl.DataFrame
        Ground truth values.
    coverage_rates : list of float or None, default=None
        Nominal coverage rates for interval calibration (e.g.,
        ``[0.9, 0.95]``).  Required when *y_pred* contains interval
        columns; ignored for class-probability predictions.
    columns : str | list[str] | None, default=None
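        Target column name(s) for interval calibration.  If None
        (non-panel data), every *y_truth* column other than ``"time"``
        and ``"vintage_time"`` is used.  When *groups* is set, this
        instead acts as a filter on member postfixes.  Ignored for
        class-probability predictions (use *target* instead).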
    target : str or None, default=None
        Target column name for class-probability calibration.  Required
        when multiple targets are present in the probability columns.
        Ignored for interval calibration.
    n_bins : int, default=10
        Number of bins for class-probability calibration.  Ignored for
        interval calibration.
    groups : list[str] | None, default=None
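        Panel group prefixes for faceted subplots.  When provided, each
        panel group gets its own subplot, with one line per member.  If
        None and *columns* is also None, panel data in *y_truth* is
        detected automatically and faceted the same way.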
    facet_by : Literal["group", "member"] | None, default="member"
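        Faceting axis for panel data. ``"group"`` creates one subplot per
        group, ``"member"`` one per member. ``None`` disables faceting.
        Ignored for non-panel data. Note that the implementation below
        does not reference *facet_by*: the interval path always lays out
        one subplot per group.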
    facet_n_cols : int, default=2
        Number of columns in the facet grid when *groups* is
        used (capped at the number of groups).
    color_palette : list[str] | None, default=None
        Custom color palette as hex codes. If None, the default yohou palette is used.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title. Defaults to ``"Calibration plot"`` for intervals or
        ``"Reliability Diagram"`` for class probabilities.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.
    line_width : float, default=2.0
        Width of the calibration line.
    line_opacity : float, default=1.0
        Opacity of the calibration line.
    reference_color : str, default="#1e293b"
        Color of the perfect-calibration reference line.
    reference_width : float, default=3.0
        Width of the reference line.
    reference_dash : str, default="dash"
        Plotly dash style of the reference line (e.g. ``"dash"``,
        ``"dot"``, or ``"solid"``).

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
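    TypeError
        If *y_pred* or *y_truth* is not a ``pl.DataFrame``.
    ValueError
        If interval columns are missing, *coverage_rates* is not
        provided for interval predictions, class-probability columns
        are ambiguous, a name in *columns* is not found in *y_truth*,
        or *y_truth* has no columns other than time columns.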

    Examples
    --------
    Interval calibration:

    >>> import polars as pl
    >>> import numpy as np
    >>> from yohou.plotting import plot_calibration

    >>> # Create sample data
    >>> n = 100
    >>> y_truth = pl.DataFrame({"y": np.random.randn(n)})
    >>> y_pred_int = pl.DataFrame({
    ...     "y_upper_0.9": np.random.randn(n) + 1.65,
    ...     "y_lower_0.9": np.random.randn(n) - 1.65,
    ...     "y_upper_0.95": np.random.randn(n) + 1.96,
    ...     "y_lower_0.95": np.random.randn(n) - 1.96,
    ... })

    >>> # Plot calibration
    >>> fig = plot_calibration(y_pred_int, y_truth, coverage_rates=[0.9, 0.95])
    >>> len(fig.data)
    2

    Class-probability calibration (reliability diagram):

    >>> import plotly.graph_objects as go
    >>> from datetime import datetime
    >>> y_pred_proba = pl.DataFrame({
    ...     "time": [datetime(2020, 1, i) for i in range(1, 11)],
    ...     "w_proba_sunny": [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.9, 0.85],
    ...     "w_proba_rainy": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.1, 0.15],
    ... })
    >>> y_truth_cat = pl.DataFrame({
    ...     "time": [datetime(2020, 1, i) for i in range(1, 11)],
    ...     "w": ["sunny", "sunny", "sunny", "rainy", "rainy", "rainy", "rainy", "rainy", "sunny", "sunny"],
    ... })
    >>> fig = plot_calibration(y_pred_proba, y_truth_cat)
    >>> isinstance(fig, go.Figure)
    True

    See Also
    --------
    [`plot_forecast`][yohou.plotting.plot_forecast] : Plot forecast with optional prediction intervals.
    [`plot_residuals`][yohou.plotting.plot_residuals] : Residual diagnostics with panel facets.
    """
    # Validate inputs
    if not isinstance(y_truth, pl.DataFrame):
        msg = f"Expected pl.DataFrame for y_truth, got {type(y_truth).__name__}"
        raise TypeError(msg)
    if not isinstance(y_pred, pl.DataFrame):
        msg = f"Expected pl.DataFrame for y_pred, got {type(y_pred).__name__}"
        raise TypeError(msg)
    validate_plotting_params(width=width, height=height)

    # Detect class-probability columns
    proba_cols = [c for c in y_pred.columns if "_proba_" in c]
    if proba_cols:
        return _plot_calibration_class_proba(
            y_pred=y_pred,
            y_truth=y_truth,
            target=target,
            n_bins=n_bins,
            color_palette=color_palette,
            title=title,
            x_label=x_label,
            y_label=y_label,
            width=width,
            height=height,
        )

    # Interval calibration path
    if coverage_rates is None:
        msg = "coverage_rates is required for interval calibration. Pass a list of coverage rates (e.g., [0.9, 0.95])."
        raise ValueError(msg)

    # Auto-detect panel data
    _, _panel_groups = inspect_panel(y_truth)
    if groups is None and columns is None and _panel_groups:
        groups = []

    if groups is not None:
        panel_cols = resolve_panel_columns(y_truth, groups, columns)

        # Group columns by panel prefix, collect unique members
        grouped, all_members = _group_panel_columns(panel_cols)

        n_groups = len(grouped)
        n_cols_grid = min(n_groups, facet_n_cols)
        n_rows = (n_groups + n_cols_grid - 1) // n_cols_grid

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols_grid,
            subplot_titles=list(grouped.keys()),
            shared_xaxes=True,
            shared_yaxes=True,
            vertical_spacing=_subplot_spacing(n_rows),
            horizontal_spacing=0.08,
        )

        palette = resolve_color_palette(color_palette, len(all_members))
        legend_tracker = LegendTracker(show_legend=show_legend)
        for group_idx, (_, group_cols) in enumerate(grouped.items()):
            row = group_idx // n_cols_grid + 1
            col_idx = group_idx % n_cols_grid + 1

            for panel_col in group_cols:
                member_name = _member_name(panel_col)
                member_idx = all_members.index(member_name)

                truth_vals = y_truth[panel_col].to_numpy().flatten()
                emp_cov = _compute_empirical_coverages(truth_vals, y_pred, panel_col, coverage_rates)

                fig.add_trace(
                    go.Scatter(
                        x=list(coverage_rates),
                        y=emp_cov,
                        mode="lines+markers",
                        name=member_name,
                        line={"color": palette[member_idx], "width": line_width},
                        opacity=line_opacity,
                        showlegend=legend_tracker.should_show(member_name),
                        legendgroup=member_name,
                        hovertemplate="<b>%{fullData.name}</b><br>Nominal: %{x:.2f}<br>Coverage: %{y:.3f}<extra></extra>",
                    ),
                    row=row,
                    col=col_idx,
                )

            fig.add_trace(
                go.Scatter(
                    x=list(coverage_rates),
                    y=list(coverage_rates),
                    mode="lines",
                    name="Perfect",
                    line={"color": reference_color, "width": reference_width, "dash": reference_dash},
                    showlegend=legend_tracker.should_show("Perfect"),
                    legendgroup="perfect",
                ),
                row=row,
                col=col_idx,
            )

        row_height = 300
        default_height = max(row_height * n_rows, 400)
        fig = apply_default_layout(
            fig,
            title=title or "Calibration plot",
            x_label=x_label or "Nominal coverage",
            y_label=y_label or "Empirical coverage",
            width=width,
            height=height or default_height,
        )
        fig.update_layout(showlegend=show_legend)
        return fig

    if columns is not None:
        target_columns = [columns] if isinstance(columns, str) else list(columns)
        for col in target_columns:
            if col not in y_truth.columns:
                msg = f"Target column '{col}' not found in y_truth"
                raise ValueError(msg)
    else:
        target_columns = [c for c in y_truth.columns if c not in ("time", "vintage_time")]
        if not target_columns:
            msg = "y_truth has no non-time columns"
            raise ValueError(msg)

    palette = resolve_color_palette(color_palette, len(target_columns))

    fig = go.Figure()

    for col_idx, target_column in enumerate(target_columns):
        truth_vals = y_truth[target_column].to_numpy().flatten()
        emp_cov = _compute_empirical_coverages(truth_vals, y_pred, target_column, coverage_rates)

        trace_name = target_column if len(target_columns) > 1 else "Empirical coverage"
        fig.add_trace(
            go.Scatter(
                x=list(coverage_rates),
                y=emp_cov,
                mode="lines+markers",
                name=trace_name,
                line={"color": palette[col_idx % len(palette)], "width": line_width},
                opacity=line_opacity,
                hovertemplate=f"<b>{trace_name}</b><br>Nominal: %{{x:.2f}}<br>Coverage: %{{y:.3f}}<extra></extra>",
            )
        )

    # Reference diagonal (always exactly one)
    fig.add_trace(
        go.Scatter(
            x=list(coverage_rates),
            y=list(coverage_rates),
            mode="lines",
            name="Perfect calibration",
            line={"color": reference_color, "width": reference_width, "dash": reference_dash},
            hovertemplate="<b>Perfect</b><br>Coverage: %{x:.2f}<extra></extra>",
        )
    )

    fig = apply_default_layout(
        fig,
        title=title or "Calibration plot",
        x_label=x_label or "Nominal coverage",
        y_label=y_label or "Empirical coverage",
        width=width,
        height=height,
    )
    fig.update_layout(showlegend=show_legend)

    return fig
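In the panel branch of the source above, the grid shape and each group's cell come from integer arithmetic: the number of grid columns is capped at the number of groups, the number of rows is a ceiling division, and each group's row and column follow from its index via // and %. The standalone sketch below reproduces only that arithmetic so the layout (and the default figure height of 300 px per row, with a 400 px minimum) can be predicted before plotting; `facet_layout` is a hypothetical helper, not a yohou function.

```python
def facet_layout(n_groups: int, facet_n_cols: int = 2):
    """Return (n_rows, n_cols, positions) mirroring the panel branch's arithmetic.

    positions[i] is the 1-based (row, col) subplot cell for the i-th group.
    """
    if n_groups < 1 or facet_n_cols < 1:
        raise ValueError("n_groups and facet_n_cols must be positive")
    n_cols = min(n_groups, facet_n_cols)
    n_rows = (n_groups + n_cols - 1) // n_cols  # ceiling division
    positions = [(i // n_cols + 1, i % n_cols + 1) for i in range(n_groups)]
    return n_rows, n_cols, positions


n_rows, n_cols, positions = facet_layout(5, facet_n_cols=2)
print(n_rows, n_cols, positions)  # 3 2 [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1)]
# Default figure height used by the panel branch when height is None.
print(max(300 * n_rows, 400))     # 900
```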

Tutorials

The following example notebooks (which can also be opened in marimo) use this component:

  • How to Score Class-Probability Forecasts (Evaluation-Search): Evaluate categorical forecasts with LogLoss, BrierScore, and Accuracy. Covers per-timestep scoring, aggregation modes, and reliability diagrams.
  • How to Evaluate Interval Forecasts (Evaluation-Search): Evaluate prediction intervals with EmpiricalCoverage, IntervalScore, MeanIntervalWidth, PinballLoss, and CalibrationError across coverage levels.
  • How to Forecast Class Probabilities (Forecasting-Models): Use ClassProbaReductionForecaster to produce calibrated probability forecasts and evaluate them with Brier score, log loss, and accuracy.
  • How to Use Distance-Based Similarity for Intervals (Forecasting-Models): Adaptive prediction intervals via similarity-weighted conformal prediction using DistanceSimilarity with configurable distance metrics and bandwidths.
  • Quickstart (Quickstart): Comprehensive end-to-end tour of yohou beyond the Getting Started tutorials, covering data loading, baseline forecasting, preprocessing pipelines, decomposition, cross-validation search, and interval prediction.
  • How to Visualize Forecast Evaluation Results (Visualization): Use plot_calibration, plot_score_per_step, and plot_forecast to diagnose forecast accuracy and interval calibration visually.
  • Forecast Visualization (Visualization): Visualize point forecasts from single and multiple models, decomposition pipeline components, and time weight decay functions with interactive Plotly.
  • How to Visualize Forecasts (Visualization): Plot point forecasts, compare multiple models, render prediction interval bands, inspect residual diagnostics, and check interval calibration.