
check_search_panel_data

yohou.testing.search.check_search_panel_data(search_cv, y_train, y_test, X_actual_train=None, X_actual_test=None, groups=None, X_future=None, X_forecast=None)

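Check that the `groups` argument propagates correctly through `search_cv.predict` on panel data. The check predicts at most 3 steps ahead (fewer if `y_test` is shorter). When `groups` is given, it asserts that each requested group name appears among the panel group prefixes of the predictions.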

Parameters

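| Name | Type | Description | Default |
|---|---|---|---|
| search_cv | BaseSearchCV | Fitted search CV instance with panel data | required |
| y_train | DataFrame | Training target data with panel groups | required |
| y_test | DataFrame | Test target data with panel groups; the forecasting horizon is `min(3, len(y_test))` | required |
| X_actual_train | DataFrame | Training features with panel groups | None |
| X_actual_test | DataFrame | Test features with panel groups | None |
| groups | list of str | Panel group names to test; if None, the group check is skipped | None |
| X_future | DataFrame | Passed through unchanged to `search_cv.predict` | None |
| X_forecast | DataFrame | Passed through unchanged to `search_cv.predict` | None |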
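The predict step of the check can be sketched with a stub in place of a fitted search CV. `RecordingSearchCV` and `call_predict_like_check` are hypothetical names, not yohou APIs, and plain lists stand in for Polars frames. The sketch only mirrors what the source does: it caps the horizon at `min(3, len(y_test))` and forwards `groups`, `X_future`, and `X_forecast` unchanged.

```python
from typing import Any


class RecordingSearchCV:
    """Hypothetical stand-in for a fitted search CV (not a yohou class).

    Instead of forecasting, it records the keyword arguments predict() receives.
    """

    def __init__(self) -> None:
        self.calls: list[dict[str, Any]] = []

    def predict(self, forecasting_horizon: int, groups=None, X_future=None, X_forecast=None):
        self.calls.append(
            {
                "forecasting_horizon": forecasting_horizon,
                "groups": groups,
                "X_future": X_future,
                "X_forecast": X_forecast,
            }
        )
        return None


def call_predict_like_check(search_cv, y_test, groups=None, X_future=None, X_forecast=None):
    """Hypothetical helper mirroring the predict call in check_search_panel_data."""
    # Horizon is at most 3 steps and never longer than the test set
    forecasting_horizon = min(3, len(y_test))
    return search_cv.predict(
        forecasting_horizon=forecasting_horizon,
        groups=groups,
        X_future=X_future,
        X_forecast=X_forecast,
    )


stub = RecordingSearchCV()
call_predict_like_check(stub, y_test=list(range(10)), groups=["sales"])
call_predict_like_check(stub, y_test=[0, 1])
print(stub.calls[0]["forecasting_horizon"], stub.calls[1]["forecasting_horizon"])  # 3 2
```

In the source shown on this page, `y_train`, `X_actual_train`, and `X_actual_test` are accepted but are not forwarded to `predict`.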

Raises

| Type | Description |
|---|---|
| AssertionError | If a requested panel group is not found in the predictions |
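The condition that raises the `AssertionError` can be isolated as a small sketch. `assert_groups_present` is a hypothetical helper, and the prefix set `{"sales", "price"}` is made-up example data. In the real function, the prefixes come from `inspect_panel(y_pred)`. The loop and message match the source.

```python
def assert_groups_present(requested: list[str], pred_group_prefixes: set[str]) -> None:
    """Hypothetical helper reproducing the assertion in check_search_panel_data."""
    for group_name in requested:
        # A requested name passes if it occurs as a substring of any prediction prefix
        assert any(group_name in prefix for prefix in pred_group_prefixes), (
            f"Requested panel group '{group_name}' not found in predictions"
        )


prefixes = {"sales", "price"}  # hypothetical prefixes found in y_pred
assert_groups_present(["sales"], prefixes)  # passes silently

message = ""
try:
    assert_groups_present(["sales", "stock"], prefixes)
except AssertionError as exc:
    message = str(exc)
print(message)  # Requested panel group 'stock' not found in predictions
```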

Source Code

```python
def check_search_panel_data(
    search_cv,
    y_train: pl.DataFrame,
    y_test: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_test: pl.DataFrame | None = None,
    groups: list[str] | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check that the groups parameter propagates correctly through predict.

    Parameters
    ----------
    search_cv : BaseSearchCV
        Fitted search CV instance with panel data
    y_train : pl.DataFrame
        Training target data with panel groups
    y_test : pl.DataFrame
        Test target data with panel groups
    X_actual_train : pl.DataFrame, optional
        Training features with panel groups
    X_actual_test : pl.DataFrame, optional
        Test features with panel groups
    groups : list of str, optional
        Panel group names to test
    X_future : pl.DataFrame, optional
        Passed through unchanged to ``search_cv.predict``
    X_forecast : pl.DataFrame, optional
        Passed through unchanged to ``search_cv.predict``

    Raises
    ------
    AssertionError
        If a requested panel group is not found in the predictions

    """
    # Use a short horizon, capped by the length of the test set
    forecasting_horizon = min(3, len(y_test))

    # Predict with the requested groups
    y_pred = search_cv.predict(
        forecasting_horizon=forecasting_horizon, groups=groups, X_future=X_future, X_forecast=X_forecast
    )

    # Only check panel structure when specific groups were requested
    if groups is not None:
        # Collect the panel group prefixes present in the predictions
        _, panel_groups = inspect_panel(y_pred)
        pred_group_prefixes = set(panel_groups.keys())

        # Every requested group must appear in the predictions
        # (extra, unrequested groups are not checked here)
        for group_name in groups:
            # Lenient match: the requested name only needs to be a substring of a prefix
            assert any(group_name in prefix for prefix in pred_group_prefixes), (
                f"Requested panel group '{group_name}' not found in predictions"
            )
```
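The prefix set searched by the assertion comes from `inspect_panel(y_pred)`. The sketch below is a hypothetical stand-in for that step. It assumes panel columns are named `<group>__<member>` and that a non-panel column such as `time` contains no separator; this page does not confirm either convention. It also illustrates a consequence of the `group_name in prefix` test in the source: matching is by substring, so a partial name such as `"sale"` is accepted whenever a `sales` prefix exists.

```python
def panel_prefixes(columns: list[str], sep: str = "__") -> set[str]:
    """Hypothetical stand-in for the keys of inspect_panel's panel_groups.

    Assumes panel columns are named '<group><sep><member>'; columns without
    the separator are treated as non-panel and skipped.
    """
    return {col.split(sep, 1)[0] for col in columns if sep in col}


columns = ["time", "sales__store_1", "sales__store_2", "price__store_1"]
prefixes = panel_prefixes(columns)
print(sorted(prefixes))  # ['price', 'sales']

# The source's membership test is a substring test, so a partial name also passes
print(any("sale" in p for p in prefixes))  # True
# A full column name is longer than any prefix, so it does not match
print(any("sales__store_1" in p for p in prefixes))  # False
```

A stricter variant would test `group_name in pred_group_prefixes` directly. Whether that is appropriate depends on how yohou actually names panel groups.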