Skip to content

check_search_fit_sets_attributes

yohou.testing.search.check_search_fit_sets_attributes(search_cv, y, X_actual=None, forecasting_horizon=3, X_future=None, X_forecast=None)

Check fit() sets required search CV attributes.

Validates that fit() (run on a clone of the search object) creates the required attributes cv_results_, best_params_, best_score_, best_index_, scorer_, n_splits_, and multimetric_, and additionally best_forecaster_ and refit_time_ when refit is enabled.

Parameters

Name Type Description Default
search_cv BaseSearchCV

Unfitted search CV instance

required
y DataFrame

Training target data with "time" column

required
X_actual DataFrame

Training features with "time" column

None
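forecasting_horizon int

Number of steps ahead to forecast

3
X_future DataFrame

Forwarded to fit() as X_future. When provided and refit is enabled, best_forecaster_ must expose a non-empty _step_column_names_

None
X_forecast DataFrame

Forwarded to fit() as X_forecast. When provided and refit is enabled, best_forecaster_ must expose a non-empty _step_column_names_

None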

Raises

Type Description
AssertionError

If required attributes are not set after fit()

Source Code

Show/Hide source
def check_search_fit_sets_attributes(
    search_cv,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check fit() sets required search CV attributes.

    A clone of ``search_cv`` is fitted (``fit`` is never called on
    ``search_cv`` itself), then the clone is checked for these attributes
    and types:

    - ``cv_results_``: dict
    - ``best_params_``: dict
    - ``best_score_``: int, float or numpy number
    - ``best_index_``: int or numpy integer
    - ``scorer_``: presence only
    - ``n_splits_``: int or numpy integer
    - ``multimetric_``: bool

    When the clone's ``refit`` attribute is truthy, ``best_forecaster_`` and
    ``refit_time_`` must also be set. If, in addition, ``X_future`` or
    ``X_forecast`` is given, ``best_forecaster_`` must expose a non-empty
    ``_step_column_names_``.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Unfitted search CV instance. Only a clone of it is fitted.
    y : pl.DataFrame
        Training target data with "time" column
    X_actual : pl.DataFrame, optional
        Training features with "time" column (passed positionally to fit()).
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast
    X_future : pl.DataFrame, optional
        Forwarded to ``fit()`` as ``X_future``. When provided (and refit is
        enabled), triggers the ``_step_column_names_`` check on
        ``best_forecaster_``.
    X_forecast : pl.DataFrame, optional
        Forwarded to ``fit()`` as ``X_forecast``. When provided (and refit is
        enabled), triggers the ``_step_column_names_`` check on
        ``best_forecaster_``.

    Raises
    ------
    AssertionError
        If required attributes are not set after fit(), have an unexpected
        type, or (when applicable) ``best_forecaster_`` lacks non-empty
        step column names.

    """
    search_cv_clone = clone(search_cv)
    search_cv_clone.fit(y, X_actual, forecasting_horizon=forecasting_horizon, X_future=X_future, X_forecast=X_forecast)

    # Check core fitted attributes
    assert hasattr(search_cv_clone, "cv_results_"), "fit() must set cv_results_ attribute"
    assert isinstance(search_cv_clone.cv_results_, dict), (
        f"cv_results_ should be dict, got {type(search_cv_clone.cv_results_)}"
    )

    assert hasattr(search_cv_clone, "best_params_"), "fit() must set best_params_ attribute"
    assert isinstance(search_cv_clone.best_params_, dict), (
        f"best_params_ should be dict, got {type(search_cv_clone.best_params_)}"
    )

    assert hasattr(search_cv_clone, "best_score_"), "fit() must set best_score_ attribute"
    assert isinstance(search_cv_clone.best_score_, int | float | np.number), (
        f"best_score_ should be numeric, got {type(search_cv_clone.best_score_)}"
    )

    assert hasattr(search_cv_clone, "best_index_"), "fit() must set best_index_ attribute"
    assert isinstance(search_cv_clone.best_index_, int | np.integer), (
        f"best_index_ should be int, got {type(search_cv_clone.best_index_)}"
    )

    # scorer_ is only checked for presence; its type is not constrained here.
    assert hasattr(search_cv_clone, "scorer_"), "fit() must set scorer_ attribute"

    assert hasattr(search_cv_clone, "n_splits_"), "fit() must set n_splits_ attribute"
    assert isinstance(search_cv_clone.n_splits_, int | np.integer), (
        f"n_splits_ should be int, got {type(search_cv_clone.n_splits_)}"
    )

    assert hasattr(search_cv_clone, "multimetric_"), "fit() must set multimetric_ attribute"
    assert isinstance(search_cv_clone.multimetric_, bool), (
        f"multimetric_ should be bool, got {type(search_cv_clone.multimetric_)}"
    )

    # Check best_forecaster_ when refit=True (any truthy refit value enables these checks)
    if search_cv_clone.refit:
        assert hasattr(search_cv_clone, "best_forecaster_"), "fit() must set best_forecaster_ when refit=True"
        assert hasattr(search_cv_clone, "refit_time_"), "fit() must set refit_time_ when refit=True"

        # Depth assertion: step columns propagate through search to best_forecaster_
        if X_future is not None or X_forecast is not None:
            best_forecaster = search_cv_clone.best_forecaster_
            assert hasattr(best_forecaster, "_step_column_names_"), (
                "best_forecaster_ must have _step_column_names_ when X_future/X_forecast provided"
            )
            step_column_names = best_forecaster._step_column_names_
            # Guard against None so a missing value surfaces as the documented
            # AssertionError instead of a TypeError raised by len().
            assert step_column_names is not None and len(step_column_names) > 0, (
                "best_forecaster_._step_column_names_ should be non-empty when X_future/X_forecast provided"
            )