def _has_requested_rate(col: str, interval_pattern: re.Pattern, rates: set[float]) -> bool:
    """Return whether ``col`` matches ``interval_pattern`` with a coverage rate in ``rates``.

    The coverage rate is parsed from capture group 3 of ``interval_pattern``.
    """
    match = interval_pattern.match(col)
    return match is not None and float(match.group(3)) in rates


def _subselect_frames(
    y_true: pl.DataFrame,
    y_pred: pl.DataFrame,
    selected_cols: list[str],
    pred_type: str,
    coverage_rates: list[float] | None,
    interval_pattern: re.Pattern | None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Select ``selected_cols`` from ``y_true`` and the corresponding columns from ``y_pred``.

    Parameters
    ----------
    y_true : pl.DataFrame
        True values DataFrame.
    y_pred : pl.DataFrame
        Predicted values DataFrame.
    selected_cols : list[str]
        Non-empty list of ``y_true`` column names to keep. ``"time"`` is
        prepended if absent.
    pred_type : str
        Prediction type; decides how ``y_pred`` columns map to targets.
    coverage_rates : list[float] or None
        If given together with ``interval_pattern``, interval columns are
        further restricted to these coverage rates.
    interval_pattern : re.Pattern or None
        Regex whose group 3 holds the coverage rate of an interval column.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        Subselected (y_true, y_pred).
    """
    # Always preserve the time column during subselection.
    if "time" not in selected_cols:
        selected_cols = ["time"] + selected_cols
    targets = [col for col in selected_cols if col != "time"]

    if pred_type == "interval":
        # Interval predictions are named ``{target}_lower_*`` / ``{target}_upper_*``.
        # Only "time" is carried over from the non-target columns.
        pred_cols = ["time"]
        requested = set(coverage_rates) if coverage_rates is not None else None
        for col in targets:
            prefixes = (f"{col}_lower_", f"{col}_upper_")
            matches = [c for c in y_pred.columns if c.startswith(prefixes)]
            if requested is not None and interval_pattern is not None:
                matches = [m for m in matches if _has_requested_rate(m, interval_pattern, requested)]
            pred_cols.extend(matches)
    elif pred_type == "class_proba":
        # Class-probability predictions are named ``{target}_proba_{class}``.
        pred_cols = ["time"] if "time" in y_pred.columns else []
        for col in targets:
            pred_cols.extend(c for c in y_pred.columns if c.startswith(f"{col}_proba_"))
    else:
        # Point-like forecasts: prediction columns share the target names.
        available = set(y_pred.columns)
        pred_cols = [c for c in selected_cols if c in available]

    return y_true.select(selected_cols), y_pred.select(pred_cols)


def check_scorer_column_selection(
    scorer: "BaseScorer",
    y_true: pl.DataFrame,
    y_pred: pl.DataFrame,
    pred_type: str,
    coverage_rates: list[float] | None = None,
    interval_pattern: re.Pattern | None = None,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Subselect columns based on scorer configuration.

    Parameters
    ----------
    scorer : BaseScorer
        Scorer instance; its optional ``groups`` and ``components``
        attributes (``None`` or missing means "no restriction") drive the
        selection.
    y_true : pl.DataFrame
        True values DataFrame.
    y_pred : pl.DataFrame
        Predicted values DataFrame.
    pred_type : str
        Prediction type ('point', 'interval', 'conformity', 'class_proba').
    coverage_rates : list[float], optional
        Coverage rates for interval forecasts.
    interval_pattern : re.Pattern, optional
        Regex pattern for matching interval column names; group 3 must
        capture the coverage rate.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame]
        Filtered (y_true, y_pred) DataFrames. When nothing is requested, the
        inputs are returned unchanged.

    Raises
    ------
    ValueError
        If requested coverage rates are absent from interval predictions, if
        requested groups are not found in the data, or if requested
        components are not found in the data (for panel data, components are
        matched against column names with the ``group__`` prefix removed).

    See Also
    --------
    - [`inspect_panel`][yohou.utils.panel.inspect_panel] : Detect panel groups in a DataFrame.
    - [`check_groups`][yohou.utils.validation.check_groups] : Validate panel group names for forecaster operations.
    """
    groups = getattr(scorer, "groups", None)
    components = getattr(scorer, "components", None)
    if groups is None and components is None and coverage_rates is None:
        return y_true, y_pred

    # Validate coverage_rates if present (interval scorers).
    if coverage_rates is not None and pred_type == "interval" and interval_pattern is not None:
        available_rates = set()
        for col in y_pred.columns:
            if col in ("time", "vintage_time"):
                continue
            match = interval_pattern.match(col)
            if match:
                available_rates.add(float(match.group(3)))
        missing_rates = set(coverage_rates) - available_rates
        if missing_rates:
            raise ValueError(
                f"Requested coverage_rates {sorted(missing_rates)} not found in predictions. "
                f"Available rates: {sorted(available_rates)}"
            )

    _, y_groups = inspect_panel(y_true)

    # Validate panel groups if specified (must exist in data).
    if groups is not None:
        missing_groups = set(groups) - set(y_groups.keys())
        if missing_groups:
            raise ValueError(
                f"Invalid groups: {sorted(missing_groups)} not found in data. "
                f"Available groups: {sorted(y_groups.keys())}"
            )

    is_panel = len(y_groups) > 0
    if is_panel:
        # Panel data: filter by groups and/or components.
        groups_to_include = groups if groups is not None else list(y_groups.keys())
        group_cols = [col for name in groups_to_include if name in y_groups for col in y_groups[name]]
        if components is not None:
            # Components are given unprefixed; strip the "group__" prefix to compare.
            available_components = {col.split("__", 1)[1] for col in group_cols}
            missing_components = set(components) - available_components
            if missing_components:
                # Without this check an unmatched component list would leave
                # selected_cols empty and silently score every column.
                raise ValueError(
                    f"Invalid components: {sorted(missing_components)} not found in data. "
                    f"Available components: {sorted(available_components)}"
                )
            selected_cols = [col for col in group_cols if col.split("__", 1)[1] in components]
        else:
            selected_cols = group_cols
        if selected_cols:
            y_true, y_pred = _subselect_frames(
                y_true, y_pred, selected_cols, pred_type, coverage_rates, interval_pattern
            )
    elif components is not None:
        # Global data: filter by components only; names must exist in data.
        available_components = [col for col in y_true.columns if col != "time"]
        missing_components = set(components) - set(available_components)
        if missing_components:
            raise ValueError(
                f"Invalid components: {sorted(missing_components)} not found in data. "
                f"Available components: {sorted(available_components)}"
            )
        selected_cols = [col for col in y_true.columns if col in components]
        if selected_cols:
            y_true, y_pred = _subselect_frames(
                y_true, y_pred, selected_cols, pred_type, coverage_rates, interval_pattern
            )
    elif coverage_rates is not None and pred_type == "interval" and interval_pattern is not None:
        # No component filter, but a coverage-rate filter: keep "time" plus the
        # interval columns at the requested rates ("vintage_time" is dropped).
        requested = set(coverage_rates)
        kept = ["time"] + [
            col
            for col in y_pred.columns
            if col not in ("time", "vintage_time") and _has_requested_rate(col, interval_pattern, requested)
        ]
        y_pred = y_pred.select(kept)
    return y_true, y_pred