def check_clone_preserves_forecaster_params(forecaster) -> None:
"""Check sklearn's clone() preserves init parameters.
Enhanced version that handles nested estimators and meta-forecasters like
DecompositionPipeline and ColumnForecaster with list of (name, estimator) tuples.
Parameters
----------
forecaster : BaseForecaster
Forecaster instance
Raises
------
AssertionError
If cloned forecaster has different parameters
"""
forecaster_clone = clone(forecaster)
# Get parameters
original_params = forecaster.get_params(deep=False)
cloned_params = forecaster_clone.get_params(deep=False)
# Check same parameter keys
assert set(original_params.keys()) == set(cloned_params.keys()), (
f"clone() should have same parameter keys, got {set(cloned_params.keys())} vs {set(original_params.keys())}"
)
# Check parameter values (for nested estimators, check type)
for key in original_params:
orig_val = original_params[key]
cloned_val = cloned_params[key]
# For None values
if orig_val is None:
assert cloned_val is None, f"Parameter {key}: expected None, got {cloned_val}"
# For list of (name, estimator) tuples (meta-estimators like DecompositionPipeline, FeaturePipeline)
elif isinstance(orig_val, list) and len(orig_val) > 0 and isinstance(orig_val[0], tuple):
assert isinstance(cloned_val, list), f"Parameter {key}: expected list, got {type(cloned_val)}"
assert len(orig_val) == len(cloned_val), f"Parameter {key}: different lengths"
for i, (orig_item, cloned_item) in enumerate(zip(orig_val, cloned_val, strict=False)):
assert isinstance(orig_item, tuple), f"Parameter {key}[{i}]: expected tuple"
assert isinstance(cloned_item, tuple), f"Parameter {key}[{i}]: expected tuple"
assert len(orig_item) in (2, 3), (
f"Parameter {key}[{i}]: expected (name, estimator) or (name, estimator, columns) tuple, got length {len(orig_item)}"
)
assert len(cloned_item) == len(orig_item), (
f"Parameter {key}[{i}]: clone tuple length {len(cloned_item)} != original {len(orig_item)}"
)
orig_name, orig_est = orig_item[0], orig_item[1]
cloned_name, cloned_est = cloned_item[0], cloned_item[1]
# Names should match exactly
assert orig_name == cloned_name, f"Parameter {key}[{i}]: different names {cloned_name} != {orig_name}"
# Estimators should be different instances but same type
assert type(orig_est) is type(cloned_est), (
f"Parameter {key}[{i}] estimator: different types {type(cloned_est)} vs {type(orig_est)}"
)
assert orig_est is not cloned_est, (
f"Parameter {key}[{i}] estimator: should be cloned, not same instance"
)
# Check estimator params match
if hasattr(orig_est, "get_params"):
orig_est_params = orig_est.get_params(deep=True) # ty: ignore[call-non-callable]
cloned_est_params = cloned_est.get_params(deep=True) # ty: ignore[unresolved-attribute]
for param_key in orig_est_params:
orig_param = orig_est_params.get(param_key)
cloned_param = cloned_est_params.get(param_key)
if hasattr(orig_param, "get_params"):
assert type(orig_param) is type(cloned_param), (
f"Parameter {key}[{i}]__{param_key}: different types"
)
elif orig_param != cloned_param:
assert orig_param == cloned_param, (
f"Parameter {key}[{i}]__{param_key}: {cloned_param} != {orig_param}"
)
# For 3-tuples (name, estimator, columns), compare the columns element
if len(orig_item) == 3:
orig_cols = orig_item[2]
cloned_cols = cloned_item[2]
assert orig_cols == cloned_cols, f"Parameter {key}[{i}] columns: {cloned_cols} != {orig_cols}"
elif isinstance(orig_val, type):
assert orig_val is cloned_val, (
f"Parameter {key}: class type should be preserved by clone, got {cloned_val} vs {orig_val}"
)
# For estimator instances, check type and params (recursively)
elif hasattr(orig_val, "get_params"):
assert type(orig_val) is type(cloned_val), (
f"Parameter {key}: different types {type(cloned_val)} vs {type(orig_val)}"
)
# Use deep=True to get all nested params, compare them
orig_deep_params = orig_val.get_params(deep=True)
cloned_deep_params = cloned_val.get_params(deep=True)
# Compare only primitive values and types (not object instances)
for param_key in orig_deep_params:
orig_param = orig_deep_params.get(param_key)
cloned_param = cloned_deep_params.get(param_key)
# Skip comparing estimator instances themselves, just check types
if hasattr(orig_param, "get_params"):
assert type(orig_param) is type(cloned_param), f"Parameter {key}__{param_key}: different types"
elif orig_param != cloned_param:
assert orig_param == cloned_param, f"Parameter {key}__{param_key}: {cloned_param} != {orig_param}"
# For other values, direct comparison
else:
try:
are_equal = bool(orig_val == cloned_val)
except Exception:
are_equal = False
if not are_equal:
# Fall back to type comparison for objects that don't define __eq__
# (e.g., PyTorch modules, neuralforecast loss functions).
assert type(orig_val) is type(cloned_val), f"Parameter {key}: {cloned_val} != {orig_val}"
# Check they are different objects
assert forecaster_clone is not forecaster, "clone() should create new instance"