Skip to content

check_splitter_parameter_constraints

yohou.testing.splitter.check_splitter_parameter_constraints(splitter_class, param_name, invalid_values)

Check parameter constraints are enforced via sklearn validation.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| splitter_class | type | Splitter class | required |
| param_name | str | Parameter name to test | required |
| invalid_values | list | List of invalid values that should trigger ValueError | required |

Raises

| Type | Description |
|---|---|
| AssertionError | If invalid values are accepted |

Source Code

Show/Hide source
def check_splitter_parameter_constraints(
    splitter_class,
    param_name: str,
    invalid_values: list,
) -> None:
    """Check parameter constraints are enforced via sklearn validation.

    For every entry of ``invalid_values`` a splitter is instantiated with
    ``param_name`` set to that entry and ``get_n_splits`` is called. If that
    raises a ``ValueError`` unrelated to the parameter, ``split`` is also run
    on a 100-row daily frame. The value counts as rejected when a
    ``ValueError`` or ``TypeError`` is raised whose message contains
    ``param_name`` or the word "parameter".

    Parameters
    ----------
    splitter_class : type
        Splitter class
    param_name : str
        Parameter name to test
    invalid_values : list
        List of invalid values that should trigger ValueError (a TypeError
        with a matching message is accepted as well)

    Raises
    ------
    AssertionError
        If invalid values are accepted, or if the raised error message does
        not look like a parameter validation error

    """
    # Default valid values for required parameters (per splitter)
    defaults = {}
    if splitter_class == SlidingWindowSplitter:
        defaults = {"n_splits": 3, "test_size": 5}

    for invalid_value in invalid_values:
        try:
            # Create instance with invalid parameter + defaults for required params
            params = defaults.copy()
            params[param_name] = invalid_value
            splitter = splitter_class(**params)

            # sklearn validates on first method call, so try get_n_splits or split
            try:
                splitter.get_n_splits(y=None, X_actual=None)
            except ValueError as e:
                # Check if it's a parameter validation error
                error_msg = str(e)
                if param_name in error_msg or "parameter" in error_msg.lower():
                    continue  # Parameter validation worked
                # Otherwise try with actual data: 100 consecutive days.
                # Built with timedelta because datetime(2020, 1, i) is only
                # valid for i <= 31 and would itself raise ValueError,
                # masking the splitter's real behaviour.
                start = datetime.datetime(2020, 1, 1)
                y_test = pl.DataFrame({
                    "time": [start + datetime.timedelta(days=offset) for offset in range(100)],
                    "value": range(100),
                })
                list(splitter.split(y_test))

            # If we reach here, invalid value was accepted
            raise AssertionError(f"{splitter_class.__name__}: invalid {param_name}={invalid_value} was accepted")

        except (ValueError, TypeError) as e:
            # Expected: parameter validation should catch this
            error_msg = str(e)
            # Explicit raise (not `assert`) so the check is not stripped
            # when Python runs with -O.
            if not (param_name in error_msg or "parameter" in error_msg.lower()):
                raise AssertionError(f"Expected validation error for {param_name}, got: {e}") from e