Skip to content

check_inverse_transform_round_trip

yohou.testing.transformer.check_inverse_transform_round_trip(transformer, X, y=None, atol=1e-06, rtol=1e-05)

Check inverse_transform(transform(X)) ≈ X with shape validation.

More comprehensive than `check_inverse_transform_identity`:

- Validates shape preservation
- Checks dtype consistency
- Handles panel data columns
- Configurable tolerance

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `transformer` | `BaseTransformer` | Unfitted invertible transformer | *required* |
| `X` | `DataFrame` | Test data | *required* |
| `y` | `DataFrame` | Target data | `None` |
| `atol` | `float` | Absolute tolerance for numerical comparison | `1e-06` |
| `rtol` | `float` | Relative tolerance for numerical comparison | `1e-05` |

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If the round trip fails |

Source Code

Show/Hide source
def check_inverse_transform_round_trip(
    transformer,
    X: pl.DataFrame,
    y: pl.DataFrame | None = None,
    atol: float = 1e-6,
    rtol: float = 1e-5,
) -> None:
    """Verify that ``inverse_transform(transform(X))`` reproduces ``X``.

    Stricter than ``check_inverse_transform_identity``: besides comparing
    values, it asserts that the shape, the column set and the per-column
    dtypes (all columns except ``"time"``) survive the round trip. The
    numeric tolerance is caller-configurable.

    Transformers whose sklearn tags do not mark them as invertible are
    skipped: the function returns without asserting anything.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted invertible transformer. Only a clone of it is fitted.
    X : pl.DataFrame
        Test data. Only the trailing ``len(transform(X))`` rows are compared,
        because ``transform`` may return fewer rows than it was given.
    y : pl.DataFrame, optional
        Target data, forwarded to ``fit``.
    atol : float
        Absolute tolerance for numerical comparison.
    rtol : float
        Relative tolerance for numerical comparison.

    Raises
    ------
    AssertionError
        If the reconstruction differs from the expected data in shape,
        columns, dtypes or values.

    """
    sub_tags = transformer.__sklearn_tags__().transformer_tags
    # Guard clause: nothing to check for non-invertible transformers.
    if not sub_tags or not sub_tags.invertible:
        return

    fitted = clone(transformer)
    fitted.fit(X, y)

    transformed = fitted.transform(X)
    n_kept = len(transformed)

    # yohou's inverse_transform takes the past observations (X_p) that
    # immediately precede the rows that were transformed.
    past_obs = None
    lookback = fitted.observation_horizon
    if lookback > 0:
        stop = len(X) - n_kept
        # NOTE(review): if lookback exceeds the number of dropped rows, the
        # slice start is negative and polars wraps it from the end -- confirm
        # this cannot happen for the transformers under test.
        past_obs = X[stop - lookback : stop]

    restored = fitted.inverse_transform(transformed, past_obs)
    # Leading rows dropped by transform cannot be reconstructed; compare only
    # against the trailing rows of X that match the transformed output.
    reference = X.tail(n_kept)

    assert restored.shape == reference.shape, (
        f"Shape mismatch: {reference.shape} -> {transformed.shape} -> {restored.shape}"
    )

    expected_cols = set(reference.columns)
    actual_cols = set(restored.columns)
    assert actual_cols == expected_cols, (
        f"Columns mismatch: {expected_cols} vs {actual_cols}"
    )

    # The "time" column is excluded from the per-column dtype check.
    value_columns = reference.select(~cs.by_name("time")).columns
    for name in value_columns:
        want = reference[name].dtype
        got = restored[name].dtype
        assert want == got, f"Dtype changed for '{name}': {want} -> {got}"

    assert_frame_equal(reference, restored, rel_tol=rtol, abs_tol=atol)