def validate_scorer_data(
    scorer: BaseScorer,
    y_true: pl.DataFrame | None = None,
    y_pred: pl.DataFrame | None = None,
    *,
    scores: pl.DataFrame | None = None,
    reset: bool = False,
    inverse: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame | None, ScoringContext | None]:
    """Validate and prepare scorer input data.

    The function operates in one of three contexts, selected by ``inverse``
    and ``reset`` (``inverse`` takes precedence when both are set):

    - **Inverse context** (``inverse=True``): validates ``y_pred`` and
      ``scores`` and checks that their non-time columns match.
    - **Fit context** (``reset=True``): validates ``y_true`` (the training
      target) only.
    - **Normal score context** (default): validates ``y_true`` against
      ``y_pred``, aligns them on "time" and subselects columns.

    Parameters
    ----------
    scorer : BaseScorer
        The scorer instance calling this function. Its ``__sklearn_tags__()``
        are read in the normal context; the optional attributes
        ``seasonality``, ``interval_`` and ``coverage_rates`` are read when
        present.
    y_true : pl.DataFrame, default=None
        True values with "time" column.

        - In fit context (reset=True): This is y_train. Always required.
        - In inverse context: Not used (pass ``y_pred`` and ``scores``).
        - In normal score context: Always required.
    y_pred : pl.DataFrame, default=None
        Predicted values with "time" column, optionally with a
        "vintage_time" column. Required in normal score context and in
        inverse context.
    scores : pl.DataFrame, default=None
        Conformity scores with "time" column. Required when inverse=True.
    reset : bool, default=False
        If True, validate in fit context (skips prediction structure checks).
    inverse : bool, default=False
        If True, validate in inverse_score context. Requires ``y_pred`` and
        ``scores``.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame | None, ScoringContext | None]
        Validated and prepared DataFrames and scoring context. The "time"
        (and "vintage_time") columns are removed from the returned frames:

        - Normal context: (y_true, y_pred, ScoringContext)
        - Inverse context: (y_pred, scores, ScoringContext)
        - Fit context (reset=True): (y_train, None, None)

    Raises
    ------
    ValueError
        If a DataFrame required by the active context is None, if the
        non-time columns of ``y_pred`` and ``scores`` differ (inverse
        context), if the training data has no more rows than an integer
        ``scorer.seasonality`` (fit context), if the scorer tags carry no
        ``prediction_type``, if panel groups differ between ``y_true`` and
        ``y_pred``, or if expected prediction columns are missing or
        non-numeric (normal context). The validation helpers called here
        may raise further errors of their own.

    Notes
    -----
    - Time values, ``vintage_time`` and the forecasting step are carried in
      the returned ``ScoringContext`` instead of the DataFrames.
    - Time alignment keeps only time points present in both ``y_true`` and
      ``y_pred`` (semi-joins); ``y_true`` is then left-joined onto the "time"
      column of ``y_pred`` so it has one row per prediction row.
    - The forecasting step is only computed when ``y_pred`` has a
      "vintage_time" column and the scorer has a non-None ``interval_``.

    See Also
    --------
    - [`BaseScorer`][yohou.metrics.base.BaseScorer] : Base class for all scorers.
    - [`validate_time_weight`][yohou.utils.validate_data.validate_time_weight] : Validate time weighting parameters.
    """
    # Runtime validation: enforce parameter requirements for each context
    if not inverse and not reset and y_pred is None:
        # Normal score context: y_pred required
        raise ValueError("`y_pred` cannot be None for scoring. Set reset=True for fit/calibration.")

    if inverse:
        # Type narrowing: inverse scoring requires y_pred and scores
        if y_pred is None:
            raise ValueError("`y_pred` is required for inverse scoring. Cannot be None.")
        if scores is None:
            raise ValueError("`scores` is required for inverse scoring. Cannot be None.")

        # Validate time columns (required)
        # NOTE(review): the normal score path validates multi-vintage y_pred with
        # _check_multi_vintage_time instead of check_time_column; confirm that
        # check_time_column accepts frames carrying "vintage_time" here.
        check_time_column(y_pred)
        check_time_column(scores)

        has_vintage = "vintage_time" in y_pred.columns

        # Check column schema compatibility (exclude time/vintage_time columns).
        # Only y_pred may carry "vintage_time"; scores are compared as-is minus "time".
        exclude_cols_pred = ["time", "vintage_time"] if has_vintage else ["time"]
        exclude_cols_scores = ["time"]
        y_pred_cols = set(y_pred.select(~cs.by_name(*exclude_cols_pred)).columns)
        score_cols = set(scores.select(~cs.by_name(*exclude_cols_scores)).columns)
        if y_pred_cols != score_cols:
            raise ValueError(
                "Column mismatch between y_pred and conformity_scores. "
                f"y_pred has {sorted(y_pred_cols)}, conformity_scores has {sorted(score_cols)}."
            )

        # Extract vintage_time before dropping (used by multi-vintage aggregation)
        vintage_time = None
        forecasting_step = None
        if has_vintage:
            vintage_time = y_pred["vintage_time"]
            # Compute forecasting step if scorer has interval_ attribute
            scorer_interval = getattr(scorer, "interval_", None)
            if scorer_interval is not None:
                forecasting_step = _compute_forecasting_step(y_pred["time"], vintage_time, scorer_interval)
            y_pred = y_pred.drop("vintage_time")

        # Extract time values before dropping (always present after validation)
        time_values = y_pred["time"].to_list()

        # Drop time columns for consistency with normal path
        y_pred = y_pred.drop("time")
        scores = scores.drop("time")

        # Function-scope import, kept as in the original code
        # (presumably to avoid an import cycle - confirm).
        from yohou.metrics._context import ScoringContext as _ScoringContext  # noqa: PLC0415

        return (
            y_pred,
            scores,
            _ScoringContext(
                time_values=time_values,
                vintage_time=vintage_time,
                forecasting_step=forecasting_step,
            ),
        )

    if reset:
        # At fit time, y_true is y_train (always required), y_pred is always None
        if y_true is None:
            raise ValueError("`y_train` is required for scorer.fit(). Cannot be None.")
        check_time_column(y_true)

        # Validate seasonality for scorers with an integer seasonality parameter.
        # The row count is the number of observations; the "time" column is a
        # column, not a row, so no adjustment is needed.
        if hasattr(scorer, "seasonality"):
            seasonality_val = scorer.seasonality
            n_rows = y_true.height
            if isinstance(seasonality_val, int) and n_rows <= seasonality_val:
                raise ValueError(
                    f"Training data length ({n_rows}) must be greater than "
                    f"seasonality ({seasonality_val}). Cannot compute seasonal naive forecast errors."
                )

        # At fit time: drop time from y_train
        y_true = y_true.drop("time")
        return y_true, None, None

    # At score time, y_true is always required
    if y_true is None:
        raise ValueError("`y_true` cannot be None for scorer.")
    # Already enforced by the first guard; repeated for static type narrowing.
    if y_pred is None:
        raise ValueError("`y_pred` cannot be None for scorer.")

    # Validate time columns
    check_time_column(y_true)
    # Multi-vintage predictions (from observe_predict with stride) have
    # repeating time values across vintages, so the global time column is
    # not monotonically sorted. Validate per-vintage sorting instead.
    if "vintage_time" in y_pred.columns:
        _check_multi_vintage_time(y_pred)
    else:
        check_time_column(y_pred)

    tags = scorer.__sklearn_tags__()
    scorer_tags = getattr(tags, "scorer_tags", None)
    pred_type = getattr(scorer_tags, "prediction_type", None) if scorer_tags is not None else None
    if pred_type is None:
        raise ValueError("Scorer tags must have prediction_type attribute")

    # Panel consistency check
    _, y_groups = inspect_panel(y_true)
    _, X_groups = inspect_panel(y_pred)
    if set(y_groups.keys()) != set(X_groups.keys()):
        raise ValueError(
            f"Panel groups mismatch. `y_true` has {sorted(y_groups.keys())}. `y_pred` has {sorted(X_groups.keys())}."
        )

    # Validate column presence and types
    for col in y_true.columns:
        if col == "time":
            continue

        if pred_type == "point":
            if col not in y_pred.columns:
                raise ValueError(f"'{col}' is present in `y_true` but missing in `y_pred`.")
            # Relaxed check: do not enforce exact dtype match (e.g. Int64 vs Float64 is fine)
            # But ensure both are numeric to avoid invalid operations
            if not (y_true.schema[col].is_numeric() and y_pred.schema[col].is_numeric()):
                raise ValueError(
                    f"Column '{col}' type mismatch. `y_true`: {y_true.schema[col]}, "
                    f"`y_pred`: {y_pred.schema[col]}. Both must be numeric."
                )

        elif pred_type == "interval":
            lower_prefix = f"{col}_lower_"
            upper_prefix = f"{col}_upper_"
            related_cols = [c for c in y_pred.columns if c.startswith((lower_prefix, upper_prefix))]
            lower_found = any(c.startswith(lower_prefix) for c in related_cols)
            upper_found = any(c.startswith(upper_prefix) for c in related_cols)
            if not lower_found or not upper_found:
                raise ValueError(f"Interval columns for `y_true` '{col}' missing in `y_pred`.")
            for rc in related_cols:
                if not (y_true.schema[col].is_numeric() and y_pred.schema[rc].is_numeric()):
                    raise ValueError(
                        f"Column '{rc}' type mismatch. `y_true` '{col}': {y_true.schema[col]}, "
                        f"`y_pred`: {y_pred.schema[rc]}. Both must be numeric."
                    )

        elif pred_type == "class_proba":
            proba_cols = [c for c in y_pred.columns if c.startswith(f"{col}_proba_")]
            if not proba_cols:
                raise ValueError(
                    f"No probability columns found for target '{col}' in `y_pred`. "
                    f"Expected columns matching '{col}_proba_<class_label>'."
                )
            for pc in proba_cols:
                if not y_pred.schema[pc].is_numeric():
                    raise ValueError(f"Probability column '{pc}' must be numeric, got {y_pred.schema[pc]}.")

    # Align by time.
    # Use semi-joins so that duplicate times in y_pred (multi-vintage data
    # from observe_predict) are preserved rather than cross-producted.
    unique_truth_times = y_true.select("time").unique()
    unique_pred_times = y_pred.select("time").unique()
    y_pred = y_pred.join(unique_truth_times, on="time", how="semi")
    y_true_filtered = y_true.join(unique_pred_times, on="time", how="semi")
    # Replicate y_true rows to match y_pred (1:1 for single-vintage,
    # 1:N when y_pred has multiple vintages per time point).
    # NOTE(review): downstream code presumably relies on y_true rows lining up
    # with y_pred rows; recent polars releases document join row order as
    # unspecified unless `maintain_order` is passed - confirm against the
    # pinned polars version.
    y_true = y_pred.select("time").join(y_true_filtered, on="time", how="left")

    # Subselect columns based on scorer configuration
    coverage_rates: list[float] | None = getattr(scorer, "coverage_rates", None)
    # Extract filter keys from polymorphic param (list or dict)
    if isinstance(coverage_rates, dict):
        coverage_rates = [float(k) for k in coverage_rates]

    interval_pattern = re.compile(r"^(.+)_(lower|upper)_([\d.]+)$")
    y_true, y_pred = check_scorer_column_selection(
        scorer=scorer,
        y_true=y_true,
        y_pred=y_pred,
        pred_type=pred_type,
        coverage_rates=coverage_rates,
        interval_pattern=interval_pattern,
    )

    # Truncate partial vintage: if the last observe_predict window was
    # shorter than the regular stride, drop that vintage before scoring.
    if "vintage_time" in y_pred.columns:
        y_true, y_pred = _truncate_partial_vintage(y_true, y_pred)

    # Extract time values before dropping (all scorers get time-less DataFrames)
    time_values = y_true["time"].to_list() if "time" in y_true.columns else None

    # Extract vintage_time and compute forecasting_step before dropping
    vintage_time: pl.Series | None = None
    forecasting_step: pl.Series | None = None
    if "vintage_time" in y_pred.columns:
        vintage_time = y_pred["vintage_time"]
        # Compute forecasting_step if scorer has interval_ from fit()
        interval_ = getattr(scorer, "interval_", None)
        if interval_ is not None and time_values is not None:
            forecasting_step = _compute_forecasting_step(y_pred["time"], vintage_time, interval_)
        y_pred = y_pred.drop("vintage_time")

    # Drop time columns for all scorers (conformity scorers can reconstruct from time_values)
    y_true = y_true.drop("time")
    if "time" in y_pred.columns:
        y_pred = y_pred.drop("time")

    # For point scorers, strip extra columns (e.g. interval _lower_/_upper_ columns
    # that may be present in mixed multimetric scenarios) to prevent schema mismatches.
    if pred_type == "point":
        extra_cols = [c for c in y_pred.columns if c not in y_true.columns]
        if extra_cols:
            y_pred = y_pred.drop(extra_cols)

    # Function-scope import, kept as in the original code
    # (presumably to avoid an import cycle - confirm).
    from yohou.metrics._context import ScoringContext as _ScoringContext  # noqa: PLC0415

    context = _ScoringContext(
        time_values=time_values,  # type: ignore
        vintage_time=vintage_time,
        forecasting_step=forecasting_step,
    )
    return y_true, y_pred, context