
plot_splits

yohou.plotting.model_selection.plot_splits(y, splitter, *, X_actual=None, train_color=None, test_color=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, resampler=None, line_width=10.0)

Plot cross-validation splits as a timeline visualization.

Creates a timeline with one row per fold, drawing a thick horizontal segment for the training span and another for the test span. This is useful for understanding how temporal CV strategies such as expanding or sliding windows move through the series.
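As a rough sketch of the geometry behind the plot, the hypothetical helper below mirrors the fold loop in the source code further down. Each fold becomes one row, and its train and test segments run from the time at the first index to the time at the last index. Plain Python lists stand in for the Polars "time" column and for the index arrays a splitter yields.

```python
from datetime import date, timedelta


def fold_segments(times, splits):
    """Hypothetical helper: map each (train_idx, test_idx) fold to segment endpoints.

    A segment runs from the time at the first index to the time at the
    last index of each index array, as in the plot_splits loop.
    """
    segments = []
    for i, (train_idx, test_idx) in enumerate(splits):
        segments.append(
            {
                "fold": f"Fold {i + 1}",
                "train": (times[train_idx[0]], times[train_idx[-1]]),
                "test": (times[test_idx[0]], times[test_idx[-1]]),
            }
        )
    return segments


# Ten daily timestamps and two hand-written expanding-window folds
times = [date(2020, 1, 1) + timedelta(days=d) for d in range(10)]
splits = [
    (list(range(0, 6)), list(range(6, 8))),
    (list(range(0, 8)), list(range(8, 10))),
]
for seg in fold_segments(times, splits):
    print(seg["fold"], seg["train"], seg["test"])
```

Because only the first and last indices are read, a fold whose train indices contain gaps would still be drawn as one continuous segment.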

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| y | DataFrame | Target time series with a mandatory "time" column. | required |
| splitter | BaseSplitter | A yohou splitter instance (e.g., ExpandingWindowSplitter, SlidingWindowSplitter). | required |
| X_actual | DataFrame or None | Actual features forwarded to splitter.split(). They do not determine the splits; the parameter is accepted for API consistency. | None |
| train_color | str \| None | Color for train segments. If None, uses the first color of the yohou palette. | None |
| test_color | str \| None | Color for test segments. If None, uses the second color of the yohou palette. | None |
| show_legend | bool | Whether to show the legend. | True |
| title | str \| None | Plot title. If None, "Cross-Validation Splits" is used. | None |
| x_label | str \| None | X-axis label. If None, "Time" is used. | None |
| y_label | str \| None | Y-axis label. If None, "Fold" is used. | None |
| width | int \| None | Plot width in pixels. | None |
| height | int \| None | Plot height in pixels. If None, 300 + n_splits * 30 is used. | None |
| resampler | bool \| Literal['widget'] \| None | Enable plotly-resampler for large datasets. True or "widget" creates a FigureWidgetResampler; False or None uses a plain go.Figure. | None |
| line_width | float | Width of the train/test bars, in pixels. | 10.0 |
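In the source code below, the optional arguments fall back to their defaults through Python's `or`, so an empty string (or a height of 0) also triggers the default. The sketch restates those fallback rules as a standalone function. `resolve_defaults` is a hypothetical name, and the two palette colors are placeholders rather than yohou's actual palette.

```python
def resolve_defaults(
    title=None,
    x_label=None,
    y_label=None,
    height=None,
    train_color=None,
    test_color=None,
    n_splits=0,
    palette=("#1f77b4", "#ff7f0e"),  # placeholder colors, not yohou's palette
):
    """Hypothetical restatement of plot_splits' fallback rules."""
    return {
        "title": title or "Cross-Validation Splits",
        "x_label": x_label or "Time",
        "y_label": y_label or "Fold",
        # 300 px base plus 30 px per fold
        "height": height or (300 + n_splits * 30),
        "train_color": train_color or palette[0],
        "test_color": test_color or palette[1],
    }


print(resolve_defaults(n_splits=3))  # height 390
print(resolve_defaults(title="", height=500, n_splits=10))  # "" falls back to the default title
```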

Returns

| Type | Description |
|---|---|
| Figure | Plotly figure object. |

Raises

| Type | Description |
|---|---|
| TypeError | If y is not a Polars DataFrame or splitter is not a BaseSplitter. |
| ValueError | If y is empty or is missing the "time" column. |

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_splits
>>> from yohou.model_selection import ExpandingWindowSplitter
>>> # Create sample data
>>> y = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 12, 31), "1d", eager=True),
...     "value": list(range(366)),
... })
>>> # Create splitter and plot
>>> splitter = ExpandingWindowSplitter(n_splits=3, test_size=30)
>>> fig = plot_splits(y, splitter)
>>> len(fig.data) > 0
True
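To get a feel for what the example above draws, the sketch below generates expanding-window indices for the same shape of data (366 daily rows, 3 folds, 30-row test windows). The layout is an assumption, not yohou's ExpandingWindowSplitter: it places the test windows back to back at the end of the series and trains on every earlier row. The real splitter may differ, for example in how it handles strides or gaps.

```python
def expanding_window_splits(n_samples, n_splits, test_size):
    """Hypothetical expanding-window layout (not yohou's implementation).

    Assumes the n_splits test windows tile the final n_splits * test_size
    rows back to back, and each training set is every row before its test window.
    """
    first_test_start = n_samples - n_splits * test_size
    if first_test_start <= 0:
        raise ValueError("Not enough samples for the requested splits")
    splits = []
    for k in range(n_splits):
        test_start = first_test_start + k * test_size
        train_idx = list(range(0, test_start))
        test_idx = list(range(test_start, test_start + test_size))
        splits.append((train_idx, test_idx))
    return splits


# Same shape as the example: 366 daily rows, 3 folds, 30-row test windows
for i, (train_idx, test_idx) in enumerate(expanding_window_splits(366, 3, 30)):
    print(f"Fold {i + 1}: train 0..{train_idx[-1]}, test {test_idx[0]}..{test_idx[-1]}")
```

Under this assumed layout, every train segment starts on 2020-01-01 and grows by 30 days per fold, while the test segments step forward through the last 90 days of the series.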

See Also

plot_cv_results_scatter : Plot hyperparameter search results.
ExpandingWindowSplitter : Expanding window cross-validation.
SlidingWindowSplitter : Sliding window cross-validation.
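The two splitters listed above differ in how the training window evolves. As a hypothetical comparison (`sliding_window_splits` and its `train_size` argument are illustrative, not yohou's SlidingWindowSplitter API), the sketch keeps a fixed-length training window immediately before each test window, assuming the test windows tile the end of the series. In plot_splits, such a layout would appear as equal-length train segments shifting right, rather than the left-anchored, growing train segments of an expanding window.

```python
def sliding_window_splits(n_samples, n_splits, test_size, train_size):
    """Hypothetical sliding-window layout (not yohou's implementation).

    Assumes the n_splits test windows tile the final n_splits * test_size
    rows back to back; each training set keeps only the train_size rows
    immediately before its test window.
    """
    first_test_start = n_samples - n_splits * test_size
    if first_test_start < train_size:
        raise ValueError("Not enough samples for the requested train_size")
    splits = []
    for k in range(n_splits):
        test_start = first_test_start + k * test_size
        train_idx = list(range(test_start - train_size, test_start))
        test_idx = list(range(test_start, test_start + test_size))
        splits.append((train_idx, test_idx))
    return splits


for i, (train_idx, test_idx) in enumerate(sliding_window_splits(366, 3, 30, 180)):
    print(f"Fold {i + 1}: train {train_idx[0]}..{train_idx[-1]}, test {test_idx[0]}..{test_idx[-1]}")
```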

Source Code

def plot_splits(
    y: pl.DataFrame,
    splitter: BaseSplitter,
    *,
    X_actual: pl.DataFrame | None = None,
    train_color: str | None = None,
    test_color: str | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    resampler: bool | Literal["widget"] | None = None,
    line_width: float = 10.0,
) -> go.Figure:
    """
    Plot cross-validation splits as a timeline visualization.

    Creates a horizontal bar chart showing train/test splits for each fold,
    useful for understanding temporal CV strategies like expanding or sliding windows.

    Parameters
    ----------
    y : pl.DataFrame
        Target time series with mandatory "time" column.
    splitter : BaseSplitter
        A yohou splitter instance (e.g., ExpandingWindowSplitter, SlidingWindowSplitter).
    X_actual : pl.DataFrame or None, default=None
        Actual features passed to ``splitter.split()``. Not used for
        splitting but accepted for API consistency.
    train_color : str | None, default=None
        Color for train segments. If None, uses first color from yohou palette.
    test_color : str | None, default=None
        Color for test segments. If None, uses second color from yohou palette.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title. Defaults to "Cross-Validation Splits".
    x_label : str | None, default=None
        X-axis label. Defaults to "Time".
    y_label : str | None, default=None
        Y-axis label. Defaults to "Fold".
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels. Defaults to 300 + n_splits * 30.
    resampler : bool | Literal["widget"] | None, default=None
        Enable plotly-resampler for large datasets.  ``True`` or
        ``"widget"`` creates a ``FigureWidgetResampler``; ``False`` or
        ``None`` uses a plain ``go.Figure``.
    line_width : float, default=10.0
        Width of the train/test bars.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
    TypeError
        If y is not a Polars DataFrame or splitter is not a BaseSplitter.
    ValueError
        If DataFrame is empty or missing 'time' column.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_splits
    >>> from yohou.model_selection import ExpandingWindowSplitter

    >>> # Create sample data
    >>> y = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 12, 31), "1d", eager=True),
    ...     "value": list(range(366)),
    ... })

    >>> # Create splitter and plot
    >>> splitter = ExpandingWindowSplitter(n_splits=3, test_size=30)
    >>> fig = plot_splits(y, splitter)
    >>> len(fig.data) > 0
    True

    See Also
    --------
    [`plot_cv_results_scatter`][yohou.plotting.plot_cv_results_scatter] : Plot hyperparameter search results.
    `ExpandingWindowSplitter` : Expanding window cross-validation.
    `SlidingWindowSplitter` : Sliding window cross-validation.
    """
    # Validate inputs
    validate_plotting_data(y)
    validate_plotting_params(width=width, height=height)

    if not isinstance(splitter, BaseSplitter):
        msg = f"Expected BaseSplitter, got {type(splitter).__name__}"
        raise TypeError(msg)

    # Get colors
    colors = resolve_color_palette(None, 2)
    train_color = train_color or colors[0]
    test_color = test_color or colors[1]

    # Get splits
    splits = list(splitter.split(y, X_actual))
    n_splits = len(splits)

    # Create figure
    fig = _create_figure(resampler)

    # Get time column
    times = y["time"]

    # Plot each split
    for i, (train_idx, test_idx) in enumerate(splits):
        fold_label = f"Fold {i + 1}"

        # Get train time range
        t_train_start = times[int(train_idx[0])]
        t_train_end = times[int(train_idx[-1])]

        # Get test time range
        t_test_start = times[int(test_idx[0])]
        t_test_end = times[int(test_idx[-1])]

        # Plot train segment
        fig.add_trace(
            go.Scatter(
                x=[t_train_start, t_train_end],
                y=[fold_label, fold_label],
                mode="lines",
                line={"color": train_color, "width": line_width},
                name="Train" if i == 0 else None,
                showlegend=(i == 0),
                legendgroup="train",
                hovertemplate=f"Train<br>Start: %{{x}}<br>Fold: {fold_label}<extra></extra>",
            )
        )

        # Plot test segment
        fig.add_trace(
            go.Scatter(
                x=[t_test_start, t_test_end],
                y=[fold_label, fold_label],
                mode="lines",
                line={"color": test_color, "width": line_width},
                name="Test" if i == 0 else None,
                showlegend=(i == 0),
                legendgroup="test",
                hovertemplate=f"Test<br>Start: %{{x}}<br>Fold: {fold_label}<extra></extra>",
            )
        )

    # Set default labels
    title_default = title or "Cross-Validation Splits"
    x_label_default = x_label or "Time"
    y_label_default = y_label or "Fold"

    # Calculate height based on number of splits
    height_default = height or (300 + n_splits * 30)

    fig = apply_default_layout(
        fig,
        title=title_default,
        x_label=x_label_default,
        y_label=y_label_default,
        width=width,
        height=height_default,
    )
    fig.update_layout(showlegend=show_legend)

    return fig
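Two details of the loop above are easy to miss. First, only the first fold's traces get a legend entry, and `legendgroup` links every later fold's train (or test) trace to that entry. Second, the doubled braces in the f-string `hovertemplate` emit Plotly's literal `%{x}` placeholder. The sketch reproduces that logic with plain dictionaries standing in for the `go.Scatter` keyword arguments; `trace_specs` and its default colors are hypothetical, and the hover text is simplified.

```python
def trace_specs(n_folds, train_color="#1f77b4", test_color="#ff7f0e", line_width=10.0):
    """Hypothetical plain-dict stand-ins for the go.Scatter kwargs in plot_splits."""
    specs = []
    for i in range(n_folds):
        fold_label = f"Fold {i + 1}"
        for kind, color in (("Train", train_color), ("Test", test_color)):
            specs.append(
                {
                    "y": [fold_label, fold_label],
                    "line": {"color": color, "width": line_width},
                    # Only fold 1 is named and shown in the legend
                    "name": kind if i == 0 else None,
                    "showlegend": i == 0,
                    # Shared group: one legend click toggles this kind across all folds
                    "legendgroup": kind.lower(),
                    # {{x}} in an f-string renders as the literal %{x} Plotly placeholder
                    "hovertemplate": f"{kind}<br>%{{x}}<br>Fold: {fold_label}<extra></extra>",
                }
            )
    return specs


for spec in trace_specs(2):
    print(spec["name"], spec["showlegend"], spec["hovertemplate"])
```

Because Plotly hides and shows traces in the same legend group together, clicking "Train" in the legend toggles the train segments of every fold at once.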

Tutorials

The following example notebooks use this component:

  • CV Splitters (Getting-Started): Demonstrate ExpandingWindowSplitter and SlidingWindowSplitter for temporal cross-validation with configurable test_size, stride, and fold visualisation.
  • Quickstart (Quickstart): Comprehensive end-to-end tour of yohou beyond the Getting Started tutorials, covering data loading, baseline forecasting, preprocessing pipelines, decomposition, cross-validation search, and interval prediction.
  • How to Visualize Model Selection Results (Visualization): Visualise CV fold geometry with expanding and sliding window splitters and hyperparameter search results with plot_splits and plot_cv_results_scatter.