
MedianAbsoluteError

yohou.metrics.point.MedianAbsoluteError

Bases: BasePointScorer

Median Absolute Error metric for point forecasts.

Computes the median of the absolute differences between predictions and actual values. This metric is highly robust to outliers and gives a more stable measure of typical error magnitude than mean-based metrics such as MAE.

The MedianAE is defined as:

\[\text{MedianAE} = \text{median}(|y_i - \hat{y}_i|)\]

where \(y_i\) is the actual value and \(\hat{y}_i\) is the predicted value.
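You can check the definition by hand. The following is a minimal plain-Python sketch of the formula, applied to the values used in the Examples section. It uses the standard library only, and `median_absolute_error` is a hypothetical helper, not part of yohou.

```python
from statistics import median


def median_absolute_error(y_true, y_pred):
    """Median of the element-wise absolute errors |y_i - y_hat_i|."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    # statistics.median sorts its input; for an even count it averages the two middle values.
    return median(abs(t - p) for t, p in zip(y_true, y_pred))


# Absolute errors are [2.0, 1.0, 2.0]; their median is 2.0.
print(median_absolute_error([10.0, 20.0, 30.0], [12.0, 19.0, 28.0]))  # 2.0
```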

Parameters

aggregation_method : list of str or str, default="all"
    Dimensions to aggregate over. Options:
    - "stepwise": Aggregate across forecasting steps.
    - "vintagewise": Aggregate across vintages (observed times).
    - "componentwise": Aggregate across components; returns a per-timestep DataFrame.
    - "groupwise": Aggregate across panel groups (panel data only).
    - "all": Aggregate across all dimensions (returns a scalar). Same as
      ["stepwise", "vintagewise", "componentwise", "groupwise"].
    Example outputs:
    - ["stepwise", "vintagewise"]: Per-component (and per-group) DataFrame.
    - "componentwise" or ["componentwise"]: Per-timestep (and per-group) DataFrame.
    - "groupwise" or ["groupwise"]: Per-component per-timestep DataFrame (panel aggregated).
    - ["stepwise", "vintagewise", "componentwise"]: Scalar (global) or per-group DataFrame (panel).
    - "all": Scalar float (hierarchically aggregated for panel data).
groups : list of str, dict of str to float, or None, default=None
    Panel group filter (list) or filter with weights (dict).
components : list of str, dict of str to float, or None, default=None
    Component filter (list) or filter with weights (dict).
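The difference between collapsing the component axis and collapsing the time axis can be shown with a plain-Python sketch on a hypothetical table of absolute errors. It assumes a single forecast vintage, so the per-component result corresponds to collapsing the steps within that one vintage. The component names, data, and helper functions are illustrative and are not yohou API. The per-timestep median mirrors the componentwise branch of `score`, which takes the median across component columns row by row.

```python
from statistics import median

# Hypothetical absolute errors for one vintage: one row per timestep,
# one entry per component.
abs_errors = {
    "2020-01-01": {"sales": 2.0, "returns": 0.5, "stock": 4.0},
    "2020-01-02": {"sales": 1.0, "returns": 3.0, "stock": 1.5},
}


def componentwise_median(rows):
    """Collapse the component axis: one median per timestep."""
    return {t: median(errs.values()) for t, errs in rows.items()}


def per_component_median(rows):
    """Collapse the time axis: one median per component."""
    components = next(iter(rows.values())).keys()
    return {c: median(r[c] for r in rows.values()) for c in components}


print(componentwise_median(abs_errors))  # {'2020-01-01': 2.0, '2020-01-02': 1.5}
print(per_component_median(abs_errors))  # {'sales': 1.5, 'returns': 1.75, 'stock': 2.75}
```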

Attributes

lower_is_better : bool
    Always True for MedianAE.
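`lower_is_better` lets generic model-selection code decide whether to minimise or maximise a score. The sketch below illustrates the idea with a hypothetical helper, `pick_best`, which is not part of yohou. The scores are made-up illustrative numbers.

```python
def pick_best(scores, lower_is_better=True):
    """Return the candidate name with the best score for the metric's direction."""
    choose = min if lower_is_better else max
    return choose(scores, key=scores.get)


scores = {"naive": 2.0, "seasonal_naive": 1.5}  # illustrative MedianAE values
print(pick_best(scores, lower_is_better=True))  # seasonal_naive
```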

Examples

>>> import polars as pl
>>> from datetime import datetime
>>> from yohou.metrics import MedianAbsoluteError
>>> y_true = pl.DataFrame({
...     "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
...     "value": [10.0, 20.0, 30.0],
... })
>>> y_pred = pl.DataFrame({
...     "vintage_time": [datetime(2019, 12, 31)] * 3,
...     "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
...     "value": [12.0, 19.0, 28.0],
... })
>>> medae = MedianAbsoluteError()
>>> _ = medae.fit(y_true)
>>> medae.score(y_true, y_pred)
2.0

Notes

  • MedianAE is highly robust to outliers and extreme errors: it stays bounded as long as fewer than half of the absolute errors are extreme
  • Provides a better measure of typical error when the error distribution is skewed
  • Less sensitive than MAE to a few very large prediction errors
  • Interpretable in the same units as the target variable
  • Suitable when outliers should not dominate the evaluation
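A single bad forecast is enough to show the robustness claims. In this standard-library sketch, one prediction is off by 100 units. MAE jumps to 21.2, while MedianAE stays at 2.0 because only one of the five errors is extreme. The data are illustrative and are not yohou output.

```python
from statistics import mean, median

y_true = [10.0, 20.0, 30.0, 40.0, 50.0]
y_pred = [12.0, 19.0, 28.0, 41.0, 150.0]  # last forecast is wildly off

abs_err = [abs(t - p) for t, p in zip(y_true, y_pred)]  # [2.0, 1.0, 2.0, 1.0, 100.0]
mae = mean(abs_err)      # dominated by the single 100.0 error
medae = median(abs_err)  # middle value of the sorted errors

print(mae)    # 21.2
print(medae)  # 2.0
```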

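See Also

  • MeanAbsoluteError : Mean-based absolute error, more sensitive to outliers
  • MaxAbsoluteError : Maximum absolute error, worst-case measure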

Source Code

class MedianAbsoluteError(BasePointScorer):
    r"""Median Absolute Error metric for point forecasts.

    Computes the median of the absolute differences between predictions and actual values.
    This metric is highly robust to outliers and gives a more stable measure of
    typical error magnitude than mean-based metrics such as MAE.

    The MedianAE is defined as:

    $$\text{MedianAE} = \text{median}(|y_i - \hat{y}_i|)$$

    where $y_i$ is the actual value and $\hat{y}_i$ is the predicted value.

    Parameters
    ----------
    aggregation_method : list of str or str, default="all"
        Dimensions to aggregate over. Options:
        - "stepwise": Aggregate across forecasting steps.
        - "vintagewise": Aggregate across vintages (observed times).
        - "componentwise": Aggregate across components; returns a per-timestep DataFrame.
        - "groupwise": Aggregate across panel groups (panel data only).
        - "all": Aggregate across all dimensions (returns scalar). Same as
          ["stepwise", "vintagewise", "componentwise", "groupwise"].
        Example outputs:
        - ["stepwise", "vintagewise"]: Per-component (and per-group) DataFrame.
        - "componentwise" or ["componentwise"]: Per-timestep (and per-group) DataFrame.
        - "groupwise" or ["groupwise"]: Per-component per-timestep DataFrame (panel aggregated).
        - ["stepwise", "vintagewise", "componentwise"]: Scalar (global) or per-group DataFrame (panel).
        - "all": Scalar float (hierarchically aggregated for panel data).
    groups : list of str, dict of str to float, or None, default=None
        Panel group filter (list) or filter with weights (dict).
    components : list of str, dict of str to float, or None, default=None
        Component filter (list) or filter with weights (dict).

    Attributes
    ----------
    lower_is_better : bool
        Always True for MedianAE.

    Examples
    --------
    >>> import polars as pl
    >>> from datetime import datetime
    >>> from yohou.metrics import MedianAbsoluteError
    >>> y_true = pl.DataFrame({
    ...     "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
    ...     "value": [10.0, 20.0, 30.0],
    ... })
    >>> y_pred = pl.DataFrame({
    ...     "vintage_time": [datetime(2019, 12, 31)] * 3,
    ...     "time": [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)],
    ...     "value": [12.0, 19.0, 28.0],
    ... })
    >>> medae = MedianAbsoluteError()
    >>> _ = medae.fit(y_true)
    >>> medae.score(y_true, y_pred)
    2.0

    Notes
    -----
    - MedianAE is highly robust to outliers and extreme errors: it stays bounded as long as fewer than half of the absolute errors are extreme
    - Provides a better measure of typical error when the error distribution is skewed
    - Less sensitive than MAE to a few very large prediction errors
    - Interpretable in the same units as the target variable
    - Suitable when outliers should not dominate the evaluation

    See Also
    --------
    - [`MeanAbsoluteError`][yohou.metrics.point.MeanAbsoluteError] : Mean-based absolute error, more sensitive to outliers
    - [`MaxAbsoluteError`][yohou.metrics.point.MaxAbsoluteError] : Maximum absolute error, worst-case measure

    """

    _parameter_constraints: dict = {
        **BasePointScorer._parameter_constraints,
    }

    _metric_name = "median_ae"

    def __init__(
        self,
        aggregation_method: list[str] | str = "all",
        groups: list[str] | dict[str, float] | None = None,
        components: list[str] | dict[str, float] | None = None,
    ) -> None:
        super().__init__(
            aggregation_method=aggregation_method,
            groups=groups,
            components=components,
        )

    def _compute_raw_errors(self, y_truth: pl.DataFrame, y_pred: pl.DataFrame) -> pl.DataFrame:
        """Compute per-row absolute errors for median aggregation."""
        return (y_truth - y_pred).select(pl.all().abs())

    def score(  # type: ignore
        self,
        y_truth: pl.DataFrame,
        y_pred: pl.DataFrame,
        /,
        vintage_weight: Callable | pl.DataFrame | dict | None = None,
        **params,
    ) -> float | pl.DataFrame:
        """Compute median absolute error.

        Parameters
        ----------
        y_truth : pl.DataFrame
            True values with "time" column.
        y_pred : pl.DataFrame
            Predicted values with "time" column.
        vintage_weight : callable, pl.DataFrame, dict, or None, default=None
            Per-vintage weights for cross-vintage aggregation.
        **params : dict
            Metadata to route to nested estimators.

        Returns
        -------
        float or pl.DataFrame
            Aggregated median absolute error.

        Raises
        ------
        TypeError
            If time_weight or step_weight is passed (the median is not weight-compatible).

        """
        self._reject_weights(**params)
        check_is_fitted(self, ["_is_fitted"])

        y_truth, y_pred, context = validate_scorer_data(
            self,
            y_truth,
            y_pred,
        )

        # Resolve vintage_weight into context
        context = self._resolve_vintage_weight_to_context(context, vintage_weight)

        dims = self._normalize_agg_methods(self.aggregation_method)
        collapse_steps = "stepwise" in dims
        collapse_vintages = "vintagewise" in dims

        if not collapse_steps and not collapse_vintages:
            # Componentwise/groupwise only: no step or vintage collapse, so return
            # one score per row (timestep), computed below as the median across columns.
            abs_errors = self._compute_raw_errors(y_truth, y_pred)
            # Per-row median across columns (components)
            result = abs_errors.select(pl.concat_list(pl.all()).alias("_err")).select(
                pl.col("_err").list.eval(pl.element().median()).list.first().alias("score")
            )
            time_values = context.time_values if context is not None else None
            if time_values is not None:
                result = result.with_columns(pl.Series("time", time_values).cast(pl.Datetime))
                result = result.select(["time"] + [c for c in result.columns if c != "time"])
            result = self._rename_metric_columns(result)
            return result

        def _compute_median(yt_slice: pl.DataFrame, yp_slice: pl.DataFrame) -> pl.DataFrame:
            """Compute per-column median absolute error."""
            errors = (yt_slice - yp_slice).select(pl.all().abs())
            return errors.select(pl.all().median())

        result = self._map_per_vintage(y_truth, y_pred, context, _compute_median)
        return self._aggregate_per_vintage_scores(result, context)
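When steps or vintages are collapsed, `score` first computes a per-column median within each vintage (`_compute_median`, applied through `_map_per_vintage`). It then passes the per-vintage scores to `_aggregate_per_vintage_scores`, together with the resolved `vintage_weight`. The sketch below reproduces the first step in plain Python. The second step is an assumption: it combines vintages with a weighted mean, which may not match the rule yohou actually uses. All names and data here are hypothetical.

```python
from statistics import median

# Hypothetical absolute errors: vintage -> component -> errors over forecast steps.
errors_by_vintage = {
    "2019-12-31": {"value": [2.0, 1.0, 2.0]},
    "2020-01-01": {"value": [0.5, 4.0, 1.0]},
}


def per_vintage_medians(errors):
    """Step 1 (mirrors _compute_median): median over steps, per vintage and component."""
    return {v: {c: median(e) for c, e in comps.items()} for v, comps in errors.items()}


def combine_vintages(per_vintage, weights=None):
    """Step 2 (ASSUMPTION): weighted mean of per-vintage medians, per component.

    yohou's _aggregate_per_vintage_scores may combine vintages differently.
    """
    vintages = list(per_vintage)
    w = weights or {v: 1.0 for v in vintages}  # equal weights when none are given
    total = sum(w[v] for v in vintages)
    components = per_vintage[vintages[0]].keys()
    return {c: sum(w[v] * per_vintage[v][c] for v in vintages) / total for c in components}


pv = per_vintage_medians(errors_by_vintage)  # {'2019-12-31': {'value': 2.0}, '2020-01-01': {'value': 1.0}}
print(combine_vintages(pv))  # {'value': 1.5}
print(combine_vintages(pv, {"2019-12-31": 3.0, "2020-01-01": 1.0}))  # {'value': 1.75}
```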

Methods

score(y_truth, y_pred, /, vintage_weight=None, **params)

Compute median absolute error.

Parameters

y_truth : DataFrame, required
    True values with a "time" column.
y_pred : DataFrame, required
    Predicted values with a "time" column.
vintage_weight : callable, pl.DataFrame, dict, or None, default=None
    Per-vintage weights for cross-vintage aggregation.
**params : dict
    Metadata to route to nested estimators.

Returns

float or DataFrame
    Aggregated median absolute error.

Raises

TypeError
    If time_weight or step_weight is passed (the median is not weight-compatible).
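`score` calls `self._reject_weights(**params)` before it checks fitting or validates data, so unsupported weights fail fast with a `TypeError`. Its implementation is not shown on this page. The following is a hypothetical stand-in that rejects `time_weight` and `step_weight` by name; `reject_weights` is not yohou API, and the real error message may differ.

```python
def reject_weights(**params):
    """Hypothetical guard mirroring the documented Raises section of score."""
    # This scorer computes an unweighted median, so these keywords are refused.
    unsupported = params.keys() & {"time_weight", "step_weight"}
    if unsupported:
        names = ", ".join(sorted(unsupported))
        raise TypeError(f"{names} not supported: median is not weight-compatible")


reject_weights(sample_id=1)  # unrelated metadata passes through silently

try:
    reject_weights(step_weight=[1.0, 2.0])
except TypeError as exc:
    print(exc)  # step_weight not supported: median is not weight-compatible
```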

Tutorials

The following example notebooks use this component:

  • How to Use Point Forecast Metrics (Evaluation-Search): compare MAE, MAPE, MASE, RMSE, and other point metrics across multiple forecasters with componentwise and groupwise aggregation.
