def check_observe_extends_observations(
    forecaster,
    y_train: pl.DataFrame,
    y_observe: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_observe: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check observe() extends observation buffers correctly.

    Before calling ``forecaster.observe(...)``, asserts that ``observed_time_``
    agrees with the last ``"time"`` entry of the ``_y_observed`` and
    ``_X_t_observed`` buffers. After the call, asserts that ``observed_time_``
    did not move backwards. It also asserts that every non-None buffer
    (every group, for panel data) ends at the updated ``observed_time_``.

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance
    y_train : pl.DataFrame
        Original training data (not used by these checks)
    y_observe : pl.DataFrame
        New data for update, passed positionally to ``observe()``
    X_actual_train : pl.DataFrame, optional
        Features for training (not used by these checks)
    X_actual_observe : pl.DataFrame, optional
        Features for update, passed positionally to ``observe()``
    X_future : pl.DataFrame, optional
        Forwarded to ``observe()`` as ``X_future``
    X_forecast : pl.DataFrame, optional
        Forwarded to ``observe()`` as ``X_forecast``

    Raises
    ------
    AssertionError
        If observation buffers are not extended correctly
    """
    # Snapshot observed_time_ *by value*. For panel data it is a dict, and
    # observe() may update that dict in place. Keeping only a reference would
    # make the before/after comparison below compare the dict to itself.
    current_observed_time = forecaster.observed_time_
    if isinstance(current_observed_time, dict):
        original_observed_time = dict(current_observed_time)
    else:
        original_observed_time = current_observed_time

    _check_buffer_before_observe(forecaster._y_observed, original_observed_time, "_y_observed")
    _check_buffer_before_observe(forecaster._X_t_observed, original_observed_time, "_X_t_observed")

    # Update with new data
    forecaster.observe(y_observe, X_actual_observe, X_future=X_future, X_forecast=X_forecast)

    updated_observed_time = forecaster.observed_time_
    if isinstance(updated_observed_time, dict):
        # Panel data: every group's observed time must not move backwards
        for group_name, group_time in updated_observed_time.items():
            assert group_time >= original_observed_time[group_name], (
                f"observed_time_ for group {group_name} should be updated"
            )
    else:
        # Non-panel data: observed_time_ is a single value
        assert updated_observed_time >= original_observed_time, (
            "observed_time_ should not move backwards after observe()"
        )

    # Buffers are re-read after observe(), since observe() may replace them.
    _check_buffer_after_observe(forecaster._y_observed, updated_observed_time, "_y_observed")
    _check_buffer_after_observe(forecaster._X_t_observed, updated_observed_time, "_X_t_observed")


def _check_buffer_before_observe(buffer, observed_time, buffer_name: str) -> None:
    """Assert observed_time matches buffer's last time (only the first group for panels)."""
    if buffer is None:
        return
    if isinstance(buffer, dict):
        # Panel data: observed_time is a dict keyed by group. The first group
        # is checked as a representative; an empty dict has nothing to check.
        if not buffer:
            return
        first_group = next(iter(buffer))
        group_buffer = buffer[first_group]
        # A per-group buffer can be None (e.g. when observation_horizon == 0)
        if group_buffer is not None:
            assert observed_time[first_group] == group_buffer["time"][-1], (
                f"observed_time_ should match last time in {buffer_name} before observe()"
            )
    else:
        # Non-panel data: observed_time is a single value
        assert observed_time == buffer["time"][-1], (
            f"observed_time_ should match last time in {buffer_name} before observe()"
        )


def _check_buffer_after_observe(buffer, observed_time, buffer_name: str) -> None:
    """Assert every non-None buffer (each panel group) ends at the updated observed_time."""
    if buffer is None:
        return
    if isinstance(buffer, dict):
        # Panel data: check all groups; None buffers (e.g. observation_horizon == 0) are skipped
        for group_name, group_buffer in buffer.items():
            if group_buffer is not None:
                assert group_buffer["time"][-1] == observed_time[group_name], (
                    f"Last time in {buffer_name}['{group_name}'] should match updated observed_time_"
                )
    else:
        # Non-panel data
        assert buffer["time"][-1] == observed_time, (
            f"Last time in {buffer_name} should match updated observed_time_ after observe()"
        )