Skip to content

check_class_proba_prediction_structure

yohou.testing.class_proba.check_class_proba_prediction_structure(forecaster, y_test, X_actual_test=None)

Check class-probability predictions have correct column structure.

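Validates that predict_class_proba output contains "vintage_time" and "time" columns, contains no interval columns (names with _lower_ or _upper_), and contains a {target}_proba_{class} column for every target and class. For panel data, each group's columns are prefixed with {group}__.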

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `forecaster` | `BaseClassProbaForecaster` | Fitted class-probability forecaster instance. | required |
| `y_test` | `DataFrame` | Test target data. | required |
| `X_actual_test` | `DataFrame` or `None` | Test features. | `None` |

Raises

| Type | Description |
| --- | --- |
| `AssertionError` | If prediction structure is incorrect. |

Source Code

Show/Hide source
def check_class_proba_prediction_structure(
    forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None
) -> None:
    """Check class-probability predictions have correct column structure.

    Requests a forecast from ``predict_class_proba`` with a horizon of at
    most 3 steps (fewer when ``y_test`` is shorter) and asserts that the
    returned frame:

    * has ``"vintage_time"`` and ``"time"`` columns,
    * has no interval columns (any name containing ``"_lower_"`` or
      ``"_upper_"``), and
    * has a ``{target}_proba_{class}`` column for every target/class pair in
      ``forecaster.classes_``. When ``inspect_panel`` reports panel groups
      for ``y_test``, the expected name is
      ``{group}__{target}_proba_{class}`` for every group.

    Parameters
    ----------
    forecaster : BaseClassProbaForecaster
        Fitted class-probability forecaster instance.
    y_test : pl.DataFrame
        Test target data.
    X_actual_test : pl.DataFrame or None, default=None
        Test features. Currently not used by this check.

    Raises
    ------
    AssertionError
        If prediction structure is incorrect.

    """
    horizon = min(3, len(y_test))
    predictions = forecaster.predict_class_proba(forecasting_horizon=horizon)
    pred_columns = predictions.columns

    for required in ("vintage_time", "time"):
        assert required in pred_columns, f"Class-proba predictions must have '{required}'"

    # Class-probability output must not carry prediction-interval columns.
    unexpected = [name for name in pred_columns if "_lower_" in name or "_upper_" in name]
    assert not unexpected, f"Class-proba predictions should not have interval columns, found: {unexpected}"

    _, panel_groups = inspect_panel(y_test)
    if len(panel_groups) == 0:
        # Non-panel data: column names carry no group prefix.
        prefixes = [""]
    else:
        # Panel data: classes_ keys are unprefixed target names, so each
        # group's columns are "<group>__<target>_proba_<class>".
        prefixes = [f"{group}__" for group in panel_groups]

    # Iterate target -> group -> class so the first reported missing column
    # matches the traversal order of the original checks.
    for target_col, class_labels in forecaster.classes_.items():
        for prefix in prefixes:
            for label in class_labels:
                expected = f"{prefix}{target_col}_proba_{label}"
                assert expected in pred_columns, f"Missing probability column: {expected}"