Skip to content

check_forecaster_methods_call_check_is_fitted

yohou.testing.forecaster.check_forecaster_methods_call_check_is_fitted(forecaster, y, X_actual=None, forecasting_horizon=3, X_future=None, X_forecast=None)

Check that all forecaster methods (except `fit`) raise `NotFittedError` when the forecaster is unfitted.

Validates that predict()/predict_interval(), observe(), rewind(), and observe_predict()/observe_predict_interval() methods all check fitted state and raise NotFittedError before operating on an unfitted forecaster.

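Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `forecaster` | `BaseForecaster` | Unfitted forecaster instance | *required* |
| `y` | `DataFrame` | Training/test target data with a `"time"` column | *required* |
| `X_actual` | `DataFrame` | Training/test features with a `"time"` column | `None` |
| `forecasting_horizon` | `int` | Number of steps ahead to forecast | `3` |
| `X_future` | `DataFrame` | Not used by this check | `None` |
| `X_forecast` | `DataFrame` | Not used by this check | `None` |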

Raises

| Type | Description |
|---|---|
| `AssertionError` | If any method fails to raise `NotFittedError` when called on an unfitted forecaster |

Source Code

Show/Hide source
def _assert_method_raises_not_fitted(forecaster, method_name: str, call) -> None:
    """Assert that ``call()`` raises NotFittedError on an unfitted forecaster.

    Parameters
    ----------
    forecaster : BaseForecaster
        Forecaster whose method ``call`` invokes; only used for error messages.
    method_name : str
        Name of the method invoked by ``call``; only used for error messages.
    call : callable
        Zero-argument callable that invokes the method under test.

    Raises
    ------
    AssertionError
        If ``call()`` returns normally or raises anything other than
        NotFittedError.

    """
    message = (
        f"{type(forecaster).__name__}.{method_name}() must raise NotFittedError "
        "when called on unfitted forecaster"
    )
    try:
        call()
    except NotFittedError:
        return  # Expected: the method checked fitted state first.
    except Exception as exc:
        # Any other exception still means NotFittedError was not raised; report
        # it as a check failure naming the method, and keep the original cause.
        raise AssertionError(f"{message}, but raised {type(exc).__name__}: {exc}") from exc
    # Reached only when the call returned normally.
    raise AssertionError(message)


def check_forecaster_methods_call_check_is_fitted(
    forecaster,
    y: pl.DataFrame,
    X_actual: pl.DataFrame | None = None,
    forecasting_horizon: int = 3,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check all forecaster methods (except fit) raise NotFittedError when unfitted.

    Validates that predict()/predict_interval(), observe(), rewind(), and
    observe_predict()/observe_predict_interval() methods all check fitted state
    and raise NotFittedError before operating on an unfitted forecaster.

    A forecaster that has a ``predict_interval`` attribute but no ``predict``
    attribute is exercised through the interval API (``predict_interval`` and
    ``observe_predict_interval`` with ``coverage_rates=[0.9]``). All others are
    exercised through the point API (``predict`` and ``observe_predict``).

    Parameters
    ----------
    forecaster : BaseForecaster
        Unfitted forecaster instance. It is cloned first, so the instance
        passed in is never called directly.
    y : pl.DataFrame
        Training/test target data with "time" column. Rows ``[50:53]`` are
        passed to observe()/observe_predict*() and rows ``[40:50]`` to rewind().
    X_actual : pl.DataFrame, optional
        Training/test features with "time" column, sliced with the same row
        windows as ``y``.
    forecasting_horizon : int, default=3
        Number of steps ahead to forecast.
    X_future : pl.DataFrame, optional
        Not used by this check.
    X_forecast : pl.DataFrame, optional
        Not used by this check.

    Raises
    ------
    AssertionError
        If any method returns normally, or raises an exception other than
        NotFittedError, when called on the unfitted forecaster.

    """
    forecaster_clone = clone(forecaster)

    # Determine if this is a point or interval forecaster
    is_interval = hasattr(forecaster_clone, "predict_interval") and not hasattr(forecaster_clone, "predict")

    # Fixed row windows handed to the data-consuming methods. An unfitted
    # forecaster is expected to reject the call before it inspects this data.
    y_observe = y[50:53]
    y_rewind = y[40:50]
    X_observe = X_actual[50:53] if X_actual is not None else None
    X_rewind = X_actual[40:50] if X_actual is not None else None

    # predict() / predict_interval()
    if is_interval:
        _assert_method_raises_not_fitted(
            forecaster_clone,
            "predict_interval",
            lambda: forecaster_clone.predict_interval(
                forecasting_horizon=forecasting_horizon,
                coverage_rates=[0.9],
            ),
        )
    else:
        _assert_method_raises_not_fitted(
            forecaster_clone,
            "predict",
            lambda: forecaster_clone.predict(forecasting_horizon=forecasting_horizon),
        )

    # observe()
    _assert_method_raises_not_fitted(
        forecaster_clone,
        "observe",
        lambda: forecaster_clone.observe(y_observe, X_observe),
    )

    # rewind()
    _assert_method_raises_not_fitted(
        forecaster_clone,
        "rewind",
        lambda: forecaster_clone.rewind(y_rewind, X_rewind),
    )

    # observe_predict() / observe_predict_interval()
    if is_interval:
        _assert_method_raises_not_fitted(
            forecaster_clone,
            "observe_predict_interval",
            lambda: forecaster_clone.observe_predict_interval(y_observe, X_observe, coverage_rates=[0.9]),
        )
    else:
        _assert_method_raises_not_fitted(
            forecaster_clone,
            "observe_predict",
            lambda: forecaster_clone.observe_predict(y_observe, X_observe),
        )