def _check_buffer_before_rewind(buffer, observed_time, buffer_name: str) -> None:
    """Assert that a pre-rewind observation buffer ends at ``observed_time``.

    ``buffer`` is either ``None`` (nothing to check), a dict of per-group
    frames (panel data) or a single frame indexable by ``"time"``.
    """
    if buffer is None:
        return
    message = f"observed_time_ should match last time in {buffer_name} before rewind()"
    if not isinstance(buffer, dict):
        # Non-panel data
        assert observed_time == buffer["time"][-1], message
        return
    # Panel data: an empty mapping has no group to inspect, so skip instead of
    # letting next() raise StopIteration.
    if not buffer:
        return
    # Only the first group is spot-checked before the rewind.
    first_group = next(iter(buffer))
    group_frame = buffer[first_group]
    # A group's entry can be None (noted upstream: observation_horizon == 0).
    if group_frame is not None:
        assert observed_time[first_group] == group_frame["time"][-1], message


def _check_buffer_after_rewind(buffer, observed_time, buffer_name: str) -> None:
    """Assert that a post-rewind observation buffer ends at the reset ``observed_time``.

    For panel data every group with a non-``None`` frame is checked.
    """
    if buffer is None:
        return
    if not isinstance(buffer, dict):
        # Non-panel data
        assert buffer["time"][-1] == observed_time, (
            f"Last time in {buffer_name} should match reset observed_time_ after rewind()"
        )
        return
    # Panel data
    for group_name, group_frame in buffer.items():
        # A group's entry can be None (noted upstream: observation_horizon == 0).
        if group_frame is None:
            continue
        assert group_frame["time"][-1] == observed_time[group_name], (
            f"Last time in {buffer_name}['{group_name}'] should match reset observed_time_"
        )


def check_rewind_replaces_observations(
    forecaster,
    y_train: pl.DataFrame,
    y_reset: pl.DataFrame,
    X_actual_train: pl.DataFrame | None = None,
    X_actual_reset: pl.DataFrame | None = None,
    X_future: pl.DataFrame | None = None,
    X_forecast: pl.DataFrame | None = None,
) -> None:
    """Check rewind() replaces observation buffers correctly.

    The check runs in three steps: verify that ``observed_time_`` agrees with
    the last ``"time"`` value of the ``_y_observed`` / ``_X_t_observed``
    buffers, call ``forecaster.rewind(...)`` with the reset data, then verify
    that ``observed_time_`` and both buffers now end at the last ``"time"``
    value of ``y_reset``. Buffers may be ``None``, a single frame (non-panel)
    or a dict of per-group frames (panel).

    Parameters
    ----------
    forecaster : BaseForecaster
        Fitted forecaster instance. Must expose ``observed_time_``,
        ``_y_observed``, ``_X_t_observed`` and ``rewind()``.
    y_train : pl.DataFrame
        Original training data. Not read by this check; kept for a uniform
        call signature.
    y_reset : pl.DataFrame
        New data for reset. Its last ``"time"`` value is the expected
        ``observed_time_`` after the rewind.
    X_actual_train : pl.DataFrame, optional
        Features for training. Not read by this check.
    X_actual_reset : pl.DataFrame, optional
        Features for reset, passed positionally to ``rewind()``.
    X_future : pl.DataFrame, optional
        Forwarded to ``rewind()`` as the ``X_future`` keyword.
    X_forecast : pl.DataFrame, optional
        Forwarded to ``rewind()`` as the ``X_forecast`` keyword.

    Raises
    ------
    AssertionError
        If observation buffers are not replaced correctly
    """
    # Precondition: the buffers agree with the observed time before rewinding.
    original_observed_time = forecaster.observed_time_
    _check_buffer_before_rewind(forecaster._y_observed, original_observed_time, "_y_observed")
    _check_buffer_before_rewind(forecaster._X_t_observed, original_observed_time, "_X_t_observed")

    # Reset to new data
    forecaster.rewind(y_reset, X_actual_reset, X_future=X_future, X_forecast=X_forecast)

    # The observed time must now equal the last timestamp of the reset data.
    reset_observed_time = forecaster.observed_time_
    expected_time = y_reset["time"][-1]
    if isinstance(reset_observed_time, dict):
        # Panel data: every group is expected to share the reset data's last time.
        for group_name in reset_observed_time:
            assert reset_observed_time[group_name] == expected_time, (
                f"observed_time_['{group_name}'] should be reset to last time in reset data"
            )
    else:
        # Non-panel data
        assert reset_observed_time == expected_time, "observed_time_ should be reset to last time in reset data"

    # The buffers themselves must have been replaced to match.
    _check_buffer_after_rewind(forecaster._y_observed, reset_observed_time, "_y_observed")
    _check_buffer_after_rewind(forecaster._X_t_observed, reset_observed_time, "_X_t_observed")