Skip to content

BaseSplitter

yohou.model_selection.split.BaseSplitter

Bases: BaseEstimator, ABC

Base class for yohou time series cross-validation splitters.

Provides a scikit-learn-style cross-validator interface (split() / get_n_splits()) on top of BaseEstimator, with time series-specific functionality including polars DataFrame support and panel data awareness.

All concrete splitters should inherit from this class and implement the abstract split(), _iter_test_indices(), and get_n_splits() methods.

Attributes

Name Type Description
interval_ str

Detected time interval of the data, set during split().

Notes

This is an abstract base class. Concrete splitters should inherit from this class and implement split(), _iter_test_indices(), and get_n_splits(), all of which are abstract.

See Also

Source Code

Show/Hide source
class BaseSplitter(BaseEstimator, ABC):
    """Base class for yohou time series cross-validation splitters.

    Provides a scikit-learn-style cross-validator interface (``split`` /
    ``get_n_splits``) on top of ``BaseEstimator``, with time series-specific
    functionality including polars DataFrame support and panel data awareness.

    All concrete splitters should inherit from this class and implement the
    abstract ``split()``, ``_iter_test_indices()`` and ``get_n_splits()``
    methods.

    Attributes
    ----------
    interval_ : str
        Detected time interval of the data, set by concrete ``split()``
        implementations.

    Notes
    -----
    This is an abstract base class: ``split()``, ``_iter_test_indices()`` and
    ``get_n_splits()`` are abstract, so concrete splitters must implement all
    three.

    Subclasses may declare ``_parameter_constraints`` and ``_tags`` as plain
    class-level dicts holding only their own entries. Both are merged with the
    declarations of every class in the MRO, and classes that come earlier in
    the MRO (more derived) take precedence.

    See Also
    --------
    - [`ExpandingWindowSplitter`][yohou.model_selection.split.ExpandingWindowSplitter] : Expanding-window cross-validation.
    - [`SlidingWindowSplitter`][yohou.model_selection.split.SlidingWindowSplitter] : Sliding-window cross-validation.

    """

    _parameter_constraints: dict = {}
    _tags: ClassVar[dict[str, Any]] = {}

    # Fitted attributes (set during split())
    interval_: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Merge parameter constraints from all classes in the MRO.

        The constraints a subclass declares itself are recorded in
        ``_own_parameter_constraints`` before ``cls._parameter_constraints``
        is replaced by the merged view. Every class's own declaration is then
        merged along the reversed MRO, so a more-derived declaration overrides
        an inherited one.

        Parameters
        ----------
        **kwargs : Any
            Forwarded unchanged to ``super().__init_subclass__``.

        """
        super().__init_subclass__(**kwargs)

        declared = cls.__dict__.get("_parameter_constraints")
        # Keep only what this class itself declared. Merging the already-merged
        # dicts of intermediate classes would re-apply their ancestors' values.
        # Under multiple inheritance, that can clobber an override made by a
        # sibling base that precedes those ancestors in the MRO.
        cls._own_parameter_constraints = dict(declared) if isinstance(declared, dict) else {}

        merged: dict = {}
        for klass in reversed(cls.__mro__):
            namespace = klass.__dict__
            # Classes processed by this hook expose their own declaration.
            # Anything else (this base, mixins, object) is read as declared.
            own = namespace.get(
                "_own_parameter_constraints",
                namespace.get("_parameter_constraints"),
            )
            if own and isinstance(own, dict):
                merged.update(own)
        cls._parameter_constraints = merged

    @abstractmethod
    def split(
        self,
        y: pl.DataFrame,
        X_actual: pl.DataFrame | None = None,
    ) -> Iterator[tuple[np.ndarray[Any, np.dtype[np.intp]], np.ndarray[Any, np.dtype[np.intp]]]]:
        """Generate indices to split time series data.

        Must be implemented by concrete splitter classes.

        Parameters
        ----------
        y : pl.DataFrame
            Target time series used to generate train/test split indices.
            Must have a ``"time"`` column.
        X_actual : pl.DataFrame or None, default=None
            Actual features.  Not used for splitting but accepted for
            API consistency.

        Yields
        ------
        train : ndarray of intp
            Training set row indices for that split.
        test : ndarray of intp
            Test set row indices for that split.

        """

    @abstractmethod
    def _iter_test_indices(
        self,
        y: pl.DataFrame,
        X_actual: pl.DataFrame | None = None,
    ) -> Iterator[np.ndarray[Any, np.dtype[np.intp]]]:
        """Generate test indices for each split.

        Must be implemented by concrete splitter classes.

        Parameters
        ----------
        y : pl.DataFrame
            Target time series.
        X_actual : pl.DataFrame or None, default=None
            Actual features. Not used for splitting but accepted for
            API consistency.

        Yields
        ------
        test : ndarray of intp
            Test set indices for this split.

        """

    @abstractmethod
    def get_n_splits(
        self,
        y: pl.DataFrame | None = None,
        X_actual: pl.DataFrame | None = None,
    ) -> int:
        """Return the number of cross-validation folds.

        Must be implemented by concrete splitter classes.

        Parameters
        ----------
        y : pl.DataFrame or None, default=None
            Not used.  Accepted for API consistency.
        X_actual : pl.DataFrame or None, default=None
            Not used.  Accepted for API consistency.

        Returns
        -------
        int
            The number of cross-validation folds.

        """

    def __sklearn_tags__(self):
        """Get metadata tags for this splitter.

        Starts from splitter defaults, with ``supports_panel_data`` enabled.
        It then applies the flat ``_tags`` dicts declared along the MRO;
        base classes are applied first, so the most-derived class wins.

        Each key is routed to the first tag group that has an attribute of
        that name, checked in the order ``splitter_tags``, ``input_tags``,
        then the top-level tags object. Keys that match none of these are
        ignored.

        Returns
        -------
        tags : Tags
            Metadata tags describing splitter capabilities.

        """
        tags = Tags(estimator_type="splitter")
        if tags.splitter_tags is not None:
            tags.splitter_tags.supports_panel_data = True

        # Merge class-level _tags dict (flat keys) into tag dataclasses.
        # Walk MRO in reverse so most-derived class wins.
        merged_tags: dict[str, Any] = {}
        for klass in reversed(type(self).__mro__):
            class_tags = klass.__dict__.get("_tags")
            if class_tags and isinstance(class_tags, dict):
                merged_tags.update(class_tags)

        for key, value in merged_tags.items():
            if tags.splitter_tags is not None and hasattr(tags.splitter_tags, key):
                setattr(tags.splitter_tags, key, value)
            elif tags.input_tags is not None and hasattr(tags.input_tags, key):
                setattr(tags.input_tags, key, value)
            elif hasattr(tags, key):
                setattr(tags, key, value)

        return tags

Methods

__init_subclass__(**kwargs)

Merge parameter constraints from all classes in the MRO.

Source Code
Show/Hide source
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Merge parameter constraints from all classes in the MRO.

    The constraints a subclass declares itself are recorded in
    ``_own_parameter_constraints`` before ``cls._parameter_constraints`` is
    replaced by the merged view. Every class's own declaration is then
    merged along the reversed MRO, so a more-derived declaration overrides
    an inherited one.

    Parameters
    ----------
    **kwargs : Any
        Forwarded unchanged to ``super().__init_subclass__``.

    """
    super().__init_subclass__(**kwargs)

    declared = cls.__dict__.get("_parameter_constraints")
    # Keep only what this class itself declared. Merging the already-merged
    # dicts of intermediate classes would re-apply their ancestors' values.
    # Under multiple inheritance, that can clobber an override made by a
    # sibling base that precedes those ancestors in the MRO.
    cls._own_parameter_constraints = dict(declared) if isinstance(declared, dict) else {}

    merged: dict = {}
    for klass in reversed(cls.__mro__):
        namespace = klass.__dict__
        # Classes processed by this hook expose their own declaration.
        # Anything else (the base class, mixins, object) is read as declared.
        own = namespace.get(
            "_own_parameter_constraints",
            namespace.get("_parameter_constraints"),
        )
        if own and isinstance(own, dict):
            merged.update(own)
    cls._parameter_constraints = merged

split(y, X_actual=None) abstractmethod

Generate indices to split time series data.

Parameters
Name Type Description Default
y DataFrame

Target time series used to generate train/test split indices. Must have a "time" column.

required
X_actual DataFrame or None

Actual features. Not used for splitting but accepted for API consistency.

None

Yields:

Name Type Description
train ndarray

Training set row indices for that split.

test ndarray

Test set row indices for that split.

Source Code
Show/Hide source
@abstractmethod
def split(
    self,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
) -> Iterator[tuple[np.ndarray[Any, np.dtype[np.intp]], np.ndarray[Any, np.dtype[np.intp]]]]:
    """Generate indices to split time series data.

    Abstract: concrete splitter classes must override this method.

    Parameters
    ----------
    y : pl.DataFrame
        Target time series used to generate train/test split indices.
        Must have a ``"time"`` column.
    X_actual : pl.DataFrame or None, default=None
        Actual features.  Not used for splitting but accepted for
        API consistency.

    Yields
    ------
    train : ndarray of intp
        Training set row indices for that split.
    test : ndarray of intp
        Test set row indices for that split.

    """

get_n_splits(y=None, X_actual=None) abstractmethod

Return the number of cross-validation folds.

Parameters
Name Type Description Default
y DataFrame or None

Not used. Accepted for API consistency.

None
X_actual DataFrame or None

Not used. Accepted for API consistency.

None
Returns
Type Description
int

The number of cross-validation folds.

Source Code
Show/Hide source
@abstractmethod
def get_n_splits(
    self,
    y: pl.DataFrame | None = None,
    X_actual: pl.DataFrame | None = None,
) -> int:
    """Return the number of cross-validation folds.

    Abstract: concrete splitter classes must override this method.

    Parameters
    ----------
    y : pl.DataFrame or None, default=None
        Not used.  Accepted for API consistency.
    X_actual : pl.DataFrame or None, default=None
        Not used.  Accepted for API consistency.

    Returns
    -------
    int
        The number of cross-validation folds.

    """

__sklearn_tags__()

Get metadata tags for this splitter.

Returns
Name Type Description
tags Tags

Metadata tags describing splitter capabilities.

Source Code
Show/Hide source
def __sklearn_tags__(self):
    """Build the metadata tags describing this splitter.

    Begins with splitter defaults, with ``supports_panel_data`` switched on
    when splitter tags exist. It then layers on the flat ``_tags`` dicts
    declared by the classes of ``type(self).__mro__``. Base classes are
    applied first, so a more-derived declaration overrides an inherited one.

    Each key goes to the first tag group that has an attribute of that name,
    checked in the order ``splitter_tags``, ``input_tags``, then the
    top-level tags object. Keys matching none of them are skipped.

    Returns
    -------
    tags : Tags
        Metadata tags describing splitter capabilities.

    """
    tags = Tags(estimator_type="splitter")
    if tags.splitter_tags is not None:
        tags.splitter_tags.supports_panel_data = True

    # Collect flat tag overrides, base classes first so subclasses win.
    combined: dict[str, Any] = {}
    for owner in reversed(type(self).__mro__):
        declared = owner.__dict__.get("_tags")
        if declared and isinstance(declared, dict):
            combined.update(declared)

    for name, setting in combined.items():
        # Re-read the groups every time: a top-level key may replace one.
        groups = (tags.splitter_tags, tags.input_tags)
        target = next(
            (group for group in groups if group is not None and hasattr(group, name)),
            None,
        )
        if target is None and hasattr(tags, name):
            target = tags
        if target is not None:
            setattr(target, name, setting)

    return tags