Skip to content

check_scorer_parameter_validation

yohou.testing.scorer.check_scorer_parameter_validation(scorer_class, param_name, invalid_value, error_match=None)

Check that parameter validation raises ValueError or TypeError for invalid inputs.

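Constructs the scorer with one invalid parameter and checks that `fit()` rejects it. Following the scikit-learn pattern, parameters are validated in `fit()`, not in `__init__` or `score()`.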

Parameters

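| Name | Type | Description | Default |
|---|---|---|---|
| `scorer_class` | `type` | Scorer class | *required* |
| `param_name` | `str` | Parameter name to test | *required* |
| `invalid_value` | `any` | Invalid value that should trigger ValueError or TypeError | *required* |
| `error_match` | `str` | Expected substring in the error message; if `None`, any ValueError or TypeError is accepted | `None` |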

Raises

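| Type | Description |
|---|---|
| `AssertionError` | If the invalid value is accepted by `fit()`, or if the raised error message does not contain `error_match` |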

Source Code

Show/Hide source
def check_scorer_parameter_validation(
    scorer_class,
    param_name: str,
    invalid_value: Any,
    error_match: str | None = None,
) -> None:
    """Check parameter validation raises ValueError or TypeError for invalid inputs.

    Constructs ``scorer_class`` with ``param_name`` set to ``invalid_value``
    and calls ``fit()`` on a small ground-truth series. Following the
    scikit-learn pattern, parameter validation is expected to happen in
    ``fit()`` (not in ``__init__`` or ``score()``).

    Parameters
    ----------
    scorer_class : type
        Scorer class
    param_name : str
        Parameter name to test
    invalid_value : any
        Invalid value that should trigger ValueError or TypeError
    error_match : str, optional
        Expected substring in error message. If None, any ValueError or
        TypeError raised by ``fit()`` is accepted.

    Raises
    ------
    AssertionError
        If the invalid value is accepted by ``fit()``, or if the raised
        error message does not contain ``error_match``.

    """
    # Minimal ground-truth data: 10 timestamps (2020-01-01..10) with float values.
    y_truth = pl.DataFrame({
        "time": [datetime.datetime(2020, 1, i) for i in range(1, 11)],
        "value": [float(i) for i in range(10)],
    })

    # Construction happens outside the try block, so an exception raised by
    # __init__ propagates to the caller instead of being counted as a pass.
    scorer = scorer_class(**{param_name: invalid_value})

    try:
        scorer.fit(y_truth)
    except (ValueError, TypeError) as e:
        # Expected path: validation rejected the parameter. Optionally verify
        # that the message mentions what went wrong.
        if error_match is not None and error_match not in str(e):
            raise AssertionError(f"Expected error containing '{error_match}', got: {e}") from e
        return

    # fit() completed without raising, so the invalid value was accepted.
    raise AssertionError(f"{scorer_class.__name__}: invalid {param_name}={invalid_value} was accepted in fit()")