Skip to content

check_clone_preserves_forecaster_params

yohou.testing.forecaster.check_clone_preserves_forecaster_params(forecaster)

Check sklearn's clone() preserves init parameters.

Enhanced version that handles nested estimators and meta-forecasters like DecompositionPipeline and ColumnForecaster with list of (name, estimator) tuples.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| forecaster | BaseForecaster | Forecaster instance | required |

Raises

| Type | Description |
|------|-------------|
| AssertionError | If cloned forecaster has different parameters |

Source Code

Show/Hide source
def check_clone_preserves_forecaster_params(forecaster) -> None:
    """Check sklearn's clone() preserves init parameters.

    Enhanced version that handles nested estimators and meta-forecasters like
    DecompositionPipeline and ColumnForecaster with list of (name, estimator) tuples.

    Each top-level init parameter is compared according to its kind:

    - ``None`` must stay ``None``.
    - A non-empty list whose first item is a tuple is treated as a list of
      ``(name, estimator)`` or ``(name, estimator, columns)`` entries and
      compared entry by entry.
    - A class object must be the very same class in the clone.
    - An object exposing ``get_params`` must have the same type and matching
      deep parameters.
    - Anything else must compare equal; a type-only comparison is used solely
      for values whose equality is undefined or undecidable.

    Parameters
    ----------
    forecaster : BaseForecaster
        Forecaster instance

    Raises
    ------
    AssertionError
        If cloned forecaster has different parameters

    """
    forecaster_clone = clone(forecaster)

    # Shallow params only: nested estimators are inspected explicitly below.
    original_params = forecaster.get_params(deep=False)
    cloned_params = forecaster_clone.get_params(deep=False)

    # Check same parameter keys
    assert set(original_params.keys()) == set(cloned_params.keys()), (
        f"clone() should have same parameter keys, got {set(cloned_params.keys())} vs {set(original_params.keys())}"
    )

    for key, orig_val in original_params.items():
        cloned_val = cloned_params[key]

        if orig_val is None:
            assert cloned_val is None, f"Parameter {key}: expected None, got {cloned_val}"
        # List of (name, estimator[, columns]) tuples (meta-estimators like
        # DecompositionPipeline, FeaturePipeline)
        elif isinstance(orig_val, list) and len(orig_val) > 0 and isinstance(orig_val[0], tuple):
            _assert_named_estimators_match(key, orig_val, cloned_val)
        elif isinstance(orig_val, type):
            assert orig_val is cloned_val, (
                f"Parameter {key}: class type should be preserved by clone, got {cloned_val} vs {orig_val}"
            )
        # Estimator instances: same type and same (deep) parameters
        elif hasattr(orig_val, "get_params"):
            assert type(orig_val) is type(cloned_val), (
                f"Parameter {key}: different types {type(cloned_val)} vs {type(orig_val)}"
            )
            _assert_nested_params_match(key, orig_val, cloned_val)
        # Everything else: value comparison
        else:
            assert _clone_values_match(orig_val, cloned_val), f"Parameter {key}: {cloned_val} != {orig_val}"

    # Check they are different objects
    assert forecaster_clone is not forecaster, "clone() should create new instance"


def _clone_values_match(original, cloned) -> bool:
    """Return True when ``cloned`` is an acceptable copy of ``original``.

    Values are required to compare equal whenever equality is meaningful.
    A type-only comparison is used as a fallback exclusively when equality
    cannot be decided: objects exposing ``get_params`` (their parameters are
    compared separately), objects that inherit ``object.__eq__`` (identity
    comparison, always False for a copy), NaN-like values, and values whose
    ``==`` raises or yields a result with no single truth value (e.g.
    multi-element arrays).
    """
    if original is cloned:
        return True
    if hasattr(original, "get_params"):
        return type(original) is type(cloned)
    # Recurse into containers so that e.g. a list of objects without __eq__
    # is judged element by element instead of failing on list equality.
    if isinstance(original, (list, tuple)):
        return (
            type(original) is type(cloned)
            and len(original) == len(cloned)
            and all(_clone_values_match(a, b) for a, b in zip(original, cloned, strict=True))
        )
    if isinstance(original, dict):
        return (
            type(original) is type(cloned)
            and original.keys() == cloned.keys()
            and all(_clone_values_match(value, cloned[name]) for name, value in original.items())
        )
    try:
        if bool(original == cloned):
            return True
        # NaN-like values are unequal even to themselves.
        both_nan_like = bool(original != original) and bool(cloned != cloned)
    except Exception:
        # Deliberately broad: third-party types raise assorted errors from
        # ``==`` / ``bool()``; in that case only the type can be checked.
        return type(original) is type(cloned)
    if both_nan_like or type(original).__eq__ is object.__eq__:
        return type(original) is type(cloned)
    return False


def _assert_nested_params_match(label, original_est, cloned_est) -> None:
    """Assert the deep parameters of two estimators agree.

    ``label`` is the prefix used in assertion messages (e.g. ``"steps[0]"``).
    """
    original_deep = original_est.get_params(deep=True)
    cloned_deep = cloned_est.get_params(deep=True)
    for param_key, orig_param in original_deep.items():
        assert param_key in cloned_deep, f"Parameter {label}__{param_key}: missing from clone"
        cloned_param = cloned_deep[param_key]
        if hasattr(orig_param, "get_params"):
            # Nested estimator instances are copies; their own parameters
            # appear as separate ``<name>__<param>`` keys, so only the type
            # is compared here.
            assert type(orig_param) is type(cloned_param), f"Parameter {label}__{param_key}: different types"
        else:
            assert _clone_values_match(orig_param, cloned_param), (
                f"Parameter {label}__{param_key}: {cloned_param} != {orig_param}"
            )


def _assert_named_estimators_match(key, orig_val, cloned_val) -> None:
    """Assert a list of (name, estimator[, columns]) tuples was cloned faithfully."""
    assert isinstance(cloned_val, list), f"Parameter {key}: expected list, got {type(cloned_val)}"
    assert len(orig_val) == len(cloned_val), f"Parameter {key}: different lengths"

    for i, (orig_item, cloned_item) in enumerate(zip(orig_val, cloned_val, strict=False)):
        assert isinstance(orig_item, tuple), f"Parameter {key}[{i}]: expected tuple"
        assert isinstance(cloned_item, tuple), f"Parameter {key}[{i}]: expected tuple"
        assert len(orig_item) in (2, 3), (
            f"Parameter {key}[{i}]: expected (name, estimator) or (name, estimator, columns) tuple, got length {len(orig_item)}"
        )
        assert len(cloned_item) == len(orig_item), (
            f"Parameter {key}[{i}]: clone tuple length {len(cloned_item)} != original {len(orig_item)}"
        )

        orig_name, orig_est = orig_item[0], orig_item[1]
        cloned_name, cloned_est = cloned_item[0], cloned_item[1]

        # Names should match exactly
        assert orig_name == cloned_name, f"Parameter {key}[{i}]: different names {cloned_name} != {orig_name}"

        # Estimators should be different instances but same type
        assert type(orig_est) is type(cloned_est), (
            f"Parameter {key}[{i}] estimator: different types {type(cloned_est)} vs {type(orig_est)}"
        )
        assert orig_est is not cloned_est, f"Parameter {key}[{i}] estimator: should be cloned, not same instance"

        # Check estimator params match
        if hasattr(orig_est, "get_params"):
            _assert_nested_params_match(f"{key}[{i}]", orig_est, cloned_est)

        # For 3-tuples (name, estimator, columns), compare the columns element
        if len(orig_item) == 3:
            orig_cols = orig_item[2]
            cloned_cols = cloned_item[2]
            assert orig_cols == cloned_cols, f"Parameter {key}[{i}] columns: {cloned_cols} != {orig_cols}"