Skip to content

check_class_proba_prediction_sums

yohou.testing.class_proba.check_class_proba_prediction_sums(forecaster, y_test, X_actual_test=None)

Check probabilities sum to approximately 1.0 per row per target.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `forecaster` | `BaseClassProbaForecaster` | Fitted class-probability forecaster instance. | required |
| `y_test` | `DataFrame` | Test target data. | required |
| `X_actual_test` | `DataFrame` or `None` | Test features. | `None` |

Raises

| Type | Description |
| --- | --- |
| `AssertionError` | If probabilities do not sum to approximately 1.0. |

Source Code

Show/Hide source
def check_class_proba_prediction_sums(
    forecaster, y_test: pl.DataFrame, X_actual_test: pl.DataFrame | None = None
) -> None:
    """Assert that predicted class probabilities add up to ~1.0 in every row.

    A short forecast (at most 3 steps, fewer if ``y_test`` is shorter) is
    requested through ``forecaster.predict_class_proba``. For each target in
    ``forecaster.classes_`` the columns named ``<target>_proba_<label>`` are
    summed horizontally and each row total is compared with 1.0 using an
    absolute tolerance of 1e-6. When ``inspect_panel(y_test)`` reports panel
    groups, the same check is run per group on columns prefixed with
    ``<group>__``.

    Parameters
    ----------
    forecaster : BaseClassProbaForecaster
        Fitted class-probability forecaster instance.
    y_test : pl.DataFrame
        Test target data. Only its length and panel structure are used.
    X_actual_test : pl.DataFrame or None, default=None
        Test features. Not referenced by the body of this check.

    Raises
    ------
    AssertionError
        If the probabilities of any row do not sum to approximately 1.0.

    """
    horizon = min(3, len(y_test))
    predictions = forecaster.predict_class_proba(forecasting_horizon=horizon)

    _, panel_groups = inspect_panel(y_test)

    # Panel predictions carry a "<group>__" prefix on every column; without
    # panel groups the columns are unprefixed, modelled here as an empty prefix.
    if len(panel_groups) > 0:
        column_prefixes = [f"{group}__" for group in panel_groups]
    else:
        column_prefixes = [""]

    for prefix in column_prefixes:
        for target, labels in forecaster.classes_.items():
            qualified_target = f"{prefix}{target}"
            columns = [f"{qualified_target}_proba_{label}" for label in labels]
            totals = predictions.select(columns).sum_horizontal()
            for row, total in enumerate(totals):
                assert abs(total - 1.0) < 1e-6, (
                    f"Probabilities for {qualified_target} at row {row} sum to {total}, expected ~1.0"
                )