Skip to content

check_fit_sets_attributes

yohou.testing.transformer.check_fit_sets_attributes(transformer, X, y=None)

Check fit() sets required attributes.

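Validates that fit() sets the scikit-learn attributes feature_names_in_ and n_features_in_ and the yohou-specific _observation_horizon attribute. It also checks that feature_names_in_ and n_features_in_ match the non-"time" columns of X.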

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `transformer` | `BaseTransformer` | Unfitted transformer instance | *required* |
| `X` | `DataFrame` | Training data with a `"time"` column | *required* |
| `y` | `DataFrame` | Target data for supervised transformers | `None` |

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If required attributes are not set after `fit()`, or if their values do not match `X` |

Source Code

Show/Hide source
def check_fit_sets_attributes(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check that fit() populates the attributes every transformer must expose.

    Fitting is performed on ``clone(transformer)``, not on ``transformer``
    itself. After fitting, the clone must carry:

    * the scikit-learn attributes ``feature_names_in_`` and ``n_features_in_``;
    * the yohou-specific attribute ``_observation_horizon``. Only its
      presence is checked here, not its value.

    The two scikit-learn attributes are then compared against ``X``. The
    expected feature names are the columns of ``X`` other than ``"time"``,
    in their original order. The expected feature count is the number of
    those columns.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer instance
    X : pl.DataFrame
        Training data with "time" column
    y : pl.DataFrame, optional
        Target data for supervised transformers

    Raises
    ------
    AssertionError
        If a required attribute is missing after fit(), or if
        ``feature_names_in_`` / ``n_features_in_`` disagree with the
        non-"time" columns of ``X``.

    """
    fitted = clone(transformer)
    fitted.fit(X, y)

    # Presence checks, in a fixed order: the two scikit-learn attributes
    # first, then the yohou-specific one. The first missing attribute
    # determines the failure message.
    for attr_name in ("feature_names_in_", "n_features_in_", "_observation_horizon"):
        assert hasattr(fitted, attr_name), f"fit() must set {attr_name} attribute"

    # "time" is excluded because it is the time index, not a feature.
    feature_cols = [name for name in X.columns if name != "time"]
    n_expected = len(feature_cols)

    learned_names = fitted.feature_names_in_
    # list() normalises array-like containers before the ordered comparison.
    assert list(learned_names) == feature_cols, (
        f"feature_names_in_ mismatch: {learned_names} vs {feature_cols}"
    )

    learned_count = fitted.n_features_in_
    assert learned_count == n_expected, (
        f"n_features_in_ should be {n_expected}, got {learned_count}"
    )