
check_observe_transform_sequential_consistency

yohou.testing.transformer.check_observe_transform_sequential_consistency(transformer, X, y=None)

Check observe_transform(A) then observe_transform(B) == observe_transform(A+B).

Sequential observe_transform calls should produce the same output as a single observe_transform call on concatenated data. Also verifies that internal state (_X_observed) is consistent after both operations.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| transformer | BaseTransformer | Unfitted transformer | required |
| X | DataFrame | Training data (will be split into fit, A, B portions) | required |
| y | DataFrame | Target data | None |

Raises

| Type | Description |
| --- | --- |
| AssertionError | If sequential updates produce different results than combined update |

Notes

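This check splits X into three parts:

- X_fit: Used for initial fit (the first len(X) // 2 rows)
- A: First observe_transform batch
- B: Second observe_transform batch (takes the extra row when the remainder is odd)

Then verifies:

1. concat(observe_transform(A), observe_transform(B)) == observe_transform(concat(A, B))
2. _X_observed is identical after both operations

Both comparisons are approximate (rel_tol=1e-6, abs_tol=1e-8). If X has fewer than 12 rows, the check returns without asserting anything. The _X_observed comparison is skipped unless both fitted transformers expose that attribute.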
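The property described above can be illustrated without yohou or polars. The sketch below uses plain Python lists and two hypothetical toy transformers (`LagDifference` and `BatchCentering` are invented for illustration and are not part of the library). It compares with exact equality instead of a tolerance-based frame comparison, so treat it as a rough model of the check, not its implementation.

```python
class LagDifference:
    """Hypothetical toy transformer: first difference against the previous value."""

    def fit(self, X):
        # Keep everything seen so far, loosely mirroring the _X_observed state.
        self._X_observed = list(X)
        return self

    def observe_transform(self, X):
        out = []
        previous = self._X_observed[-1]
        for value in X:
            out.append(value - previous)
            previous = value
        self._X_observed.extend(X)
        return out


class BatchCentering(LagDifference):
    """Hypothetical toy transformer whose output depends on batch boundaries."""

    def observe_transform(self, X):
        mean = sum(X) / len(X)  # uses only the current batch, so batching matters
        self._X_observed.extend(X)
        return [value - mean for value in X]


def sequential_matches_combined(make_transformer, X):
    # Same three-way split as the check: fit portion, then batches A and B.
    fit_size = len(X) // 2
    update_size = (len(X) - fit_size) // 2
    X_fit = X[:fit_size]
    A = X[fit_size:fit_size + update_size]
    B = X[fit_size + update_size:]

    # Path 1: two sequential observe_transform calls.
    t1 = make_transformer().fit(X_fit)
    sequential = t1.observe_transform(A) + t1.observe_transform(B)

    # Path 2: one observe_transform call on the concatenated batches.
    t2 = make_transformer().fit(X_fit)
    combined = t2.observe_transform(A + B)

    return sequential == combined and t1._X_observed == t2._X_observed


print(sequential_matches_combined(LagDifference, [i * i for i in range(12)]))
print(sequential_matches_combined(BatchCentering, [float(i) for i in range(12)]))
```

A transformer such as `LagDifference`, whose output for a row depends only on the rows observed before it, satisfies the property. One such as `BatchCentering`, which computes a statistic over the current batch alone, does not, even though its observed state ends up identical on both paths.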

Source Code

def check_observe_transform_sequential_consistency(transformer, X: pl.DataFrame, y: pl.DataFrame | None = None) -> None:
    """Check observe_transform(A) then observe_transform(B) == observe_transform(A+B).

    Sequential observe_transform calls should produce the same output as a
    single observe_transform call on concatenated data. Also verifies that
    internal state (_X_observed) is consistent after both operations.

    Parameters
    ----------
    transformer : BaseTransformer
        Unfitted transformer
    X : pl.DataFrame
        Training data (will be split into fit, A, B portions)
    y : pl.DataFrame, optional
        Target data

    Raises
    ------
    AssertionError
        If sequential updates produce different results than combined update

    Notes
    -----
    This check splits X into three parts:
    - X_fit: Used for initial fit
    - A: First observe_transform batch
    - B: Second observe_transform batch

    Then verifies:
    1. concat(observe_transform(A), observe_transform(B)) == observe_transform(concat(A, B))
    2. _X_observed is identical after both operations

    """
    # Need enough data for 3-way split
    if len(X) < 12:
        return

    # Split into fit portion and two update portions
    fit_size = len(X) // 2
    update_size = (len(X) - fit_size) // 2
    X_fit = X[:fit_size]
    A = X[fit_size : fit_size + update_size]
    B = X[fit_size + update_size :]

    # Path 1: Sequential observe_transform calls
    transformer1 = clone(transformer)
    transformer1.fit(X_fit, y)
    A_trans = transformer1.observe_transform(A)
    B_trans = transformer1.observe_transform(B)
    output_sequential = pl.concat([A_trans, B_trans])

    # Path 2: Single observe_transform on concatenated data
    transformer2 = clone(transformer)
    transformer2.fit(X_fit, y)
    AB = pl.concat([A, B])
    output_combined = transformer2.observe_transform(AB)

    # Outputs should be equivalent
    assert_frame_equal(
        output_sequential,
        output_combined,
        rel_tol=1e-6,
        abs_tol=1e-8,
    )

    # Internal state should also match
    if hasattr(transformer1, "_X_observed") and hasattr(transformer2, "_X_observed"):
        assert_frame_equal(
            transformer1._X_observed,
            transformer2._X_observed,
            rel_tol=1e-6,
            abs_tol=1e-8,
        )
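The split arithmetic at the top of the function can be checked in isolation. The helper below is a minimal sketch; `split_sizes` is a hypothetical name, not a yohou function, and it only reproduces the integer arithmetic from the listing above.

```python
def split_sizes(n_rows):
    """Return (fit, A, B) row counts using the same arithmetic as the check."""
    fit_size = n_rows // 2
    update_size = (n_rows - fit_size) // 2
    # B is whatever remains, so it absorbs the extra row when the remainder is odd.
    return fit_size, update_size, n_rows - fit_size - update_size


for n in (12, 13, 15):
    print(n, split_sizes(n))
```

With the 12-row minimum, X_fit has 6 rows and A and B have 3 rows each.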