Custom Estimator Reference¶
Complete API reference for implementing custom Yohou components. Each component type has a base class, methods to implement, and systematic checks.
Component Types¶
| Type | Base class | Methods to implement | Use case |
|---|---|---|---|
| Point forecaster | BasePointForecaster | _predict_one, (_fit, _observation_horizon) | Single-value continuous or categorical predictions |
| Interval forecaster | BaseIntervalForecaster | _predict_one, (_fit, _observation_horizon) | Prediction intervals with coverage rates |
| Class-probability forecaster | BaseClassProbaForecaster | _predict_one, (_fit, _observation_horizon) | Categorical outcome probabilities |
| Transformer | BaseTransformer | _fit, _transform, get_feature_names_out, (_inverse_transform) | Feature engineering, preprocessing |
| Scorer | BasePointScorer / BaseIntervalScorer | _compute_raw_errors | Evaluation metrics |
Point Forecaster¶
Tier 1 (simple) forecasters override _observation_horizon and _predict_one.
Override _fit() when custom fitting logic is needed. The base fit() handles
validation, transformer setup, and panel detection, then calls _fit() automatically.
import polars as pl
import polars.selectors as cs
from yohou.point.base import BasePointForecaster


class LastValueForecaster(BasePointForecaster):
    """Repeats the last observed value."""

    _tags = {"ignores_exogenous": True, "stateful": True}

    @property
    def _observation_horizon(self):
        return 1

    def _predict_one(self, groups, **params):
        last_value = self._y_observed.select(~cs.by_name("time")).row(-1)[0]
        return pl.DataFrame({
            self._y_columns[0]: [last_value] * self.fit_forecasting_horizon_,
        })
- _observation_horizon (property): declares the lookback window size. Return a value computed from constructor params (e.g. self.seasonality).
- _fit(y_t, X_t, forecasting_horizon) (optional): receives transformed data after _pre_fit() runs. Override this for model-specific fitting logic.
- _predict_one(groups, **params): produces raw predictions for one forecast step. Read from self._y_observed so that observe() updates carry through. Return a DataFrame without a "time" column.
- Tier 2 (complex) forecasters override fit() directly for full control.
Interval Forecaster¶
from yohou.interval import BaseIntervalForecaster


class MyIntervalForecaster(BaseIntervalForecaster):
    _tags = {"ignores_exogenous": True}

    def _predict_one(self, groups, **params):
        # Produce interval predictions for one forecast step
        # Column naming: {target}_lower_{rate}, {target}_upper_{rate}
        ...
The base class validates and stores coverage_rates, then calls _fit()
automatically. Override _fit() for model-specific logic. Test with
_yield_yohou_forecaster_checks (interval-specific checks are included
automatically).
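The column-naming convention noted in the code comment above can be sketched in plain Python. The helper name interval_column_names is hypothetical, and the way the rate is rendered inside the column name (here, the float's default string form) is an assumption; compare with an existing Yohou interval forecaster before relying on it.

```python
def interval_column_names(target, coverage_rates):
    """Build {target}_lower_{rate} / {target}_upper_{rate} names.

    Hypothetical helper; the rate formatting is an assumption.
    """
    names = []
    for rate in coverage_rates:
        names.append(f"{target}_lower_{rate}")
        names.append(f"{target}_upper_{rate}")
    return names


columns = interval_column_names("sales", [0.8, 0.95])
print(columns)
```

Each coverage rate contributes one lower and one upper column, so a forecaster with k rates returns 2k columns per target.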
Class-Probability Forecaster¶
from yohou.class_proba import BaseClassProbaForecaster


class MyClassProbaForecaster(BaseClassProbaForecaster):
    def _predict_one(self, groups, **params):
        # Return DataFrame with {target}_proba_{class} columns
        # Probabilities should sum to 1 per row
        ...
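The two requirements in the comments above (column naming and rows summing to 1) can be illustrated without Yohou. This is a minimal sketch: to_proba_row is a hypothetical helper that works on a plain dict, whereas a real _predict_one would return a DataFrame.

```python
def to_proba_row(target, raw_scores):
    """Normalize non-negative class scores into one row of probabilities.

    Hypothetical helper: keys follow the {target}_proba_{class} convention.
    """
    total = sum(raw_scores.values())
    if total <= 0:
        raise ValueError("at least one class score must be positive")
    return {
        f"{target}_proba_{cls}": score / total
        for cls, score in raw_scores.items()
    }


row = to_proba_row("weather", {"sun": 3.0, "rain": 1.0})
print(row)
```

Dividing by the row total is the simplest way to guarantee the sum-to-one property; a model that already emits calibrated probabilities would skip that step.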
Transformer¶
import polars as pl
import polars.selectors as cs
from yohou.base import BaseTransformer


class ScaleTransformer(BaseTransformer):
    """Multiplies all numeric values by a constant."""

    _tags = {"stateful": False, "invertible": True}
    _parameter_constraints: dict = {
        "factor": [float, int],
    }

    def __init__(self, factor=2.0):
        self.factor = factor

    def _fit(self, X, y=None):
        self.columns_ = [c for c in X.columns if c != "time"]

    def _transform(self, X):
        return X.with_columns(cs.numeric() * self.factor)

    def _inverse_transform(self, X_t, X_p=None):
        return X_t.with_columns(cs.numeric() / self.factor)

    def get_feature_names_out(self, input_features=None):
        return self.columns_ if input_features is None else input_features
- _fit(X, y) is called at the end of fit(). The base class handles validation and fitted attribute setup (feature_names_in_, interval_, etc.) before calling _fit.
- _transform(X) receives already-validated data with the "time" column. Return a DataFrame with "time" preserved.
- _inverse_transform(X_t, X_p) reverses the transformation. X_p provides warmup rows for stateful transformers. Set _tags = {"invertible": True} when implementing this.
For stateful transformers that need a lookback window:
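import numbers

# Assumption: Interval is scikit-learn's constraint class; adjust the import
# if Yohou re-exports it elsewhere.
from sklearn.utils._param_validation import Interval
from yohou.base import BaseTransformer


class MyWindowTransformer(BaseTransformer):
    _tags = {"stateful": True}
    _parameter_constraints: dict = {
        "window_size": [Interval(numbers.Integral, 1, None, closed="left")],
    }

    def __init__(self, window_size=5):
        self.window_size = window_size

    @property
    def observation_horizon(self):
        # The lookback window equals the configured window size
        return self.window_size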
Scorer¶
Scorers implement _compute_raw_errors() which receives aligned DataFrames
(time column already removed) and returns per-timestep per-component raw
scores. The base class score() method handles validation, time weighting,
aggregation, post-aggregate transforms, and column renaming automatically:
import polars as pl
from yohou.metrics.base import BasePointScorer


class MaxAbsoluteError(BasePointScorer):
    """Maximum absolute error across the forecast horizon."""

    _parameter_constraints: dict = {
        **BasePointScorer._parameter_constraints,
    }
    _metric_name = "max_ae"

    def __init__(
        self,
        aggregation_method="all",
        groups=None,
        components=None,
    ):
        super().__init__(
            aggregation_method=aggregation_method,
            groups=groups,
            components=components,
        )

    def _compute_raw_errors(self, y_truth, y_pred):
        return (y_truth - y_pred).select(pl.all().abs())
- Raw scores must be a DataFrame without a "time" column.
- _aggregate_scores applies the aggregation_method strategy (stepwise, vintagewise, componentwise, groupwise, all).
- Override _post_aggregate for transforms after aggregation (e.g. sqrt for RMSE).
For interval scorers, extend BaseIntervalScorer instead. It adds
coverage_rates and "coveragewise" aggregation.
For the full walkthrough, see How to Create a Custom Scorer.
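The RMSE case mentioned in the scorer notes shows why the square root belongs after aggregation rather than in the raw errors. The sketch below mirrors the three stages with plain functions; the function names are hypothetical and do not reproduce the actual signatures of _compute_raw_errors, _aggregate_scores, or _post_aggregate.

```python
import math


def raw_squared_errors(y_truth, y_pred):
    # Stage 1: per-timestep raw scores (squared errors)
    return [(t - p) ** 2 for t, p in zip(y_truth, y_pred)]


def aggregate_mean(raw):
    # Stage 2: collapse the raw scores into a single value
    return sum(raw) / len(raw)


def post_aggregate_sqrt(value):
    # Stage 3: transform applied once, after aggregation
    return math.sqrt(value)


raw = raw_squared_errors([1.0, 2.0, 3.0], [1.0, 4.0, 1.0])
rmse = post_aggregate_sqrt(aggregate_mean(raw))
print(raw, rmse)
```

Taking the square root per timestep and then averaging would give the mean absolute error instead, which is why the transform has to run after the mean.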
Anatomy of an Estimator¶
__init__¶
Follow scikit-learn conventions: every parameter must be stored as an attribute
with the same name. Forward shared parameters (target_transformer,
feature_transformer, panel_strategy) to the parent:
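A minimal sketch of this pattern follows. ForecasterBaseStandIn is a hypothetical stand-in so the example runs without Yohou; in real code the parent would be a class such as BasePointForecaster. The None defaults and the "some_strategy" value are placeholders, not documented Yohou values.

```python
class ForecasterBaseStandIn:
    """Hypothetical stand-in for a Yohou forecaster base class."""

    def __init__(self, target_transformer=None, feature_transformer=None,
                 panel_strategy=None):
        # Defaults of None are assumptions made for this sketch.
        self.target_transformer = target_transformer
        self.feature_transformer = feature_transformer
        self.panel_strategy = panel_strategy


class SeasonalNaiveSketch(ForecasterBaseStandIn):
    def __init__(self, seasonality=7, target_transformer=None,
                 feature_transformer=None, panel_strategy=None):
        # Store the parameter unchanged under its own name; validation and
        # derived values belong in fit(), not here.
        self.seasonality = seasonality
        # Forward the shared parameters to the parent.
        super().__init__(
            target_transformer=target_transformer,
            feature_transformer=feature_transformer,
            panel_strategy=panel_strategy,
        )


# "some_strategy" is a placeholder value, not a documented option.
forecaster = SeasonalNaiveSketch(seasonality=12, panel_strategy="some_strategy")
print(forecaster.seasonality, forecaster.panel_strategy)
```

Keeping __init__ free of logic is what lets scikit-learn style cloning and parameter inspection work, since both rely on reading the attributes back by name.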
_tags¶
Declare tags as a class-level dictionary. The base class merges _tags from
all classes in the MRO (most-derived wins):
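The merge rule can be illustrated with a small standalone sketch. merged_tags is a hypothetical helper written for this example and is not Yohou's implementation; it only demonstrates what "most-derived wins" means for class-level dictionaries.

```python
def merged_tags(cls):
    """Merge `_tags` dicts along the MRO; the most-derived class wins."""
    tags = {}
    # Walk from the most generic class to the most derived one so that
    # later (more derived) entries overwrite earlier ones.
    for klass in reversed(cls.__mro__):
        tags.update(vars(klass).get("_tags", {}))
    return tags


class Base:
    _tags = {"stateful": False, "ignores_exogenous": False}


class Child(Base):
    _tags = {"stateful": True}


class GrandChild(Child):
    _tags = {"ignores_exogenous": True}


print(merged_tags(GrandChild))
```

A subclass therefore only needs to list the tags it changes; everything else is inherited from its parents.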
Common tag keys for forecasters: ignores_exogenous, stateful,
forecaster_type, uses_reduction, supports_panel_data.
Common tag keys for transformers: stateful, invertible.
For advanced cases (dynamic tags based on constructor params or child
estimators), override __sklearn_tags__() instead:
def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.forecaster_tags.ignores_exogenous = self.my_param > 0
    return tags
_parameter_constraints¶
Define only your own constraints. Parent constraints are merged automatically
via __init_subclass__:
_parameter_constraints: dict = {
    "my_param": [Interval(numbers.Integral, 1, None, closed="left")],
    "mode": [StrOptions({"fast", "accurate"})],
}
Constraint types: Interval, StrOptions, HasMethods, type classes
(int, float, str), None, "boolean", callable, "no_validation".
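How an __init_subclass__ hook can merge constraints is sketched below. ConstrainedBase is a hypothetical stand-in, not Yohou's base class, and plain types are used as constraints to keep the example dependency-free.

```python
class ConstrainedBase:
    """Hypothetical base class that merges constraints for its subclasses."""

    _parameter_constraints: dict = {"panel_strategy": [str, None]}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Constraints declared directly on the new subclass, if any.
        own = cls.__dict__.get("_parameter_constraints", {})
        # Constraints inherited from the parents (already merged there).
        inherited = getattr(super(cls, cls), "_parameter_constraints", {})
        # Subclass entries override inherited ones with the same key.
        cls._parameter_constraints = {**inherited, **own}


class MyEstimator(ConstrainedBase):
    _parameter_constraints: dict = {"my_param": [int]}


class MyChildEstimator(MyEstimator):
    _parameter_constraints: dict = {"mode": [str]}


print(sorted(MyChildEstimator._parameter_constraints))
```

Because the hook runs once per subclass at class-creation time, each level only declares its own parameters and still ends up with the full set.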
observation_horizon¶
Declare as a @property that computes from constructor params.
Returns 0 for stateless estimators (the default):
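A minimal sketch of both cases, using hypothetical stand-in classes rather than Yohou base classes:

```python
class StatelessStandIn:
    """Hypothetical base: stateless estimators need no lookback."""

    @property
    def observation_horizon(self):
        return 0


class RollingMeanStandIn(StatelessStandIn):
    """Hypothetical stateful estimator with a configurable window."""

    def __init__(self, window_size=5):
        self.window_size = window_size

    @property
    def observation_horizon(self):
        # Computed on every access, so it follows later parameter changes.
        return self.window_size


estimator = RollingMeanStandIn(window_size=3)
print(StatelessStandIn().observation_horizon, estimator.observation_horizon)
```

Deriving the value from the constructor parameter on each access, instead of caching it in __init__, keeps it correct after the parameter is changed on an existing instance.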
The base class uses this to manage internal memory buffers for observe()
and rewind() operations.
Testing¶
Yohou provides systematic check generators that validate API conformance:
from conftest import run_checks
from yohou.testing import _yield_yohou_forecaster_checks


def test_my_forecaster(y_X_factory):
    from yohou.model_selection import train_test_split

    y, X = y_X_factory(length=100)
    y_train, y_test = train_test_split(y, test_size=20)
    forecaster = MyForecaster()
    forecaster.fit(y_train, forecasting_horizon=len(y_test))
    run_checks(
        forecaster,
        _yield_yohou_forecaster_checks(forecaster, y_train, None, y_test),
    )
Check Generators¶
| Generator | Checks | Component |
|---|---|---|
| _yield_yohou_forecaster_checks | 27 | Forecasters (point, interval, class-probability) |
| _yield_yohou_transformer_checks | 26 | Transformers |
| _yield_yohou_scorer_checks | 11 | Scorers |
| _yield_yohou_splitter_checks | 8 | Splitters |
Skipping Checks¶
Skip checks that don't apply with expected_failures:
run_checks(
    transformer,
    _yield_yohou_transformer_checks(transformer, X_train, None, X_test),
    expected_failures={"check_inverse_transform_identity"},
)
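The run_checks helper comes from the project's conftest and is not shown on this page. The sketch below illustrates one way such a helper could honor expected_failures. It assumes each generated item is a (name, check_callable) pair, which is an assumption about the generator output, and run_checks_sketch is a hypothetical name.

```python
def run_checks_sketch(estimator, checks, expected_failures=frozenset()):
    """Run (name, check) pairs; tolerate failures named in expected_failures."""
    unexpected = []
    for name, check in checks:
        try:
            check(estimator)
        except AssertionError as exc:
            # Only failures that were not announced count against the run.
            if name not in expected_failures:
                unexpected.append((name, str(exc)))
    if unexpected:
        raise AssertionError(f"unexpected check failures: {unexpected}")


def check_has_fit(estimator):
    assert hasattr(estimator, "fit"), "missing fit()"


def check_inverse_transform_identity(estimator):
    assert hasattr(estimator, "inverse_transform"), "missing inverse_transform()"


class FitOnly:
    def fit(self):
        return self


checks = [
    ("check_has_fit", check_has_fit),
    ("check_inverse_transform_identity", check_inverse_transform_identity),
]
run_checks_sketch(
    FitOnly(),
    checks,
    expected_failures={"check_inverse_transform_identity"},
)
```

Listing a check by name keeps the rest of the suite strict: any failure outside the set still fails the test.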
Panel Data¶
The base class handles panel dispatch automatically when columns use the
group__column naming convention. Your _predict_one receives
groups indicating which groups to predict for:
def _predict_one(self, groups, **params):
    if self.groups_ is None:
        # Non-panel: self._y_observed is a pl.DataFrame
        ...
    else:
        # Panel: self._y_observed is a dict[str, pl.DataFrame]
        for group in groups:
            group_data = self._y_observed[group]
            ...
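The group__column naming convention can be illustrated with a small standalone helper. group_columns is hypothetical and only shows how prefixed column names map to groups; the actual panel detection is done by the base class.

```python
def group_columns(columns, separator="__"):
    """Map each group prefix to the column names that belong to it."""
    groups = {}
    for name in columns:
        # The time column and unprefixed columns are not panel columns.
        if name == "time" or separator not in name:
            continue
        group, column = name.split(separator, 1)
        groups.setdefault(group, []).append(column)
    return groups


panel = group_columns(
    ["time", "store_a__sales", "store_a__visits", "store_b__sales"]
)
print(panel)
```

An empty result corresponds to the non-panel branch above, where there are no groups to iterate over.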
Using _pre_fit() Directly¶
For complex estimators that need full control over the fit process (reduction
forecasters, ensemble methods, pipeline composites), call _pre_fit() instead
of super().fit():
class MyComplexForecaster(BasePointForecaster):
    def fit(self, y, X_actual=None, forecasting_horizon=1, **params):
        self._pre_fit(y, X_actual, forecasting_horizon)
        # Custom fitting logic...
        return self
Call self._pre_fit(y, X_actual, forecasting_horizon) to run validation, transformer
setup, and panel detection.
See Also¶
- How to Create a Custom Point Forecaster for a focused walkthrough
- Extending Yohou for design rationale
- Extensions for official and community extensions