Skip to content

check_observe_predict_interval_with_step_columns

yohou.testing.forecaster.check_observe_predict_interval_with_step_columns(forecaster, y_train, X_actual_train, y_test, X_actual_test=None, X_future=None, X_forecast=None, forecasting_horizon=3, coverage_rates=None)

Check observe_predict_interval works with step columns (lightweight).

Runs observe_predict_interval with stride=max(1, len(y_test)//2) (roughly two iterations) and validates the output structure and per-vintage time sorting.

Parameters

Name Type Description Default
forecaster BaseForecaster

Unfitted interval forecaster instance.

required
y_train DataFrame

Training target data.

required
X_actual_train DataFrame or None

Training features.

required
y_test DataFrame

Test target data (at least 10 rows).

required
X_actual_test DataFrame or None

Test features.

None
X_future DataFrame or None

Known-future features.

None
X_forecast DataFrame or None

External forecasts.

None
forecasting_horizon int

Number of steps ahead.

3
coverage_rates list of float or None

Coverage rates for prediction intervals. Defaults to [0.9].

None

Source Code

Show/Hide source
def check_observe_predict_interval_with_step_columns(
    forecaster,
    y_train: pl.DataFrame,
    X_actual_train: pl.DataFrame | None,
    y_test: pl.DataFrame,
    X_actual_test: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    coverage_rates: list[float] | None = None,
) -> None:
    """Check observe_predict_interval works with step columns (lightweight).

    A clone of ``forecaster`` is fitted on the training data. Then
    ``observe_predict_interval`` is run over ``y_test`` with a stride of
    ``max(1, len(y_test) // 2)``, i.e. roughly two update/predict
    iterations. The result must be a non-empty ``pl.DataFrame`` that has
    ``time`` and ``vintage_time`` columns. Within every distinct
    ``vintage_time``, the ``time`` column must be in ascending order.

    Parameters
    ----------
    forecaster : BaseForecaster
        Unfitted interval forecaster instance. It is cloned, so the
        caller's object is not fitted by this check.
    y_train : pl.DataFrame
        Training target data.
    X_actual_train : pl.DataFrame or None
        Training features.
    y_test : pl.DataFrame
        Test target data (at least 10 rows).
    X_actual_test : pl.DataFrame or None
        Test features.
    X_future : pl.DataFrame or None
        Known-future features.
    X_forecast : pl.DataFrame or None
        External forecasts.
    forecasting_horizon : int, default=3
        Number of steps ahead.
    coverage_rates : list of float or None, default=None
        Coverage rates for prediction intervals. Defaults to [0.9].

    Raises
    ------
    AssertionError
        If the output type, required columns, non-emptiness, or
        per-vintage time ordering check fails.

    """
    rates = [0.9] if coverage_rates is None else coverage_rates

    # fit() and observe_predict_interval() receive identical values for
    # these keyword arguments, so build them once.
    shared_kwargs = {
        "forecasting_horizon": forecasting_horizon,
        "coverage_rates": rates,
        "X_future": X_future,
        "X_forecast": X_forecast,
    }

    fitted = clone(forecaster)
    fitted.fit(y_train, X_actual_train, **shared_kwargs)

    # Half the test length; the floor of 1 keeps the stride valid for
    # very short test sets.
    half_stride = max(1, len(y_test) // 2)
    result = fitted.observe_predict_interval(
        y_test,
        X_actual=X_actual_test,
        stride=half_stride,
        **shared_kwargs,
    )

    # Structural checks, in a fixed order so the first failure is reported.
    assert isinstance(result, pl.DataFrame), (
        f"observe_predict_interval() must return pl.DataFrame, got {type(result).__name__}"
    )
    columns = result.columns
    assert "time" in columns, "observe_predict_interval() output must contain 'time' column"
    assert "vintage_time" in columns, "observe_predict_interval() output must contain 'vintage_time' column"
    assert len(result) > 0, "observe_predict_interval() must return non-empty DataFrame"

    # Each vintage's rows must be time-ordered; a failure here points to
    # observation state leaking between iterations.
    vintage_expr = pl.col("vintage_time")
    for vintage_value in result["vintage_time"].unique():
        vintage_times = result.filter(vintage_expr == vintage_value)["time"]
        assert vintage_times.is_sorted(), (
            f"'time' column within vintage_time={vintage_value} is not sorted in ascending order"
        )