Skip to content

check_rewind_replaces_observations

yohou.testing.forecaster.check_rewind_replaces_observations(forecaster, y_train, y_reset, X_actual_train=None, X_actual_reset=None, X_future=None, X_forecast=None)

Check rewind() replaces observation buffers correctly.

Parameters

Name Type Description Default
forecaster BaseForecaster

Fitted forecaster instance

required
y_train DataFrame

Original training data

required
y_reset DataFrame

New data for reset

required
X_actual_train DataFrame

Features for training

None
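X_actual_reset DataFrame

Features for reset

None
X_future DataFrame

Passed through to rewind() as X_future

None
X_forecast DataFrame

Passed through to rewind() as X_forecast

None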

Raises

Type Description
AssertionError

If observation buffers are not replaced correctly

Source Code

Show/Hide source
def check_rewind_replaces_observations(
    forecaster,
    y_train: pl.DataFrame,
    y_reset: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_reset: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check rewind() replaces observation buffers correctly.

    The check runs in three steps:

    1. Before rewinding, ``observed_time_`` must agree with the last ``"time"``
       value held in ``_y_observed`` and ``_X_t_observed`` (when those buffers
       are populated).
    2. ``forecaster.rewind(y_reset, X_actual_reset, X_future=..., X_forecast=...)``
       is called.
    3. After rewinding, ``observed_time_`` must equal the last ``"time"`` value
       of ``y_reset``, and both buffers must end at that same time.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance
    y_train : pl.DataFrame
        Original training data. Currently unused by this check.
    y_reset : pl.DataFrame
        New data for reset. Must expose a ``"time"`` column.
    X_actual_train : pl.DataFrame, optional
        Features for training. Currently unused by this check.
    X_actual_reset : pl.DataFrame, optional
        Features for reset
    X_future : pl.DataFrame, optional
        Passed through unchanged to ``forecaster.rewind`` as ``X_future``
    X_forecast : pl.DataFrame, optional
        Passed through unchanged to ``forecaster.rewind`` as ``X_forecast``

    Raises
    ------
    AssertionError
        If observation buffers are not replaced correctly

    """
    # Precondition: the buffers and observed_time_ agree before we touch anything.
    observed_time_before = forecaster.observed_time_
    _assert_buffer_consistent_before_rewind(forecaster._y_observed, observed_time_before, "_y_observed")
    _assert_buffer_consistent_before_rewind(forecaster._X_t_observed, observed_time_before, "_X_t_observed")

    # Reset to new data
    forecaster.rewind(y_reset, X_actual_reset, X_future=X_future, X_forecast=X_forecast)

    # Check buffers were replaced
    reset_observed_time = forecaster.observed_time_
    expected_time = y_reset["time"][-1]

    if isinstance(reset_observed_time, dict):
        # Panel data: every group is compared against the same last time of y_reset.
        for group_name, group_time in reset_observed_time.items():
            assert group_time == expected_time, (
                f"observed_time_['{group_name}'] should be reset to last time in reset data"
            )
    else:
        # Non-panel data
        assert reset_observed_time == expected_time, "observed_time_ should be reset to last time in reset data"

    _assert_buffer_matches_after_rewind(forecaster._y_observed, reset_observed_time, "_y_observed")
    _assert_buffer_matches_after_rewind(forecaster._X_t_observed, reset_observed_time, "_X_t_observed")


def _assert_buffer_consistent_before_rewind(buffer, observed_time, buffer_name: str) -> None:
    """Assert ``observed_time`` equals the last ``"time"`` in ``buffer`` prior to rewind().

    ``buffer`` may be ``None`` (nothing to check), a dict keyed by group (panel
    data, only the first group is inspected), or a single frame (non-panel data).
    """
    if buffer is None:
        return

    message = f"observed_time_ should match last time in {buffer_name} before rewind()"

    if not isinstance(buffer, dict):
        # Non-panel data
        last_time = buffer["time"][-1]
        assert observed_time == last_time, message
        return

    # Panel data: spot-check the first group only.
    first_group = next(iter(buffer))
    group_frame = buffer[first_group]
    # A group's buffer can be None when observation_horizon == 0
    if group_frame is not None:
        last_time = group_frame["time"][-1]
        assert observed_time[first_group] == last_time, message


def _assert_buffer_matches_after_rewind(buffer, observed_time, buffer_name: str) -> None:
    """Assert the last ``"time"`` in ``buffer`` equals ``observed_time`` after rewind().

    ``buffer`` may be ``None`` (nothing to check), a dict keyed by group (panel
    data, every non-None group is checked), or a single frame (non-panel data).
    """
    if buffer is None:
        return

    if not isinstance(buffer, dict):
        # Non-panel data
        last_time = buffer["time"][-1]
        assert last_time == observed_time, (
            f"Last time in {buffer_name} should match reset observed_time_ after rewind()"
        )
        return

    # Panel data
    for group_name, group_frame in buffer.items():
        # A group's buffer can be None when observation_horizon == 0
        if group_frame is None:
            continue
        last_time = group_frame["time"][-1]
        assert last_time == observed_time[group_name], (
            f"Last time in {buffer_name}['{group_name}'] should match reset observed_time_"
        )