Skip to content

check_splitter_non_overlapping_tests

yohou.testing.splitter.check_splitter_non_overlapping_tests(splitter, y, X_actual=None)

Check that test sets do not overlap when the splitter's produces_non_overlapping_tests tag is True.

Parameters

Name Type Description Default
splitter BaseSplitter

Splitter instance

required
y DataFrame

Target time series with "time" column

required
X_actual DataFrame | None

Exogenous features

None

Raises

Type Description
AssertionError

If test sets overlap when they shouldn't

Source Code

Show/Hide source
def check_splitter_non_overlapping_tests(splitter, y: pl.DataFrame, X_actual: pl.DataFrame | None = None) -> None:
    """Assert that no index is shared between the test sets of two splits.

    The check is only enforced when the splitter's tags declare
    ``produces_non_overlapping_tests``; otherwise the function returns
    without calling ``split`` at all.

    Parameters
    ----------
    splitter : BaseSplitter
        Splitter instance. It must provide ``__sklearn_tags__()`` and
        ``split(y, X_actual)``, the latter yielding ``(train, test)`` pairs.
    y : pl.DataFrame
        Target time series with "time" column
    X_actual : pl.DataFrame, optional
        Exogenous features

    Raises
    ------
    AssertionError
        If any two splits share at least one test index while the splitter
        claims to produce non-overlapping test sets.

    """
    if not splitter.__sklearn_tags__().splitter_tags.produces_non_overlapping_tests:
        # Overlap is permitted for this splitter, so there is nothing to verify.
        return

    # Only the test half of each (train, test) pair matters here; sets make
    # the pairwise intersection below cheap.
    folds = [set(test_part) for _train_part, test_part in splitter.split(y, X_actual)]

    # Compare every fold against each fold that comes after it, so each
    # unordered pair is examined exactly once.
    for i, earlier in enumerate(folds):
        for j, later in enumerate(folds[i + 1 :], start=i + 1):
            shared = earlier.intersection(later)
            assert not shared, (
                f"Splits {i} and {j} have overlapping test sets (indices: {sorted(shared)}), "
                "but produces_non_overlapping_tests=True"
            )