Skip to content

check_search_cv_results_structure

yohou.testing.search.check_search_cv_results_structure(search_cv, y, X_actual=None, forecasting_horizon=3, X_future=None, X_forecast=None)

Check cv_results_ has required structure.

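Validates that cv_results_ contains a non-empty params list and, for single-metric searches, the mean_test_score, rank_test_score, and split{n}_test_score keys, each with one entry per candidate. For multimetric searches the per-split keys are named split{n}_test_{scorer_name}.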

Parameters

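| Name | Type | Description | Default |
| --- | --- | --- | --- |
| search_cv | BaseSearchCV | Fitted search CV instance | required |
| y | DataFrame | Training target data | required |
| X_actual | DataFrame | Training features | None |
| forecasting_horizon | int | Number of steps ahead to forecast | 3 |
| X_future | DataFrame | Passed through to `fit` as `X_future` | None |
| X_forecast | DataFrame | Passed through to `fit` as `X_forecast` | None |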

Raises

Type Description
AssertionError

If cv_results_ structure is invalid

Source Code

Show/Hide source
def _assert_result_column(cv_results, key: str, n_candidates: int) -> None:
    """Assert that ``cv_results[key]`` exists and has one entry per candidate."""
    assert key in cv_results, f"cv_results_ must have '{key}' key"
    n_entries = len(cv_results[key])
    assert n_entries == n_candidates, f"{key} length {n_entries} should match params length {n_candidates}"


def check_search_cv_results_structure(
    search_cv,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check cv_results_ has required structure.

    A clone of ``search_cv`` is fit on the given data and the clone's
    ``cv_results_`` is inspected. The check requires a non-empty ``params``
    list and, for every metric, the ``mean_test_<metric>``,
    ``rank_test_<metric>`` and ``split{n}_test_<metric>`` keys, each with one
    entry per candidate. ``<metric>`` is ``score`` for single-metric searches
    and each key of ``scorer_`` for multimetric searches.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Search CV instance; it is cloned and the clone is fit
    y : pl.DataFrame
        Training target data
    X_actual : pl.DataFrame, optional
        Training features
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast
    X_future : pl.DataFrame, optional
        Forwarded unchanged to ``fit`` as ``X_future``
    X_forecast : pl.DataFrame, optional
        Forwarded unchanged to ``fit`` as ``X_forecast``

    Raises
    ------
    AssertionError
        If cv_results_ structure is invalid

    """
    search_cv_clone = clone(search_cv)
    search_cv_clone.fit(y, X_actual, forecasting_horizon=forecasting_horizon, X_future=X_future, X_forecast=X_forecast)

    cv_results = search_cv_clone.cv_results_

    # Check required keys
    assert "params" in cv_results, "cv_results_ must have 'params' key"
    assert isinstance(cv_results["params"], list), (
        f"cv_results_['params'] should be list, got {type(cv_results['params'])}"
    )

    n_candidates = len(cv_results["params"])
    assert n_candidates > 0, "cv_results_['params'] should not be empty"

    # Result columns are suffixed with "score" for a single metric and with the
    # scorer name for multimetric searches.
    if search_cv_clone.multimetric_:
        scorer = search_cv_clone.scorer_
        # If scorer_ is not dict-like its metric names are not discoverable
        # here, so no per-metric columns are checked in that case.
        metric_names = list(scorer.keys()) if hasattr(scorer, "keys") else []
    else:
        metric_names = ["score"]

    n_splits = search_cv_clone.n_splits_
    for metric_name in metric_names:
        # Aggregate columns (previously validated only for single-metric searches)
        _assert_result_column(cv_results, f"mean_test_{metric_name}", n_candidates)
        _assert_result_column(cv_results, f"rank_test_{metric_name}", n_candidates)

        # Per-split columns
        for split_idx in range(n_splits):
            _assert_result_column(cv_results, f"split{split_idx}_test_{metric_name}", n_candidates)