
check_scorer_aggregation_methods

yohou.testing.scorer.check_scorer_aggregation_methods(scorer, y_truth, y_pred, aggregation_methods)

Check that each aggregation method, applied on its own, produces valid output.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| scorer | BaseScorer | Scorer instance with an aggregation_method parameter | required |
| y_truth | DataFrame | Ground-truth values | required |
| y_pred | DataFrame | Predicted values | required |
| aggregation_methods | list of str | Valid aggregation methods to test, one at a time | required |

Raises

| Type | Description |
|---|---|
| AssertionError | If any aggregation method fails or produces invalid output |

Notes

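A single aggregation method (e.g., ['stepwise']) may return a DataFrame; a scalar is returned only when all available methods are applied together. The check therefore accepts either form: a DataFrame must be non-empty and free of nulls, and a scalar must be a non-NaN number.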
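The return-type rule in the Notes can be illustrated with a toy scorer. This is a stdlib-only sketch, not yohou's implementation: ToyScorer and the method names "over_steps" and "over_columns" are hypothetical, and a plain list stands in for a polars DataFrame. Collapsing one axis of a (steps x columns) error grid leaves one value per remaining row or column; collapsing both leaves a single number.

```python
from statistics import fmean


class ToyScorer:
    """Hypothetical scorer: mean absolute error over a (steps x columns) grid.

    aggregation_method lists the axes to collapse. Collapsing one axis leaves
    a list of per-column or per-step errors (standing in for a DataFrame);
    collapsing both axes yields a single float.
    """

    def __init__(self, aggregation_method=("over_steps", "over_columns")):
        self.aggregation_method = list(aggregation_method)

    def score(self, y_truth, y_pred):
        # Absolute error for every (step, column) cell
        errors = [
            [abs(t - p) for t, p in zip(row_t, row_p)]
            for row_t, row_p in zip(y_truth, y_pred)
        ]
        methods = set(self.aggregation_method)
        if methods == {"over_steps", "over_columns"}:
            # Full aggregation: one scalar
            return fmean(e for row in errors for e in row)
        if methods == {"over_steps"}:
            # Collapse time steps: one value per column
            return [fmean(col) for col in zip(*errors)]
        if methods == {"over_columns"}:
            # Collapse columns: one value per time step
            return [fmean(row) for row in errors]
        raise ValueError(f"unknown aggregation_method: {self.aggregation_method}")


y_truth = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y_pred = [[1.5, 2.0], [2.0, 4.5], [5.0, 7.0]]

print(ToyScorer(["over_steps"]).score(y_truth, y_pred))  # [0.5, 0.5]
print(ToyScorer().score(y_truth, y_pred))  # 0.5
```

A partial aggregation keeps a table-like result, which is why the check validates DataFrames and scalars separately rather than requiring a scalar for every method.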

Source Code

import numpy as np
import polars as pl
from sklearn.base import clone  # assumed source of clone()


def check_scorer_aggregation_methods(
    scorer,
    y_truth: pl.DataFrame,
    y_pred: pl.DataFrame,
    aggregation_methods: list[str],
) -> None:
    """Check all aggregation_method combinations produce valid output.

    Parameters
    ----------
    scorer : BaseScorer
        Scorer instance with aggregation_method parameter
    y_truth : pl.DataFrame
        Ground truth
    y_pred : pl.DataFrame
        Predictions
    aggregation_methods : list of str
        Valid aggregation methods to test

    Raises
    ------
    AssertionError
        If any aggregation method fails or produces invalid output

    Notes
    -----
    A single aggregation method (e.g., ['stepwise']) may return a DataFrame;
    a scalar is returned only when all available methods are applied together.

    """
    for agg_method in aggregation_methods:
        # Fresh clone per method so parameter changes never leak between iterations
        scorer_copy = clone(scorer)
        scorer_copy.set_params(aggregation_method=[agg_method])

        # Wrap only fitting and scoring: a failure there is re-raised with the
        # method name, while the validation asserts below keep their own messages
        try:
            # Always fit scorer before scoring
            scorer_copy.fit(y_truth)
            score = scorer_copy.score(y_truth, y_pred)
        except Exception as e:
            raise AssertionError(f"aggregation_method={agg_method} failed: {e}") from e

        # Validate return type - can be scalar or DataFrame depending on aggregation
        if isinstance(score, pl.DataFrame):
            # DataFrame is valid for partial aggregations (e.g., stepwise only)
            assert len(score) > 0, f"aggregation_method={agg_method}: returned empty DataFrame"
            assert score.null_count().sum_horizontal()[0] == 0, (
                f"aggregation_method={agg_method}: DataFrame contains null values"
            )
        elif isinstance(score, int | float | np.number):
            # Scalar is valid for full aggregations
            assert not np.isnan(score), f"aggregation_method={agg_method}: score is NaN"
        else:
            raise AssertionError(
                f"aggregation_method={agg_method}: score should be numeric or DataFrame, got {type(score)}"
            )
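For readers without yohou installed, the per-method loop and its validation rules can be mirrored in plain Python. This is a hedged sketch, not the library code: copy.deepcopy stands in for clone, plain attribute assignment stands in for set_params, a list stands in for a DataFrame with None playing the role of a null, and check_each_method and NaNOnSecondScorer are hypothetical names. In this sketch only fit and score sit inside the try block, so a validation failure surfaces with its own message instead of being wrapped a second time.

```python
import copy
import math
from numbers import Real


def check_each_method(scorer, y_truth, y_pred, methods):
    """Stdlib mirror of the per-method loop; lists stand in for DataFrames."""
    for method in methods:
        # Deep copy stands in for clone(): each method starts from clean state
        trial = copy.deepcopy(scorer)
        trial.aggregation_method = [method]

        # Wrap only fitting and scoring errors with the method name
        try:
            trial.fit(y_truth)
            score = trial.score(y_truth, y_pred)
        except Exception as e:
            raise AssertionError(f"aggregation_method={method} failed: {e}") from e

        if isinstance(score, list):
            # Table-like result: must be non-empty and contain no nulls
            assert score, f"aggregation_method={method}: returned empty table"
            assert None not in score, f"aggregation_method={method}: table contains nulls"
        elif isinstance(score, Real):
            # Scalar result: must not be NaN
            assert not math.isnan(score), f"aggregation_method={method}: score is NaN"
        else:
            raise AssertionError(
                f"aggregation_method={method}: score should be numeric or a table, got {type(score)}"
            )


class NaNOnSecondScorer:
    """Hypothetical scorer that returns NaN for method "b", to show a failure."""

    def __init__(self):
        self.aggregation_method = ["a", "b"]

    def fit(self, y_truth):
        return self

    def score(self, y_truth, y_pred):
        if self.aggregation_method == ["b"]:
            return float("nan")
        # Per-element absolute errors for any other method
        return [abs(t - p) for t, p in zip(y_truth, y_pred)]


try:
    check_each_method(NaNOnSecondScorer(), [1.0, 2.0], [1.5, 2.0], ["a", "b"])
except AssertionError as err:
    print(err)  # aggregation_method=b: score is NaN
```

Method "a" passes because it returns a non-empty list with no nulls; method "b" is rejected by the NaN check, and the message names the offending method.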