Skip to content

check_search_method_availability

yohou.testing.search.check_search_method_availability(search_cv, y, X_actual=None, forecasting_horizon=3, X_future=None, X_forecast=None)

Check @available_if decorator logic with refit=True/False.

Parameters

Name Type Description Default
search_cv BaseSearchCV

Unfitted search CV instance

required
y DataFrame

Training target data

required
X_actual DataFrame

Training features

None
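forecasting_horizon int

Number of steps ahead to forecast

3
X_future DataFrame

Passed through unchanged to search_cv.fit as X_future

None
X_forecast DataFrame

Passed through unchanged to search_cv.fit as X_forecast

None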

Raises

Type Description
AssertionError

If method availability doesn't match refit setting

Source Code

Show/Hide source
def check_search_method_availability(
    search_cv,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check @available_if decorator logic with refit=True/False.

    Two clones of ``search_cv`` are fitted. With refitting enabled, ``predict``
    must be exposed and callable. With ``refit=False``, calling ``predict``
    must raise ``AttributeError``. The passed-in ``search_cv`` is never fitted
    itself; only clones are.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Unfitted search CV instance
    y : pl.DataFrame
        Training target data
    X_actual : pl.DataFrame, optional
        Training features
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast
    X_future : pl.DataFrame, optional
        Passed through unchanged to ``fit`` as ``X_future``.
    X_forecast : pl.DataFrame, optional
        Passed through unchanged to ``fit`` as ``X_forecast``.

    Raises
    ------
    AssertionError
        If method availability doesn't match refit setting

    """
    # Both fits receive identical keyword arguments.
    fit_kwargs = {
        "forecasting_horizon": forecasting_horizon,
        "X_future": X_future,
        "X_forecast": X_forecast,
    }

    # Test with refit=True
    search_cv_refit = clone(search_cv)
    # For multimetric, preserve the original refit value (scorer name or callable);
    # forcing refit=True there is presumably rejected, as in scikit-learn.
    # For single metric, set refit=True.
    original_refit = search_cv.refit
    keeps_multimetric_refit = isinstance(search_cv.scoring, dict) and (
        isinstance(original_refit, str) or callable(original_refit)
    )
    if not keeps_multimetric_refit:
        search_cv_refit.refit = True
    search_cv_refit.fit(y, X_actual, **fit_kwargs)

    # Methods should be available
    assert hasattr(search_cv_refit, "predict"), "predict() should be available when refit=True"
    assert callable(search_cv_refit.predict), "predict should be callable when refit=True"

    # Test with refit=False
    search_cv_no_refit = clone(search_cv)
    search_cv_no_refit.refit = False
    search_cv_no_refit.fit(y, X_actual, **fit_kwargs)

    # Methods should raise AttributeError; only the predict call is guarded so
    # the failure assertion itself is never caught by the except clause.
    try:
        search_cv_no_refit.predict(forecasting_horizon=1)
    except AttributeError:
        pass  # Expected behavior
    else:
        raise AssertionError("predict() should raise AttributeError when refit=False")