Skip to content

validate_scorer_data

yohou.utils.validate_data.validate_scorer_data(scorer, y_true=None, y_pred=None, *, scores=None, reset=False, inverse=False)

validate_scorer_data(
    scorer: BaseScorer,
    y_true: pl.DataFrame,
    y_pred: None = None,
    *,
    scores: None = None,
    reset: bool = True,
    inverse: bool = False,
) -> tuple[pl.DataFrame, None, None]
validate_scorer_data(
    scorer: BaseScorer,
    y_true: None = None,
    y_pred: pl.DataFrame = ...,
    *,
    scores: pl.DataFrame = ...,
    reset: bool = False,
    inverse: bool = True,
) -> tuple[pl.DataFrame, pl.DataFrame, ScoringContext]
validate_scorer_data(
    scorer: BaseScorer,
    y_true: pl.DataFrame = ...,
    y_pred: pl.DataFrame = ...,
    *,
    scores: None = None,
    reset: bool = False,
    inverse: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame, ScoringContext]

Validate and prepare scorer input data.

Parameters

Name Type Description Default
scorer BaseScorer

The scorer instance calling this function.

required
y_true DataFrame

True values with "time" column.
- In fit context (reset=True): this is y_train. Always required.
- In inverse context: not used (pass y_pred and scores instead); may be None.
- In normal score context: always required.

None
y_pred DataFrame

Predicted values with "time" column (and optionally "vintage_time"). Required in normal score context and in inverse context.

None
scores DataFrame

Conformity scores with "time" column. Required when inverse=True.

None
reset bool

If True, validate in fit context (skips prediction structure checks). No time alignment is performed, and the "time" column is dropped from the returned data.

False
inverse bool

If True, validate in inverse_score context. Requires scores parameter.

False

Returns

Type Description
tuple[DataFrame, DataFrame | None, ScoringContext | None]

Validated and prepared DataFrames and scoring context:
- Normal context: (y_true, y_pred, ScoringContext)
- Inverse context: (y_pred, scores, ScoringContext)
- Fit context (reset=True): (y_train, None, None)

Notes

  • The time column (and vintage_time, if present) is always dropped from the returned DataFrames; in scoring contexts, the time values are returned via ScoringContext
  • Performs basic validation: None checks, time column existence, panel consistency
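  • Time alignment preserves common time points only (semi-joins); duplicate prediction times from multi-vintage data are kept, and y_true rows are replicated to match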
  • Time values are extracted before the time column is dropped, for all scorer types

See Also

Source Code

Show/Hide source
def validate_scorer_data(
    scorer: BaseScorer,
    y_true: pl.DataFrame | None = None,
    y_pred: pl.DataFrame | None = None,
    *,
    scores: pl.DataFrame | None = None,
    reset: bool = False,
    inverse: bool = False,
) -> tuple[pl.DataFrame, pl.DataFrame | None, ScoringContext | None]:
    """Validate and prepare scorer input data.

    Parameters
    ----------
    scorer : BaseScorer
        The scorer instance calling this function.
    y_true : pl.DataFrame, default=None
        True values with "time" column.

        - In fit context (reset=True): This is y_train. Always required.
        - In inverse context: Not used (pass ``y_pred`` and ``scores``); may be None.
        - In normal score context: Always required.
    y_pred : pl.DataFrame, default=None
        Predicted values with "time" column (and optionally "vintage_time").
        Required in normal score context and in inverse context.
    scores : pl.DataFrame, default=None
        Conformity scores with "time" column. Required when inverse=True.
    reset : bool, default=False
        If True, validate in fit context (skips prediction structure checks).
        No time alignment is performed and the "time" column is dropped.
    inverse : bool, default=False
        If True, validate in inverse_score context. Requires ``y_pred`` and
        ``scores``. Takes precedence over ``reset``.

    Returns
    -------
    tuple[pl.DataFrame, pl.DataFrame | None, ScoringContext | None]
        Validated and prepared DataFrames and scoring context:

        - Normal context: (y_true, y_pred, ScoringContext)
        - Inverse context: (y_pred, scores, ScoringContext)
        - Fit context (reset=True): (y_train, None, None)

    Raises
    ------
    ValueError
        If a required input is None; if ``y_pred`` and ``scores`` have different
        value columns (inverse context); if the training data is not longer than
        the scorer's integer ``seasonality`` (fit context); if the scorer tags
        lack ``prediction_type``; if panel groups differ between ``y_true`` and
        ``y_pred``; or if required prediction columns are missing or non-numeric.

    Notes
    -----
    - The "time" column (and "vintage_time", if present) is always dropped from
      the returned DataFrames; in scoring contexts the time values are returned
      via ``ScoringContext.time_values``.
    - Performs basic validation: None checks, time column existence, panel consistency
    - Time alignment preserves common time points only (semi-joins); duplicate
      times in multi-vintage ``y_pred`` are kept and ``y_true`` rows are
      replicated to match.
    - Time values are extracted before the time column is dropped, for all
      scorer types.

    See Also
    --------
    - [`BaseScorer`][yohou.metrics.base.BaseScorer] : Base class for all scorers.
    - [`validate_time_weight`][yohou.utils.validate_data.validate_time_weight] : Validate time weighting parameters.

    """
    # Runtime validation: enforce parameter requirements for each context
    if not inverse and not reset and y_pred is None:
        # Normal score context: y_pred required
        raise ValueError("`y_pred` cannot be None for scoring. Set reset=True for fit/calibration.")

    if inverse:
        # Type narrowing: inverse scoring requires y_pred and scores
        if y_pred is None:
            raise ValueError("`y_pred` is required for inverse scoring. Cannot be None.")
        if scores is None:
            raise ValueError("`scores` is required for inverse scoring. Cannot be None.")

        # Validate time columns (required)
        check_time_column(y_pred)
        check_time_column(scores)

        # Check column schema compatibility (exclude time/vintage_time columns)
        exclude_cols_pred = ["time"]
        exclude_cols_scores = ["time"]

        if "vintage_time" in y_pred.columns:
            exclude_cols_pred.append("vintage_time")

        y_pred_cols = set(y_pred.select(~cs.by_name(*exclude_cols_pred)).columns)
        score_cols = set(scores.select(~cs.by_name(*exclude_cols_scores)).columns)

        if y_pred_cols != score_cols:
            raise ValueError(
                f"Column mismatch between y_pred and conformity_scores. "
                f"y_pred has {sorted(y_pred_cols)}, conformity_scores has {sorted(score_cols)}."
            )

        # Extract vintage_time before dropping (used by multi-vintage aggregation)
        vintage_time = None
        forecasting_step = None
        if "vintage_time" in y_pred.columns:
            vintage_time = y_pred["vintage_time"]
            # Compute forecasting step only when the scorer has a fitted interval_
            interval_ = getattr(scorer, "interval_", None)
            if interval_ is not None:
                forecasting_step = _compute_forecasting_step(y_pred["time"], y_pred["vintage_time"], interval_)
            y_pred = y_pred.drop("vintage_time")

        # Extract time values before dropping (always present after validation)
        time_values = y_pred["time"].to_list()

        # Drop time columns for consistency with normal path
        y_pred = y_pred.drop("time")
        scores = scores.drop("time")

        # Local import, presumably to avoid a circular import with yohou.metrics
        from yohou.metrics._context import ScoringContext as _ScoringContext  # noqa: PLC0415

        return (
            y_pred,
            scores,
            _ScoringContext(
                time_values=time_values,
                vintage_time=vintage_time,
                forecasting_step=forecasting_step,
            ),
        )

    if reset:
        # At fit time, y_true is y_train (always required), y_pred is always None
        if y_true is None:
            raise ValueError("`y_train` is required for scorer.fit(). Cannot be None.")

        check_time_column(y_true)

        # Validate seasonality for scorers with an integer seasonality parameter.
        # len() counts rows; the "time" column does not affect the row count, so
        # the reported length is len(y_true) itself.
        if hasattr(scorer, "seasonality"):
            seasonality_val = scorer.seasonality
            if isinstance(seasonality_val, int) and len(y_true) <= seasonality_val:
                raise ValueError(
                    f"Training data length ({len(y_true)}) must be greater than "
                    f"seasonality ({seasonality_val}). Cannot compute seasonal naive forecast errors."
                )

        # At fit time: drop time from y_train
        y_true = y_true.drop("time")

        return y_true, None, None

    # At score time, y_true is always required
    if y_true is None:
        raise ValueError("`y_true` cannot be None for scorer.")

    # Unreachable at runtime (already checked at the top for the normal context);
    # kept so static type checkers narrow y_pred to pl.DataFrame below.
    if y_pred is None:
        raise ValueError("`y_pred` cannot be None for scorer.")

    # Validate time columns
    check_time_column(y_true)
    # Multi-vintage predictions (from observe_predict with stride) have
    # repeating time values across vintages, so the global time column is
    # not monotonically sorted.  Validate per-vintage sorting instead.
    if "vintage_time" in y_pred.columns:
        _check_multi_vintage_time(y_pred)
    else:
        check_time_column(y_pred)

    tags = scorer.__sklearn_tags__()
    scorer_tags = getattr(tags, "scorer_tags", None)
    pred_type = getattr(scorer_tags, "prediction_type", None) if scorer_tags is not None else None

    if pred_type is None:
        raise ValueError("Scorer tags must have prediction_type attribute")

    # Panel consistency check
    _, y_groups = inspect_panel(y_true)
    _, pred_groups = inspect_panel(y_pred)
    if set(y_groups.keys()) != set(pred_groups.keys()):
        raise ValueError(
            f"Panel groups mismatch. `y_true` has {sorted(y_groups.keys())}. `y_pred` has {sorted(pred_groups.keys())}."
        )

    # Validate column presence and types
    for col in y_true.columns:
        if col == "time":
            continue

        if pred_type == "point":
            if col not in y_pred.columns:
                raise ValueError(f"'{col}' is present in `y_true` but missing in `y_pred`.")
            # Relaxed check: do not enforce exact dtype match (e.g. Int64 vs Float64 is fine)
            # But ensure both are numeric to avoid invalid operations
            if not (y_true.schema[col].is_numeric() and y_pred.schema[col].is_numeric()):
                raise ValueError(
                    f"Column '{col}' type mismatch. `y_true`: {y_true.schema[col]}, "
                    f"`y_pred`: {y_pred.schema[col]}. Both must be numeric."
                )
        elif pred_type == "interval":
            lower_prefix = f"{col}_lower_"
            upper_prefix = f"{col}_upper_"
            related_cols = [c for c in y_pred.columns if c.startswith((lower_prefix, upper_prefix))]
            lower_found = any(c.startswith(lower_prefix) for c in related_cols)
            upper_found = any(c.startswith(upper_prefix) for c in related_cols)
            if not lower_found or not upper_found:
                raise ValueError(f"Interval columns for `y_true` '{col}' missing in `y_pred`.")

            for rc in related_cols:
                if not (y_true.schema[col].is_numeric() and y_pred.schema[rc].is_numeric()):
                    raise ValueError(
                        f"Column '{rc}' type mismatch. `y_true` '{col}': {y_true.schema[col]}, "
                        f"`y_pred`: {y_pred.schema[rc]}. Both must be numeric."
                    )
        elif pred_type == "class_proba":
            proba_cols = [c for c in y_pred.columns if c.startswith(f"{col}_proba_")]
            if not proba_cols:
                raise ValueError(
                    f"No probability columns found for target '{col}' in `y_pred`. "
                    f"Expected columns matching '{col}_proba_<class_label>'."
                )
            for pc in proba_cols:
                if not y_pred.schema[pc].is_numeric():
                    raise ValueError(f"Probability column '{pc}' must be numeric, got {y_pred.schema[pc]}.")

    # Align by time.
    # Use semi-joins so that duplicate times in y_pred (multi-vintage data
    # from observe_predict) are preserved rather than cross-producted.
    unique_truth_times = y_true.select("time").unique()
    unique_pred_times = y_pred.select("time").unique()

    y_pred = y_pred.join(unique_truth_times, on="time", how="semi")
    y_true_filtered = y_true.join(unique_pred_times, on="time", how="semi")

    # Replicate y_true rows to match y_pred (1:1 for single-vintage,
    # 1:N when y_pred has multiple vintages per time point).
    y_true = y_pred.select("time").join(y_true_filtered, on="time", how="left")

    # Subselect columns based on scorer configuration
    coverage_rates: list[float] | None = getattr(scorer, "coverage_rates", None)
    # Extract filter keys from polymorphic param (list or dict)
    if isinstance(coverage_rates, dict):
        coverage_rates = [float(k) for k in coverage_rates]
    # Matches "<target>_<lower|upper>_<rate>" interval column names
    interval_pattern = re.compile(r"^(.+)_(lower|upper)_([\d.]+)$")

    y_true, y_pred = check_scorer_column_selection(
        scorer=scorer,
        y_true=y_true,
        y_pred=y_pred,
        pred_type=pred_type,
        coverage_rates=coverage_rates,
        interval_pattern=interval_pattern,
    )

    # Truncate partial vintage: if the last observe_predict window was
    # shorter than the regular stride, drop that vintage before scoring.
    if "vintage_time" in y_pred.columns:
        y_true, y_pred = _truncate_partial_vintage(y_true, y_pred)

    # Extract time values before dropping (all scorers get time-less DataFrames)
    time_values = y_true["time"].to_list() if "time" in y_true.columns else None

    # Extract vintage_time and compute forecasting_step before dropping
    vintage_time: pl.Series | None = None
    forecasting_step: pl.Series | None = None

    if "vintage_time" in y_pred.columns:
        vintage_time = y_pred["vintage_time"]

        # Compute forecasting_step if scorer has interval_ from fit()
        interval_ = getattr(scorer, "interval_", None)
        if interval_ is not None and time_values is not None:
            forecasting_step = _compute_forecasting_step(y_pred["time"], vintage_time, interval_)

        y_pred = y_pred.drop("vintage_time")

    # Drop time columns for all scorers (conformity scorers can reconstruct from
    # time_values). Guarded like the extraction above so a missing column does
    # not raise here.
    if "time" in y_true.columns:
        y_true = y_true.drop("time")

    if "time" in y_pred.columns:
        y_pred = y_pred.drop("time")

    # For point scorers, strip extra columns (e.g. interval _lower_/_upper_ columns
    # that may be present in mixed multimetric scenarios) to prevent schema mismatches.
    if pred_type == "point":
        extra_cols = [c for c in y_pred.columns if c not in y_true.columns]
        if extra_cols:
            y_pred = y_pred.drop(extra_cols)

    # Local import, presumably to avoid a circular import with yohou.metrics
    from yohou.metrics._context import ScoringContext as _ScoringContext  # noqa: PLC0415

    context = _ScoringContext(
        time_values=time_values,  # type: ignore
        vintage_time=vintage_time,
        forecasting_step=forecasting_step,
    )

    return y_true, y_pred, context