Skip to content

all_estimators

yohou.utils.discovery.all_estimators(type_filter=None)

Get a list of all estimators from yohou.

This function crawls the module and gets all classes that inherit from BaseEstimator. Classes that are defined in test-modules are not included.

Parameters

Name Type Description Default
type_filter {'forecaster', 'point', 'interval', 'class_proba', 'transformer', 'splitter', 'scorer', 'point_scorer', 'interval_scorer', 'class_proba_scorer', 'conformity_scorer'} or list of such str

Which kind of estimators should be returned. If None, no filter is applied and all estimators are returned. Possible values are:

  • 'forecaster': All forecasters, whatever their forecaster type (point, interval, class_proba, or a combination of these)
  • 'point': Only point forecasters
  • 'interval': Only interval forecasters
  • 'class_proba': Only class-probability forecasters
  • 'transformer': Transformers
  • 'splitter': Cross-validation splitters
  • 'scorer': All scorers
  • 'point_scorer': Only point scorers
  • 'interval_scorer': Only interval scorers
  • 'class_proba_scorer': Only class-probability scorers
  • 'conformity_scorer': Only conformity scorers
None

Returns

Name Type Description
estimators list of tuples

List of (name, class), where name is the class name as string and class is the actual type of the class.

Examples

>>> from yohou.utils.discovery import all_estimators
>>> estimators = all_estimators()
>>> type(estimators)
<class 'list'>
>>> forecasters = all_estimators(type_filter='forecaster')
>>> points = all_estimators(type_filter='point')

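See Also

  • all_displays: Get all display classes from yohou.
  • all_functions: Get all public functions from yohou.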

Source Code

Show/Hide source
def all_estimators(type_filter: str | list[str] | None = None) -> list[tuple[str, type]]:
    """Get a list of all estimators from `yohou`.

    This function crawls the module and gets all classes that inherit
    from `BaseEstimator`. Classes that are defined in test-modules are not
    included.

    Parameters
    ----------
    type_filter : {"forecaster", "point", "interval", "class_proba", \
            "transformer", "splitter", "scorer", "point_scorer", \
            "interval_scorer", "class_proba_scorer", "conformity_scorer"} or list of such str, default=None
        Which kind of estimators should be returned. If None, no filter is
        applied and all estimators are returned. Possible values are:

        - 'forecaster': All forecasters (point, interval, class_proba, or both)
        - 'point': Only point forecasters
        - 'interval': Only interval forecasters
        - 'class_proba': Only class-probability forecasters
        - 'transformer': Transformers
        - 'splitter': Cross-validation splitters
        - 'scorer': All scorers
        - 'point_scorer': Only point scorers
        - 'interval_scorer': Only interval scorers
        - 'class_proba_scorer': Only class-probability scorers
        - 'conformity_scorer': Only conformity scorers

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class.

    Examples
    --------
    >>> from yohou.utils.discovery import all_estimators
    >>> estimators = all_estimators()
    >>> type(estimators)
    <class 'list'>
    >>> forecasters = all_estimators(type_filter='forecaster')
    >>> points = all_estimators(type_filter='point')

    See Also
    --------
    - [`all_displays`][yohou.utils.discovery.all_displays] : Get all display classes from yohou.
    - [`all_functions`][yohou.utils.discovery.all_functions] : Get all public functions from yohou.
    """

    def is_abstract(c: type) -> bool:
        """Check if a class is abstract."""
        if not (hasattr(c, "__abstractmethods__")):
            return False
        abstract_methods = c.__abstractmethods__
        return bool(abstract_methods)

    all_classes = []
    root = str(Path(__file__).parent.parent)  # yohou package
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        for _, module_name, _ in pkgutil.walk_packages(path=[root], prefix="yohou."):
            module_parts = module_name.split(".")
            if any(part in _MODULE_TO_IGNORE for part in module_parts):
                continue
            try:
                module = import_module(module_name)
            except ImportError:
                continue
            classes = inspect.getmembers(module, inspect.isclass)
            classes = [
                (name, est_cls)
                for name, est_cls in classes
                if not name.startswith("_") and est_cls.__module__ == module_name
            ]

            all_classes.extend(classes)

    all_classes_set = set(all_classes)

    estimators = [c for c in all_classes_set if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator")]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]
    # Exclude internal base classes that are not meant to be instantiated directly
    _BASE_CLASSES = {"BaseForecaster", "BaseReductionForecaster", "BaseIntervalForecaster"}
    estimators = [c for c in estimators if c[0] not in _BASE_CLASSES]

    if type_filter is not None:
        type_filter = [type_filter] if not isinstance(type_filter, list) else list(type_filter)

        filtered_estimators = []

        # Define valid filter types
        valid_filters = {
            "forecaster",
            "point",
            "interval",
            "class_proba",
            "transformer",
            "splitter",
            "scorer",
            "point_scorer",
            "interval_scorer",
            "class_proba_scorer",
            "conformity_scorer",
        }

        # Check for invalid filters
        invalid_filters = set(type_filter) - valid_filters
        if invalid_filters:
            raise ValueError(
                f"Invalid type_filter values: {sorted(invalid_filters)}. Valid options are: {sorted(valid_filters)}"
            )

        # Filter estimators by tags
        for name, est_cls in estimators:
            try:
                # Get tags from instance (tags may be instance-dependent)
                # Try default initialization
                try:
                    instance = est_cls()
                except TypeError:
                    # Skip estimators that require constructor arguments
                    continue

                tags = instance.__sklearn_tags__()
                est_type = tags.estimator_type

                # Check base estimator type
                if est_type in type_filter:
                    filtered_estimators.append((name, est_cls))
                    continue

                # Check forecaster sub-types
                if est_type == "forecaster" and hasattr(tags, "forecaster_tags"):
                    forecaster_type = tags.forecaster_tags.forecaster_type
                    if forecaster_type is not None and forecaster_type & frozenset(type_filter):
                        filtered_estimators.append((name, est_cls))
                    elif "forecaster" in type_filter:
                        # Generic forecaster filter matches all forecasters
                        filtered_estimators.append((name, est_cls))

                # Check scorer sub-types
                elif est_type == "scorer" and hasattr(tags, "scorer_tags"):
                    prediction_type = tags.scorer_tags.prediction_type
                    if (
                        prediction_type == "point"
                        and "point_scorer" in type_filter
                        or prediction_type == "interval"
                        and "interval_scorer" in type_filter
                        or prediction_type == "conformity"
                        and "conformity_scorer" in type_filter
                        or prediction_type == "class_proba"
                        and "class_proba_scorer" in type_filter
                    ):
                        filtered_estimators.append((name, est_cls))
                    elif "scorer" in type_filter:
                        # Generic scorer filter matches all scorers
                        filtered_estimators.append((name, est_cls))

            except Exception as e:
                # Skip estimators that can't be instantiated or don't have proper tags
                logging.getLogger(__name__).debug("Skipped %s: %s", name, e)
                continue

        estimators = filtered_estimators

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple
    return sorted(set(estimators), key=itemgetter(0))