
How to Create a Custom Scorer

This guide shows you how to implement a custom forecasting metric that plugs into yohou's scoring pipeline. Use this when the built-in metrics don't cover your evaluation needs. The guide follows the most common pattern (a per-row decomposable point metric) end to end.

Prerequisites

1. Subclass the Base

Create a class that extends BasePointScorer and implement _compute_raw_errors. The base score() method handles weight resolution, weight application, aggregation, vintage collapse, and column renaming automatically.

Every scorer must define two class attributes:

  • _metric_name (str): Controls output column names (e.g., "mae" produces value__mae).
  • _parameter_constraints (dict): Merged with the base class constraints for validation. At minimum, unpack **BasePointScorer._parameter_constraints into it.

For metrics where higher is better (such as R² or accuracy), also set _lower_is_better = False.

import polars as pl

from yohou.metrics.base import BasePointScorer


class MeanAbsoluteError(BasePointScorer):

    _parameter_constraints: dict = {
        **BasePointScorer._parameter_constraints,
    }

    _metric_name = "mae"

    def __init__(
        self,
        aggregation_method: list[str] | str = "all",
        groups: list[str] | dict[str, float] | None = None,
        components: list[str] | dict[str, float] | None = None,
    ) -> None:
        super().__init__(
            aggregation_method=aggregation_method,
            groups=groups,
            components=components,
        )

    def _compute_raw_errors(
        self, y_truth: pl.DataFrame, y_pred: pl.DataFrame
    ) -> pl.DataFrame:
        """Return per-row absolute errors."""
        return (y_truth - y_pred).select(pl.all().abs())

_compute_raw_errors receives DataFrames containing only value columns (no time or vintage_time columns). Return a DataFrame with the same shape: one row per timestep, one column per component.

Scorers that implement _compute_raw_errors inherit full weight support (time_weight, step_weight, vintage_weight) from BasePointScorer.score() with no additional code.
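To make this division of labour concrete, here is a plain-Python toy model of the pattern. It is a hypothetical stand-in, not yohou's BasePointScorer: dicts of lists replace polars DataFrames, only a time_weight-style weighted mean is modelled (no step or vintage weights, vintage collapse, or groups), and the names ToyBasePointScorer and ToyMAE are invented for illustration. The subclass writes only _compute_raw_errors; the base score() applies weights, aggregates, and names the output column.

```python
from abc import ABC, abstractmethod
from typing import Optional

# Toy stand-in for a DataFrame: component name -> values, one entry per timestep.
Frame = dict[str, list[float]]


class ToyBasePointScorer(ABC):
    """Hypothetical, simplified model of a base point scorer (NOT yohou's class).

    score() owns weighting, aggregation, and column naming; subclasses only
    supply _compute_raw_errors().
    """

    _metric_name: str = ""

    @abstractmethod
    def _compute_raw_errors(self, y_truth: Frame, y_pred: Frame) -> Frame:
        """Return one error per row for each component column."""

    def score(
        self,
        y_truth: Frame,
        y_pred: Frame,
        time_weight: Optional[list[float]] = None,
    ) -> dict[str, float]:
        raw = self._compute_raw_errors(y_truth, y_pred)
        n_rows = len(next(iter(raw.values())))
        # Uniform weights reduce the weighted mean to a plain mean.
        weights = time_weight if time_weight is not None else [1.0] * n_rows
        total = sum(weights)
        # Weighted mean per component, named "<component>__<metric_name>".
        return {
            f"{col}__{self._metric_name}": sum(w * e for w, e in zip(weights, errors)) / total
            for col, errors in raw.items()
        }


class ToyMAE(ToyBasePointScorer):
    _metric_name = "mae"

    def _compute_raw_errors(self, y_truth: Frame, y_pred: Frame) -> Frame:
        # Same shape in, same shape out: per-row absolute errors per component.
        return {
            col: [abs(t - p) for t, p in zip(y_truth[col], y_pred[col])]
            for col in y_truth
        }


scorer = ToyMAE()
print(scorer.score({"value": [10.0, 20.0, 30.0]}, {"value": [12.0, 19.0, 28.0]}))
```

With these values the per-row errors are 2, 1, 2, so the unweighted score is 5/3; putting all the weight on the last row (time_weight=[0.0, 0.0, 1.0]) gives 2.0. The subclass never touches the weights, which is the point of the design.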

If you are evaluating prediction intervals instead of point predictions, extend BaseIntervalScorer and implement _compute_raw_scores. See the yohou.metrics API Reference for all base class options.

2. Add Custom Parameters

If your metric needs additional configuration (e.g., an epsilon to prevent division by zero), add a parameter to __init__ and register its constraint:

import numbers

from yohou.utils._compat import Interval


class MeanAbsolutePercentageError(BasePointScorer):

    _parameter_constraints: dict = {
        **BasePointScorer._parameter_constraints,
        "epsilon": [Interval(numbers.Real, 0, None, closed="neither")],
    }

    _metric_name = "mape"

    def __init__(
        self,
        epsilon: float = 1e-8,
        aggregation_method: list[str] | str = "all",
        groups: list[str] | dict[str, float] | None = None,
        components: list[str] | dict[str, float] | None = None,
    ) -> None:
        super().__init__(
            aggregation_method=aggregation_method,
            groups=groups,
            components=components,
        )
        self.epsilon = epsilon

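    def _compute_raw_errors(
        self, y_truth: pl.DataFrame, y_pred: pl.DataFrame
    ) -> pl.DataFrame:
        """Return per-row absolute percentage errors; epsilon avoids dividing by zero."""
        pct_errors = {}
        for col in y_truth.columns:
            abs_errors = (y_truth[col] - y_pred[col]).abs()
            pct_errors[col] = (abs_errors / (y_truth[col].abs() + self.epsilon)) * 100.0
        return pl.DataFrame(pct_errors)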
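A quick worked check of the same formula shows why epsilon exists and why its constraint excludes 0. The helper mape_raw below is hypothetical (plain Python lists instead of polars columns), not part of yohou.

```python
def mape_raw(y_truth: list[float], y_pred: list[float], epsilon: float = 1e-8) -> list[float]:
    """Per-row absolute percentage errors, mirroring the formula in the scorer above."""
    return [abs(t - p) / (abs(t) + epsilon) * 100.0 for t, p in zip(y_truth, y_pred)]


# A normal row gives roughly 10%; a zero actual gives a huge but finite value.
errors = mape_raw([100.0, 0.0], [110.0, 0.5])

# With epsilon = 0 the zero actual divides by zero.
try:
    mape_raw([0.0], [0.5], epsilon=0.0)
except ZeroDivisionError:
    print("epsilon must be strictly positive")
```

Note the second row: a zero actual still yields an enormous percentage, so epsilon prevents a crash but not the usual instability of MAPE near zero. Also, polars float division by zero typically produces inf or NaN rather than raising, so without epsilon the failure would be silent.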

Available constraint validators include Interval (numeric range) and StrOptions (string enum), both from yohou.utils._compat.
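The dict-unpacking idiom and the open interval can be checked in isolation. This sketch is hypothetical: ToyBase and ToyMAPE are invented names that use placeholder strings instead of yohou's validator objects, and is_strictly_positive_real only approximates what Interval(numbers.Real, 0, None, closed="neither") accepts.

```python
import numbers


class ToyBase:
    # Placeholder strings stand in for real validator objects.
    _parameter_constraints: dict = {"aggregation_method": ["<validator>"]}


class ToyMAPE(ToyBase):
    # Unpack the parent's dict into a NEW dict, then add the subclass's own keys.
    _parameter_constraints: dict = {
        **ToyBase._parameter_constraints,
        "epsilon": ["<positive real>"],
    }


def is_strictly_positive_real(value: object) -> bool:
    """Approximates the epsilon constraint: a real number strictly greater than 0."""
    return isinstance(value, numbers.Real) and value > 0
```

Building a new dict, rather than calling .update() on the inherited one, matters because class attributes are shared: mutating ToyBase._parameter_constraints in place would leak "epsilon" into every other subclass.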

If your scorer has no custom parameters, skip this step.

3. Customize Aggregation

The base class collapses raw errors to a scalar using the mean. Override these hooks when you need different behavior:

Post-aggregation transform (e.g., square root for RMSE):

def _transform_scores(self, df):
    return df.select(pl.all().sqrt())
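The order of the stages matters, which is why the square root belongs in a post-aggregation hook. Assuming an RMSE scorer whose _compute_raw_errors returns squared errors (an assumption; only the hook is shown here), the three stages reduce to this plain-Python arithmetic. rmse_three_stage is a hypothetical helper.

```python
import math


def rmse_three_stage(y_truth: list[float], y_pred: list[float]) -> float:
    # Stage 1, _compute_raw_errors: per-row squared errors.
    raw = [(t - p) ** 2 for t, p in zip(y_truth, y_pred)]
    # Stage 2, base-class aggregation: mean over rows (this is the MSE).
    mse = sum(raw) / len(raw)
    # Stage 3, _transform_scores: square root of the aggregated score.
    return math.sqrt(mse)


truth = [10.0, 20.0, 30.0]
pred = [12.0, 19.0, 28.0]
rmse = rmse_three_stage(truth, pred)  # sqrt((4 + 1 + 4) / 3) = sqrt(3)

# Taking the root per row BEFORE averaging just reproduces the MAE (5 / 3).
wrong_order = sum(math.sqrt((t - p) ** 2) for t, p in zip(truth, pred)) / len(truth)
```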

Custom row aggregation (e.g., max instead of mean):

def _collapse_rows(self, df, context, dims):
    return self._collapse_rows_with(df, context, dims, agg_fn="max")

Training-data statistics (e.g., RMSSE needs seasonal naive errors): override fit(), call super().fit(), and store computed statistics as attributes ending in _.
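The fit-then-score pattern can be sketched without yohou. ToyRMSSE below is hypothetical and has no base class, so it cannot call super().fit(); in a real scorer you would do that first. It learns the mean squared seasonal-naive error from training data in fit(), stores it as scale_, and divides the forecast MSE by it in score().

```python
import math


class ToyRMSSE:
    """Hypothetical fit-then-score sketch (not yohou's RMSSE implementation)."""

    def __init__(self, seasonality: int = 1) -> None:
        self.seasonality = seasonality

    def fit(self, y_train: list[float]) -> "ToyRMSSE":
        # A real yohou scorer would call super().fit(...) here first.
        m = self.seasonality
        if len(y_train) <= m:
            raise ValueError("need more than `seasonality` training points")
        # Squared errors of the seasonal naive forecast y[t] ~ y[t - m].
        naive_sq_errors = [(y_train[i] - y_train[i - m]) ** 2 for i in range(m, len(y_train))]
        # Trailing underscore marks state learned during fit().
        self.scale_ = sum(naive_sq_errors) / len(naive_sq_errors)
        return self

    def score(self, y_truth: list[float], y_pred: list[float]) -> float:
        mse = sum((t - p) ** 2 for t, p in zip(y_truth, y_pred)) / len(y_truth)
        return math.sqrt(mse / self.scale_)


scorer = ToyRMSSE(seasonality=1).fit([1.0, 2.0, 3.0, 4.0])  # naive errors all 1, so scale_ = 1.0
```

A constant training series makes scale_ zero, so a production version would need to guard against that case.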

If mean aggregation with no transform is sufficient, skip this step.

4. Register in __init__.py

Note

This step applies only when contributing a scorer to the Yohou package itself. If you are building a scorer in your own package, skip this section.

Add your scorer to src/yohou/metrics/__init__.py so it is accessible via get_scorer() and make_scorer():

from .point import MeanAbsoluteError

_SCORER_REGISTRY: dict[str, type[BaseScorer]] = {
    # ...existing entries...
    "mae": MeanAbsoluteError,
}

The registry key becomes the short name used by get_scorer("mae").
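Conceptually, a name-keyed registry is a dict from short names to classes. The sketch below illustrates that lookup with invented names (ToyScorer, _TOY_REGISTRY, toy_get_scorer). It assumes, for illustration only, that lookup returns a fresh instance and that an unknown name raises ValueError; yohou's real get_scorer() may behave differently.

```python
class ToyScorer:
    """Hypothetical stand-in for BaseScorer."""


class ToyMeanAbsoluteError(ToyScorer):
    """Hypothetical stand-in for MeanAbsoluteError."""


# Short name -> scorer class, mirroring the shape of _SCORER_REGISTRY.
_TOY_REGISTRY: dict[str, type[ToyScorer]] = {"mae": ToyMeanAbsoluteError}


def toy_get_scorer(name: str) -> ToyScorer:
    """Look up a scorer class by short name and return a fresh instance."""
    try:
        scorer_cls = _TOY_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_TOY_REGISTRY))
        raise ValueError(f"Unknown scorer {name!r}; known names: {known}") from None
    return scorer_cls()
```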

5. Test Your Scorer

yohou provides check generators that validate API conformance (tag accessibility, aggregation methods, parameter validation, fit/score lifecycle, multi-vintage scoring):

import polars as pl
from datetime import datetime
from yohou.testing import _yield_yohou_scorer_checks

y_truth = pl.DataFrame({
    "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
    "value": [10.0, 20.0, 30.0],
})
y_pred = pl.DataFrame({
    "vintage_time": [datetime(2019, 12, 31)] * 3,
    "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
    "value": [12.0, 19.0, 28.0],
})

scorer = MeanAbsoluteError()
scorer.fit(y_truth)

for check_name, check_func, check_kwargs in _yield_yohou_scorer_checks(
    scorer, y_truth, y_pred
):
    check_func(scorer, **check_kwargs)

The generator yields roughly 11 checks; the exact set depends on the scorer type. check_scorer_multi_vintage automatically builds a two-vintage dataset from your test data and verifies that the scorer produces a finite result. The generated checks focus on API conformance, so add your own tests for numerical correctness alongside them.
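A numerical-correctness test pairs a hand-computed expectation with an independent oracle. Here is a minimal pytest-style sketch for the MAE example, using the same three points as the check-generator data; reference_mae and the test name are hypothetical.

```python
import math


def reference_mae(y_truth: list[float], y_pred: list[float]) -> float:
    """Independent oracle: plain arithmetic, not the scorer under test."""
    return sum(abs(t - p) for t, p in zip(y_truth, y_pred)) / len(y_truth)


def test_mae_on_guide_data() -> None:
    # Same values as the y_truth / y_pred frames above:
    # |10 - 12| + |20 - 19| + |30 - 28| = 2 + 1 + 2 = 5, over 3 rows.
    truth = [10.0, 20.0, 30.0]
    pred = [12.0, 19.0, 28.0]
    assert math.isclose(reference_mae(truth, pred), 5 / 3)
    # Next step (not shown): compare the value returned by your scorer's score()
    # method against 5 / 3. Check the API reference for score()'s exact
    # signature and return type before extracting the scalar.
```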

See Also