
train_test_split

yohou.model_selection.split.train_test_split(*arrays, test_size, X_forecast=None)

Split time series data into temporal train and test sets.

A time series counterpart to sklearn.model_selection.train_test_split. Data is always split in temporal order (no shuffling): the earliest rows form the training set and the most recent rows form the test set. Rows are split by position and are not re-sorted, so every array should already be ordered by time.

Row-indexed arrays (y, X_actual) are split by position. X_forecast, when provided, is split by vintage_time range using the cutoff time inferred from the first positional array.
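The vintage rule for X_forecast can be sketched in plain Python. This is an illustration, not yohou's implementation. It assumes the first array's "time" values are a sorted list and X_forecast rows are dicts with "vintage_time" and "time" keys. The helper split_forecast_by_vintage is hypothetical.

```python
from datetime import date, timedelta


def split_forecast_by_vintage(times, n_test, forecast_rows):
    """Hypothetical helper mirroring the documented vintage_time rule."""
    split_idx = len(times) - n_test
    cutoff_time = times[split_idx - 1]  # last "time" in the training portion
    test_end_time = times[-1]           # last "time" of the first array
    train = [r for r in forecast_rows if r["vintage_time"] <= cutoff_time]
    test = [
        r for r in forecast_rows
        if cutoff_time < r["vintage_time"] <= test_end_time
    ]
    return train, test


# Ten daily timestamps, 2020-01-01 .. 2020-01-10, as in the Examples below.
times = [date(2020, 1, 1) + timedelta(days=i) for i in range(10)]

# Hypothetical one-step-ahead forecasts issued daily, 2020-01-01 .. 2020-01-11.
forecasts = [
    {
        "vintage_time": date(2020, 1, 1) + timedelta(days=i),
        "time": date(2020, 1, 2) + timedelta(days=i),
    }
    for i in range(11)
]

train, test = split_forecast_by_vintage(times, n_test=2, forecast_rows=forecasts)
print(len(train), len(test))  # 8 2
```

With test_size=2 the cutoff is 2020-01-08. The vintage issued on 2020-01-11 falls after test_end_time, so it lands in neither set.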

Parameters

*arrays : DataFrame, default: ()

One or more Polars DataFrames to split by row index. All must have the same number of rows. When X_forecast is provided, the first array must also contain a "time" column, which is used to derive the vintage cutoff.

test_size : int or float, required

If int, the number of rows to allocate to the test set; it must be in [1, n_samples - 1]. If float, the fraction of total rows for testing; it must be in (0.0, 1.0), and the number of test rows is max(1, round(n_samples * test_size)).

X_forecast : DataFrame or None, default: None

External forecasts with "vintage_time" and "time" columns, split by vintage_time range. Training receives vintages where vintage_time <= cutoff_time; testing receives vintages where cutoff_time < vintage_time <= test_end_time. Here cutoff_time is the last "time" value in the training portion of the first positional array, and test_end_time is the last "time" value of that array. Vintages after test_end_time are excluded from both sets.
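The test_size rules above can be sketched as a small validation function. It is a plain-Python illustration of the documented behaviour; the name resolve_n_test is hypothetical and not part of yohou.

```python
def resolve_n_test(n_samples, test_size):
    """Hypothetical helper: turn test_size into a number of test rows."""
    if isinstance(test_size, float):
        if not 0.0 < test_size < 1.0:
            raise ValueError(f"float test_size must be in (0.0, 1.0), got {test_size}")
        # Fraction of rows, rounded with Python's round() (ties go to even), at least 1.
        return max(1, round(n_samples * test_size))
    if isinstance(test_size, int):
        if not 1 <= test_size <= n_samples - 1:
            raise ValueError(
                f"int test_size must be in [1, {n_samples - 1}], got {test_size}"
            )
        return test_size
    raise TypeError(f"test_size must be int or float, got {type(test_size).__name__}")


print(resolve_n_test(10, 2))     # 2
print(resolve_n_test(10, 0.3))   # 3
print(resolve_n_test(10, 0.25))  # 2, because round(2.5) == 2
```

Because round() uses half-to-even rounding, a fraction that lands exactly on .5 can round down.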

Returns

list of pl.DataFrame

Alternating train/test pairs for each positional array, followed by the X_forecast train/test pair if X_forecast is provided.

With one array: [arr_train, arr_test].

With two arrays: [arr1_train, arr1_test, arr2_train, arr2_test].

With X_forecast: the pair X_forecast_train, X_forecast_test is appended last, giving [..., X_forecast_train, X_forecast_test].
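The alternating return order is what makes tuple unpacking work. A minimal sketch, using plain lists as stand-ins for DataFrames and a hypothetical positional_split helper, shows how two aligned arrays come back as [arr1_train, arr1_test, arr2_train, arr2_test].

```python
def positional_split(*arrays, n_test):
    """Hypothetical helper: split equal-length sequences by position."""
    n_samples = len(arrays[0])
    if any(len(a) != n_samples for a in arrays):
        raise ValueError("All arrays must have the same number of rows.")
    split_idx = n_samples - n_test
    result = []
    for arr in arrays:
        # Earliest rows go to train, most recent rows to test.
        result.extend([arr[:split_idx], arr[split_idx:]])
    return result


y = list(range(10))
X_actual = [v * 10 for v in range(10)]
y_train, y_test, X_train, X_test = positional_split(y, X_actual, n_test=2)
print(y_test, X_test)  # [8, 9] [80, 90]
```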

Raises

ValueError

If no arrays are provided, arrays have different lengths, test_size is out of range, the first array is missing a "time" column when X_forecast is provided, or X_forecast is missing a "vintage_time" or "time" column.

TypeError

If test_size is neither an int nor a float.

Examples

>>> import polars as pl
>>> from yohou.model_selection import train_test_split

Split y with an integer test_size (8 training rows, 2 test rows):

>>> y = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 1, 10), eager=True),
...     "value": list(range(10)),
... })
>>> y_train, y_test = train_test_split(y, test_size=2)
>>> len(y_train), len(y_test)
(8, 2)

Split with a fractional test_size (round(10 * 0.3) = 3 test rows):

>>> y_train, y_test = train_test_split(y, test_size=0.3)
>>> len(y_train), len(y_test)
(7, 3)

Source Code

def train_test_split(
    *arrays: pl.DataFrame,
    test_size: int | float,
    X_forecast: pl.DataFrame | None = None,
) -> list[pl.DataFrame]:
    """Split time series data into temporal train and test sets.

    A time series counterpart to :func:`sklearn.model_selection.train_test_split`.
    Data is always split in temporal order (no shuffling): the earliest rows
    form the training set and the most recent rows form the test set.

    Row-indexed arrays (``y``, ``X_actual``) are split by position.
    ``X_forecast``, when provided, is split by ``vintage_time`` range using
    the cutoff time inferred from the first positional array.

    Parameters
    ----------
    *arrays : pl.DataFrame
        One or more Polars DataFrames to split by row index. All must
        have the same number of rows. When ``X_forecast`` is provided,
        the first array must also contain a ``"time"`` column, which is
        used to derive the vintage cutoff.
    test_size : int or float
        If ``int``, the number of rows to allocate to the test set.
        If ``float``, the fraction of total rows for testing (must be
        in ``(0.0, 1.0)``).
    X_forecast : pl.DataFrame or None, default=None
        External forecasts with ``"vintage_time"`` and ``"time"`` columns.
        Split by ``vintage_time`` range: training receives vintages where
        ``vintage_time <= cutoff_time``, testing receives vintages where
        ``cutoff_time < vintage_time <= test_end_time``. The cutoff and
        test end times are derived from the ``"time"`` column of the first
        positional array.

    Returns
    -------
    list of pl.DataFrame
        Alternating train/test pairs for each positional array, followed
        by the X_forecast train/test pair if ``X_forecast`` is provided.

        With one array: ``[arr_train, arr_test]``.

        With two arrays: ``[arr1_train, arr1_test, arr2_train, arr2_test]``.

        With ``X_forecast``:
        ``[..., X_forecast_train, X_forecast_test]`` appended.

    Raises
    ------
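    ValueError
        If no arrays are provided, arrays have different lengths,
        ``test_size`` is out of range, the first array is missing a
        ``"time"`` column when ``X_forecast`` is provided, or
        ``X_forecast`` lacks a ``"vintage_time"`` or ``"time"`` column.
    TypeError
        If ``test_size`` is neither an ``int`` nor a ``float``.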

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.model_selection import train_test_split

    Split ``y`` with an integer ``test_size`` (8 training rows, 2 test rows):

    >>> y = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 1, 10), eager=True),
    ...     "value": list(range(10)),
    ... })
    >>> y_train, y_test = train_test_split(y, test_size=2)
    >>> len(y_train), len(y_test)
    (8, 2)

    Split with a fractional test_size:

    >>> y_train, y_test = train_test_split(y, test_size=0.3)
    >>> len(y_train), len(y_test)
    (7, 3)
    """
    if len(arrays) == 0:
        msg = "At least one array is required."
        raise ValueError(msg)

    n_samples = len(arrays[0])
    for i, arr in enumerate(arrays[1:], start=1):
        if len(arr) != n_samples:
            msg = (
                f"All arrays must have the same number of rows. "
                f"Array 0 has {n_samples} rows but array {i} has {len(arr)} rows."
            )
            raise ValueError(msg)

    if isinstance(test_size, float):
        if not 0.0 < test_size < 1.0:
            msg = f"test_size as a float must be in (0.0, 1.0), got {test_size}."
            raise ValueError(msg)
        n_test = max(1, round(n_samples * test_size))
    elif isinstance(test_size, int):
        if test_size < 1 or test_size >= n_samples:
            msg = f"test_size as an int must be in [1, {n_samples - 1}], got {test_size}."
            raise ValueError(msg)
        n_test = test_size
    else:
        msg = f"test_size must be int or float, got {type(test_size).__name__}."
        raise TypeError(msg)

    split_idx = n_samples - n_test
    result: list[pl.DataFrame] = []
    for arr in arrays:
        result.append(arr[:split_idx])
        result.append(arr[split_idx:])

    if X_forecast is not None:
        first = arrays[0]
        if "time" not in first.columns:
            msg = (
                "The first positional array must contain a 'time' column "
                "when X_forecast is provided (needed to derive the vintage "
                "cutoff time)."
            )
            raise ValueError(msg)
        if "vintage_time" not in X_forecast.columns or "time" not in X_forecast.columns:
            msg = (
                "X_forecast must contain both 'vintage_time' and 'time' columns. "
                f"Found columns: {list(X_forecast.columns)}"
            )
            raise ValueError(msg)

        cutoff_time = first["time"][split_idx - 1]
        test_end_time = first["time"][-1]
        result.append(X_forecast.filter(pl.col("vintage_time") <= cutoff_time))
        result.append(
            X_forecast.filter((pl.col("vintage_time") > cutoff_time) & (pl.col("vintage_time") <= test_end_time))
        )

    return result

Tutorials

The following example notebooks use this component:

  • How to Tune Fourier Seasonality Terms (Data-Features)
    Explore how Fourier harmonic count affects seasonal fit quality, compare Fourier vs Pattern seasonality, and tune harmonics jointly with GridSearchCV.

  • How to Aggregate Scorer Results (Evaluation-Search)
    Demonstrate all scorer aggregation strategies (stepwise, vintagewise, componentwise, groupwise, coveragewise, all) on panel data with weighted group aggregation.

  • How to Use Lagged Forecasts as Features (Forecasting-Models)
    Compare ForecastedFeatureForecaster strategies (actual, predicted, rewind) and split ratio tuning for chaining feature and target forecasters.

  • How to Configure LocalPanelForecaster (Panel-Data)
    Wrap any forecaster with LocalPanelForecaster for fully independent per-group clones, parallel fitting via n_jobs, and selective group operations.

  • How to Forecast Panel Prediction Intervals (Panel-Data)
    Combine conformal and quantile regression intervals on panel data with per-group coverage analysis, calibration plots, and groupwise interval scoring.

  • How to Apply Stationarity to Panel Data (Panel-Data)
    Apply per-group stationarity transforms on panel data with SeasonalDifferencing, DecompositionPipeline (polynomial trend + pattern seasonality), and residuals.