plot_score_heatmap

yohou.plotting.evaluation.plot_score_heatmap(scorer, y_truth, y_pred, *, x_dim='step', y_dim='vintage', color_palette=None, text_format='.2f', show_text=True, title=None, x_label=None, y_label=None, width=None, height=None)

Plot a 2D heatmap of scores across two forecast dimensions.

Creates a heatmap where each cell shows the score for a specific combination of forecast dimensions (e.g., horizon step vs vintage).

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scorer | BaseScorer | Yohou scorer instance. Used to compute per-cell scores. | required |
| y_truth | DataFrame | Ground truth with "time" column. | required |
| y_pred | DataFrame | Predictions with "vintage_time" and "time" columns. Only single-model input is supported. | required |
| x_dim | str | Dimension for the x-axis: "step" (forecast horizon) or "vintage" (observed time). | "step" |
| y_dim | str | Dimension for the y-axis: "step" or "vintage". Must differ from x_dim. | "vintage" |
| color_palette | str or None | Plotly colorscale name. If None, auto-selects based on scorer._lower_is_better: "Blues" when lower is better, "Blues_r" otherwise. | None |
| text_format | str | Format string for cell text annotations. | ".2f" |
| show_text | bool | Whether to display numeric annotations in cells. | True |
| title | str or None | Plot title. | None |
| x_label | str or None | X-axis label. | None |
| y_label | str or None | Y-axis label. | None |
| width | int or None | Plot width in pixels. | None |
| height | int or None | Plot height in pixels. | None |
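
The defaults for color_palette and text_format can be illustrated without plotting anything. The sketch below mirrors the documented behaviour in plain Python; `resolve_colorscale` and `format_cells` are hypothetical helper names introduced only for illustration and are not part of the yohou API. In the source listing, a scorer without a `_lower_is_better` attribute is treated as if lower is better.

```python
def resolve_colorscale(color_palette, lower_is_better=True):
    """Hypothetical helper: an explicit palette wins, otherwise Blues / Blues_r."""
    if color_palette is not None:
        return color_palette
    return "Blues" if lower_is_better else "Blues_r"


def format_cells(rows, text_format=".2f"):
    """Hypothetical helper: apply a format spec to every cell of a score matrix."""
    return [[f"{value:{text_format}}" for value in row] for row in rows]


scores = [[2.0, 1.0, 2.0], [1.0, 1.0, float("nan")]]

print(resolve_colorscale(None, lower_is_better=True))   # Blues
print(resolve_colorscale(None, lower_is_better=False))  # Blues_r
print(resolve_colorscale("Viridis"))                    # Viridis
print(format_cells(scores))          # missing cells are rendered as 'nan'
print(format_cells(scores, ".1e"))   # any format spec valid for floats works
```

text_format is a format specification without the leading colon, so ".2f" and ".1e" are valid while "{:.2f}" is not.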

Returns

| Type | Description |
| --- | --- |
| Figure | Plotly figure with heatmap. |

Raises

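| Type | Description |
| --- | --- |
| ValueError | If x_dim and y_dim are the same, if either is not "step" or "vintage", if y_pred has no "vintage_time" column, or if y_pred has only a single vintage. |
| TypeError | If the scorer does not return a DataFrame when scoring componentwise. |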
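
The order of the input checks can be sketched in plain Python. `check_heatmap_inputs` is a hypothetical stand-in written for illustration, not a yohou function; it takes the column names and the number of distinct vintages instead of a DataFrame so that it runs without any dependencies.

```python
def check_heatmap_inputs(x_dim, y_dim, columns, n_vintages):
    """Hypothetical mirror of the checks in the source listing, in the same order."""
    if x_dim == y_dim:
        raise ValueError(f"x_dim and y_dim must differ, got {x_dim!r} twice")
    for name, value in (("x_dim", x_dim), ("y_dim", y_dim)):
        if value not in ("step", "vintage"):
            raise ValueError(f"{name} must be 'step' or 'vintage', got {value!r}")
    if "vintage_time" not in columns:
        raise ValueError("y_pred must have a 'vintage_time' column")
    if n_vintages <= 1:
        raise ValueError("at least 2 vintages are required")


cases = {
    "same dims": ("step", "step", ["vintage_time", "time"], 2),
    "unknown dim": ("step", "horizon", ["vintage_time", "time"], 2),
    "no vintage column": ("step", "vintage", ["time", "value"], 2),
    "single vintage": ("step", "vintage", ["vintage_time", "time"], 1),
}
for label, args in cases.items():
    try:
        check_heatmap_inputs(*args)
    except ValueError as exc:
        print(f"{label}: {exc}")

# A valid combination passes silently.
check_heatmap_inputs("vintage", "step", ["vintage_time", "time", "value"], 2)
```

Because the equality check runs first, passing the same invalid name for both dimensions reports that they must differ rather than that the name is unknown.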

Examples

>>> import polars as pl
>>> from datetime import datetime
>>> from yohou.metrics import MeanAbsoluteError
>>> from yohou.plotting import plot_score_heatmap
>>> y_truth = pl.DataFrame({
...     "time": [datetime(2020, 1, i) for i in range(1, 4)],
...     "value": [10.0, 20.0, 30.0],
... })
>>> y_pred = pl.DataFrame({
...     "vintage_time": [datetime(2019, 12, 30)] * 3 + [datetime(2019, 12, 31)] * 3,
...     "time": [datetime(2020, 1, i) for i in range(1, 4)] * 2,
...     "value": [12.0, 19.0, 28.0, 11.0, 21.0, 29.0],
... }).sort("time")
>>> fig = plot_score_heatmap(MeanAbsoluteError(), y_truth, y_pred)
>>> len(fig.data)
1
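
To sanity-check the cell values for the toy data above, the matrix can be worked out by hand. This sketch assumes that, for a single "value" column, the componentwise MeanAbsoluteError of one vintage reduces to the absolute error at each forecast time; that is an assumption about the scorer, not something this page states.

```python
from datetime import datetime

# Same toy data as the example above, written as plain Python structures.
truth = {datetime(2020, 1, d): v for d, v in zip(range(1, 4), [10.0, 20.0, 30.0])}
forecasts = {
    datetime(2019, 12, 30): [12.0, 19.0, 28.0],
    datetime(2019, 12, 31): [11.0, 21.0, 29.0],
}
times = sorted(truth)

# Assumption: one row per vintage, one column per horizon step, and each cell
# is the absolute error of that vintage's forecast at that step.
expected = [
    [abs(pred - truth[t]) for pred, t in zip(values, times)]
    for _, values in sorted(forecasts.items())
]
print(expected)  # [[2.0, 1.0, 2.0], [1.0, 1.0, 1.0]]
```

With the default x_dim="step" and y_dim="vintage", each inner list corresponds to one heatmap row (one vintage) and each position to one horizon step.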

See Also

plot_score_per_step : Score by horizon step.
plot_score_per_vintage : Score by vintage.

Source Code

def plot_score_heatmap(
    scorer: BaseScorer,
    y_truth: pl.DataFrame,
    y_pred: pl.DataFrame,
    *,
    x_dim: Literal["step", "vintage"] = "step",
    y_dim: Literal["step", "vintage"] = "vintage",
    color_palette: str | None = None,
    text_format: str = ".2f",
    show_text: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Plot a 2D heatmap of scores across two forecast dimensions.

    Creates a heatmap where each cell shows the score for a specific
    combination of forecast dimensions (e.g., horizon step vs vintage).

    Parameters
    ----------
    scorer : BaseScorer
        Yohou scorer instance. Used to compute per-cell scores.
    y_truth : pl.DataFrame
        Ground truth with ``"time"`` column.
    y_pred : pl.DataFrame
        Predictions with ``"vintage_time"`` and ``"time"`` columns.
        Only single-model input is supported.
    x_dim : str, default="step"
        Dimension for the x-axis: ``"step"`` (forecast horizon) or
        ``"vintage"`` (observed time).
    y_dim : str, default="vintage"
        Dimension for the y-axis: ``"step"`` or ``"vintage"``.
        Must differ from ``x_dim``.
    color_palette : str or None, default=None
        Plotly colorscale name. If None, auto-selects based on
        ``scorer._lower_is_better``: ``"Blues"`` when lower is better,
        ``"Blues_r"`` otherwise.
    text_format : str, default=".2f"
        Format string for cell text annotations.
    show_text : bool, default=True
        Whether to display numeric annotations in cells.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.

    Returns
    -------
    go.Figure
        Plotly figure with heatmap.

    Raises
    ------
    ValueError
        If ``x_dim`` and ``y_dim`` are the same, or if ``y_pred`` has
        a single vintage.

    Examples
    --------
    >>> import polars as pl
    >>> from datetime import datetime
    >>> from yohou.metrics import MeanAbsoluteError
    >>> from yohou.plotting import plot_score_heatmap

    >>> y_truth = pl.DataFrame({
    ...     "time": [datetime(2020, 1, i) for i in range(1, 4)],
    ...     "value": [10.0, 20.0, 30.0],
    ... })
    >>> y_pred = pl.DataFrame({
    ...     "vintage_time": [datetime(2019, 12, 30)] * 3 + [datetime(2019, 12, 31)] * 3,
    ...     "time": [datetime(2020, 1, i) for i in range(1, 4)] * 2,
    ...     "value": [12.0, 19.0, 28.0, 11.0, 21.0, 29.0],
    ... }).sort("time")

    >>> fig = plot_score_heatmap(MeanAbsoluteError(), y_truth, y_pred)
    >>> len(fig.data)
    1

    See Also
    --------
    [`plot_score_per_step`][yohou.plotting.plot_score_per_step] : Score by horizon step.
    [`plot_score_per_vintage`][yohou.plotting.plot_score_per_vintage] : Score by vintage.
    """
    validate_plotting_data(y_truth)
    validate_plotting_data(y_pred)
    validate_plotting_params(width=width, height=height)

    if x_dim == y_dim:
        msg = f"x_dim and y_dim must differ, got x_dim={x_dim!r} and y_dim={y_dim!r}"
        raise ValueError(msg)

    valid_dims = {"step", "vintage"}
    for dim_name, dim_val in [("x_dim", x_dim), ("y_dim", y_dim)]:
        if dim_val not in valid_dims:
            msg = f"{dim_name} must be one of {valid_dims}, got {dim_val!r}"
            raise ValueError(msg)

    if "vintage_time" not in y_pred.columns:
        msg = "y_pred must have a 'vintage_time' column for heatmap plotting"
        raise ValueError(msg)

    vintages = y_pred["vintage_time"].unique().sort()
    if len(vintages) <= 1:
        msg = "y_pred has only a single vintage (vintage_time). plot_score_heatmap requires at least 2 vintages."
        raise ValueError(msg)

    # Compute per-vintage, per-step scores
    scorer_cw = _prepare_scorer_for_componentwise(copy.deepcopy(scorer))
    scorer_cw.fit(y_truth)

    score_matrix: list[list[float]] = []
    vintage_labels: list[str] = []
    n_steps: int | None = None

    for vintage_val in vintages:
        y_pred_v = y_pred.filter(pl.col("vintage_time") == vintage_val)
        scores_df = scorer_cw.score(y_truth, y_pred_v)

        if not isinstance(scores_df, pl.DataFrame):
            msg = f"Scorer must return DataFrame for componentwise aggregation, got {type(scores_df).__name__}"
            raise TypeError(msg)

        score_cols = [c for c in scores_df.columns if c not in _SCORER_META_COLS]
        if len(score_cols) == 1:
            row_scores = scores_df[score_cols[0]].to_list()
        else:
            row_scores = scores_df.select(score_cols).mean_horizontal().to_list()

        if n_steps is None:
            n_steps = len(row_scores)

        # Pad or truncate to consistent length
        row_scores = row_scores[:n_steps] + [float("nan")] * max(0, n_steps - len(row_scores))

        score_matrix.append(row_scores)
        vintage_labels.append(str(vintage_val))

    step_labels = [str(i) for i in range(1, (n_steps or 0) + 1)]
    z = np.array(score_matrix)

    # Arrange dimensions
    if x_dim == "step" and y_dim == "vintage":
        x_labels = step_labels
        y_labels = vintage_labels
        z_plot = z
    else:
        # x_dim == "vintage" and y_dim == "step"
        x_labels = vintage_labels
        y_labels = step_labels
        z_plot = z.T

    # Auto-select colorscale
    if color_palette is None:
        lower_is_better = getattr(scorer, "_lower_is_better", True)
        colorscale = "Blues" if lower_is_better else "Blues_r"
    else:
        colorscale = color_palette

    # Build text annotations
    text_vals = [[f"{v:{text_format}}" for v in row] for row in z_plot] if show_text else None

    fig = go.Figure(
        data=go.Heatmap(
            z=z_plot,
            x=x_labels,
            y=y_labels,
            colorscale=colorscale,
            text=text_vals,
            texttemplate="%{text}" if show_text else None,
            hovertemplate="x: %{x}<br>y: %{y}<br>Score: %{z:.3f}<extra></extra>",
        )
    )

    scorer_name = scorer.__class__.__name__
    default_x = x_label or ("Horizon Step" if x_dim == "step" else "Vintage (Observed Time)")
    default_y = y_label or ("Horizon Step" if y_dim == "step" else "Vintage (Observed Time)")
    default_title = title or f"{scorer_name} Heatmap"

    fig = apply_default_layout(
        fig,
        title=default_title,
        x_label=default_x,
        y_label=default_y,
        width=width,
        height=height,
    )

    return fig
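
Two steps in the listing above are easy to miss: every row is padded or truncated to the number of steps found for the first vintage, and the matrix is transposed when the vintage goes on the x-axis. The following is a dependency-free sketch of both steps; `pad_or_truncate` and `arrange` are hypothetical names, and plain lists stand in for the NumPy array used in the source.

```python
import math


def pad_or_truncate(row, n_steps):
    """Cut a row to n_steps entries, or fill the missing tail with NaN."""
    return row[:n_steps] + [float("nan")] * max(0, n_steps - len(row))


def arrange(z, x_dim):
    """Rows are vintages by construction; transpose when vintage goes on x."""
    if x_dim == "step":
        return z
    return [list(column) for column in zip(*z)]


raw_rows = [
    [2.0, 1.0, 2.0],       # first vintage fixes n_steps = 3
    [1.0, 1.0],            # shorter: padded with NaN
    [3.0, 3.0, 3.0, 3.0],  # longer: the fourth step is dropped
]
n_steps = len(raw_rows[0])
z = [pad_or_truncate(row, n_steps) for row in raw_rows]

print(z)                      # 3 vintages x 3 steps
print(arrange(z, "vintage"))  # 3 steps x 3 vintages
print(math.isnan(z[1][2]))    # True
```

One consequence of this design is that a first vintage with fewer steps than later ones silently hides the extra steps of those later vintages.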

Tutorials

The following example notebooks use this component:

  • How to Evaluate Interval Forecasts (Evaluation-Search): Evaluate prediction intervals with EmpiricalCoverage, IntervalScore, MeanIntervalWidth, PinballLoss, and CalibrationError across coverage levels.

  • How to Score Multi-Vintage Forecasts (Evaluation-Search): Generate multi-vintage predictions with observe_predict, score per step and per vintage, and visualize with heatmap, per-step, and per-vintage plots.

  • How to Use Distance-Based Similarity for Intervals (Forecasting-Models): Adaptive prediction intervals via similarity-weighted conformal prediction using DistanceSimilarity with configurable distance metrics and bandwidths.

  • How to Build Interval Forecasts with Reduction (Forecasting-Models): Wrap any quantile-capable sklearn estimator with IntervalReductionForecaster to produce calibrated prediction intervals across multiple horizons.

  • How to Combine Forecasters with VotingPointForecaster (Forecasting-Models): Build point ensembles with VotingPointForecaster using mean, weighted, and median aggregation strategies.

  • Naive Forecasters (Getting-Started): Baseline forecasting (the first portion of the First Forecast tutorial) with SeasonalNaive using different seasonality periods, the observe/predict streaming workflow, and rolling evaluation patterns.

  • How to Visualize Forecast Evaluation Results (Visualization): Use plot_calibration, plot_score_per_step, and plot_forecast to diagnose forecast accuracy and interval calibration visually.