Skip to content

BaseHardLabelScorer

yohou.metrics.base.BaseHardLabelScorer

Bases: BaseClassProbaScorer

Base class for confusion-matrix classification metrics.

Extends `BaseClassProbaScorer` for metrics that argmax predicted probabilities into hard labels and compute scores from the resulting confusion matrix (TP, FP, FN counts).

The `score()` method is inherited from `BaseClassProbaScorer`. This class overrides `_compute_raw_errors` (produces per-row confusion indicators) and `_aggregate_scores` (sums indicators, computes metrics, applies class averaging); concrete subclasses implement `_compute_metric_from_counts`.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `average` | `str` | Class averaging strategy: `"macro"` (unweighted mean across classes), `"micro"` (aggregate counts across classes first), or `"weighted"` (support-weighted mean). | `"macro"` |
| `zero_division` | `float` | Value to return when a metric denominator is zero. | `0.0` |
| `aggregation_method` | `list of str` or `str` | Which dimensions to aggregate. | `"all"` |
| `groups` | `list of str`, `dict of str to float`, or `None` | Panel group filter or filter with weights. | `None` |
| `components` | `list of str`, `dict of str to float`, or `None` | Component filter or filter with weights. | `None` |

Source Code

Show/Hide source
class BaseHardLabelScorer(BaseClassProbaScorer, metaclass=abc.ABCMeta):
    """Base class for confusion-matrix classification metrics.

    Extends :class:`BaseClassProbaScorer` for metrics that argmax predicted
    probabilities into hard labels and compute scores from the resulting
    confusion matrix (TP, FP, FN counts).

    The ``score()`` method is inherited from :class:`BaseClassProbaScorer`.
    This class overrides :meth:`_compute_raw_errors` (produces per-row
    confusion indicators) and :meth:`_aggregate_scores` (sums indicators,
    computes metrics, applies class averaging). Concrete subclasses only
    implement :meth:`_compute_metric_from_counts`, which turns TP/FP/FN
    counts into a metric value.

    Parameters
    ----------
    average : str, default="macro"
        Class averaging strategy: ``"macro"`` (unweighted mean across
        classes), ``"micro"`` (aggregate counts across classes first),
        or ``"weighted"`` (support-weighted mean).
    zero_division : float, default=0.0
        Value to return when a metric denominator is zero.
    aggregation_method : list of str or str, default="all"
        Which dimensions to aggregate.
    groups : list of str, dict of str to float, or None, default=None
        Panel group filter or filter with weights.
    components : list of str, dict of str to float, or None, default=None
        Component filter or filter with weights.

    """

    _parameter_constraints: dict = {
        **BaseClassProbaScorer._parameter_constraints,
        "average": [StrOptions({"macro", "micro", "weighted"})],
        "zero_division": "no_validation",
    }

    def __init__(
        self,
        average: str = "macro",
        zero_division: float = 0.0,
        aggregation_method: list[str] | str = "all",
        groups: list[str] | dict[str, float] | None = None,
        components: list[str] | dict[str, float] | None = None,
    ):
        super().__init__(
            aggregation_method=aggregation_method,
            groups=groups,
            components=components,
        )
        self.average = average
        self.zero_division = zero_division

    def _compute_raw_errors(self, y_truth: pl.DataFrame, y_pred: pl.DataFrame) -> pl.DataFrame:
        """Compute per-row confusion indicators for each class and target.

        Argmaxes predicted probabilities into hard labels and computes
        binary TP/FP/FN indicators per row per class. Targets for which
        no probability columns are found in ``y_pred`` are skipped.

        Returns a wide DataFrame with columns like
        ``{target}_tp_{class}``, ``{target}_fp_{class}``,
        ``{target}_fn_{class}`` (target includes panel group prefix
        when present, e.g. ``grpA__weather_tp_sunny``).

        """
        target_cols = self._extract_target_columns(y_truth)
        all_indicator_cols: list[pl.Series] = []

        for target_col in target_cols:
            proba_cols, class_labels = self._extract_class_proba_columns(y_pred, target_col)
            if not proba_cols:
                continue
            self._build_confusion_indicators(
                y_truth[target_col],
                y_pred.select(proba_cols),
                class_labels,
                target_col,
                all_indicator_cols,
            )

        return pl.DataFrame(all_indicator_cols)

    @staticmethod
    def _build_confusion_indicators(
        truth_series: pl.Series,
        proba_df: pl.DataFrame,
        class_labels: list[str],
        prefix: str,
        out: list[pl.Series],
    ) -> None:
        """Build TP/FP/FN indicator columns for one target.

        Parameters
        ----------
        truth_series : pl.Series
            True class labels. Values are converted to ``str`` before being
            compared with ``class_labels``.
        proba_df : pl.DataFrame
            Probability columns for this target, in the same order as
            ``class_labels``.
        class_labels : list of str
            Class label names (one per proba column).
        prefix : str
            Column name prefix (target or group__target).
        out : list of pl.Series
            Output list; three float64 indicator Series per class
            (``_tp_``, ``_fp_``, ``_fn_``) are appended in place.

        """
        truth_arr = truth_series.to_numpy().astype(str)
        proba_arr = proba_df.to_numpy()
        # Hard label = class whose probability column is largest in each row.
        pred_indices = np.argmax(proba_arr, axis=1)
        pred_labels = np.array(class_labels)[pred_indices]

        for cls in class_labels:
            is_true = (truth_arr == cls).astype(np.float64)
            is_pred = (pred_labels == cls).astype(np.float64)
            tp = is_true * is_pred
            fp = (1.0 - is_true) * is_pred
            fn = is_true * (1.0 - is_pred)
            out.append(pl.Series(f"{prefix}_tp_{cls}", tp))
            out.append(pl.Series(f"{prefix}_fp_{cls}", fp))
            out.append(pl.Series(f"{prefix}_fn_{cls}", fn))

    def _aggregate_scores(
        self, raw_scores: pl.DataFrame, context: ScoringContext | None = None
    ) -> float | pl.DataFrame:
        """Aggregate confusion indicators into final metric scores.

        Pipeline: sum rows (per-vintage) → compute metric from counts →
        delegate tail to _aggregate_per_vintage_scores.

        """
        dims = self._normalize_agg_methods(self.aggregation_method)

        # 1. Collapse rows via SUM (not mean) — produces per-vintage sums
        result = self._collapse_rows_sum(raw_scores, context, dims)

        # 2 & 3. Compute metric from counts and apply class averaging
        result = self._compute_and_average(result)

        # 4. Delegate tail (components → groups → transform → vintage collapse → finalize)
        return self._aggregate_per_vintage_scores(result, context)

    def _collapse_rows_sum(
        self,
        df: pl.DataFrame,
        context: ScoringContext | None,
        dims: set[str],
    ) -> pl.DataFrame:
        """Collapse row dimensions using sum (for confusion counts)."""
        return self._collapse_rows_with(df, context, dims, agg_fn="sum")

    def _compute_and_average(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute metric from counts and apply class averaging.

        For micro averaging, sums TP/FP/FN across all classes before
        computing the metric. For macro/weighted, computes per-class
        metrics first, then averages. Any ``average`` value other than
        ``"micro"`` or ``"weighted"`` falls back to macro averaging.

        Metadata columns (``coverage_rate``, ``forecasting_step``,
        ``vintage_time``, ``time``) that are present in ``df`` are kept in
        front of the score columns, including when ``df`` has zero rows.

        """
        meta_names = {"coverage_rate", "forecasting_step", "vintage_time", "time"}
        meta_cols = [c for c in df.columns if c in meta_names]
        val_cols = [c for c in df.columns if c not in meta_names]

        averagers = {
            "micro": self._micro_average,
            "weighted": self._weighted_average,
        }
        average_fn = averagers.get(self.average, self._macro_average)

        # Parse column structure: group targets and their classes
        targets = self._parse_indicator_columns(val_cols)
        score_df = pl.DataFrame([
            average_fn(df, col_prefix, class_labels) for col_prefix, class_labels in targets.items()
        ])

        # Decide on the presence of metadata columns, not on row count, so an
        # empty frame keeps the same schema as a non-empty one.
        if not meta_cols:
            return score_df
        return pl.concat([df.select(meta_cols), score_df], how="horizontal")

    @staticmethod
    def _parse_indicator_columns(val_cols: list[str]) -> dict[str, list[str]]:
        """Parse indicator columns into ``{target_prefix: [class_labels]}``.

        Columns follow the pattern ``{prefix}_tp_{class}``,
        ``{prefix}_fp_{class}``, ``{prefix}_fn_{class}``. Either the prefix
        (target or ``group__target`` name) or the class label may itself
        contain ``"_tp_"``, ``"_fp_"`` or ``"_fn_"``. So every occurrence of
        an indicator token is treated as a candidate split point.

        A split is accepted only when all three sibling columns for the
        resulting ``(prefix, class)`` pair are present in ``val_cols``.
        Candidates are tried right-to-left, so the last occurrence wins
        when several are valid.

        If no candidate validates, the column is split at the last
        occurrence of the first indicator it contains (checked in
        ``tp``, ``fp``, ``fn`` order). Columns containing no indicator
        are ignored.

        Parameters
        ----------
        val_cols : list of str
            Value column names (metadata columns already removed).

        Returns
        -------
        dict of str to list of str
            Target prefixes mapped to their class labels, both in order of
            first appearance in ``val_cols``.

        """
        indicators = ("_tp_", "_fp_", "_fn_")
        available = set(val_cols)
        # dict-as-ordered-set: keeps first-seen order with O(1) membership.
        targets: dict[str, dict[str, None]] = {}

        for col in val_cols:
            candidates: list[tuple[int, str]] = []
            for indicator in indicators:
                pos = col.find(indicator)
                while pos != -1:
                    candidates.append((pos, indicator))
                    pos = col.find(indicator, pos + 1)
            if not candidates:
                continue

            split: tuple[str, str] | None = None
            for pos, indicator in sorted(candidates, reverse=True):
                prefix, cls = col[:pos], col[pos + len(indicator) :]
                if all(f"{prefix}{other}{cls}" in available for other in indicators):
                    split = (prefix, cls)
                    break

            if split is None:
                # No self-consistent (prefix, class) pair: fall back to the
                # legacy rule so unusual inputs keep their previous parsing.
                indicator = next(ind for ind in indicators if ind in col)
                pos = col.rindex(indicator)
                split = (col[:pos], col[pos + len(indicator) :])

            prefix, cls = split
            targets.setdefault(prefix, {})[cls] = None

        return {prefix: list(classes) for prefix, classes in targets.items()}

    def _per_class_metric(self, df: pl.DataFrame, col_prefix: str, cls: str, name: str) -> pl.Series:
        """Compute the metric for a single class of one target.

        Gathers the ``{col_prefix}_{tp,fp,fn}_{cls}`` columns into a
        ``tp``/``fp``/``fn`` frame, passes it to
        :meth:`_compute_metric_from_counts`, and returns the result as a
        Series named ``name``.
        """
        counts = pl.DataFrame({
            "tp": df[f"{col_prefix}_tp_{cls}"],
            "fp": df[f"{col_prefix}_fp_{cls}"],
            "fn": df[f"{col_prefix}_fn_{cls}"],
        })
        return self._compute_metric_from_counts(counts).to_series().alias(name)

    def _micro_average(self, df: pl.DataFrame, col_prefix: str, class_labels: list[str]) -> pl.Series:
        """Micro averaging: sum counts across classes, then compute metric."""
        counts = df.select([
            pl.sum_horizontal([f"{col_prefix}_{kind}_{cls}" for cls in class_labels]).alias(kind)
            for kind in ("tp", "fp", "fn")
        ])
        return self._compute_metric_from_counts(counts).to_series().alias(col_prefix)

    def _macro_average(self, df: pl.DataFrame, col_prefix: str, class_labels: list[str]) -> pl.Series:
        """Macro averaging: compute per-class metric, then unweighted mean.

        ``pl.mean_horizontal`` is used for the mean, so null per-class
        values are skipped rather than propagated.
        """
        per_class = [
            self._per_class_metric(df, col_prefix, cls, f"_cls_{i}") for i, cls in enumerate(class_labels)
        ]
        stacked = pl.DataFrame(per_class)
        return stacked.select(pl.mean_horizontal(pl.all())).to_series().alias(col_prefix)

    def _weighted_average(self, df: pl.DataFrame, col_prefix: str, class_labels: list[str]) -> pl.Series:
        """Weighted averaging: per-class metric weighted by class support.

        Support of a class is ``TP + FN`` (number of true instances). The
        denominator (total support) is clamped to at least 1.0. So, assuming
        finite per-class metric values, a row where every class has zero
        support yields 0.0 regardless of ``zero_division``.
        """
        metric_cols: list[str] = []
        support_cols: list[str] = []
        tmp_columns: list[pl.Series] = []
        for i, cls in enumerate(class_labels):
            metric_name, support_name = f"_m_{i}", f"_s_{i}"
            metric_cols.append(metric_name)
            support_cols.append(support_name)
            tmp_columns.append(self._per_class_metric(df, col_prefix, cls, metric_name))
            tmp_columns.append((df[f"{col_prefix}_tp_{cls}"] + df[f"{col_prefix}_fn_{cls}"]).alias(support_name))
        tmp = pl.DataFrame(tmp_columns)

        weighted_sum = pl.sum_horizontal([
            pl.col(mc) * pl.col(sc) for mc, sc in zip(metric_cols, support_cols, strict=True)
        ])
        total_support = pl.sum_horizontal([pl.col(sc) for sc in support_cols])
        result = tmp.select((weighted_sum / pl.max_horizontal(total_support, pl.lit(1.0))).alias(col_prefix))
        return result.to_series()

    @abc.abstractmethod
    def _compute_metric_from_counts(self, counts: pl.DataFrame) -> pl.DataFrame:
        """Compute metric value from confusion counts.

        Parameters
        ----------
        counts : pl.DataFrame
            DataFrame with ``"tp"``, ``"fp"``, ``"fn"`` columns.
            Each row corresponds to a temporal group (or a single row
            when all rows are collapsed).

        Returns
        -------
        pl.DataFrame
            Single-column DataFrame with computed metric values.

        """