
check_observe_concatenates_memory

yohou.testing.transformer.check_observe_concatenates_memory(transformer, X, y=None)

Check that observe() appends new data and keeps the memory within the observation horizon.

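The observe() method should append new observations to _X_observed and then trim the result so that it holds at most observation_horizon rows. The check fits a clone of the transformer on the first 80% of X, passes up to the next 10 rows to observe(), and asserts that _X_observed ends up with exactly min(initial length + number of new rows, observation_horizon) rows.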
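The expected append-and-trim behaviour can be illustrated with a toy class. This is a minimal sketch, not yohou's implementation: ToyWindowTransformer is hypothetical, it stores rows in a plain Python list instead of a polars DataFrame, and it assumes that fit() seeds the memory with the most recent observation_horizon training rows. Only the attribute names _X_observed and observation_horizon come from the check itself.

```python
class ToyWindowTransformer:
    """Hypothetical transformer that remembers at most `observation_horizon` rows."""

    def __init__(self, observation_horizon: int) -> None:
        self.observation_horizon = observation_horizon

    def _trim(self, rows: list) -> list:
        # Drop the oldest rows so that at most `observation_horizon` remain.
        start = max(0, len(rows) - self.observation_horizon)
        return rows[start:]

    def fit(self, X: list, y=None) -> "ToyWindowTransformer":
        # Assumption: fitting seeds the memory with the tail of the training data.
        self._X_observed = self._trim(list(X))
        return self

    def observe(self, X: list) -> "ToyWindowTransformer":
        # Append the new rows after the existing memory, then trim from the front.
        self._X_observed = self._trim(self._X_observed + list(X))
        return self


t = ToyWindowTransformer(observation_horizon=5).fit(list(range(80)))
print(len(t._X_observed))  # 5
t.observe(list(range(80, 90)))
print(t._X_observed)  # [85, 86, 87, 88, 89]
```

Here the memory is already full after fit(), so observing 10 more rows keeps its length at 5, matching min(5 + 10, 5). Note that the check compares lengths only, so it does not verify which rows survive the trim.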

Parameters

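| Name | Type | Description | Default |
|---|---|---|---|
| transformer | BaseTransformer | Unfitted transformer to check. It is cloned before fitting, so the instance passed in is not modified. | required |
| X | DataFrame | Training data. The first 80% of rows are used for fit(), and up to 10 of the remaining rows are passed to observe(). | required |
| y | DataFrame or None | Target data, passed unchanged to fit(). | None |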

Raises

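| Type | Description |
|---|---|
| AssertionError | If observe() lets _X_observed exceed observation_horizon, or leaves it with a length other than min(initial length + new rows, observation_horizon). |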

Source Code

def check_observe_concatenates_memory(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check observe() appends new data and maintains horizon size.

    The observe() method should append new observations to _X_observed
    and trim to observation_horizon length.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer
    X : pl.DataFrame
        Initial training data
    y : pl.DataFrame, optional
        Target data

    Raises
    ------
    AssertionError
        If observe() doesn't properly maintain memory

    """
    transformer_clone = clone(transformer)

    # Split data: fit on first 80%, update with next 10 rows
    X_train, X_temp = train_test_split(X, test_size=0.2, shuffle=False)
    X_update = X_temp.head(10)  # Take first 10 rows from remaining 20%

    transformer_clone.fit(X_train, y)

    horizon = transformer_clone.observation_horizon
    initial_memory_len = len(transformer_clone._X_observed)

    transformer_clone.observe(X_update)

    # Memory should not exceed horizon
    assert len(transformer_clone._X_observed) <= horizon, (
        f"_X_observed length {len(transformer_clone._X_observed)} exceeds horizon {horizon}"
    )

    # Memory should have grown or stayed at horizon
    expected_len = min(initial_memory_len + len(X_update), horizon)
    assert len(transformer_clone._X_observed) == expected_len, (
        f"Expected _X_observed length {expected_len}, got {len(transformer_clone._X_observed)}"
    )
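The row counts this check works with can be worked out by hand. The sketch below rests on two assumptions: that train_test_split follows scikit-learn's rounding for a float test_size (held-out rows = ceil(0.2 * n), training rows = n minus that), and that fit() leaves min(n_train, observation_horizon) rows in _X_observed. The helper expected_memory_len is hypothetical and only mirrors the arithmetic of the check.

```python
import math


def expected_memory_len(n_rows: int, horizon: int, test_size: float = 0.2, n_update: int = 10) -> dict:
    """Hypothetical helper mirroring the row arithmetic of the check."""
    # Assumption: scikit-learn-style rounding, held-out rows = ceil(test_size * n_rows).
    n_test = math.ceil(test_size * n_rows)
    n_train = n_rows - n_test
    # X_temp.head(10) returns fewer rows when the held-out part is shorter than 10.
    n_observed = min(n_update, n_test)
    # Assumption: fit() seeds the memory with at most `horizon` training rows.
    initial = min(n_train, horizon)
    # Same formula as `expected_len` in the check.
    expected = min(initial + n_observed, horizon)
    return {"n_train": n_train, "n_observed": n_observed, "initial": initial, "expected": expected}


for n_rows, horizon in [(100, 5), (100, 200), (25, 200)]:
    print(n_rows, horizon, expected_memory_len(n_rows, horizon))
```

With a short horizon the memory is already full after fit() and stays at the horizon; with a long horizon it grows by exactly the number of observed rows, which for a 25-row X is only 5 because the held-out part is shorter than 10 rows.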