Skip to content

check_class_proba_classes_attribute

yohou.testing.class_proba.check_class_proba_classes_attribute(forecaster)

Check classes_ and n_classes_ attributes are populated correctly after fit.

Validates that classes_ is a dict mapping target column names to sorted lists of class labels, that n_classes_ maps each target to an integer class count, and that label_to_code_ maps each label to a numeric code.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `forecaster` | `BaseClassProbaForecaster` | Fitted class-probability forecaster instance. | *required* |

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If `classes_`, `n_classes_`, or `label_to_code_` are invalid. |

Source Code

Show/Hide source
def check_class_proba_classes_attribute(forecaster) -> None:
    """Check classes_ and n_classes_ attributes are populated correctly after fit.

    Validates that ``classes_`` is a non-empty dict mapping target column
    names to sorted lists of at least two distinct class labels, that
    ``n_classes_`` maps each target to an integer class count, and that
    ``label_to_code_`` maps each label to a numeric code.

    ``n_classes_`` and ``label_to_code_`` must cover exactly the targets
    in ``classes_``: every target of ``classes_`` must appear in both, and
    neither may contain a target that ``classes_`` lacks.

    Parameters
    ----------
    forecaster : BaseClassProbaForecaster
        Fitted class-probability forecaster instance.

    Raises
    ------
    AssertionError
        If classes_, n_classes_, or label_to_code_ are invalid.

    """
    # Local imports keep this checker self-contained.
    import numbers
    from collections.abc import Mapping

    # --- classes_ -----------------------------------------------------------
    assert hasattr(forecaster, "classes_"), "Fitted forecaster must have classes_ attribute"
    classes = forecaster.classes_
    assert isinstance(classes, dict), f"classes_ should be dict, got {type(classes)}"
    assert len(classes) > 0, "classes_ should not be empty"

    for target_col, labels in classes.items():
        assert isinstance(labels, list), f"classes_[{target_col!r}] should be list, got {type(labels)}"
        assert len(labels) >= 2, f"classes_[{target_col!r}] should have at least 2 classes, got {len(labels)}"
        try:
            sorted_labels = sorted(labels)
        except TypeError as err:
            # Mixed, non-comparable label types would otherwise escape as a
            # TypeError instead of the documented AssertionError.
            raise AssertionError(
                f"classes_[{target_col!r}] labels are not mutually comparable, got {labels}"
            ) from err
        assert labels == sorted_labels, f"classes_[{target_col!r}] should be sorted, got {labels}"
        # The list is sorted at this point, so any duplicates are adjacent.
        assert all(a != b for a, b in zip(labels, labels[1:])), (
            f"classes_[{target_col!r}] should not contain duplicate labels, got {labels}"
        )

    # --- n_classes_ ---------------------------------------------------------
    assert hasattr(forecaster, "n_classes_"), "Fitted forecaster must have n_classes_ attribute"
    n_classes = forecaster.n_classes_
    assert isinstance(n_classes, dict), f"n_classes_ should be dict, got {type(n_classes)}"
    for target_col, n in n_classes.items():
        assert target_col in classes, f"n_classes_ key {target_col!r} not in classes_"
        # bool subclasses int but is never a meaningful class count.
        assert isinstance(n, numbers.Integral) and not isinstance(n, bool), (
            f"n_classes_[{target_col!r}] should be an integer, got {type(n)}"
        )
        assert n == len(classes[target_col]), (
            f"n_classes_[{target_col!r}]={n} doesn't match len(classes_[{target_col!r}])={len(classes[target_col])}"
        )
    # The loop above only checks n_classes_ -> classes_; check the reverse too.
    missing_counts = [col for col in classes if col not in n_classes]
    assert not missing_counts, f"n_classes_ is missing targets from classes_: {missing_counts}"

    # --- label_to_code_ -----------------------------------------------------
    assert hasattr(forecaster, "label_to_code_"), "Fitted forecaster must have label_to_code_ attribute"
    label_to_code = forecaster.label_to_code_
    assert isinstance(label_to_code, dict), (
        f"label_to_code_ should be dict, got {type(label_to_code)}"
    )

    for target_col, mapping in label_to_code.items():
        assert target_col in classes, f"label_to_code_ key {target_col!r} not in classes_"
        assert isinstance(mapping, Mapping), (
            f"label_to_code_[{target_col!r}] should be a mapping, got {type(mapping)}"
        )
        assert set(mapping.keys()) == set(classes[target_col]), (
            f"label_to_code_[{target_col!r}] keys {set(mapping.keys())} "
            f"don't match classes_ {set(classes[target_col])}"
        )
        for label, code in mapping.items():
            assert isinstance(code, numbers.Real) and not isinstance(code, bool), (
                f"label_to_code_[{target_col!r}][{label!r}] should be numeric, got {type(code)}"
            )
    # Every target with classes must also have a label -> code mapping.
    missing_codes = [col for col in classes if col not in label_to_code]
    assert not missing_codes, f"label_to_code_ is missing targets from classes_: {missing_codes}"