check_transform_drops_warmup_rows

yohou.testing.transformer.check_transform_drops_warmup_rows(transformer, X, y=None)

Check that stateful transformers drop exactly observation_horizon rows.

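Stateful transformers (observation_horizon > 0) treat the first observation_horizon rows as warm-up, so their transform() output must omit exactly those rows. The output must have len(X) - observation_horizon rows, and its first timestamp must equal X["time"][observation_horizon]. Stateless transformers (observation_horizon == 0) are skipped, because they may legitimately change the row count (e.g. Downsampler).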
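A minimal pure-Python sketch of that contract follows. Everything in it is hypothetical: LagTransformer is not a yohou class, and frames are plain dicts of lists standing in for pl.DataFrame. A lag-k feature needs k earlier rows before its first valid value, so the sketch sets observation_horizon = k and drops exactly the first k rows. As a result, the output starts at X["time"][k].

```python
from datetime import date, timedelta


class LagTransformer:
    """Hypothetical stateful transformer: emits value[t - lag] at each time t.

    Frames are plain dicts of equal-length lists (a stand-in for pl.DataFrame).
    """

    def __init__(self, lag):
        self.lag = lag

    def fit(self, X, y=None):
        # A lag feature needs `lag` earlier rows before its first valid output.
        self.observation_horizon = self.lag
        return self

    def transform(self, X):
        h = self.observation_horizon
        return {
            # Drop the warm-up rows: the output starts at X["time"][h].
            "time": X["time"][h:],
            # Output row j sits at input row j + h, so its lagged value is input row j.
            "value_lag": X["value"][: len(X["value"]) - h],
        }


start = date(2024, 1, 1)
X = {
    "time": [start + timedelta(days=i) for i in range(6)],
    "value": [10, 11, 12, 13, 14, 15],
}

t = LagTransformer(lag=2).fit(X)
X_t = t.transform(X)
print(len(X["time"]), len(X_t["time"]))  # 6 4
print(X_t["time"][0] == X["time"][2])    # True
print(X_t["value_lag"])                  # [10, 11, 12, 13]
```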

Parameters

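| Name | Type | Description | Default |
|---|---|---|---|
| transformer | BaseTransformer | Unfitted transformer. The check fits a clone, so the instance passed in is not modified. | required |
| X | DataFrame | Training data with a "time" column and more than observation_horizon rows. | required |
| y | DataFrame | Target data. | None |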
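Because the check reads X["time"][observation_horizon], X needs at least observation_horizon + 1 rows. A few more rows give a less trivial comparison. The helpers below (min_rows_for_check and make_time_index) are hypothetical, written only to show the arithmetic for sizing synthetic test data. The 24-step horizon is an arbitrary example value.

```python
from datetime import datetime, timedelta


def min_rows_for_check(observation_horizon, extra=1):
    """Hypothetical helper: smallest row count that leaves `extra` output rows.

    X["time"][observation_horizon] must exist, so `extra` must be at least 1.
    """
    if extra < 1:
        raise ValueError("extra must be >= 1 so the output is non-empty")
    return observation_horizon + extra


def make_time_index(n_rows, start=datetime(2024, 1, 1), step=timedelta(hours=1)):
    """Build an evenly spaced list of timestamps for a synthetic 'time' column."""
    return [start + i * step for i in range(n_rows)]


horizon = 24  # e.g. a transformer that looks back one day of hourly data
n = min_rows_for_check(horizon, extra=10)
times = make_time_index(n)
print(n, times[horizon])  # 34 2024-01-02 00:00:00
```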

Raises

| Type | Description |
|---|---|
| AssertionError | If the number of dropped rows does not match observation_horizon, or if the first output timestamp is not X["time"][observation_horizon]. |
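The two assertions in the check catch different bugs. The row-count assertion alone would accept a transformer that removes the right number of rows from the wrong end of the frame, and the timestamp assertion catches that case. The sketch below repeats both comparisons on plain lists of timestamps, a stand-in for the polars "time" column (assert_warmup_dropped is a hypothetical name). It prints which assertion each buggy output trips.

```python
def assert_warmup_dropped(in_times, out_times, horizon):
    """Mirror of the check's two assertions, applied to lists of timestamps."""
    assert len(out_times) == len(in_times) - horizon, (
        f"expected {len(in_times) - horizon} rows, got {len(out_times)}"
    )
    assert out_times[0] == in_times[horizon], (
        f"first output timestamp should be {in_times[horizon]}, got {out_times[0]}"
    )


in_times = list(range(10))  # integer "timestamps" keep the example short
horizon = 3

# Correct: the warm-up rows are removed from the start.
assert_warmup_dropped(in_times, in_times[horizon:], horizon)

# "kept": warm-up rows left in place (e.g. filled with nulls instead of dropped).
# "wrong end": right row count, but the rows were dropped from the end.
for name, out in [("kept", in_times), ("wrong end", in_times[:-horizon])]:
    try:
        assert_warmup_dropped(in_times, out, horizon)
    except AssertionError as err:
        print(f"{name}: {err}")
# kept: expected 7 rows, got 10
# wrong end: first output timestamp should be 3, got 0
```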

Source Code

```python
import polars as pl
from sklearn.base import clone  # assumed import location; yohou follows the scikit-learn estimator API


def check_transform_drops_warmup_rows(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check that stateful transformers drop exactly observation_horizon rows.

    Stateful transformers (observation_horizon > 0) must drop the first
    observation_horizon rows in their transform() output. Stateless
    transformers are skipped (they may legitimately change row count,
    e.g. Downsampler).

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer; a clone is fitted, so it is not modified.
    X : pl.DataFrame
        Training data with a "time" column and more than
        observation_horizon rows.
    y : pl.DataFrame, optional
        Target data.

    Raises
    ------
    AssertionError
        If the number of dropped rows doesn't match observation_horizon,
        or if the first output timestamp is not X["time"][observation_horizon].

    """
    # Fit a clone so the caller's (unfitted) transformer is left untouched
    transformer_clone = clone(transformer)
    transformer_clone.fit(X, y)

    # Read the warm-up length from the fitted clone
    horizon = transformer_clone.observation_horizon
    X_t = transformer_clone.transform(X)

    # Stateless transformers may legitimately change row count (e.g. Downsampler)
    if horizon == 0:
        return

    # Row count: exactly `horizon` warm-up rows were removed
    assert len(X_t) == len(X) - horizon, (
        f"Stateful transformer should drop exactly {horizon} rows "
        f"(observation_horizon), but input has {len(X)} rows and "
        f"output has {len(X_t)} rows (dropped {len(X) - len(X_t)})"
    )
    # Alignment: the output starts right after the warm-up rows
    assert X_t["time"][0] == X["time"][horizon], (
        f"First output timestamp should be X['time'][{horizon}] = {X['time'][horizon]}, "
        f"but got {X_t['time'][0]}"
    )
```
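The control flow of the check (clone, fit, read observation_horizon, transform, skip stateless transformers, then compare) can be exercised without yohou or polars. In this sketch, copy.deepcopy stands in for clone. This is an assumption: deepcopy copies whatever state the object already has, whereas a scikit-learn-style clone returns an unfitted copy. Frames are dicts of lists, and HalfDownsampler and RollingSum are hypothetical transformers. A trailing window of length w needs w - 1 earlier rows, so RollingSum reports observation_horizon = w - 1.

```python
import copy


def check_drops_warmup_rows(transformer, X, y=None):
    """Pure-Python mirror of check_transform_drops_warmup_rows for dict-of-lists frames."""
    est = copy.deepcopy(transformer)  # stand-in for clone(): the caller's object is never fitted
    est.fit(X, y)
    horizon = est.observation_horizon
    X_t = est.transform(X)
    if horizon == 0:
        return  # stateless: the row count may change legitimately
    n_in, n_out = len(X["time"]), len(X_t["time"])
    assert n_out == n_in - horizon, f"expected {n_in - horizon} rows, got {n_out}"
    assert X_t["time"][0] == X["time"][horizon], "first output timestamp is misaligned"


class HalfDownsampler:
    """Hypothetical stateless transformer that keeps every second row."""

    def fit(self, X, y=None):
        self.observation_horizon = 0
        return self

    def transform(self, X):
        return {k: v[::2] for k, v in X.items()}


class RollingSum:
    """Hypothetical stateful transformer: sum over a trailing window."""

    def __init__(self, window):
        self.window = window

    def fit(self, X, y=None):
        # The first full window ends at row window - 1.
        self.observation_horizon = self.window - 1
        return self

    def transform(self, X):
        h, v = self.observation_horizon, X["value"]
        sums = [sum(v[i - h : i + 1]) for i in range(h, len(v))]
        return {"time": X["time"][h:], "value_sum": sums}


X = {"time": list(range(8)), "value": [1] * 8}
check_drops_warmup_rows(HalfDownsampler(), X)  # skipped: horizon is 0, 8 rows become 4
check_drops_warmup_rows(RollingSum(3), X)      # passes: 2 warm-up rows dropped
original = RollingSum(3)
check_drops_warmup_rows(original, X)
print(hasattr(original, "observation_horizon"))  # False: only the copy was fitted
```

Checking hasattr on the instance passed in illustrates why the check clones before fitting: the caller's transformer stays unfitted and can be reused by other checks.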