Skip to content

check_search_clone_preserves_params

yohou.testing.search.check_search_clone_preserves_params(search_cv)

Check that sklearn's `clone()` preserves search CV parameters.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `search_cv` | `BaseSearchCV` | Search CV instance (fitted or unfitted). | *required* |

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If clone doesn't preserve parameters. |

Source Code

Show/Hide source
def check_search_clone_preserves_params(search_cv) -> None:
    """Check that sklearn ``clone()`` preserves search CV parameters.

    The clone must expose the same parameter names as the original, hold a
    forecaster of the same type, carry a ``scoring`` value that matches the
    original (``None`` stays ``None``; a dict keeps the same scorer names and
    per-scorer types; anything else keeps the same type), and must not carry
    over fitted state such as ``cv_results_``.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Search CV instance (fitted or unfitted)

    Raises
    ------
    AssertionError
        If clone doesn't preserve parameters

    """
    search_cv_cloned = clone(search_cv)

    # Get parameters from original and clone
    params_original = search_cv.get_params()
    params_cloned = search_cv_cloned.get_params()

    # Check that parameter names match (dict key views compare like sets)
    assert params_original.keys() == params_cloned.keys(), (
        f"Cloned search CV should have same parameter names, "
        f"got {set(params_original.keys())} vs {set(params_cloned.keys())}"
    )

    # Check that forecaster type matches
    assert type(params_original["forecaster"]) is type(params_cloned["forecaster"]), (
        f"Cloned forecaster type should match, "
        f"got {type(params_original['forecaster'])} vs {type(params_cloned['forecaster'])}"
    )

    # Check that scoring is preserved. Every shape is validated, including
    # None, so a clone that gains or loses a scorer is also caught.
    scoring_original = params_original["scoring"]
    scoring_cloned = params_cloned["scoring"]
    if scoring_original is None:
        assert scoring_cloned is None, (
            f"Cloned scoring should be None when original is None, got {type(scoring_cloned)}"
        )
    elif isinstance(scoring_original, dict):
        assert isinstance(scoring_cloned, dict), "Cloned scoring should be dict when original is dict"
        # Multi-metric scoring: the clone must keep the same scorer names...
        assert scoring_original.keys() == scoring_cloned.keys(), (
            f"Cloned scoring dict should have same scorer names, "
            f"got {set(scoring_original.keys())} vs {set(scoring_cloned.keys())}"
        )
        # ...and each named scorer must keep its type.
        for scorer_name, scorer in scoring_original.items():
            assert type(scorer) is type(scoring_cloned[scorer_name]), (
                f"Cloned scorer {scorer_name!r} type should match, "
                f"got {type(scorer)} vs {type(scoring_cloned[scorer_name])}"
            )
    else:
        assert type(scoring_original) is type(scoring_cloned), (
            f"Cloned scoring type should match, "
            f"got {type(scoring_original)} vs {type(scoring_cloned)}"
        )

    # Check that fitted state is NOT cloned (only meaningful if the original
    # has been fitted and therefore exposes cv_results_)
    if hasattr(search_cv, "cv_results_"):
        assert not hasattr(search_cv_cloned, "cv_results_"), "Cloned search CV should not have fitted attributes"