check_rewind_transform_behavior

yohou.testing.transformer.check_rewind_transform_behavior(transformer, X, y=None)

Check rewind_transform() behavior and contract.

Verifies that rewind_transform():

1. Does not use pre-existing _X_observed from the transformer's memory
2. Calls transform() and discards the first observation_horizon values
3. Resets the internal state with the input data
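
The contract above can be illustrated with a toy stand-in. The sketch below is not yohou code: `ToyDiff` is a hypothetical first-difference transformer over plain Python lists, written only to show what the three checks mean. It assumes, as the check's source does, that `transform()` works on its input alone and already drops the first `observation_horizon` rows.

```python
class ToyDiff:
    """Hypothetical first-difference transformer over lists (not part of yohou)."""

    observation_horizon = 1  # rows of history needed to produce one output row

    def fit(self, X):
        # Keep the tail of the training data as "observed" memory.
        self._X_observed = list(X[-self.observation_horizon:])
        return self

    def transform(self, X):
        # Differences are computed within X only, so the first
        # observation_horizon rows have no output and are dropped.
        return [b - a for a, b in zip(X, X[1:])]

    def rewind_transform(self, X):
        # 1. Ignore whatever is currently stored in self._X_observed.
        # 2. Transform X on its own (warmup rows are dropped).
        out = self.transform(X)
        # 3. Reset the memory to the tail of X.
        self._X_observed = list(X[-self.observation_horizon:])
        return out


t = ToyDiff().fit([1, 2, 4])            # memory now holds [4]
result = t.rewind_transform([10, 13, 17, 22])
print(result)         # differences within the new data only
print(t._X_observed)  # memory replaced by the tail of the new data
```

If `rewind_transform()` had used the old memory (`[4]`), the output would have started with `10 - 4`, which is the kind of leak the check is meant to catch.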

Parameters

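| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transformer | BaseTransformer | Unfitted transformer (it is cloned before fitting) | required |
| X | DataFrame | Training data. Must have at least 2 * observation_horizon + 1 rows (5 rows when observation_horizon is 0); with fewer rows the check returns without asserting anything | required |
| y | DataFrame \| None | Target data | None |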

Raises

| Type | Description |
| --- | --- |
| AssertionError | If rewind_transform() does not follow the expected contract |

Source Code

def check_rewind_transform_behavior(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check rewind_transform() behavior and contract.

    Verifies that rewind_transform():
    1. Does not use pre-existing _X_observed from transformer's memory
    2. Calls transform() and discards the first observation_horizon values
    3. Resets the internal state with the input data

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer
    X : pl.DataFrame
        Training data (needs to be long enough)
    y : pl.DataFrame, optional
        Target data

    Raises
    ------
    AssertionError
        If rewind_transform doesn't follow the expected contract

    """
    # Fit a clone on the full data to read the fitted observation_horizon
    transformer_clone = clone(transformer)
    transformer_clone.fit(X, y)

    horizon = transformer_clone.observation_horizon

    # Need enough data: at least 2 * observation_horizon + 1
    min_len = 2 * horizon + 1 if horizon > 0 else 5
    if len(X) < min_len:
        return

    # Split data into fit and test portions
    split_point = len(X) // 2
    X_fit = X[:split_point]
    X_new = X[split_point:]

    # Test that rewind_transform doesn't use pre-existing memory
    transformer1 = clone(transformer)
    transformer1.fit(X_fit, y)

    # Apply rewind_transform
    X_rewind_trans = transformer1.rewind_transform(X_new)

    # Expected behavior: transform(X_new) (transform already drops warmup rows)
    transformer2 = clone(transformer)
    transformer2.fit(X_fit, y)  # Fit with same data to have same fitted params
    X_expected = transformer2.transform(X_new)

    # Check outputs match
    assert_frame_equal(
        X_rewind_trans,
        X_expected,
        rel_tol=1e-6,
        abs_tol=1e-8,
    )

    # Check that internal state was reset to X_new
    if hasattr(transformer1, "_X_observed") and horizon > 0:
        expected_observed = X_new[-horizon:] if horizon <= len(X_new) else X_new
        assert_frame_equal(
            transformer1._X_observed,
            expected_observed,
            rel_tol=1e-6,
            abs_tol=1e-8,
        )