How to Create a Custom Transformer
This guide shows you how to implement a time series transformer that plugs into Yohou's fit/transform/observe lifecycle. Use this when the built-in preprocessing transformers do not cover your feature engineering or data transformation needs.
Prerequisites
- Yohou installed (Getting Started)
- Familiarity with Yohou transformers (Use Preprocessing Transformers)
- Understanding of the observation horizon concept (Core Concepts)
Try it interactively
Implement a ScaleTransformer from scratch, validate it with the check generator, and use it in a forecast pipeline.
1. Subclass BaseTransformer
Create a class that extends BaseTransformer and implement three methods:
- _fit(X, y=None): store any state computed from the training data. Must set at least one fitted attribute (a name ending in _).
- _transform(X): return a transformed pl.DataFrame. Always preserve the "time" column.
- get_feature_names_out(input_features=None): return the list of output column names (excluding "time").
```python
import polars as pl
import polars.selectors as cs

from yohou.base import BaseTransformer


class ScaleTransformer(BaseTransformer):
    """Multiplies all numeric columns by a constant factor."""

    _parameter_constraints: dict = {
        "factor": [float, int],
    }

    def __init__(self, factor=2.0):
        self.factor = factor

    def _fit(self, X, y=None):
        # Fitted attribute (trailing underscore): every column except "time"
        self.columns_ = [c for c in X.columns if c != "time"]

    def _transform(self, X):
        # cs.numeric() excludes Date/Datetime dtypes, so a temporal "time" column passes through
        return X.with_columns(cs.numeric() * self.factor)

    def get_feature_names_out(self, input_features=None):
        return self.columns_ if input_features is None else input_features
```
You implement the underscore hooks (_fit, _transform), not the public fit() and
transform(). The base class fit() handles input validation, sets
feature_names_in_, n_features_in_, X_schema_, and interval_, then delegates to
your _fit(). Similarly, transform() validates its input and delegates to
_transform().
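The division of labour is the classic template-method pattern. The framework-free sketch below illustrates it; MiniBaseTransformer and MiniScale are hypothetical stand-ins (not yohou classes), they use a dict of lists instead of a pl.DataFrame, and they validate far less than the real base class does.

```python
class MiniBaseTransformer:
    """Hypothetical stand-in for BaseTransformer: public methods validate, hooks do the work."""

    def fit(self, X, y=None):
        # Validation and bookkeeping live in the base class, not in subclasses
        if "time" not in X:
            raise ValueError("X must contain a 'time' column")
        self.feature_names_in_ = [c for c in X if c != "time"]
        self.n_features_in_ = len(self.feature_names_in_)
        self._fit(X, y)  # delegate to the subclass hook
        return self

    def transform(self, X):
        if not hasattr(self, "feature_names_in_"):
            raise RuntimeError("call fit() before transform()")
        return self._transform(X)  # delegate to the subclass hook

    def _fit(self, X, y=None):
        raise NotImplementedError

    def _transform(self, X):
        raise NotImplementedError


class MiniScale(MiniBaseTransformer):
    def __init__(self, factor=2.0):
        self.factor = factor

    def _fit(self, X, y=None):
        self.columns_ = [c for c in X if c != "time"]

    def _transform(self, X):
        # Copy "time" unchanged and scale every other column
        return {c: (v if c == "time" else [x * self.factor for x in v]) for c, v in X.items()}


X = {"time": [1, 2, 3], "a": [1.0, 2.0, 3.0]}
t = MiniScale(factor=3.0).fit(X)
print(t.n_features_in_, t.transform(X)["a"])  # 1 [3.0, 6.0, 9.0]
```

The subclass never touches validation or the feature_names_in_ bookkeeping; it only supplies the two hooks.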
_parameter_constraints declares validation rules for constructor parameters.
Constraints from parent classes are merged automatically, so you only need to
declare constraints for parameters you introduce.
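The source does not show how yohou performs the merge. One common way to do it, sketched here with a hypothetical merged_constraints helper and toy Base/Child classes, is to walk the class's method resolution order from the root down so that subclass entries override parent ones:

```python
class Base:
    _parameter_constraints: dict = {"verbose": [bool]}


class Child(Base):
    # Declares only the parameter it introduces
    _parameter_constraints: dict = {"factor": [float, int]}


def merged_constraints(cls):
    """Hypothetical helper: merge _parameter_constraints along the MRO."""
    merged = {}
    # reversed MRO runs object -> Base -> Child, so a subclass entry overrides its parent's
    for klass in reversed(cls.__mro__):
        merged.update(klass.__dict__.get("_parameter_constraints", {}))
    return merged


print(sorted(merged_constraints(Child)))  # ['factor', 'verbose']
```

Reading klass.__dict__ instead of getattr() ensures each class contributes only the dictionary it defines itself, not one it inherited.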
2. Add Inverse Transform
If your transformation is reversible, implement _inverse_transform and
tag the transformer as invertible. This lets forecasters automatically
undo target transformations when producing predictions:
```python
import polars.selectors as cs

from yohou.base import BaseTransformer


class ScaleTransformer(BaseTransformer):
    """Multiplies all numeric columns by a constant factor."""

    _tags = {"invertible": True}

    _parameter_constraints: dict = {
        "factor": [float, int],
    }

    def __init__(self, factor=2.0):
        self.factor = factor

    def _fit(self, X, y=None):
        self.columns_ = [c for c in X.columns if c != "time"]

    def _transform(self, X):
        return X.with_columns(cs.numeric() * self.factor)

    def _inverse_transform(self, X_t, X_p=None):
        # Dividing by the same factor undoes _transform; no past rows are needed
        return X_t.with_columns(cs.numeric() / self.factor)

    def get_feature_names_out(self, input_features=None):
        return self.columns_ if input_features is None else input_features
```
_inverse_transform receives X_t (the transformed data) and optionally
X_p (past observations, useful for stateful transformers that need warmup
rows to reverse the operation).
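Why would an inverse need past observations at all? First differencing is a standard example: the differenced values carry no information about the level, so reversing them requires the last value observed before the window. The plain-list sketch below uses hypothetical helper names; diffs plays the role of X_t and past the role of X_p.

```python
def difference(values):
    """Forward transform: x[t] - x[t-1]. The first row has no predecessor and is dropped."""
    return [b - a for a, b in zip(values, values[1:])]


def inverse_difference(diffs, past):
    """Inverse transform: rebuild levels from diffs, seeded by the last past observation."""
    level = past[-1]  # without this warmup value the original level cannot be recovered
    restored = []
    for d in diffs:
        level += d
        restored.append(level)
    return restored


series = [10.0, 12.0, 11.0, 15.0, 14.0]
diffs = difference(series)  # [2.0, -1.0, 4.0, -1.0]
restored = inverse_difference(diffs, past=series[:1])
print(restored)  # [12.0, 11.0, 15.0, 14.0]
```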
If your transformation is not invertible, skip this step.
3. Make It Stateful
Stateful transformers maintain a memory buffer of recent observations. This is
needed when the transformation depends on a lookback window (lags, rolling
statistics, filters). Tag the transformer as stateful and override the
observation_horizon property to declare how many past rows are required:
```python
import numbers

import polars.selectors as cs

from yohou.base import BaseTransformer
from yohou.utils._compat import Interval


class RollingMeanTransformer(BaseTransformer):
    """Replaces each value with the rolling mean over the last ``window`` steps."""

    _tags = {"stateful": True}

    _parameter_constraints: dict = {
        "window": [Interval(numbers.Integral, 2, None, closed="left")],
    }

    def __init__(self, window=5):
        self.window = window

    @property
    def observation_horizon(self):
        # A full window needs window - 1 rows before the current one
        return self.window - 1

    def _fit(self, X, y=None):
        self.columns_ = [c for c in X.columns if c != "time"]

    def _transform(self, X):
        # rolling_mean leaves the first window - 1 rows null; drop exactly those
        return X.with_columns(
            cs.numeric().rolling_mean(window_size=self.window)
        ).tail(len(X) - self.observation_horizon)

    def get_feature_names_out(self, input_features=None):
        return self.columns_ if input_features is None else input_features
```
The observation_horizon value tells the framework how many warmup rows the
transform needs. The first observation_horizon rows of _transform() output
are consumed by the rolling window and dropped via .tail().
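The warmup count is easiest to see with a plain-Python trailing mean that, like polars' rolling_mean with its default settings, yields nothing until the window is full. trailing_mean is a hypothetical helper for illustration only:

```python
def trailing_mean(values, window):
    """Mean of the last `window` values at each position; None until the window is full."""
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)  # not enough history yet
        else:
            chunk = values[i + 1 - window : i + 1]
            out.append(sum(chunk) / window)
    return out


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
means = trailing_mean(values, window=3)
print(means)  # [None, None, 2.0, 3.0, 4.0, 5.0]
warmup = sum(m is None for m in means)
print(warmup, len(values) - warmup)  # 2 4
```

For a window of 3, two leading rows are undefined and must be dropped, leaving len(values) - 2 usable rows.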
The base class uses observation_horizon to manage the internal memory buffer
(self._X_observed). After fitting, observe() appends new data and trims the
buffer to observation_horizon rows, while rewind() replaces the buffer
entirely. This lifecycle is handled automatically when the transformer is used
inside a pipeline or forecaster.
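The buffer rules can be modelled in a few lines. RowBuffer below is a hypothetical stand-in: the real buffer lives in self._X_observed and holds a DataFrame, while this sketch uses a list and only illustrates append-then-trim for observe() and wholesale replacement for rewind().

```python
class RowBuffer:
    """Hypothetical model of an observation buffer trimmed to `horizon` rows."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.rows = []

    def observe(self, new_rows):
        # Append, then keep only the most recent `horizon` rows.
        # The explicit check matters: rows[-0:] would keep the whole list.
        combined = self.rows + list(new_rows)
        self.rows = combined[-self.horizon :] if self.horizon else []
        return self

    def rewind(self, rows):
        # Replace the buffer wholesale instead of appending
        self.rows = list(rows)
        return self


buf = RowBuffer(horizon=3).observe([1, 2, 3, 4])
print(buf.rows)  # [2, 3, 4]
buf.observe([5])
print(buf.rows)  # [3, 4, 5]
buf.rewind([10, 11])
print(buf.rows)  # [10, 11]
```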
For stateless transformers (no lookback needed), observation_horizon defaults
to 0 and you can skip this step.
4. Add Output Columns
If your transformer produces additional columns (e.g., lag features, calendar
features), reflect this in get_feature_names_out. The output must match the
actual columns produced by _transform (excluding "time"):
```python
import polars as pl

from yohou.base import BaseTransformer


class SimpleLagTransformer(BaseTransformer):
    """Creates a single lag-1 feature for each input column."""

    _tags = {"stateful": True}

    def __init__(self):
        pass

    @property
    def observation_horizon(self):
        return 1

    def _fit(self, X, y=None):
        self.columns_ = [c for c in X.columns if c != "time"]

    def _transform(self, X):
        # One shifted copy per input column, appended after the existing columns
        lag_exprs = [pl.col(c).shift(1).alias(f"{c}_lag_1") for c in self.columns_]
        result = X.with_columns(lag_exprs)
        return result.tail(len(result) - 1)  # drop first row (null from shift)

    def get_feature_names_out(self, input_features=None):
        cols = input_features if input_features is not None else self.columns_
        return list(cols) + [f"{c}_lag_1" for c in cols]
```
If your transformer preserves the same columns, get_feature_names_out can
simply return input_features or the stored columns_ as shown in the earlier
examples.
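A cheap check worth adding to your own tests is that the declared names equal the non-"time" columns actually produced, in order. The sketch below uses a hypothetical check_feature_names helper on plain lists; with a real transformer you would pass transformer.transform(X).columns and transformer.get_feature_names_out().

```python
def check_feature_names(output_columns, declared):
    """Fail if declared feature names differ from the produced non-'time' columns."""
    produced = [c for c in output_columns if c != "time"]
    if list(declared) != produced:
        undeclared = [c for c in produced if c not in declared]
        not_produced = [c for c in declared if c not in produced]
        raise AssertionError(
            f"mismatch: undeclared={undeclared}, not produced={not_produced}"
        )


# Columns SimpleLagTransformer would emit for inputs "time", "a", "b"
# (with_columns appends the new lag columns after the existing ones)
output = ["time", "a", "b", "a_lag_1", "b_lag_1"]
declared = ["a", "b"] + [f"{c}_lag_1" for c in ["a", "b"]]
check_feature_names(output, declared)  # passes silently

try:
    check_feature_names(output, ["a", "b"])  # lag columns not declared
except AssertionError as err:
    print(err)  # mismatch: undeclared=['a_lag_1', 'b_lag_1'], not produced=[]
```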
5. Handle Panel Data
Transformers automatically support panel data (columns prefixed with
group_name__). The base class detects panel structure and your _transform
receives the full DataFrame with prefixed columns.
If your transformation logic is column-agnostic (operating on all numeric
columns via cs.numeric()), panel data works without additional code. If you
reference specific column names, make sure to handle the prefixed variants.
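When you do need specific names, grouping columns by their prefix is a simple starting point. The sketch assumes the group_name__column layout described above, with the double underscore as separator; split_panel_columns is a hypothetical helper, not part of yohou.

```python
from collections import defaultdict


def split_panel_columns(columns, sep="__"):
    """Map each panel group to its member columns; unprefixed columns go under None."""
    groups = defaultdict(list)
    for col in columns:
        if col == "time":
            continue  # the time index is shared, not part of any group
        group, found, name = col.partition(sep)  # split at the first separator
        if found:
            groups[group].append(name)
        else:
            groups[None].append(col)
    return dict(groups)


cols = ["time", "store_a__sales", "store_a__price", "store_b__sales"]
print(split_panel_columns(cols))
# {'store_a': ['sales', 'price'], 'store_b': ['sales']}
```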
If your transformer only makes sense for non-panel data, override
__sklearn_tags__ to declare this:
```python
from yohou.base import BaseTransformer
from yohou.utils.tags import Tags


class MyUnivariateTransformer(BaseTransformer):
    def __sklearn_tags__(self) -> Tags:
        tags = super().__sklearn_tags__()
        tags.transformer_tags.supports_panel_data = False
        return tags
```
6. Test Your Transformer
Use _yield_yohou_transformer_checks to validate API conformance. It runs
26 checks covering fit/transform contracts, observation/rewind behaviour,
inverse transforms, feature names, and panel data:
```python
from sklearn.base import clone

from conftest import run_checks
from yohou.testing import _yield_yohou_transformer_checks


class TestScaleTransformer:
    def test_systematic_checks(self, time_series_train_test_factory):
        X_train, X_test = time_series_train_test_factory(
            train_length=100, test_length=30
        )
        transformer = ScaleTransformer(factor=3.0)
        transformer_fitted = clone(transformer)
        transformer_fitted.fit(X_train)
        run_checks(
            transformer_fitted,
            _yield_yohou_transformer_checks(
                transformer_fitted, X_train, None, X_test
            ),
        )
```
If a check does not apply (e.g., check_inverse_transform_identity for a
non-invertible transformer), pass it as an expected failure:
```python
run_checks(
    transformer_fitted,
    _yield_yohou_transformer_checks(transformer_fitted, X_train, None, X_test),
    expected_failures={"check_inverse_transform_identity"},
)
```
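run_checks comes from your own conftest, so its behaviour is up to you. One plausible sketch follows; it assumes the checks arrive as (name, callable) pairs, which is an assumption since the source does not show what _yield_yohou_transformer_checks yields, and the toy zero-argument checks are hypothetical.

```python
def run_checks(estimator, checks, expected_failures=frozenset()):
    """Hypothetical runner: each check must pass unless its name is in expected_failures."""
    # `estimator` is unused here; it is kept only to match the call sites above
    unexpected, unexpectedly_passed = [], []
    for name, check in checks:
        try:
            check()
        except Exception as exc:  # any exception counts as a failed check
            if name not in expected_failures:
                unexpected.append(f"{name}: {exc!r}")
            continue
        if name in expected_failures:
            unexpectedly_passed.append(name)
    assert not unexpected, "failed checks:\n" + "\n".join(unexpected)
    assert not unexpectedly_passed, f"expected to fail but passed: {unexpectedly_passed}"


def check_passes():
    pass


def check_inverse_transform_identity():
    raise AssertionError("inverse_transform(transform(X)) != X")


toy_checks = [
    ("check_passes", check_passes),
    ("check_inverse_transform_identity", check_inverse_transform_identity),
]
run_checks(None, toy_checks, expected_failures={"check_inverse_transform_identity"})
```

Failing when a listed check unexpectedly passes keeps the exclusion list from going stale once the transformer gains the missing capability.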
If any check fails, its name tells you exactly which contract is violated. Add your own tests for numerical correctness alongside the generated checks.
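A numerical-correctness test usually compares the implementation against an independent, obviously correct reference. The sketch shows that pattern in plain Python for a rolling mean: fast_rolling_mean (hypothetical, running-sum based) stands in for your transformer's output and is checked element-wise against a naive recomputation. In a real test you would compare columns of transform() output instead.

```python
import math


def naive_rolling_mean(values, window):
    """Reference: recompute every full window from scratch."""
    return [sum(values[i - window + 1 : i + 1]) / window for i in range(window - 1, len(values))]


def fast_rolling_mean(values, window):
    """Candidate under test: O(n) running sum; drops the window - 1 warmup rows."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # slide the window forward
        if i >= window - 1:
            out.append(running / window)
    return out


def test_rolling_mean_matches_reference():
    values = [0.5 * k * k - 3.0 * k for k in range(50)]
    for window in (2, 5, 10):
        fast = fast_rolling_mean(values, window)
        slow = naive_rolling_mean(values, window)
        assert len(fast) == len(values) - (window - 1)
        assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-9) for a, b in zip(fast, slow))


test_rolling_mean_matches_reference()
```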
See Also
- Use Preprocessing Transformers: using the built-in transformers
- Compose Feature Pipelines: combining transformers into pipelines
- Create a Point Forecaster: using transformers inside forecasters via target_transformer and feature_transformer
- Create an Interval Forecaster: prediction interval forecasters
- Create a Class-Probability Forecaster: categorical outcome forecasters
- Create a Custom Scorer: custom evaluation metrics
- Extending Yohou: when to extend vs compose