Skip to content

check_rewind_propagates_to_transformers

yohou.testing.forecaster.check_rewind_propagates_to_transformers(forecaster, y_train, y_reset, X_actual_train=None, X_actual_reset=None, X_future=None, X_forecast=None)

Check rewind() propagates to transformers in forecaster.

When a forecaster with transformers calls rewind(), the transformers should also have their observation buffers reset accordingly.

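Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| forecaster | BaseForecaster | Fitted forecaster instance with transformers | required |
| y_train | DataFrame | Original training data (not used by the check itself) | required |
| y_reset | DataFrame | New data for reset, passed to `rewind()` | required |
| X_actual_train | DataFrame | Features for training (not used by the check itself) | None |
| X_actual_reset | DataFrame | Features for reset, passed to `rewind()` as `X_actual` | None |
| X_future | DataFrame | Passed to `rewind()` as `X_future` | None |
| X_forecast | DataFrame | Passed to `rewind()` as `X_forecast` | None |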

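Raises

| Type | Description |
|------|-------------|
| AssertionError | If a stateful transformer (non-`None` `_X_observed` and positive `observation_horizon`) has an empty observation buffer after `rewind()` |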

Source Code

Show/Hide source
def check_rewind_propagates_to_transformers(
    forecaster,
    y_train: pl.DataFrame,
    y_reset: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_reset: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check rewind() propagates to transformers in forecaster.

    When a forecaster with transformers calls rewind(), the transformers
    should also have their observation buffers reset accordingly.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance with transformers
    y_train : pl.DataFrame
        Original training data
    y_reset : pl.DataFrame
        New data for reset
    X_actual_train : pl.DataFrame, optional
        Features for training
    X_actual_reset : pl.DataFrame, optional
        Features for reset

    Raises
    ------
    AssertionError
        If transformers are not properly reset

    """
    # Check if forecaster has transformers (target_transformer or feature_transformer)
    if not hasattr(forecaster, "target_transformer_") and not hasattr(forecaster, "feature_transformer_"):
        return  # Nothing to check

    # Rewind the forecaster
    forecaster.rewind(y_reset, X_actual=X_actual_reset, X_future=X_future, X_forecast=X_forecast)

    # Check target transformer is reset
    if hasattr(forecaster, "target_transformer_") and forecaster.target_transformer_ is not None:
        if isinstance(forecaster.target_transformer_, dict):
            # Panel data - check each transformer
            for group_name, transformer in forecaster.target_transformer_.items():
                if (
                    hasattr(transformer, "_X_observed")
                    and transformer._X_observed is not None
                    and getattr(transformer, "observation_horizon", 0) > 0
                ):
                    # Transformer should have observation data matching reset data
                    assert len(transformer._X_observed) > 0, (
                        f"Target transformer for group '{group_name}' should have observations after rewind"
                    )
        # Non-panel data
        elif (
            hasattr(forecaster.target_transformer_, "_X_observed")
            and forecaster.target_transformer_._X_observed is not None
            and getattr(forecaster.target_transformer_, "observation_horizon", 0) > 0
        ):
            assert len(forecaster.target_transformer_._X_observed) > 0, (
                "Target transformer should have observations after rewind"
            )

    # Check feature transformer is reset (if exists)
    if hasattr(forecaster, "feature_transformer_") and forecaster.feature_transformer_ is not None:
        if isinstance(forecaster.feature_transformer_, dict):
            # Panel data - check each transformer
            for group_name, transformer in forecaster.feature_transformer_.items():
                if (
                    hasattr(transformer, "_X_observed")
                    and transformer._X_observed is not None
                    and getattr(transformer, "observation_horizon", 0) > 0
                ):
                    assert len(transformer._X_observed) > 0, (
                        f"Feature transformer for group '{group_name}' should have observations after rewind"
                    )
        # Non-panel data
        elif (
            hasattr(forecaster.feature_transformer_, "_X_observed")
            and forecaster.feature_transformer_._X_observed is not None
            and getattr(forecaster.feature_transformer_, "observation_horizon", 0) > 0
        ):
            assert len(forecaster.feature_transformer_._X_observed) > 0, (
                "Feature transformer should have observations after rewind"
            )