Skip to content

check_rewind_updates_memory

yohou.testing.transformer.check_rewind_updates_memory(transformer, X, y=None)

Check rewind() updates _X_observed to last observation_horizon rows.

The rewind() method should update the transformer's memory to contain only the last observation_horizon rows of the provided data.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| transformer | BaseTransformer | Unfitted transformer | required |
| X | DataFrame | Training data (must have at least observation_horizon rows) | required |
| y | DataFrame | Target data | None |

Raises

| Type | Description |
|------|-------------|
| AssertionError | If _X_observed is not properly updated |
| ValueError | If X is too short for the transformer's observation_horizon |

Source Code

Show/Hide source
def check_rewind_updates_memory(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check rewind() updates _X_observed to last observation_horizon rows.

    The rewind() method should update the transformer's memory to contain
    only the last observation_horizon rows of the provided data.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer. It is cloned before fitting, so the caller's
        instance is not modified.
    X : pl.DataFrame
        Training data (must have at least observation_horizon rows)
    y : pl.DataFrame, optional
        Target data, passed to ``fit`` only (``rewind`` receives X alone).

    Raises
    ------
    AssertionError
        If _X_observed is missing (None) or not properly updated
    ValueError
        If X is too short for the transformer's observation_horizon

    """
    transformer_clone = clone(transformer)
    transformer_clone.fit(X, y)

    # The horizon is read from the fitted clone, after fit() has run.
    horizon = transformer_clone.observation_horizon

    if len(X) < horizon:
        raise ValueError(f"X length {len(X)} < observation_horizon {horizon}")

    # Rewind with a prefix of X that is up to 5 rows longer than the horizon,
    # so rewind() has to discard the leading rows. head() already caps the
    # result at len(X), so no explicit length check is needed here.
    X_new = X.head(horizon + 5)
    transformer_clone.rewind(X_new)

    observed = transformer_clone._X_observed
    # Report a missing memory as an AssertionError (as documented) rather
    # than letting len(None) raise a TypeError below.
    assert observed is not None, "_X_observed should be set after rewind(), got None"

    # Expected memory: the last `horizon` rows of the rewind data. The guard
    # above ensures len(X_new) >= horizon, so this has `horizon` rows.
    expected = X_new.tail(horizon)

    # Check the row count first for a clearer failure message.
    assert len(observed) == len(expected), (
        f"_X_observed length should be {len(expected)}, got {len(observed)}"
    )

    # Check _X_observed contains exactly the last horizon rows.
    assert_frame_equal(observed, expected)