Check that sklearn's `clone()` preserves search CV parameters.
**Parameters**

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `search_cv` | `BaseSearchCV` | Search CV instance (fitted or unfitted) | *required* |
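**Raises**

| Type | Description |
| --- | --- |
| `AssertionError` | If clone doesn't preserve parameters |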
Source Code
View on GitHub
Show/Hide source
def check_search_clone_preserves_params(search_cv) -> None:
    """Check sklearn clone() preserves search CV parameters.

    The instance is cloned with ``clone`` and the two ``get_params()``
    dictionaries are compared:

    * the set of parameter names must be identical;
    * the ``forecaster`` parameters must have the same type;
    * the ``scoring`` parameters must agree: ``None`` stays ``None``, a dict
      keeps the same scorer names with the same scorer types, and any other
      value keeps its type;
    * if the original has ``cv_results_``, the clone must not have it.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Search CV instance (fitted or unfitted)

    Raises
    ------
    AssertionError
        If clone doesn't preserve parameters
    KeyError
        If ``search_cv.get_params()`` has no ``forecaster`` or ``scoring`` entry.
    """
    search_cv_cloned = clone(search_cv)

    # Get parameters from original and clone
    params_original = search_cv.get_params()
    params_cloned = search_cv_cloned.get_params()

    # Check that parameter names match
    assert params_original.keys() == params_cloned.keys(), (
        "Cloned search CV should have same parameter names, "
        f"got {set(params_original.keys())} vs {set(params_cloned.keys())}"
    )

    # Check that forecaster type matches
    forecaster_original = params_original["forecaster"]
    forecaster_cloned = params_cloned["forecaster"]
    assert type(forecaster_original) is type(forecaster_cloned), (
        "Cloned forecaster type should match, "
        f"got {type(forecaster_original)} vs {type(forecaster_cloned)}"
    )

    # Check that scorer type matches
    scoring_original = params_original["scoring"]
    scoring_cloned = params_cloned["scoring"]
    if scoring_original is None:
        # Check the None case too: a clone that gains a scorer is not faithful.
        assert scoring_cloned is None, (
            f"Cloned scoring should be None when original is None, got {type(scoring_cloned)}"
        )
    elif isinstance(scoring_original, dict):
        assert isinstance(scoring_cloned, dict), "Cloned scoring should be dict when original is dict"
        # Multi-metric scoring: the scorer names and each scorer's type must survive cloning.
        assert scoring_original.keys() == scoring_cloned.keys(), (
            "Cloned scoring dict should have same scorer names, "
            f"got {set(scoring_original.keys())} vs {set(scoring_cloned.keys())}"
        )
        for scorer_name, scorer in scoring_original.items():
            # Indexing is safe: the key sets were asserted equal above.
            assert type(scorer) is type(scoring_cloned[scorer_name]), (
                f"Cloned scorer {scorer_name!r} type should match, "
                f"got {type(scorer)} vs {type(scoring_cloned[scorer_name])}"
            )
    else:
        assert type(scoring_original) is type(scoring_cloned), (
            "Cloned scoring type should match, "
            f"got {type(scoring_original)} vs {type(scoring_cloned)}"
        )

    # Check that fitted state is NOT cloned
    if hasattr(search_cv, "cv_results_"):
        assert not hasattr(search_cv_cloned, "cv_results_"), "Cloned search CV should not have fitted attributes"