Skip to content

check_point_prediction_structure

yohou.testing.point.check_point_prediction_structure(forecaster, y_test, X_actual_test=None)

Check point predictions have correct column structure.

Parameters

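| Name | Type | Description | Default |
|------|------|-------------|---------|
| `forecaster` | `BasePointForecaster` | Fitted point forecaster instance | *required* |
| `y_test` | `DataFrame` | Test target data | *required* |
| `X_actual_test` | `DataFrame \| None` | Test features | `None` |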

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If prediction structure is incorrect |

Source Code

Show/Hide source
def check_point_prediction_structure(
    forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None
) -> None:
    """Assert that a fitted point forecaster emits correctly shaped predictions.

    The forecaster is asked to predict over a short horizon: the number of rows
    in ``y_test``, capped at 3. The columns of the returned frame are then
    checked:

    1. Both the ``"vintage_time"`` and ``"time"`` columns are present.
    2. No column name contains ``"_lower_"`` or ``"_upper_"`` (interval-style
       columns belong to interval forecasters, not point forecasters).
    3. At least one column other than ``"vintage_time"`` / ``"time"`` exists,
       i.e. there is a target column.

    Parameters
    ----------
    forecaster : BasePointForecaster
        Fitted point forecaster instance; only its ``predict`` method is used.
    y_test : pl.DataFrame
        Test target data; only its length is used, to size the horizon.
    X_actual_test : pl.DataFrame, optional
        Test features. Not consulted by the current implementation.

    Raises
    ------
    AssertionError
        If prediction structure is incorrect

    Notes
    -----
    An empty ``y_test`` produces a requested horizon of 0. How ``predict``
    handles that is up to the forecaster.

    """
    # Keep the horizon small: at most 3 steps, never more than y_test's length.
    horizon = len(y_test)
    if horizon > 3:
        horizon = 3

    predictions = forecaster.predict(forecasting_horizon=horizon)
    columns = predictions.columns

    # 1. Index columns must be present, checked in a fixed order.
    required_columns = (
        ("vintage_time", "Point predictions must have 'vintage_time'"),
        ("time", "Point predictions must have 'time'"),
    )
    for column_name, failure_message in required_columns:
        assert column_name in columns, failure_message

    # 2. Interval-style columns must be absent; report every offender found.
    interval_markers = ("_lower_", "_upper_")
    offending = [
        column_name
        for column_name in columns
        if any(marker in column_name for marker in interval_markers)
    ]
    assert not offending, f"Point predictions should not have interval columns, found: {offending}"

    # 3. Anything that is not an index column counts as a target column.
    index_columns = {"vintage_time", "time"}
    has_target = any(column_name not in index_columns for column_name in columns)
    assert has_target, "Point predictions must have at least one target column"