Check upper >= lower for all coverage rates and time steps.
Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `forecaster` | `BaseIntervalForecaster` | Fitted interval forecaster instance | *required* |
| `y_test` | `DataFrame` | Test target data | *required* |
| `X_actual_test` | `DataFrame` | Test features | `None` |
Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If upper bounds are less than lower bounds |
Source Code
View on GitHub
Show/Hide source
def check_interval_bounds(forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None) -> None:
    """Check upper >= lower for all coverage rates and time steps.

    Predicts intervals for a short horizon and checks every interval column
    pair. The horizon is at most 3 steps, capped by ``len(y_test)``. For each
    coverage rate the forecaster was fitted with (``fit_coverage_rates_``),
    no value in ``<name>_lower_<rate>`` may exceed the matching value in
    ``<name>_upper_<rate>``.

    Parameters
    ----------
    forecaster : BaseIntervalForecaster
        Fitted interval forecaster instance.
    y_test : pl.DataFrame
        Test target data. Only its length and its panel structure (as
        reported by ``inspect_panel``) are used.
    X_actual_test : pl.DataFrame, optional
        Test features. NOTE(review): this argument is currently not
        forwarded to ``predict_interval``; confirm whether forecasters that
        require exogenous features should receive it.

    Raises
    ------
    AssertionError
        If upper bounds are less than lower bounds.
    """
    forecasting_horizon = min(3, len(y_test))
    y_pred = forecaster.predict_interval(forecasting_horizon=forecasting_horizon)
    coverage_rates = forecaster.fit_coverage_rates_

    # Collect (column name, coverage rate) pairs in the same order the
    # checks are performed, so the first reported violation is stable.
    _, y_panel_groups = inspect_panel(y_test)
    if len(y_panel_groups) > 0:
        # Panel data (columns with a ``__`` separator): the grouped field
        # names reported for ``y_test`` are full column names, and the
        # interval columns are derived from them.
        checks = [
            (field, rate)
            for group_prefix in y_panel_groups
            for rate in coverage_rates
            for field in y_panel_groups[group_prefix]
        ]
    else:
        # Global data: one pair of interval columns per target column.
        target_cols = list(forecaster.local_y_schema_.keys())
        checks = [(col, rate) for rate in coverage_rates for col in target_cols]

    for name, rate in checks:
        lower_vals = y_pred[f"{name}_lower_{rate}"].to_numpy()
        upper_vals = y_pred[f"{name}_upper_{rate}"].to_numpy()
        # Element-wise comparison; for float arrays NaN compares as False,
        # so missing bounds are not reported as violations.
        violations = lower_vals > upper_vals
        if violations.any():
            raise AssertionError(
                f"Found {violations.sum()} violations where lower > upper for {name} at coverage {rate}"
            )
|