Skip to content

check_grid_search_exhaustive

yohou.testing.search.check_grid_search_exhaustive(search_cv, y, X_actual=None, forecasting_horizon=3, X_future=None, X_forecast=None)

Check GridSearchCV evaluates all parameter combinations.

Parameters

Name Type Description Default
search_cv GridSearchCV

Unfitted GridSearchCV instance

required
y DataFrame

Training target data

required
X_actual DataFrame

Training features

None
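forecasting_horizon int

Number of steps ahead to forecast

3
X_future DataFrame

Optional features forwarded unchanged to the search's fit call as X_future

None
X_forecast DataFrame

Optional features forwarded unchanged to the search's fit call as X_forecast

None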

Raises

Type Description
ValueError

If search_cv is not a GridSearchCV instance, or its param_grid has an unsupported type

AssertionError

If not all parameter combinations are evaluated

Source Code

Show/Hide source
def check_grid_search_exhaustive(
    search_cv,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check GridSearchCV evaluates all parameter combinations.

    A clone of ``search_cv`` is fitted, and the number of entries in its
    ``cv_results_["params"]`` is compared against the number of candidates
    implied by ``param_grid``. That count is the product of the candidate-list
    lengths within each grid, summed over the grids when several are given.

    Parameters
    ----------
    search_cv : GridSearchCV
        GridSearchCV instance. The check fits a clone of it, not the
        instance itself.
    y : pl.DataFrame
        Training target data
    X_actual : pl.DataFrame, optional
        Training features
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast
    X_future : pl.DataFrame, optional
        Forwarded unchanged to ``fit`` as ``X_future``.
    X_forecast : pl.DataFrame, optional
        Forwarded unchanged to ``fit`` as ``X_forecast``.

    Raises
    ------
    ValueError
        If ``search_cv`` is not a GridSearchCV instance, or its
        ``param_grid`` is not a mapping or a list/tuple of mappings.
    AssertionError
        If not all parameter combinations are evaluated

    """
    import math
    from collections.abc import Mapping

    if not isinstance(search_cv, GridSearchCV):
        raise ValueError("This check requires GridSearchCV instance")

    search_cv_clone = clone(search_cv)

    # Size the grid *before* fitting so a malformed grid is reported
    # without first paying for a full search.
    param_grid = search_cv_clone.param_grid
    if isinstance(param_grid, Mapping):
        # Single grid
        grids = [param_grid]
    elif isinstance(param_grid, (list, tuple)):
        # Multiple grids
        grids = list(param_grid)
    else:
        raise ValueError(f"Invalid param_grid type: {type(param_grid)}")

    for grid in grids:
        # Without this, a non-mapping entry would fail with an opaque
        # AttributeError on `.values()` below.
        if not isinstance(grid, Mapping):
            raise ValueError(f"Invalid param_grid entry type: {type(grid)}")

    # Each grid contributes the Cartesian product of its candidate lists;
    # an empty grid contributes one (default-parameter) candidate, which is
    # what math.prod of an empty iterable yields.
    expected_combinations = sum(
        math.prod(len(param_values) for param_values in grid.values()) for grid in grids
    )

    search_cv_clone.fit(
        y,
        X_actual,
        forecasting_horizon=forecasting_horizon,
        X_future=X_future,
        X_forecast=X_forecast,
    )

    # Check actual combinations
    actual_combinations = len(search_cv_clone.cv_results_["params"])
    assert actual_combinations == expected_combinations, (
        f"GridSearchCV should evaluate {expected_combinations} combinations, got {actual_combinations}"
    )