Skip to content

check_inverse_transform_identity

yohou.testing.transformer.check_inverse_transform_identity(transformer, X, y=None, atol=1e-06, rtol=1e-05)

Check inverse_transform(transform(X)) ≈ X.

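Basic round-trip test for invertible transformers. Transformers whose tags do not mark them as invertible are skipped: the check returns without asserting anything.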

Parameters

Name Type Description Default
transformer BaseTransformer

Unfitted transformer

required
X DataFrame

Training data

required
y DataFrame

Target data

None
atol float

Absolute tolerance for numerical comparison

1e-06
rtol float

Relative tolerance for numerical comparison

1e-05

Raises

Type Description
AssertionError

If round-trip fails

Source Code

Show/Hide source
def check_inverse_transform_identity(
    transformer,
    X: pl.DataFrame,
    y: pl.DataFrame | None = None,
    atol: float = 1e-6,
    rtol: float = 1e-5,
) -> None:
    """Verify that ``inverse_transform`` undoes ``transform`` on ``X``.

    A clone of ``transformer`` is fitted on ``X`` (and ``y``), ``X`` is
    transformed, and the result is mapped back with ``inverse_transform``.
    The reconstruction is then compared with the trailing rows of ``X``,
    using as many rows as ``transform`` returned.

    Transformers whose sklearn tags do not flag them as invertible are
    skipped: the function returns without asserting anything.

    Parameters
    ----------
    transformer : BaseTransformer
        Transformer to check; it is cloned before being fitted.
    X : pl.DataFrame
        Data used for fitting, transforming and as the round-trip reference.
    y : pl.DataFrame, optional
        Target data forwarded to ``fit``.
    atol : float
        Absolute tolerance passed to the frame comparison.
    rtol : float
        Relative tolerance passed to the frame comparison.

    Raises
    ------
    AssertionError
        If the reconstructed frame differs from the expected one in shape
        or in values (beyond the given tolerances).

    """
    transformer_tags = transformer.__sklearn_tags__().transformer_tags
    if not transformer_tags or not transformer_tags.invertible:
        # Not declared invertible: there is no round trip to verify.
        return

    fitted = clone(transformer)
    fitted.fit(X, y)

    # transform() may return fewer rows than it received (e.g. differencing).
    X_trans = fitted.transform(X)

    # inverse_transform takes the past observations X_p as second argument.
    # It stays None when the transformer reports no observation horizon.
    X_p = None
    lookback = fitted.observation_horizon
    if lookback > 0:
        # Index in X of the first row that survived transform().
        first_kept = len(X) - len(X_trans)
        # The `lookback` rows sitting right before that first surviving row.
        # NOTE(review): the slice start goes negative when fewer rows were
        # dropped than the horizon -- confirm that case cannot occur.
        X_p = X[first_kept - lookback : first_kept]

    X_reconstructed = fitted.inverse_transform(X_trans, X_p)

    # Only the rows that made it through transform() can be recovered.
    X_expected = X.tail(len(X_trans))

    # Fail early with a readable message when the shapes already disagree.
    assert X_reconstructed.shape == X_expected.shape, (
        f"Shape mismatch: {X_expected.shape} -> {X_trans.shape} -> {X_reconstructed.shape}"
    )

    # Value-level comparison within the requested tolerances.
    assert_frame_equal(X_expected, X_reconstructed, rel_tol=rtol, abs_tol=atol)