Skip to content

check_interval_bounds

yohou.testing.interval.check_interval_bounds(forecaster, y_test, X_actual_test=None)

Check that upper >= lower for every coverage rate over the first `min(3, len(y_test))` forecast steps.

Parameters

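| Name | Type | Description | Default |
|------|------|-------------|---------|
| `forecaster` | `BaseIntervalForecaster` | Fitted interval forecaster instance | *required* |
| `y_test` | `DataFrame` | Test target data | *required* |
| `X_actual_test` | `DataFrame` | Test features (currently not used by the check) | `None` |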

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If upper bounds are less than lower bounds |

Source Code

Show/Hide source
def _assert_lower_not_above_upper(y_pred, name, rate) -> None:
    """Raise AssertionError if ``{name}_lower_{rate}`` exceeds ``{name}_upper_{rate}`` in ``y_pred``.

    Parameters
    ----------
    y_pred : pl.DataFrame
        Interval predictions containing the ``{name}_lower_{rate}`` and
        ``{name}_upper_{rate}`` columns.
    name : str
        Target column (or full panel field) name used as the column prefix.
    rate : float
        Coverage rate used as the column suffix.

    Raises
    ------
    AssertionError
        If any lower value is strictly greater than the matching upper value.

    """
    lower_vals = y_pred[f"{name}_lower_{rate}"].to_numpy()
    upper_vals = y_pred[f"{name}_upper_{rate}"].to_numpy()

    # Element-wise comparison; for float arrays a NaN on either side compares
    # False, so NaN bounds are not counted as violations here.
    violations = lower_vals > upper_vals
    if violations.any():
        raise AssertionError(
            f"Found {violations.sum()} violations where lower > upper for {name} at coverage {rate}"
        )


def check_interval_bounds(forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None) -> None:
    """Check upper >= lower for every coverage rate over a short forecast horizon.

    Calls ``forecaster.predict_interval`` for a horizon of
    ``min(3, len(y_test))`` steps. Then, for every rate in
    ``forecaster.fit_coverage_rates_`` and every target column, it asserts
    that no ``{col}_lower_{rate}`` value exceeds the matching
    ``{col}_upper_{rate}`` value.

    Parameters
    ----------
    forecaster : BaseIntervalForecaster
        Fitted interval forecaster instance
    y_test : pl.DataFrame
        Test target data. Its length sizes the horizon, and it is passed to
        ``inspect_panel`` to decide between the panel and global checks.
    X_actual_test : pl.DataFrame, optional
        Test features. Currently not used by this check.

    Raises
    ------
    AssertionError
        If upper bounds are less than lower bounds

    """
    # Only a short horizon is checked (at most 3 steps), not the whole of y_test.
    forecasting_horizon = min(3, len(y_test))
    # NOTE(review): X_actual_test is accepted but never forwarded to
    # predict_interval; confirm whether forecasters using features need it here.
    y_pred = forecaster.predict_interval(forecasting_horizon=forecasting_horizon)

    coverage_rates = forecaster.fit_coverage_rates_

    # Check if we have panel data (columns with __ separator)
    _, y_panel_groups = inspect_panel(y_test)

    if y_panel_groups:
        # Panel data: each group maps to its full column names, which are the
        # prefixes of the interval columns.
        for fields in y_panel_groups.values():
            for rate in coverage_rates:
                for field in fields:
                    _assert_lower_not_above_upper(y_pred, field, rate)
    else:
        # Global data: check every column of the forecaster's local target schema.
        target_cols = list(forecaster.local_y_schema_.keys())
        for rate in coverage_rates:
            for col in target_cols:
                _assert_lower_not_above_upper(y_pred, col, rate)