Skip to content

check_search_rewind_delegates

yohou.testing.search.check_search_rewind_delegates(search_cv, y_train, y_reset, X_actual_train=None, X_actual_reset=None, X_future=None, X_forecast=None)

Check rewind() delegates to best_forecaster_.rewind() correctly.

Parameters

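| Name | Type | Description | Default |
|------|------|-------------|---------|
| `search_cv` | `BaseSearchCV` | Fitted search CV instance | *required* |
| `y_train` | `DataFrame` | Training target data (not used by this check) | *required* |
| `y_reset` | `DataFrame` | Reset target data, passed to `search_cv.rewind()` | *required* |
| `X_actual_train` | `DataFrame \| None` | Training features (not used by this check) | `None` |
| `X_actual_reset` | `DataFrame \| None` | Reset features, passed to `search_cv.rewind()` | `None` |
| `X_future` | `DataFrame \| None` | Forwarded to `search_cv.rewind()` as `X_future` | `None` |
| `X_forecast` | `DataFrame \| None` | Forwarded to `search_cv.rewind()` as `X_forecast` | `None` |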

Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If `rewind()` doesn't delegate correctly |

Source Code

Show/Hide source
def check_search_rewind_delegates(
    search_cv,
    y_train: pl.DataFrame,
    y_reset: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_reset: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check rewind() delegates to best_forecaster_.rewind() correctly.

    Records ``search_cv.best_forecaster_.observed_time_``, calls
    ``search_cv.rewind()`` with the reset data, then asserts that the best
    forecaster's ``observed_time_`` changed. When ``observed_time_`` is a
    dict (panel data, keyed by group name), every group present before the
    rewind must still be present afterwards and its value must have changed.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Fitted search CV instance
    y_train : pl.DataFrame
        Training target data. Not used by this check; presumably kept so the
        signature matches sibling check functions (NOTE(review): confirm).
    y_reset : pl.DataFrame
        Reset target data, passed positionally to ``search_cv.rewind()``
    X_actual_train : pl.DataFrame, optional
        Training features. Not used by this check.
    X_actual_reset : pl.DataFrame, optional
        Reset features, passed positionally to ``search_cv.rewind()``
    X_future : pl.DataFrame, optional
        Forwarded to ``search_cv.rewind()`` as the ``X_future`` keyword
    X_forecast : pl.DataFrame, optional
        Forwarded to ``search_cv.rewind()`` as the ``X_forecast`` keyword

    Raises
    ------
    AssertionError
        If rewind() doesn't delegate correctly, including when a panel
        group present before the rewind is missing afterwards.

    """
    initial_observed_time = search_cv.best_forecaster_.observed_time_
    # Snapshot panel dicts: if rewind() updates the dict in place, a plain
    # reference would alias the post-rewind state and every per-group
    # comparison below would fail spuriously. A shallow copy suffices
    # assuming the per-group values are immutable (e.g. timestamps) — TODO confirm.
    if isinstance(initial_observed_time, dict):
        initial_observed_time = dict(initial_observed_time)

    # Rewind via search CV; delegation should reach best_forecaster_.
    search_cv.rewind(y_reset, X_actual_reset, X_future=X_future, X_forecast=X_forecast)

    reset_observed_time = search_cv.best_forecaster_.observed_time_

    if isinstance(initial_observed_time, dict):
        # Panel data case: each group's observed_time must have moved.
        for group_name, initial_time in initial_observed_time.items():
            # Report a lost group as an AssertionError (the documented
            # failure type) instead of letting indexing raise KeyError.
            assert group_name in reset_observed_time, (
                f"group {group_name} missing from observed_time after rewind"
            )
            assert reset_observed_time[group_name] != initial_time, (
                f"observed_time for group {group_name} should change after rewind"
            )
    else:
        # Non-panel case
        assert reset_observed_time != initial_observed_time, "observed_time should change after rewind"