
check_search_interval_predict_delegates

yohou.testing.search.check_search_interval_predict_delegates(search_cv, y_train, y_test, X_actual_train=None, X_actual_test=None, X_future=None, X_forecast=None)

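Check that predict_interval() works after an interval search with refit=True.

Validates that the refitted best forecaster supports predict_interval() and that the search returns a valid interval prediction DataFrame. Intervals are requested at a single coverage rate of 0.9, and the result must be a polars DataFrame with a time column and at least one column whose name contains _lower_ or _upper_.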
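The column checks this function performs can be sketched in isolation. The helper below is hypothetical (it is not part of yohou) and works on a plain list of column names instead of a polars DataFrame. The example names such as y_lower_0.9 are also assumptions about yohou's naming scheme, chosen only because they contain the _lower_/_upper_ substrings the check looks for.

```python
def interval_columns(columns: list[str]) -> list[str]:
    """Hypothetical helper mirroring the column checks in
    check_search_interval_predict_delegates, applied to a list of names."""
    # A time column must be present.
    assert "time" in columns, "Interval predictions should have 'time' column"
    # Any column whose name contains _lower_ or _upper_ counts as an interval bound.
    bounds = [c for c in columns if "_lower_" in c or "_upper_" in c]
    assert bounds, f"Interval predictions should have _lower_/_upper_ columns, got {columns}"
    return bounds


# Hypothetical column names; yohou's actual naming may differ.
ok = interval_columns(["time", "y_lower_0.9", "y_upper_0.9"])
print(ok)  # ['y_lower_0.9', 'y_upper_0.9']

try:
    interval_columns(["time", "y"])  # point predictions only, no bounds
except AssertionError as exc:
    print("rejected:", exc)
```

Because the test is a substring match, it only confirms that some bound columns exist; it does not verify that each lower bound has a matching upper bound or that lower values stay below upper values.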

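Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| search_cv | BaseSearchCV | Fitted search CV instance (interval scorer, refit=True). | required |
| y_train | DataFrame | Training target data. | required |
| y_test | DataFrame | Test target data. | required |
| X_actual_train | DataFrame | Training features. | None |
| X_actual_test | DataFrame | Test features. | None |
| X_future | DataFrame | Passed through to predict_interval() as X_future. | None |
| X_forecast | DataFrame | Passed through to predict_interval() as X_forecast. | None |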

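Raises

| Type | Description |
|---|---|
| AssertionError | If predict_interval() does not return a polars DataFrame, or if the result lacks a time column or any _lower_/_upper_ interval columns. |

Exceptions raised inside predict_interval() itself are not caught and propagate unchanged.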

Source Code

```python
import polars as pl
from sklearn.utils.validation import check_is_fitted  # assumed import; yohou may define its own helper


def check_search_interval_predict_delegates(
    search_cv,
    y_train: pl.DataFrame,
    y_test: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_test: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check predict_interval() works after interval search with refit.

    Validates that the best forecaster supports ``predict_interval`` and
    returns a valid interval prediction DataFrame.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Fitted search CV instance (interval scorer, refit=True).
    y_train : pl.DataFrame
        Training target data.
    y_test : pl.DataFrame
        Test target data.
    X_actual_train : pl.DataFrame, optional
        Training features.
    X_actual_test : pl.DataFrame, optional
        Test features.
    X_future : pl.DataFrame, optional
        Passed through to ``predict_interval``.
    X_forecast : pl.DataFrame, optional
        Passed through to ``predict_interval``.

    Raises
    ------
    AssertionError
        If predict_interval() fails or returns invalid predictions.

    """
    check_is_fitted(search_cv)

    # Request intervals at a single 90% coverage rate.
    coverage_rates = [0.9]

    y_pred = search_cv.predict_interval(
        coverage_rates=coverage_rates, X_future=X_future, X_forecast=X_forecast
    )

    # The result must be a polars DataFrame indexed by a "time" column.
    assert isinstance(y_pred, pl.DataFrame), f"predict_interval should return pl.DataFrame, got {type(y_pred)}"
    assert "time" in y_pred.columns, "Interval predictions should have 'time' column"

    # At least one lower or upper bound column must be present.
    interval_cols = [c for c in y_pred.columns if "_lower_" in c or "_upper_" in c]
    assert len(interval_cols) > 0, f"Interval predictions should have _lower_/_upper_ columns, got {y_pred.columns}"
```
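The function name suggests that search_cv.predict_interval() delegates to the forecaster refitted on the best parameters. The stdlib-only sketch below illustrates that delegation pattern under stated assumptions: StubForecaster, StubSearchCV, and the best_forecaster_ attribute are hypothetical stand-ins (the attribute name is modeled on scikit-learn's best_estimator_ convention), and a dict of column lists stands in for a polars DataFrame.

```python
class StubForecaster:
    """Hypothetical interval forecaster returning a dict of columns."""

    def predict_interval(self, coverage_rates, X_future=None, X_forecast=None):
        # Two time steps with constant placeholder bounds per coverage rate.
        columns = {"time": [0, 1]}
        for rate in coverage_rates:
            columns[f"y_lower_{rate}"] = [0.0, 0.0]
            columns[f"y_upper_{rate}"] = [1.0, 1.0]
        return columns


class StubSearchCV:
    """Hypothetical search wrapper that delegates prediction after refit."""

    def __init__(self, forecaster, refit=True):
        self.forecaster = forecaster
        self.refit = refit

    def fit(self):
        # A real search would evaluate candidates here; the stub simply keeps
        # the given forecaster as the "best" one when refit is enabled.
        if self.refit:
            self.best_forecaster_ = self.forecaster
        return self

    def predict_interval(self, coverage_rates, X_future=None, X_forecast=None):
        if not hasattr(self, "best_forecaster_"):
            raise AttributeError("predict_interval requires a fit with refit=True")
        # Forward the call, including the exogenous frames, unchanged.
        return self.best_forecaster_.predict_interval(
            coverage_rates=coverage_rates, X_future=X_future, X_forecast=X_forecast
        )


search = StubSearchCV(StubForecaster()).fit()
pred = search.predict_interval(coverage_rates=[0.9])
print(sorted(pred))  # ['time', 'y_lower_0.9', 'y_upper_0.9']
```

With refit=False the stub never sets best_forecaster_, so predict_interval raises; this mirrors the refit=True precondition stated for search_cv.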