
check_fit_predict_with_X_future

yohou.testing.forecaster.check_fit_predict_with_X_future(forecaster, y_train, X_actual_train, y_test, X_future, forecasting_horizon=3)

Check that fit and predict work when X_future is provided.

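Fits a clone of the forecaster with X_future and validates that fit sets _X_future_schema_, stores _X_future_raw_, and populates a non-empty _step_column_names_. It then checks that predict returns a pl.DataFrame containing a "time" column.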
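The attribute contract enforced by the fit-side assertions can be sketched without yohou or polars. In the sketch below, ToyXFutureForecaster and check_fit_side_attributes are hypothetical names. Frames are modelled as plain dicts of column lists instead of pl.DataFrame, and the schema format and step-column naming scheme are invented for illustration. They are not how yohou computes these attributes.

```python
class ToyXFutureForecaster:
    """Hypothetical forecaster; only mimics the attributes the check reads."""

    def fit(self, y, X=None, forecasting_horizon=1, X_future=None):
        self._X_future_schema_ = None
        self._X_future_raw_ = None
        self._step_column_names_ = []
        if X_future is not None:
            # Invented schema format: column name -> Python type name of its first value.
            self._X_future_schema_ = {
                name: type(values[0]).__name__ for name, values in X_future.items()
            }
            # Keep the raw known-future features, which the real check requires.
            self._X_future_raw_ = X_future
            # Invented naming: one derived column per feature per horizon step.
            features = [name for name in X_future if name != "time"]
            self._step_column_names_ = [
                f"{name}_step_{h}"
                for name in features
                for h in range(1, forecasting_horizon + 1)
            ]
        return self


def check_fit_side_attributes(forecaster, y_train, X_future, forecasting_horizon=3):
    """Mirror of the three fit-side assertions (cloning is omitted here)."""
    fitted = forecaster.fit(
        y_train, None, forecasting_horizon=forecasting_horizon, X_future=X_future
    )
    assert fitted._X_future_schema_ is not None, "fit() with X_future must set _X_future_schema_"
    assert len(fitted._step_column_names_) > 0, "_step_column_names_ should be non-empty"
    assert fitted._X_future_raw_ is not None, "fit() with X_future must store _X_future_raw_"
    return fitted


y_train = {"time": [1, 2, 3, 4], "y": [10.0, 11.0, 12.5, 13.0]}
X_future = {"time": [5, 6, 7], "promo": [0, 1, 0]}
fitted = check_fit_side_attributes(ToyXFutureForecaster(), y_train, X_future)
print(fitted._step_column_names_)  # ['promo_step_1', 'promo_step_2', 'promo_step_3']
```

Fitting the same toy forecaster with X_future=None leaves _X_future_schema_ as None, so the first assertion would fail. A forecaster that silently ignores X_future is the kind of bug the real check is meant to catch.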

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| forecaster | BaseForecaster | Unfitted forecaster instance. The check fits a clone of it. | required |
| y_train | DataFrame | Training target data. | required |
| X_actual_train | DataFrame or None | Training features, or None. | required |
| y_test | DataFrame | Test target data. Accepted, but not used in the body of this check. | required |
| X_future | DataFrame | Known-future features with a "time" column. | required |
| forecasting_horizon | int | Number of steps ahead to forecast. | 3 |

Source Code

```python
# Module-level imports are not shown in this excerpt; the two below are assumed.
import polars as pl
from sklearn.base import clone


def check_fit_predict_with_X_future(
    forecaster,
    y_train: pl.DataFrame,
    X_actual_train: pl.DataFrame | None,
    y_test: pl.DataFrame,
    X_future: pl.DataFrame,
    forecasting_horizon: int = 3,
) -> None:
    """Check fit + predict works with X_future provided.

    Validates that fitting with X_future sets ``_X_future_schema_``,
    populates ``_step_column_names_``, and that predict returns valid output.

    Parameters
    ----------
    forecaster : BaseForecaster
        Unfitted forecaster instance.
    y_train : pl.DataFrame
        Training target data.
    X_actual_train : pl.DataFrame or None
        Training features.
    y_test : pl.DataFrame
        Test target data.
    X_future : pl.DataFrame
        Known-future features with ``"time"`` column.
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast.

    """
    forecaster_clone = clone(forecaster)
    forecaster_clone.fit(
        y_train,
        X_actual_train,
        forecasting_horizon=forecasting_horizon,
        X_future=X_future,
    )

    # Schema set
    assert forecaster_clone._X_future_schema_ is not None, "fit() with X_future must set _X_future_schema_"

    # Step columns populated
    assert len(forecaster_clone._step_column_names_) > 0, (
        "_step_column_names_ should be non-empty after fit with X_future"
    )

    # Raw stored
    assert forecaster_clone._X_future_raw_ is not None, "fit() with X_future must store _X_future_raw_"

    # Predict works
    y_pred = forecaster_clone.predict(forecasting_horizon=forecasting_horizon)
    assert isinstance(y_pred, pl.DataFrame), f"predict() must return pl.DataFrame, got {type(y_pred).__name__}"
    assert "time" in y_pred.columns, "predict() output must contain 'time' column"
```
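The check fits clone(forecaster) rather than the instance it receives. With scikit-learn-style cloning (a new instance built from the same constructor parameters), the caller's forecaster stays unfitted, presumably so it can be reused by other checks. The sketch below illustrates this together with the two predict-side assertions. ToyForecaster, clone_unfitted and check_predict_output are hypothetical names. clone_unfitted only approximates the idea behind cloning, and a dict of columns stands in for the pl.DataFrame that the real check requires.

```python
class ToyForecaster:
    """Hypothetical forecaster; predictions are a constant fill value."""

    def __init__(self, fill_value=0.0):
        self.fill_value = fill_value

    def get_params(self):
        return {"fill_value": self.fill_value}

    def fit(self, y, X=None, forecasting_horizon=1, X_future=None):
        # Fitted attributes end with an underscore, as in the scikit-learn convention.
        self._last_time_ = y["time"][-1]
        self._X_future_raw_ = X_future
        return self

    def predict(self, forecasting_horizon=1):
        start = self._last_time_ + 1
        times = list(range(start, start + forecasting_horizon))
        # A dict of columns stands in for a polars DataFrame.
        return {"time": times, "y": [self.fill_value] * forecasting_horizon}


def clone_unfitted(estimator):
    """New instance with the same constructor parameters and no fitted state."""
    return type(estimator)(**estimator.get_params())


def check_predict_output(forecaster, y_train, X_future, forecasting_horizon=3):
    fitted = clone_unfitted(forecaster)  # the caller's instance is never fitted
    fitted.fit(y_train, None, forecasting_horizon=forecasting_horizon, X_future=X_future)
    y_pred = fitted.predict(forecasting_horizon=forecasting_horizon)
    # Mirrors the two predict-side assertions, with dict in place of pl.DataFrame.
    assert isinstance(y_pred, dict), f"predict() must return dict, got {type(y_pred).__name__}"
    assert "time" in y_pred, "predict() output must contain 'time' column"
    return y_pred


original = ToyForecaster(fill_value=1.5)
y_pred = check_predict_output(
    original,
    {"time": [1, 2, 3], "y": [1.0, 2.0, 3.0]},
    {"time": [4, 5, 6], "promo": [0, 1, 0]},
)
print(y_pred["time"])                    # [4, 5, 6]
print(hasattr(original, "_last_time_"))  # False
```

Had the sketch called original.fit(...) directly, original would carry fitted state after the check returned.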