Skip to content

check_interval_prediction_columns

yohou.testing.interval.check_interval_prediction_columns(forecaster, y_test, X_actual_test=None)

Check interval predictions have `{col}_lower_{rate}` and `{col}_upper_{rate}` format.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `forecaster` | `BaseIntervalForecaster` | Fitted interval forecaster instance | *required* |
| `y_test` | `DataFrame` | Test target data | *required* |
| `X_actual_test` | `DataFrame` | Test features | `None` |

Raises

| Type | Description |
| --- | --- |
| `AssertionError` | If interval column naming is incorrect |

Source Code

Show/Hide source
def check_interval_prediction_columns(
    forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None
) -> None:
    """Check interval predictions have {col}_lower_{rate} and {col}_upper_{rate} format.

    Calls ``forecaster.predict_interval`` with a short horizon (at most 3
    steps, capped by ``len(y_test)``). It then asserts two things for every
    coverage rate in ``forecaster.fit_coverage_rates_`` and every expected
    target column: a ``{col}_lower_{rate}`` column is present, and a
    ``{col}_upper_{rate}`` column is present.

    Expected target columns come from ``inspect_panel(y_test)`` when
    ``y_test`` contains panel groups. Otherwise they come from the keys of
    ``forecaster.local_y_schema_``.

    Parameters
    ----------
    forecaster : BaseIntervalForecaster
        Fitted interval forecaster instance
    y_test : pl.DataFrame
        Test target data
    X_actual_test : pl.DataFrame, optional
        Test features. NOTE(review): this argument is currently not
        forwarded to ``predict_interval``. Confirm whether forecasters that
        need exogenous features should receive it.

    Raises
    ------
    AssertionError
        If interval column naming is incorrect

    """
    forecasting_horizon = min(3, len(y_test))

    # Call predict_interval for interval forecasters
    y_pred = forecaster.predict_interval(forecasting_horizon=forecasting_horizon)

    # Coverage rates recorded on the forecaster during fit.
    coverage_rates = forecaster.fit_coverage_rates_

    # Each entry of `field_groups` is one ordered list of target column names.
    # Keeping groups separate preserves the original check order (group,
    # then rate, then field), so the same assertion fires first on failure.
    _, y_panel_groups = inspect_panel(y_test)
    if len(y_panel_groups) > 0:
        # Panel data: fields are the full panel column names taken from
        # y_test, e.g. "stores__store_0" -> "stores__store_0_lower_0.1".
        field_groups = [y_panel_groups[group_prefix] for group_prefix in y_panel_groups]
    else:
        # Global data: target column names from the fitted forecaster's schema.
        field_groups = [list(forecaster.local_y_schema_.keys())]

    # Build the membership set once instead of scanning the column list per check.
    available_columns = set(y_pred.columns)

    for fields in field_groups:
        for rate in coverage_rates:
            for field in fields:
                lower_col = f"{field}_lower_{rate}"
                upper_col = f"{field}_upper_{rate}"

                assert lower_col in available_columns, (
                    f"Missing lower bound column: {lower_col}. "
                    f"Available columns: {y_pred.columns}"
                )
                assert upper_col in available_columns, (
                    f"Missing upper bound column: {upper_col}. "
                    f"Available columns: {y_pred.columns}"
                )