
check_cv_alignment

yohou.model_selection.split.check_cv_alignment(cv, forecasting_horizon)

Inspect how a CV splitter's test windows align with a forecasting horizon.

Call this before scoring to understand how many vintages will be produced in each fold and whether each forecasting step is represented equally.
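In the implementation listed in the source below, `is_balanced` reduces to `test_size % stride == 0`. If the check reports an unbalanced setup, you can list the strides that would align with a given `test_size`. The sketch below does this with `balanced_strides`, a hypothetical helper that is not part of yohou.

```python
def balanced_strides(test_size: int) -> list[int]:
    """Hypothetical helper: strides for which test_size % stride == 0."""
    # A stride is balanced when the test window splits into whole strides.
    return [s for s in range(1, test_size + 1) if test_size % s == 0]


print(balanced_strides(10))  # [1, 2, 5, 10]
```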

Parameters

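| Name | Type | Description | Default |
|---|---|---|---|
| `cv` | `BaseSplitter` | A fitted or unfitted splitter instance. Only `SlidingWindowSplitter` and `ExpandingWindowSplitter` are inspected; for any other splitter, every value in the returned dict is `None`. | required |
| `forecasting_horizon` | `int` | The forecasting horizon (number of steps predicted per vintage). For an `ExpandingWindowSplitter` whose `test_size` is `None`, it is also used as the test size. | required |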

Returns

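dict, with the following keys:

| Key | Description |
|---|---|
| `n_vintages` | Number of predict calls per fold: 1 initial call plus ceil(`test_size` / `stride`). |
| `steps_per_vintage` | List of step counts per vintage. Every entry equals `forecasting_horizon`; a short final observe step changes only how many observed rows are back-filled, not the prediction size. |
| `step_counts` | Dict mapping step number (1-based) to the total number of vintages that include that step. Every step maps to `n_vintages`. |
| `is_balanced` | `True` when `test_size` is an exact multiple of `stride`, so that every observe step consumes exactly `stride` rows. |

All four values are `None` when `cv` is neither a `SlidingWindowSplitter` nor an `ExpandingWindowSplitter`.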
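To make the arithmetic behind these keys concrete, here is a minimal stand-alone sketch that mirrors the computation in the source listed below. `summarize_alignment` is a hypothetical helper, not part of yohou. It assumes `test_size` and `stride` have already been resolved from the splitter.

```python
from typing import Any


def summarize_alignment(test_size: int, stride: int, forecasting_horizon: int) -> dict[str, Any]:
    """Hypothetical re-computation of the dict returned by check_cv_alignment."""
    # One observe step per stride-sized chunk of the test window (ceil division).
    n_observe_steps = (test_size + stride - 1) // stride
    # One initial predict plus one predict after each observe step.
    n_vintages = 1 + n_observe_steps
    # Every vintage predicts the full horizon.
    steps_per_vintage = [forecasting_horizon] * n_vintages
    # Hence each step 1..forecasting_horizon appears in every vintage.
    step_counts = dict.fromkeys(range(1, forecasting_horizon + 1), n_vintages)
    # Balanced only when the test window splits into whole strides.
    is_balanced = test_size % stride == 0
    return {
        "n_vintages": n_vintages,
        "steps_per_vintage": steps_per_vintage,
        "step_counts": step_counts,
        "is_balanced": is_balanced,
    }


info = summarize_alignment(test_size=10, stride=4, forecasting_horizon=4)
print(info)
# {'n_vintages': 4, 'steps_per_vintage': [4, 4, 4, 4],
#  'step_counts': {1: 4, 2: 4, 3: 4, 4: 4}, 'is_balanced': False}
```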

Examples

>>> from yohou.model_selection import SlidingWindowSplitter, check_cv_alignment
>>> cv = SlidingWindowSplitter(n_splits=3, test_size=10, stride=4)
>>> info = check_cv_alignment(cv, forecasting_horizon=4)
>>> info["is_balanced"]
False
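The example returns `False` because a 10-row test window does not split into whole strides of 4. The following minimal sketch assumes that each observe step consumes up to `stride` rows of the test window, as the comments in the source describe. `observe_chunks` is a hypothetical helper, not a yohou function.

```python
def observe_chunks(test_size: int, stride: int) -> list[int]:
    """Hypothetical helper: rows consumed by each observe step in one fold's test window."""
    chunks = []
    remaining = test_size
    while remaining > 0:
        # Each observe step consumes up to `stride` rows; the last may be shorter.
        take = min(stride, remaining)
        chunks.append(take)
        remaining -= take
    return chunks


print(observe_chunks(10, 4))  # [4, 4, 2]
print(observe_chunks(12, 4))  # [4, 4, 4]
```

The number of chunks equals the ceiling division used for `n_observe_steps`. A short final chunk is exactly the case in which `is_balanced` is `False`.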

Source Code

def check_cv_alignment(
    cv: BaseSplitter,
    forecasting_horizon: int,
) -> dict[str, Any]:
    """Inspect how a CV splitter's test windows align with a forecasting horizon.

    Call this before scoring to understand how many vintages will be
    produced in each fold and whether each forecasting step is
    represented equally.

    Parameters
    ----------
    cv : BaseSplitter
        A fitted or unfitted splitter instance.
    forecasting_horizon : int
        The forecasting horizon (number of steps predicted per vintage).

    Returns
    -------
    dict
        ``n_vintages``
            Number of predict calls per fold
            (1 initial + ceil(test_size / stride)).
        ``steps_per_vintage``
            List of step counts per vintage.  Every entry equals
            ``forecasting_horizon``.
        ``step_counts``
            Dict mapping step number (1-based) to the total number of
            vintages that include that step.
        ``is_balanced``
            ``True`` when ``test_size`` is an exact multiple of ``stride``,
            i.e. every observe step consumes exactly ``stride`` rows.

        All values are ``None`` if ``cv`` is neither a
        ``SlidingWindowSplitter`` nor an ``ExpandingWindowSplitter``.

    Examples
    --------
    >>> from yohou.model_selection import SlidingWindowSplitter, check_cv_alignment
    >>> cv = SlidingWindowSplitter(n_splits=3, test_size=10, stride=4)
    >>> info = check_cv_alignment(cv, forecasting_horizon=4)
    >>> info["is_balanced"]
    False

    """
    # Resolve test_size and stride from the splitter
    test_size: int
    stride: int

    if isinstance(cv, SlidingWindowSplitter):
        test_size = cv.test_size
        stride = cv.stride if cv.stride is not None else cv.test_size
    elif isinstance(cv, ExpandingWindowSplitter):
        test_size = cv.test_size if cv.test_size is not None else forecasting_horizon
        stride = test_size  # ExpandingWindowSplitter always uses test_size as stride
    else:
        return {
            "n_vintages": None,
            "steps_per_vintage": None,
            "step_counts": None,
            "is_balanced": None,
        }

    # Number of vintages: initial predict + one per observe step
    n_observe_steps = (test_size + stride - 1) // stride  # ceil division
    n_vintages = 1 + n_observe_steps

    # Each vintage predicts forecasting_horizon steps, but
    # the last observe may consume fewer than stride rows.
    # This doesn't change the prediction size - it only affects
    # how many observed rows back-fill before that vintage's prediction.
    steps_per_vintage = [forecasting_horizon] * n_vintages

    # Build step counts: each vintage contributes steps 1..fh.
    # All vintages contribute equally.
    step_counts = dict.fromkeys(range(1, forecasting_horizon + 1), n_vintages)

    # Check if test_size aligns with stride
    is_balanced = test_size % stride == 0

    return {
        "n_vintages": n_vintages,
        "steps_per_vintage": steps_per_vintage,
        "step_counts": step_counts,
        "is_balanced": is_balanced,
    }
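The first step of the source resolves `test_size` and `stride` differently for each splitter type. The sketch below reproduces those rules in isolation. `SlidingStub`, `ExpandingStub`, and `resolve_window` are hypothetical stand-ins that model only the attributes the source reads; they are not yohou classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SlidingStub:
    """Hypothetical stand-in for the attributes read from SlidingWindowSplitter."""
    test_size: int
    stride: Optional[int] = None


@dataclass
class ExpandingStub:
    """Hypothetical stand-in for the attribute read from ExpandingWindowSplitter."""
    test_size: Optional[int] = None


def resolve_window(cv: object, forecasting_horizon: int) -> Optional[tuple[int, int]]:
    """Return (test_size, stride) using the same rules as check_cv_alignment."""
    if isinstance(cv, SlidingStub):
        # stride falls back to test_size when it is unset
        stride = cv.stride if cv.stride is not None else cv.test_size
        return cv.test_size, stride
    if isinstance(cv, ExpandingStub):
        # test_size falls back to the forecasting horizon; stride always equals test_size
        test_size = cv.test_size if cv.test_size is not None else forecasting_horizon
        return test_size, test_size
    # Any other splitter type is not inspected
    return None


print(resolve_window(SlidingStub(test_size=10, stride=4), 4))  # (10, 4)
print(resolve_window(SlidingStub(test_size=10), 4))  # (10, 10)
print(resolve_window(ExpandingStub(), 4))  # (4, 4)
print(resolve_window("not a splitter", 4))  # None
```

One consequence of these rules is that the expanding branch sets `stride` equal to `test_size`. As a result, an `ExpandingWindowSplitter` always yields `is_balanced == True` and `n_vintages == 2`.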