Skip to content

check_panel_single_group

yohou.testing.panel.check_panel_single_group(forecaster, y_panel, X_panel=None)

Check that cross-learning predictions are filtered to the specified panel group.

Validates that when panel_group is specified, predictions are generated only for that panel group (all columns with that prefix).

Parameters

Name Type Description Default
forecaster BaseForecaster

Fitted forecaster with panel data

required
y_panel DataFrame

Panel data with panel columns for testing

required
X_panel DataFrame

Panel features

None

Raises

Type Description
AssertionError

If filtered prediction doesn't match specified group

Source Code

Show/Hide source
def check_panel_single_group(forecaster, y_panel: pl.DataFrame, X_panel: pl.DataFrame | None = None) -> None:
    """Check that predicting for one panel group yields that group's columns.

    The panel groups are discovered from ``y_panel`` via ``inspect_panel``.
    The first group found is passed as ``panel_group`` to a prediction with
    a forecasting horizon of 3, and every column belonging to that group must
    then be reported as present in the prediction columns by
    ``_column_present``. If no groups are detected the check is a no-op.

    Note that only the *presence* of the selected group's columns is
    asserted here; columns of other groups are not checked for absence.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster with panel data
    y_panel : pl.DataFrame
        Panel data whose columns are inspected to find the panel groups
    X_panel : pl.DataFrame, optional
        Panel features. Accepted for signature compatibility; this check
        does not read it.

    Raises
    ------
    AssertionError
        If the selected group has no columns, or if any of its columns is
        not found in the prediction.

    """
    _, groups_by_prefix = inspect_panel(y_panel)

    # Nothing to verify when the data exposes no panel groups.
    if len(groups_by_prefix) == 0:
        return

    # The first group (in mapping order) is the one under test.
    target_group = next(iter(groups_by_prefix.keys()))

    # Ask for a forecast restricted to that single group.
    predicted = _call_predict(forecaster, forecasting_horizon=3, panel_group=target_group)

    # Every column registered under the group must show up in the output
    # (flat columns with __ separator).
    expected_columns = groups_by_prefix[target_group]
    assert len(expected_columns) > 0, f"Group '{target_group}' should have columns"
    for column in expected_columns:
        assert _column_present(column, predicted.columns), (
            f"Column '{column}' (or interval bounds) from group '{target_group}' "
            f"should be in predictions. Got: {predicted.columns}"
        )