Skip to content

check_predict_time_columns

yohou.testing.forecaster.check_predict_time_columns(forecaster, y_test, X_actual_test=None)

Check predictions have vintage_time and time columns.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `forecaster` | `BaseForecaster` | Fitted forecaster instance | *required* |
| `y_test` | `DataFrame` | Test target data | *required* |
| `X_actual_test` | `DataFrame` | Test features | `None` |

Raises

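| Type | Description |
|---|---|
| `AssertionError` | If predictions lack the required time columns, have a row count different from the forecasting horizon, or have time columns that are not of Datetime dtype |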

Source Code

Show/Hide source
def check_predict_time_columns(forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None) -> None:
    """Check predictions have vintage_time and time columns.

    Requests a short forecast (at most 3 steps, bounded by the length of
    ``y_test``). It then verifies that the resulting frame has
    ``vintage_time`` and ``time`` columns of Datetime dtype and exactly one
    row per requested forecast step.

    If the forecaster exposes ``predict_interval``, that method is called
    instead of ``predict``.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance
    y_test : pl.DataFrame
        Test target data. Only its length is used, to bound the horizon.
    X_actual_test : pl.DataFrame, optional
        Test features. Currently not forwarded to the forecaster.

    Raises
    ------
    AssertionError
        If predictions lack the required time columns, have a row count
        different from the forecasting horizon, or if either time column
        is not of Datetime dtype.

    """
    # NOTE(review): X_actual_test is accepted but never forwarded to
    # predict / predict_interval. Confirm whether forecasters that need
    # exogenous features should receive it here.
    forecasting_horizon = min(3, len(y_test))

    # Interval forecasters are exercised through their interval API.
    if hasattr(forecaster, "predict_interval"):
        y_pred = forecaster.predict_interval(forecasting_horizon=forecasting_horizon)
    else:
        y_pred = forecaster.predict(forecasting_horizon=forecasting_horizon)

    time_columns = ("vintage_time", "time")

    # Report the columns actually present so a failure is self-explanatory.
    columns = y_pred.columns
    for required in time_columns:
        assert required in columns, f"Predictions must have '{required}' column, got columns {columns}"

    # Validate shapes: one row per requested forecast step.
    n_rows = len(y_pred)
    assert n_rows == forecasting_horizon, f"Predictions should have {forecasting_horizon} rows, got {n_rows}"

    # Validate time column types. Comparing against the bare pl.Datetime class
    # matches any time unit / time zone parametrisation.
    for name in time_columns:
        dtype = y_pred[name].dtype
        assert dtype == pl.Datetime, f"{name} must be Datetime dtype, got {dtype}"