Skip to content

check_scorer_column_selection

yohou.utils.validation.check_scorer_column_selection(scorer, y_true, y_pred, pred_type, coverage_rates=None, interval_pattern=None)

Subselect columns based on scorer configuration.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `scorer` | `BaseScorer` | Scorer instance with `groups` and `components` attributes. | *required* |
| `y_true` | `DataFrame` | True values DataFrame. | *required* |
| `y_pred` | `DataFrame` | Predicted values DataFrame. | *required* |
| `pred_type` | `str` | Prediction type (`'point'`, `'interval'`, `'conformity'`, `'class_proba'`). | *required* |
| `coverage_rates` | `list[float]` | Coverage rates for interval forecasts. | `None` |
| `interval_pattern` | `Pattern` | Regex pattern for matching interval column names. | `None` |

Returns

| Type | Description |
| --- | --- |
| `tuple[DataFrame, DataFrame]` | Filtered `(y_true, y_pred)` DataFrames. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If groups or components are invalid, or if requested coverage rates are not found in the predictions. |

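See Also

- `inspect_panel` (`yohou.utils.panel.inspect_panel`): Detect panel groups in a DataFrame.
- `check_groups` (`yohou.utils.validation.check_groups`): Validate panel group names for forecaster operations.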

Source Code

Show/Hide source
def check_scorer_column_selection(
    scorer: "BaseScorer",
    y_true: pl.DataFrame,
    y_pred: pl.DataFrame,
    pred_type: str,
    coverage_rates: list[float] | None = None,
    interval_pattern: re.Pattern | None = None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Subselect columns based on scorer configuration.

    The scorer's optional ``groups`` and ``components`` attributes restrict
    which target columns of ``y_true`` are scored; ``y_pred`` is then reduced
    to the prediction columns that belong to those targets. How prediction
    columns are matched to a target ``col`` depends on ``pred_type``:

    - ``"interval"``: columns starting with ``f"{col}_lower_"`` or
      ``f"{col}_upper_"``, further restricted to ``coverage_rates`` when both
      ``coverage_rates`` and ``interval_pattern`` are given;
    - ``"class_proba"``: columns starting with ``f"{col}_proba_"``;
    - any other value: the column named ``col`` itself.

    The ``"time"`` column is always kept. When none of ``groups``,
    ``components`` or ``coverage_rates`` is set, the inputs are returned
    unchanged.

    Parameters
    ----------
    scorer : BaseScorer
        Scorer instance with groups and components attributes. For panel
        data, ``components`` are matched against the part of each panel
        column name after the first ``"__"``.
    y_true : pl.DataFrame
        True values DataFrame.
    y_pred : pl.DataFrame
        Predicted values DataFrame.
    pred_type : str
        Prediction type ('point', 'interval', 'conformity', 'class_proba').
    coverage_rates : list[float], optional
        Coverage rates for interval forecasts.
    interval_pattern : re.Pattern, optional
        Regex pattern for matching interval column names. Its third capture
        group must hold the coverage rate.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        Filtered (y_true, y_pred) DataFrames.

    Raises
    ------
    ValueError
        If requested coverage rates are missing from ``y_pred``, or if groups
        or components are not found in ``y_true``.

    See Also
    --------
    - [`inspect_panel`][yohou.utils.panel.inspect_panel] : Detect panel groups in a DataFrame.
    - [`check_groups`][yohou.utils.validation.check_groups] : Validate panel group names for forecaster operations.

    """
    groups = getattr(scorer, "groups", None)
    components = getattr(scorer, "components", None)

    if groups is None and components is None and coverage_rates is None:
        return y_true, y_pred

    # Every requested coverage rate must actually be present in the predictions.
    if coverage_rates is not None and pred_type == "interval" and interval_pattern is not None:
        parsed_rates = (
            _interval_rate(col, interval_pattern) for col in y_pred.columns if col not in ("time", "vintage_time")
        )
        available_rates = {rate for rate in parsed_rates if rate is not None}
        missing_rates = set(coverage_rates) - available_rates
        if missing_rates:
            raise ValueError(
                f"Requested coverage_rates {sorted(missing_rates)} not found in predictions. "
                f"Available rates: {sorted(available_rates)}"
            )

    _, y_groups = inspect_panel(y_true)

    # Validate panel groups if specified (must exist in data)
    if groups is not None:
        missing_groups = set(groups) - set(y_groups.keys())
        if missing_groups:
            raise ValueError(
                f"Invalid groups: {sorted(missing_groups)} not found in data. "
                f"Available groups: {sorted(y_groups.keys())}"
            )

    if y_groups:
        target_cols = _select_panel_targets(y_groups, groups, components)
    elif components is not None:
        target_cols = _select_global_targets(y_true, components)
    elif coverage_rates is not None and pred_type == "interval" and interval_pattern is not None:
        # No column filter: only drop interval columns at unrequested rates.
        rate_cols = [
            col
            for col in y_pred.columns
            if col not in ("time", "vintage_time") and _interval_rate(col, interval_pattern) in coverage_rates
        ]
        return y_true, y_pred.select(["time", *rate_cols])
    else:
        return y_true, y_pred

    # An empty selection (e.g. an empty ``components`` list) leaves both frames unchanged.
    if not target_cols:
        return y_true, y_pred
    return _subselect_columns(y_true, y_pred, target_cols, pred_type, coverage_rates, interval_pattern)


def _interval_rate(col: str, interval_pattern: re.Pattern) -> float | None:
    """Return the coverage rate parsed from ``col`` (third capture group), or ``None`` if it does not match."""
    match = interval_pattern.match(col)
    return float(match.group(3)) if match else None


def _member_name(col: str) -> str:
    """Return the part of a panel column name after the first ``"__"`` (the whole name if there is none)."""
    return col.split("__", 1)[-1]


def _select_panel_targets(
    y_groups: dict[str, list[str]],
    groups: list[str] | None,
    components: list[str] | None,
) -> list[str]:
    """Return panel columns of the selected groups, optionally restricted to ``components``.

    ``y_groups`` is the group mapping returned by ``inspect_panel``; it is
    assumed to map group name -> column names (TODO confirm). Raises
    ``ValueError`` if a requested component appears in none of the selected
    groups, mirroring the validation done for global data.
    """
    group_names = groups if groups is not None else list(y_groups.keys())
    # dict.fromkeys drops repeated group names (order-preserving) so no column is selected twice.
    group_cols = [col for name in dict.fromkeys(group_names) if name in y_groups for col in y_groups[name]]
    if components is None:
        return group_cols

    available_components = {_member_name(col) for col in group_cols}
    missing_components = set(components) - available_components
    if missing_components:
        raise ValueError(
            f"Invalid components: {sorted(missing_components)} not found in data. "
            f"Available components: {sorted(available_components)}"
        )
    return [col for col in group_cols if _member_name(col) in components]


def _select_global_targets(y_true: pl.DataFrame, components: list[str]) -> list[str]:
    """Return the non-panel ``y_true`` columns named in ``components``, validating that all exist."""
    available_components = [col for col in y_true.columns if col != "time"]
    missing_components = set(components) - set(available_components)
    if missing_components:
        raise ValueError(
            f"Invalid components: {sorted(missing_components)} not found in data. "
            f"Available components: {sorted(available_components)}"
        )
    return [col for col in y_true.columns if col in components]


def _subselect_columns(
    y_true: pl.DataFrame,
    y_pred: pl.DataFrame,
    target_cols: list[str],
    pred_type: str,
    coverage_rates: list[float] | None,
    interval_pattern: re.Pattern | None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Restrict ``y_true`` to ``target_cols`` plus ``"time"`` and ``y_pred`` to their prediction columns."""
    # Always preserve the time column during subselection.
    true_cols = target_cols if "time" in target_cols else ["time", *target_cols]
    targets = [col for col in true_cols if col != "time"]
    pred_columns = y_pred.columns

    if pred_type == "interval":
        # Interval predictions carry {target}_lower_* / {target}_upper_* columns.
        pred_cols = ["time"]
        for col in targets:
            matches = [c for c in pred_columns if c.startswith((f"{col}_lower_", f"{col}_upper_"))]
            if coverage_rates is not None and interval_pattern is not None:
                matches = [c for c in matches if _interval_rate(c, interval_pattern) in coverage_rates]
            pred_cols.extend(matches)
    elif pred_type == "class_proba":
        # Class probabilities carry {target}_proba_{class} columns.
        pred_cols = ["time"] if "time" in pred_columns else []
        for col in targets:
            pred_cols.extend(c for c in pred_columns if c.startswith(f"{col}_proba_"))
    else:
        # Point-like predictions share the target column names.
        available = set(pred_columns)
        pred_cols = [col for col in true_cols if col in available]

    return y_true.select(true_cols), y_pred.select(pred_cols)