Skip to content

check_observe_extends_observations

yohou.testing.forecaster.check_observe_extends_observations(forecaster, y_train, y_observe, X_actual_train=None, X_actual_observe=None, X_future=None, X_forecast=None)

Check observe() extends observation buffers correctly.

Parameters

Name Type Description Default
forecaster BaseForecaster

Fitted forecaster instance

required
y_train DataFrame

Original training data

required
y_observe DataFrame

New data for update

required
X_actual_train DataFrame

Features for training

None
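X_actual_observe DataFrame

Features for update

None
X_future DataFrame

Forwarded to observe() as the X_future keyword argument

None
X_forecast DataFrame

Forwarded to observe() as the X_forecast keyword argument

None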

Raises

Type Description
AssertionError

If observation buffers are not extended correctly

Source Code

Show/Hide source
def check_observe_extends_observations(
    forecaster,
    y_train: pl.DataFrame,
    y_observe: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_observe: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check observe() extends observation buffers correctly.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance
    y_train : pl.DataFrame
        Original training data
    y_observe : pl.DataFrame
        New data for update
    X_actual_train : pl.DataFrame, optional
        Features for training
    X_actual_observe : pl.DataFrame, optional
        Features for update

    Raises
    ------
    AssertionError
        If observation buffers are not extended correctly

    """
    # Store original buffer length
    original_observed_time = forecaster.observed_time_

    # Handle both panel (dict) and non-panel (DataFrame or scalar) data
    if forecaster._y_observed is not None:
        if isinstance(forecaster._y_observed, dict):
            # Panel data: observed_time_ is a dict
            # Check the first group as a representative
            first_group = next(iter(forecaster._y_observed.keys()))
            first_group_y = forecaster._y_observed[first_group]
            # _y_observed[group] can be None when observation_horizon == 0
            if first_group_y is not None:
                original_y_observed_last_time = first_group_y["time"][-1]
                assert original_observed_time[first_group] == original_y_observed_last_time, (
                    "observed_time_ should match last time in _y_observed before observe()"
                )
        else:
            # Non-panel data: observed_time_ is a scalar
            original_y_observed_last_time = forecaster._y_observed["time"][-1]
            assert original_observed_time == original_y_observed_last_time, (
                "observed_time_ should match last time in _y_observed before observe()"
            )

    if forecaster._X_t_observed is not None:
        if isinstance(forecaster._X_t_observed, dict):
            # Panel data
            first_group = next(iter(forecaster._X_t_observed.keys()))
            if forecaster._X_t_observed[first_group] is not None:
                original_X_t_observed_last_time = forecaster._X_t_observed[first_group]["time"][-1]
                assert original_observed_time[first_group] == original_X_t_observed_last_time, (
                    "observed_time_ should match last time in _X_t_observed before observe()"
                )
        else:
            # Non-panel data
            original_X_t_observed_last_time = forecaster._X_t_observed["time"][-1]
            assert original_observed_time == original_X_t_observed_last_time, (
                "observed_time_ should match last time in _X_t_observed before observe()"
            )

    # Update with new data
    forecaster.observe(y_observe, X_actual_observe, X_future=X_future, X_forecast=X_forecast)

    # Check buffers were extended
    updated_observed_time = forecaster.observed_time_

    # Handle both panel and non-panel data for comparison
    if isinstance(updated_observed_time, dict):
        # Panel data: check all groups were updated
        for group_name in updated_observed_time:
            assert updated_observed_time[group_name] >= original_observed_time[group_name], (
                f"observed_time_ for group {group_name} should be updated"
            )
    else:
        # Non-panel data
        assert updated_observed_time >= original_observed_time, (
            "observed_time_ should be updated to at least the last time in update data"
        )

    if forecaster._y_observed is not None:
        if isinstance(forecaster._y_observed, dict):
            # Panel data
            for group_name, y_obs in forecaster._y_observed.items():
                # _y_observed[group] can be None when observation_horizon == 0
                if y_obs is not None:
                    updated_y_observed_last_time = y_obs["time"][-1]
                    assert updated_y_observed_last_time == updated_observed_time[group_name], (
                        f"Last time in _y_observed['{group_name}'] should match updated observed_time_"
                    )
        else:
            # Non-panel data
            updated_y_observed_last_time = forecaster._y_observed["time"][-1]
            assert updated_y_observed_last_time == updated_observed_time, (
                "Last time in _y_observed should match updated observed_time_ after observe()"
            )

    if forecaster._X_t_observed is not None:
        if isinstance(forecaster._X_t_observed, dict):
            # Panel data
            for group_name, X_t_obs in forecaster._X_t_observed.items():
                if X_t_obs is not None:
                    updated_X_t_observed_last_time = X_t_obs["time"][-1]
                    assert updated_X_t_observed_last_time == updated_observed_time[group_name], (
                        f"Last time in _X_t_observed['{group_name}'] should match updated observed_time_"
                    )
        else:
            # Non-panel data
            updated_X_t_observed_last_time = forecaster._X_t_observed["time"][-1]
            assert updated_X_t_observed_last_time == updated_observed_time, (
                "Last time in _X_t_observed should match updated observed_time_ after observe()"
            )