Skip to content

check_search_observe_delegates

yohou.testing.search.check_search_observe_delegates(search_cv, y_train, y_update, X_actual_train=None, X_actual_update=None, X_future=None, X_forecast=None)

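Check that `observe()` delegates to `best_forecaster_.observe()` correctly. The check passes when `search_cv.best_forecaster_.observed_time_` increases (for every group, in the panel case) after calling `search_cv.observe()`.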

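Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `search_cv` | `BaseSearchCV` | Fitted search CV instance | *required* |
| `y_train` | `DataFrame` | Training target data (not used by this check) | *required* |
| `y_update` | `DataFrame` | Update target data, passed to `search_cv.observe()` | *required* |
| `X_actual_train` | `DataFrame` | Training features (not used by this check) | `None` |
| `X_actual_update` | `DataFrame` | Update features, passed to `search_cv.observe()` | `None` |
| `X_future` | `DataFrame` | Forwarded to `search_cv.observe()` as `X_future` | `None` |
| `X_forecast` | `DataFrame` | Forwarded to `search_cv.observe()` as `X_forecast` | `None` |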

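Raises

| Type | Description |
|------|-------------|
| `AssertionError` | If `observe()` doesn't delegate correctly, i.e. `best_forecaster_.observed_time_` does not increase after `observe()` |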

Source Code

Show/Hide source
def check_search_observe_delegates(
    search_cv,
    y_train: pl.DataFrame,
    y_update: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_update: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check that observe() delegates to best_forecaster_.observe() correctly.

    Delegation is detected through its side effect: after calling
    ``search_cv.observe(...)``, ``search_cv.best_forecaster_.observed_time_``
    must be strictly greater than it was before the call. For panel data
    (``observed_time_`` is a dict keyed by group), ``observed_time_`` must
    still be a dict afterwards, every group present before the call must
    still be present, and each group's value must have increased.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Fitted search CV instance
    y_train : pl.DataFrame
        Training target data. Not used by this check.
    y_update : pl.DataFrame
        Update target data, passed to ``search_cv.observe``.
    X_actual_train : pl.DataFrame, optional
        Training features. Not used by this check.
    X_actual_update : pl.DataFrame, optional
        Update features, passed positionally to ``search_cv.observe``.
    X_future : pl.DataFrame, optional
        Forwarded to ``search_cv.observe`` as ``X_future``.
    X_forecast : pl.DataFrame, optional
        Forwarded to ``search_cv.observe`` as ``X_forecast``.

    Raises
    ------
    AssertionError
        If observe() doesn't delegate correctly: ``observed_time_`` did not
        increase, or (panel case) it is no longer a dict or a group that
        existed before the call is missing afterwards.

    """
    initial_observed_time = search_cv.best_forecaster_.observed_time_
    # Snapshot the per-group dict: if observe() updates it in place, keeping
    # only a reference would make "before" alias "after", and the strict
    # comparison below would fail even though delegation worked.
    if isinstance(initial_observed_time, dict):
        initial_observed_time = dict(initial_observed_time)

    # Observe via search CV
    search_cv.observe(y_update, X_actual_update, X_future=X_future, X_forecast=X_forecast)

    # Read back the best forecaster's state after delegation
    updated_observed_time = search_cv.best_forecaster_.observed_time_

    if isinstance(initial_observed_time, dict):
        # Panel data case: report structural problems as AssertionError
        # rather than letting a KeyError/TypeError escape.
        assert isinstance(updated_observed_time, dict), (
            "observed_time should remain a per-group dict after observe, "
            f"got {type(updated_observed_time).__name__}"
        )
        for group_name, before in initial_observed_time.items():
            assert group_name in updated_observed_time, (
                f"observed_time for group {group_name} is missing after observe"
            )
            after = updated_observed_time[group_name]
            assert after > before, (
                f"observed_time for group {group_name} should increase after observe, "
                f"got {before} -> {after}"
            )
    else:
        # Non-panel case
        assert updated_observed_time > initial_observed_time, (
            f"observed_time should increase after observe, got {initial_observed_time} -> {updated_observed_time}"
        )