Skip to content

check_recorded_metadata

yohou.testing.metadata_routing.check_recorded_metadata(obj, method, parent, split_params=(), **params)

Check whether the expected metadata is passed to the object's method.

Parameters

Name Type Description Default
obj estimator object

Sub-estimator to check routed params for

required
method str

Sub-estimator's method where metadata is routed to (callee)

required
parent str

The parent method which called `method` (caller)

required
split_params tuple

Parameters which should be checked as subsets of the original values (used for CV splits where each fold gets a subset)

`()` (empty tuple)
**params dict

Expected metadata that should have been passed

{}

Raises

Type Description
AssertionError

If recorded metadata doesn't match expected metadata

Source Code

Show/Hide source
def check_recorded_metadata(obj, method: str, parent: str, split_params: tuple = (), **params) -> None:
    """Check whether the expected metadata is passed to the object's method.

    Every record stored under ``obj._records[method][parent]`` is compared
    against ``params``. If ``obj`` has no ``_records`` attribute, or nothing
    was recorded for this ``method``/``parent`` pair, the check passes
    without asserting anything.

    Parameters
    ----------
    obj : estimator object
        Sub-estimator to check routed params for
    method : str
        Sub-estimator's method where metadata is routed to (callee)
    parent : str
        The parent method which called `method` (caller)
    split_params : tuple, default=()
        Parameters which should be checked as subsets of the original values
        (used for CV splits where each fold gets a subset)
    **params : dict
        Expected metadata that should have been passed

    Raises
    ------
    AssertionError
        If recorded metadata doesn't match expected metadata

    """
    all_records = getattr(obj, "_records", {}).get(method, {}).get(parent, [])
    for record in all_records:
        # The set of routed metadata names must match exactly.
        assert set(params.keys()) == set(record.keys()), f"Expected {params.keys()} vs {record.keys()}"
        for key, expected_value in params.items():
            recorded_value = record[key]
            if key in split_params and recorded_value is not None:
                # Each CV fold only receives a subset of the original values,
                # so check membership rather than equality.
                if isinstance(expected_value, pl.Series):
                    expected_value = expected_value.to_numpy()
                if isinstance(recorded_value, pl.Series):
                    recorded_value = recorded_value.to_numpy()
                assert np.isin(recorded_value, expected_value).all(), (
                    f"Recorded {key!r} is not a subset of the expected values. Method: {method}"
                )
            elif isinstance(recorded_value, np.ndarray):
                assert_array_equal(recorded_value, expected_value)
            elif isinstance(recorded_value, pl.Series):
                # pl.Series.equals is meant to compare against another Series;
                # guard the type so a non-Series expected value fails with an
                # AssertionError instead of erroring inside `equals`.
                assert isinstance(expected_value, pl.Series) and recorded_value.equals(expected_value), (
                    f"Expected {expected_value} vs {recorded_value}. Method: {method}"
                )
            else:
                # Any other metadata must be forwarded as the very same object.
                assert recorded_value is expected_value, (
                    f"Expected {expected_value} vs {recorded_value}. Method: {method}"
                )