cross_val_predict

yohou.model_selection.validation.cross_val_predict(forecaster, y, X_actual=None, forecasting_horizon=1, *, X_future=None, X_forecast=None, cv=5, predict_forecasting_horizon=None, predict_stride=None, coverage_rates=None, n_jobs=None, verbose=0, pre_dispatch='2*n_jobs', method='predict')

Generate cross-validated predictions for each fold.

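For each CV fold, a fresh clone of the forecaster is fitted on the training data and used to produce predictions on the test data, so the forecaster passed in is never fitted. Predictions from all folds are concatenated into a single DataFrame with an integer "split" column identifying the originating fold. Errors raised while fitting or predicting on any fold are propagated rather than replaced by a placeholder score.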
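The per-fold loop described above can be sketched in plain Python. This is not yohou's implementation: `LastValueForecaster`, `expanding_window_splits`, and `cross_val_predict_sketch` are hypothetical names, lists of dicts stand in for polars DataFrames, and the expanding-window split (test windows of length `horizon` tiling the end of the series) is an assumed strategy, not necessarily what `cv=5` produces.

```python
class LastValueForecaster:
    """Hypothetical stand-in for a forecaster: repeats the last training value."""

    def __init__(self) -> None:
        self.last_ = None

    def fit(self, y: list) -> "LastValueForecaster":
        self.last_ = y[-1]
        return self

    def predict(self, horizon: int) -> list:
        return [self.last_] * horizon


def expanding_window_splits(n: int, n_splits: int, horizon: int):
    """Yield (train, test) index lists; test windows of length `horizon` tile the end of the series."""
    if n_splits * horizon >= n:
        raise ValueError("not enough observations for the requested splits")
    for k in range(n_splits):
        test_start = n - (n_splits - k) * horizon
        # training data always ends right before the test window (no look-ahead)
        yield list(range(test_start)), list(range(test_start, test_start + horizon))


def cross_val_predict_sketch(y: list, n_splits: int = 3, horizon: int = 2) -> list:
    rows = []
    splits = expanding_window_splits(len(y), n_splits, horizon)
    for split_idx, (train, test) in enumerate(splits):
        model = LastValueForecaster()  # fresh, unfitted model per fold (the role clone() plays)
        model.fit([y[i] for i in train])
        for t, pred in zip(test, model.predict(len(test))):
            # tag every prediction with the fold it came from
            rows.append({"time": t, "y_pred": pred, "split": split_idx})
    return rows  # one "concatenated" table covering all folds
```

For `y = [1.0, 2.0, ..., 10.0]` with three splits of horizon 2, the rows cover times 4 to 9 and their "split" values are 0, 0, 1, 1, 2, 2.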

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| forecaster | BaseForecaster | The forecaster to evaluate. It is cloned before fitting, so the instance passed in stays unfitted. | required |
| y | DataFrame | Target time series with a "time" column. | required |
| X_actual | DataFrame or None | Observed feature values with a "time" column. | None |
| forecasting_horizon | int | Number of time steps to forecast. Also passed to check_cv when building the splitter. | 1 |
| X_future | DataFrame or None | Known future features with a "time" column. | None |
| X_forecast | DataFrame or None | External forecasts with "vintage_time" and "time" columns. | None |
| cv | int, BaseSplitter, or None | Cross-validation splitting strategy, resolved through check_cv. | 5 |
| predict_forecasting_horizon | int or None | Forecasting horizon used by observe_predict; overrides forecasting_horizon when set. | None |
| predict_stride | int or None | Stride used by observe_predict; overrides the default stride when set. | None |
| coverage_rates | list of float or None | Coverage rates for interval predictions. Only used when method="predict_interval". | None |
| n_jobs | int or None | Number of jobs used to process the folds in parallel. | None |
| verbose | int | Verbosity level passed to each fold's fit-and-predict call. | 0 |
| pre_dispatch | str or int | Controls how many jobs are dispatched ahead of time during parallel execution. | "2*n_jobs" |
| method | str | Prediction method: "predict", "predict_interval", or "predict_class_proba". Any other value raises ValueError. | "predict" |

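predict_forecasting_horizon and predict_stride are only described briefly above. Assuming observe_predict works like a rolling-origin evaluation (forecast `horizon` steps, then observe the next `stride` actual values and forecast again), the forecast windows inside one test fold could be enumerated as below. The function name `rolling_forecast_windows` and the truncation of the last window at the end of the fold are assumptions for illustration, not yohou's documented behavior.

```python
def rolling_forecast_windows(test: list, horizon: int, stride: int) -> list:
    """Hypothetical: target indices of each forecast made while stepping through a test fold."""
    if horizon < 1 or stride < 1:
        raise ValueError("horizon and stride must be positive")
    windows = []
    # each iteration is one forecast origin; `stride` new actuals are observed between origins
    for offset in range(0, len(test), stride):
        windows.append(test[offset:offset + horizon])  # truncated at the end of the fold
    return windows
```

With a test fold covering indices 10 to 15, `horizon=3` and `stride=2` give the overlapping windows `[10, 11, 12]`, `[12, 13, 14]`, `[14, 15]`, while `stride=3` gives two disjoint windows. When the stride is smaller than the horizon, the same target time can therefore appear in several prediction rows of a fold.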
Returns

| Type | Description |
|---|---|
| DataFrame | Concatenated predictions from all folds, with an integer "split" column identifying the originating fold. |
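Because every row carries its fold index, per-fold metrics can be computed by grouping on "split". The sketch below uses plain Python rows; the "y_pred" column name, the hypothetical helper `per_split_mae`, and the join on "time" against a separate mapping of actual values are assumptions about the layout, since the exact prediction columns depend on the forecaster and on `method`.

```python
from collections import defaultdict


def per_split_mae(predictions: list, actuals: dict) -> dict:
    """Mean absolute error per fold, joining each prediction to its actual value on "time"."""
    errors = defaultdict(list)
    for row in predictions:
        errors[row["split"]].append(abs(row["y_pred"] - actuals[row["time"]]))
    # average the absolute errors within each fold, ordered by fold index
    return {split: sum(e) / len(e) for split, e in sorted(errors.items())}
```

With polars, the same step would typically be a `join` on "time" followed by `group_by("split")` and an aggregation. Looking at the spread of per-fold errors, rather than only their mean, shows whether accuracy drifts across the time-ordered folds.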

Source Code

def cross_val_predict(
    forecaster: BaseForecaster,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 1,
    *,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
    cv: int | BaseSplitter | None = 5,
    predict_forecasting_horizon: int | None = None,
    predict_stride: int | None = None,
    coverage_rates: list[float] | None = None,
    n_jobs: int | None = None,
    verbose: int = 0,
    pre_dispatch: str | int = "2*n_jobs",
    method: str = "predict",
) -> pl.DataFrame:
    """Generate cross-validated predictions for each fold.

    For each CV fold, the forecaster is fitted on the training data
    and predictions are produced on the test data.  Predictions from
    all folds are concatenated into a single DataFrame with a
    ``"split"`` column identifying the originating fold.

    Parameters
    ----------
    forecaster : BaseForecaster
        The forecaster to evaluate.
    y : pl.DataFrame
        Target time series with a ``"time"`` column.
    X_actual : pl.DataFrame or None, default=None
        Actual feature observations with a ``"time"`` column.
    forecasting_horizon : int, default=1
        Number of time steps to forecast.
    X_future : pl.DataFrame or None, default=None
        Known future features with a ``"time"`` column.
    X_forecast : pl.DataFrame or None, default=None
        External forecasts with ``"vintage_time"`` and ``"time"``
        columns.
    cv : int, BaseSplitter, or None, default=5
        Cross-validation splitting strategy.
    predict_forecasting_horizon : int or None, default=None
        Override forecasting horizon for ``observe_predict``.
    predict_stride : int or None, default=None
        Override stride for ``observe_predict``.
    coverage_rates : list of float or None, default=None
        Coverage rates for interval predictions.  Only used when
        ``method="predict_interval"``.
    n_jobs : int or None, default=None
        Number of parallel jobs.
    verbose : int, default=0
        Verbosity level.
    pre_dispatch : str or int, default="2*n_jobs"
        Controls pre-dispatched jobs for parallel execution.
    method : str, default="predict"
        Prediction method: ``"predict"``, ``"predict_interval"``,
        or ``"predict_class_proba"``.

    Returns
    -------
    pl.DataFrame
        Concatenated predictions from all folds with an integer
        ``"split"`` column identifying the originating fold.
    """
    valid_methods = {"predict", "predict_interval", "predict_class_proba"}
    if method not in valid_methods:
        raise ValueError(f"method must be one of {valid_methods}, got {method!r}.")

    y, X_actual = indexable(y, X_actual)

    cv_obj = check_cv(cv, forecasting_horizon)
    splits = list(cv_obj.split(y, X_actual))
    n_splits = len(splits)

    base_forecaster = clone(forecaster)

    out = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch)(
        delayed(_fit_and_score)(
            clone(base_forecaster),
            y,
            X_actual,
            forecasting_horizon,
            X_future=X_future,
            X_forecast=X_forecast,
            scorer=None,
            train=train,
            test=test,
            verbose=verbose,
            parameters=None,
            fit_params=None,
            predict_func_params=None,
            score_params=None,
            return_predictions=True,
            predict_method=method,
            predict_forecasting_horizon=predict_forecasting_horizon,
            predict_stride=predict_stride,
            coverage_rates=coverage_rates,
            error_score="raise",
            split_progress=(split_idx, n_splits),
        )
        for split_idx, (train, test) in enumerate(splits)
    )

    dfs = []
    for split_idx, result in enumerate(out):
        pred = result["predictions"]
        dfs.append(pred.with_columns(pl.lit(split_idx).alias("split")))

    return pl.concat(dfs)
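The final loop relies on Parallel returning results in the same order as the submitted folds, which is how joblib-style Parallel behaves; that ordering is what lets `enumerate(out)` label each prediction frame with the right split. The same property can be shown with the standard library's `ThreadPoolExecutor.map`, used here as a stand-in for joblib; `fit_and_predict_fold` and `run_folds` are hypothetical substitutes for `_fit_and_score` and the surrounding code.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fit_and_predict_fold(split_idx: int, n_splits: int) -> dict:
    """Hypothetical stand-in for _fit_and_score: later folds finish first on purpose."""
    time.sleep(0.01 * (n_splits - split_idx))
    return {"predictions": [f"fold-{split_idx}"]}


def run_folds(n_splits: int = 3, n_jobs: int = 3) -> list:
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        out = list(pool.map(fit_and_predict_fold, range(n_splits), [n_splits] * n_splits))
    # map() yields results in submission order, not completion order,
    # so enumerate(out) recovers each fold's split index
    return [(split_idx, result["predictions"]) for split_idx, result in enumerate(out)]
```

Even though fold 0 sleeps longest and finishes last, `run_folds(4, 4)` still returns the folds as 0, 1, 2, 3.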

Tutorials

The following example notebooks use this component:

  • Cross-Validation for Time Series (Evaluation-Search, marimo notebook): evaluate forecasters with cross_val_score, cross_validate, and cross_val_predict using temporal splitters.