check_inverse_observe_transform_identity

yohou.testing.transformer.check_inverse_observe_transform_identity(transformer, X, y=None, atol=1e-06, rtol=1e-05)

Check inverse_transform(observe_transform(X)) ≈ X.

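Verifies that inverse_transform correctly inverts the output of observe_transform.

For stateless transformers: inverse_transform(observe_transform(X), X_p=None) == X.

For stateful transformers: X_p is the observed history captured before the observe_transform window.

The check returns without asserting anything if the transformer is not tagged as invertible or if X has fewer than 10 rows.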
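The stateful case can be sketched with a toy example. The class below, `ToyDifference`, is hypothetical and not part of yohou; it works on plain lists of floats and only borrows the method names and the `_X_observed` / `observation_horizon` attributes that appear in the source of this check. It illustrates why X_p has to be captured before observe_transform runs.

```python
import math


class ToyDifference:
    """Hypothetical first-difference transformer over a list of floats."""

    observation_horizon = 1  # rows of history needed to transform or invert

    def fit(self, X):
        # Remember the last `observation_horizon` rows seen so far.
        self._X_observed = list(X[-self.observation_horizon:])
        return self

    def observe_transform(self, X):
        # Difference each new row against the row before it, using memory
        # for the first one, then advance the memory to the end of X.
        context = self._X_observed + list(X)
        diffs = [b - a for a, b in zip(context, context[1:])]
        self._X_observed = list(X[-self.observation_horizon:])
        return diffs

    def inverse_transform(self, X_t, X_p):
        # Cumulative sum starting from the last row of the history X_p.
        restored, level = [], X_p[-1]
        for d in X_t:
            level += d
            restored.append(level)
        return restored


def all_close(a, b, rtol=1e-5, atol=1e-6):
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rtol, abs_tol=atol) for x, y in zip(a, b)
    )


X = [1.0, 2.5, 4.0, 4.5, 7.0, 8.0, 8.5, 11.0, 12.0, 15.0]
X_fit, X_update = X[:5], X[5:]

t = ToyDifference().fit(X_fit)
X_p = list(t._X_observed)            # captured BEFORE observe_transform
X_trans = t.observe_transform(X_update)
X_rec = t.inverse_transform(X_trans, X_p)
roundtrip_ok = all_close(X_update, X_rec)

# Using the memory AFTER observe_transform starts from the wrong level.
X_stale = t.inverse_transform(X_trans, t._X_observed)
stale_ok = all_close(X_update, X_stale)
```

Here `roundtrip_ok` is true and `stale_ok` is false: once observe_transform has advanced the memory, it no longer describes the rows that preceded X_update. A stateless transformer has no such dependency, which is why it is inverted with X_p=None.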

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transformer | BaseTransformer | Unfitted transformer | required |
| X | DataFrame | Training data (first half is used for fit, second half for observe_transform) | required |
| y | DataFrame, optional | Target data | None |
| atol | float | Absolute tolerance for numerical comparison | 1e-06 |
| rtol | float | Relative tolerance for numerical comparison | 1e-05 |

Raises

| Type | Description |
| --- | --- |
| AssertionError | If inverse round-trip via observe_transform fails |

Notes

This check differs from check_inverse_transform_identity:

- Uses observe_transform instead of transform
- Tests round-trip on data AFTER initial fit (streaming scenario)
- Verifies X_p handling for stateful transformers during updates

Source Code

def check_inverse_observe_transform_identity(
    transformer,
    X: pl.DataFrame,
    y: pl.DataFrame | None = None,
    atol: float = 1e-6,
    rtol: float = 1e-5,
) -> None:
    """Check inverse_transform(observe_transform(X)) ≈ X.

    Verifies that inverse_transform correctly inverts observe_transform output.
    For stateless transformers: inverse_transform(observe_transform(X), X_p=None) == X
    For stateful transformers: uses X_p from before the observe_transform window.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer
    X : pl.DataFrame
        Training data (will be split for fit and observe_transform)
    y : pl.DataFrame, optional
        Target data
    atol : float
        Absolute tolerance for numerical comparison
    rtol : float
        Relative tolerance for numerical comparison

    Raises
    ------
    AssertionError
        If inverse round-trip via observe_transform fails

    Notes
    -----
    This check differs from check_inverse_transform_identity:
    - Uses observe_transform instead of transform
    - Tests round-trip on data AFTER initial fit (streaming scenario)
    - Verifies X_p handling for stateful transformers during updates

    """
    tags = transformer.__sklearn_tags__()
    if not (tags.transformer_tags and tags.transformer_tags.invertible):
        return

    # Need enough data to split into fit and update portions
    if len(X) < 10:
        return

    # Split data: first half for fit, second half for observe_transform
    split_idx = len(X) // 2
    X_fit = X[:split_idx]
    X_update = X[split_idx:]

    transformer_clone = clone(transformer)
    transformer_clone.fit(X_fit, y)

    horizon = transformer_clone.observation_horizon

    # Get X_p BEFORE observe_transform (from the fitted state)
    # For stateful transformers, X_p is the last `horizon` rows of observed data
    X_p = transformer_clone._X_observed.clone() if horizon > 0 else None

    # Apply observe_transform - this transforms X_update using memory context
    X_trans = transformer_clone.observe_transform(X_update)

    # Inverse transform to recover original
    X_reconstructed = transformer_clone.inverse_transform(X_trans, X_p)

    # The portion of X we should recover is X_update (possibly truncated)
    # If transformer drops rows, we compare against the tail of X_update
    X_expected = X_update.tail(len(X_trans))

    # Shape check
    assert X_reconstructed.shape == X_expected.shape, (
        f"Shape mismatch: {X_expected.shape} -> {X_trans.shape} -> {X_reconstructed.shape}"
    )

    # Numerical comparison
    assert_frame_equal(X_expected, X_reconstructed, rel_tol=rtol, abs_tol=atol)
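
The control flow of the listing above (skip conditions, half/half split, tail alignment, tolerance comparison) can be mirrored on plain Python lists. This is only a sketch under stated assumptions: `ToyScaler` and `toy_roundtrip_check` are hypothetical names, the `invertible` attribute stands in for the sklearn tag lookup, `copy.deepcopy` stands in for `clone`, `math.isclose` stands in for `assert_frame_equal`, and X_p is taken from the tail of the fit data rather than from `_X_observed`. Unlike the real check, which returns None, the sketch returns a status string so that the skip path is visible.

```python
import copy
import math


class ToyScaler:
    """Hypothetical stateless transformer: divides by a scale learned in fit."""

    invertible = True          # stand-in for tags.transformer_tags.invertible
    observation_horizon = 0    # stateless: no history needed

    def fit(self, X):
        self.scale_ = max(abs(v) for v in X) or 1.0
        return self

    def observe_transform(self, X):
        return [v / self.scale_ for v in X]

    def inverse_transform(self, X_t, X_p=None):
        return [v * self.scale_ for v in X_t]


def toy_roundtrip_check(transformer, X, atol=1e-6, rtol=1e-5, min_rows=10):
    """Return 'skipped' or 'passed'; raise AssertionError on a mismatch."""
    # Skip non-invertible transformers and inputs too short to split.
    if not getattr(transformer, "invertible", False) or len(X) < min_rows:
        return "skipped"

    # First half for fit, second half for observe_transform.
    split_idx = len(X) // 2
    X_fit, X_update = X[:split_idx], X[split_idx:]

    fitted = copy.deepcopy(transformer).fit(X_fit)
    horizon = fitted.observation_horizon

    # History preceding the update window; None for stateless transformers.
    X_p = X_fit[-horizon:] if horizon > 0 else None

    X_trans = fitted.observe_transform(X_update)
    X_rec = fitted.inverse_transform(X_trans, X_p)

    # If rows were dropped, compare against the tail of X_update.
    X_expected = X_update[len(X_update) - len(X_trans):]
    assert len(X_rec) == len(X_expected), "Shape mismatch"
    for expected, got in zip(X_expected, X_rec):
        assert math.isclose(expected, got, rel_tol=rtol, abs_tol=atol)
    return "passed"


long_result = toy_roundtrip_check(ToyScaler(), [float(i) for i in range(1, 13)])
short_result = toy_roundtrip_check(ToyScaler(), [1.0, 2.0, 3.0])
```

With twelve rows the sketch runs the full round trip; with three rows it takes the early-exit path, which in the real check means no assertion is ever evaluated.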