
check_panel_group_preservation

yohou.testing.transformer.check_panel_group_preservation(transformer, X_panel, y=None)

Check that transformers preserve panel group names after transformation.

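Panel data uses columns named with a __ separator (<GROUP>__<SERIES>). After transformation, the panel group names must remain unchanged even if the series names change (e.g., store_1__sales may become store_1__diff_s_7_sales, but the group must stay store_1). The check returns without asserting anything if X_panel contains no panel columns, or if fitting or transforming a clone of the transformer raises NotImplementedError.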
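To illustrate the naming convention, the sketch below derives group names from column names by splitting on the first __. The helper panel_groups is hypothetical. It stands in for yohou's inspect_panel, whose exact parsing rules are not shown on this page, so treat the splitting rule as an assumption.

```python
# Minimal sketch, assuming panel columns are named "<GROUP>__<SERIES>" and that
# the group is everything before the first "__". `panel_groups` is a hypothetical
# helper, not part of yohou's API.


def panel_groups(columns: list[str]) -> dict[str, list[str]]:
    """Map each panel group to the series names that belong to it."""
    groups: dict[str, list[str]] = {}
    for col in columns:
        if "__" not in col:
            continue  # not a panel column (e.g. a "time" column)
        group, series = col.split("__", 1)
        groups.setdefault(group, []).append(series)
    return groups


before = ["time", "store_1__sales", "store_2__sales"]
after = ["time", "store_1__diff_s_7_sales", "store_2__diff_s_7_sales"]

print(panel_groups(before))  # {'store_1': ['sales'], 'store_2': ['sales']}
print(panel_groups(after))   # {'store_1': ['diff_s_7_sales'], 'store_2': ['diff_s_7_sales']}
# The series names changed, but the set of group names is identical.
print(set(panel_groups(before)) == set(panel_groups(after)))  # True
```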

Parameters

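| Name | Type | Description | Default |
|------|------|-------------|---------|
| transformer | BaseTransformer | Unfitted transformer. It is cloned before fitting, so the instance passed in is not modified. | required |
| X_panel | DataFrame | Panel data whose columns follow the <GROUP>__<SERIES> naming convention. | required |
| y | DataFrame | Target data passed to fit. Optional. | None |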

Raises

| Type | Description |
|------|-------------|
| AssertionError | If panel group names are not preserved after transformation. |

Source Code

def check_panel_group_preservation(transformer, X_panel: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check that transformers preserve panel group names after transformation.

    Panel data uses columns with ``__`` separator (``<GROUP>__<SERIES>``).
    After transformation, the panel group names must remain unchanged even
    if the series names change (e.g., ``store_1__sales`` may become
    ``store_1__diff_s_7_sales`` but the group must stay ``store_1``).

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer.
    X_panel : pl.DataFrame
        Panel data with panel columns.
    y : pl.DataFrame, optional
        Target data.

    Raises
    ------
    AssertionError
        If panel group names are not preserved after transformation.

    """
    # Check if X_panel actually has panel columns
    _, input_panel_groups = inspect_panel(X_panel)

    if not input_panel_groups:
        # Not panel data, skip
        return

    input_group_names = set(input_panel_groups.keys())

    transformer_clone = clone(transformer)

    try:
        transformer_clone.fit(X_panel, y)
        X_trans = transformer_clone.transform(X_panel)
    except NotImplementedError:
        # Transformer explicitly doesn't support panel data
        return

    # Inspect the output for panel groups
    _, output_panel_groups = inspect_panel(X_trans)
    output_group_names = set(output_panel_groups.keys())

    # Output must have the same panel group names
    assert output_group_names == input_group_names, (
        f"Panel group names changed after transformation. "
        f"Input groups: {sorted(input_group_names)}, "
        f"Output groups: {sorted(output_group_names)}. "
        f"Transformers must preserve panel group prefixes. "
        f"Output columns: {X_trans.columns}"
    )
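The same control flow can be mirrored without polars or yohou by working on lists of column names. In this sketch, group_names, check_groups_preserved, seasonal_diff and prefix_everything are all hypothetical. Each "transformer" is a plain function from input column names to output column names, so there is no clone or fit step, and a group is again assumed to be the text before the first __.

```python
# Stdlib-only sketch of the control flow in check_panel_group_preservation,
# operating on column names instead of DataFrames. All names are hypothetical.
from typing import Callable


def group_names(columns: list[str]) -> set[str]:
    # Assumption: the group is the text before the first "__".
    return {col.split("__", 1)[0] for col in columns if "__" in col}


def check_groups_preserved(
    transform: Callable[[list[str]], list[str]], columns: list[str]
) -> None:
    input_groups = group_names(columns)
    if not input_groups:
        return  # not panel data: nothing to check
    try:
        output_columns = transform(columns)
    except NotImplementedError:
        return  # transformer opts out of panel support
    output_groups = group_names(output_columns)
    assert output_groups == input_groups, (
        f"Input groups: {sorted(input_groups)}, "
        f"output groups: {sorted(output_groups)}"
    )


def seasonal_diff(cols: list[str]) -> list[str]:
    # Rewrites only the series part: store_1__sales -> store_1__diff_s_7_sales
    return [c.replace("__", "__diff_s_7_", 1) for c in cols]


def prefix_everything(cols: list[str]) -> list[str]:
    # Prepends to the whole name: store_1__sales -> lag_1_store_1__sales
    return ["lag_1_" + c for c in cols]


cols = ["store_1__sales", "store_2__sales"]
check_groups_preserved(seasonal_diff, cols)  # passes silently

try:
    check_groups_preserved(prefix_everything, cols)
    message = None
except AssertionError as exc:
    message = str(exc)
print(message)
```

seasonal_diff passes because it only rewrites the series part of each name, while prefix_everything fails because prepending to the whole column name changes the group prefix itself.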