check_feature_names_out_match

yohou.testing.transformer.check_feature_names_out_match(transformer, X, y=None)

Check get_feature_names_out() matches transform() output columns.

The feature names returned by get_feature_names_out() should match the actual columns in the transform() output (excluding 'time').

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transformer | BaseTransformer | Unfitted transformer | required |
| X | DataFrame | Training data | required |
| y | DataFrame | Target data | None |

Raises

| Type | Description |
| --- | --- |
| AssertionError | If feature names don't match output columns |
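
The contract this check enforces can be sketched with stand-in classes rather than yohou itself. Everything named here is hypothetical: `Frame` replaces a polars DataFrame (it only exposes `columns`), `RenamingTransformer` and `StaleNamesTransformer` are made-up transformers, `check_names` mirrors the documented check, and `copy.deepcopy` is assumed as a substitute for `clone`.

```python
import copy


class Frame:
    """Hypothetical stand-in for a DataFrame: only exposes column names."""

    def __init__(self, columns):
        self.columns = list(columns)


class RenamingTransformer:
    """Hypothetical transformer that suffixes every non-time column."""

    def fit(self, X, y=None):
        self.names_ = [f"{c}_scaled" for c in X.columns if c != "time"]
        return self

    def transform(self, X):
        return Frame(["time"] + self.names_)

    def get_feature_names_out(self):
        return list(self.names_)


class StaleNamesTransformer(RenamingTransformer):
    """Hypothetical buggy variant: reports input names, not output names."""

    def fit(self, X, y=None):
        super().fit(X, y)
        self.input_names_ = [c for c in X.columns if c != "time"]
        return self

    def get_feature_names_out(self):
        return list(self.input_names_)


def check_names(transformer, X, y=None):
    """Mirror of the documented check, with deepcopy standing in for clone."""
    fitted = copy.deepcopy(transformer).fit(X, y)
    actual = [c for c in fitted.transform(X).columns if c != "time"]
    names = list(fitted.get_feature_names_out())
    assert names == actual, f"get_feature_names_out() mismatch: {names} vs {actual}"


X = Frame(["time", "load", "temp"])

# Consistent transformer: the check returns None without raising.
result = check_names(RenamingTransformer(), X)

# Inconsistent transformer: the check raises AssertionError.
try:
    check_names(StaleNamesTransformer(), X)
    raised = False
except AssertionError:
    raised = True
```

In this sketch the buggy transformer outputs `load_scaled` and `temp_scaled` but still reports `load` and `temp`, which is the kind of drift the check is meant to catch.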

Source Code

def check_feature_names_out_match(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check get_feature_names_out() matches transform() output columns.

    The feature names returned by get_feature_names_out() should match
    the actual columns in the transform() output (excluding 'time').

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer
    X : pl.DataFrame
        Training data
    y : pl.DataFrame, optional
        Target data

    Raises
    ------
    AssertionError
        If feature names don't match output columns

    """
    transformer_clone = clone(transformer)
    transformer_clone.fit(X, y)

    X_trans = transformer_clone.transform(X)
    feature_names = transformer_clone.get_feature_names_out()

    # Get actual feature columns (exclude time)
    actual_features = [col for col in X_trans.columns if col != "time"]

    assert list(feature_names) == actual_features, (
        f"get_feature_names_out() mismatch: {list(feature_names)} vs {actual_features}"
    )
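
Because the assertion compares two lists, column order matters as well as membership. A small illustration with made-up column names (none of them come from yohou):

```python
# Hypothetical names reported by get_feature_names_out()
feature_names = ["load_lag_1", "temp_lag_1"]

# Hypothetical transform() output columns, in a different order
output_columns = ["time", "temp_lag_1", "load_lag_1"]

# Same filtering step as in the check: drop the 'time' column
actual_features = [col for col in output_columns if col != "time"]

# A set comparison would accept this pair; the list comparison does not.
same_members = set(feature_names) == set(actual_features)
same_order = list(feature_names) == actual_features
```

Here `same_members` is true while `same_order` is false, so a transformer that reorders its output columns without updating its reported names would fail the check.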