
check_schema

yohou.utils.validation.check_schema(df, expected_schema, groups=None)

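Validate a DataFrame's schema and return it with canonical column ordering.

Checks that the data contains every expected column with the expected dtype, and returns the DataFrame with its columns in canonical order: the time column first, followed by the schema columns in schema order. Because the columns are chosen with a select, any column not named in the expected schema is dropped from the result rather than rejected.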
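The select-then-compare logic can be modeled without polars. The sketch below is a hypothetical stdlib-only model, not part of yohou: it assumes a "DataFrame" is a dict mapping column names to value lists, with dtypes given as plain strings such as "Int64" in a separate dict. It shows the two effects of the internal select call: "time" comes first followed by the schema columns, and columns outside the schema are silently dropped. The KeyError stands in for the column-lookup error polars raises.

```python
from __future__ import annotations


def check_schema_sketch(
    data: dict[str, list],
    dtypes: dict[str, str],
    expected_schema: dict[str, str],
) -> dict[str, list]:
    """Hypothetical stdlib model of check_schema for non-panel data.

    `data` maps column names to values and `dtypes` maps column names to
    dtype names; together they stand in for a polars DataFrame.
    """
    expected_columns = ["time"] + list(expected_schema)
    # Mirrors df.select(expected_columns): a missing column fails here...
    missing = [c for c in expected_columns if c not in data]
    if missing:
        raise KeyError(f"missing columns: {missing}")
    # ...and any column not listed is left out, in the listed order.
    selected = {c: data[c] for c in expected_columns}
    # Compare the dtypes of every non-time column with the expectation.
    incoming = {c: dtypes[c] for c in expected_columns if c != "time"}
    if incoming != expected_schema:
        raise ValueError(
            f"Schema mismatch. Expected: {expected_schema}, got: {incoming}"
        )
    return selected


data = {"value": [10, 20], "extra": ["a", "b"], "time": [1, 2]}
dtypes = {"value": "Int64", "extra": "String", "time": "Int64"}
result = check_schema_sketch(data, dtypes, {"value": "Int64"})
print(list(result))  # ['time', 'value']
```

Modeling the check this way makes it easy to see that ordering is enforced by the selection, while the dtype comparison only decides whether to raise.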

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | DataFrame to validate; must include a "time" column. | required |
| expected_schema | dict[str, DataType] | Expected schema for the non-time columns. For panel data, use unprefixed column names. | required |
| groups | list[str] or None | Group prefixes for panel data. If provided, the expected schema is expanded to one column named "{group}__{column}" per group and schema column (e.g., "panel__series_0"). None for global data. | None |
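How groups expands the schema can be sketched with plain dicts. The helper expand_schema below is hypothetical (not a yohou function) and uses dtype strings in place of polars DataType objects; its loop follows the same group-major order as the source code, so every column of the first group comes before any column of the next.

```python
from __future__ import annotations


def expand_schema(
    expected_schema: dict[str, str],
    groups: list[str] | None = None,
) -> tuple[list[str], dict[str, str]]:
    """Return (ordered column list, full schema) as check_schema builds them."""
    if groups is None:
        # Global data: the schema is used unchanged, after "time".
        return ["time", *expected_schema], dict(expected_schema)
    columns = ["time"]
    full_schema: dict[str, str] = {}
    # Group-major order: all columns of one group, then the next group.
    for group in groups:
        for col, dtype in expected_schema.items():
            name = f"{group}__{col}"
            columns.append(name)
            full_schema[name] = dtype
    return columns, full_schema


cols, schema = expand_schema({"x": "Int64", "y": "Float64"}, groups=["a", "b"])
print(cols)  # ['time', 'a__x', 'a__y', 'b__x', 'b__y']
```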

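Returns

| Type | Description |
|---|---|
| DataFrame | DataFrame with columns in canonical order: ["time"] followed by the schema columns. For panel data, the columns are grouped by prefix in the order given by groups. |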

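Raises

| Type | Description |
|---|---|
| ValueError | If the dtypes of the selected (non-time) columns don't match the expected schema. |
| polars.exceptions.ColumnNotFoundError | If an expected column, including "time", is missing; raised by the internal select call before the dtype check runs. |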
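Since the columns are already reordered by the select, the dtype check never fails because of column order alone: it is a plain dict comparison, and Python dict equality ignores insertion order. The snippet below illustrates that with dtype strings standing in for polars DataType objects (an assumption; with real dtypes the per-value comparison follows polars' own equality rules). mismatch_message is a hypothetical helper that reproduces the shape of the error text.

```python
from __future__ import annotations

expected = {"s0": "Int64", "s1": "Int64"}
# Same names and dtypes, different insertion order.
incoming_reordered = {"s1": "Int64", "s0": "Int64"}
# One dtype differs.
incoming_wrong = {"s0": "Int64", "s1": "Float64"}

print(incoming_reordered == expected)  # True: dict equality ignores order
print(incoming_wrong != expected)      # True: this case would raise ValueError


def mismatch_message(expected: dict[str, str], got: dict[str, str]) -> str:
    # Same f-string shape as the error raised by check_schema.
    return f"Schema mismatch. Expected: {expected}, got: {got}"


print(mismatch_message(expected, incoming_wrong))
```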

Examples

>>> import polars as pl
>>> # Non-panel data validation
>>> df = pl.DataFrame({"value": [10, 20], "time": [1, 2]})
>>> expected_schema = {"value": pl.Int64}
>>> result = check_schema(df, expected_schema)
>>> list(result.columns)
['time', 'value']
>>> # Schema mismatch raises error
>>> df_wrong = pl.DataFrame({"time": [1, 2], "value": [10.0, 20.0]})  # Float64
>>> check_schema(df_wrong, expected_schema)
Traceback (most recent call last):
    ...
ValueError: Schema mismatch. Expected: {'value': Int64}, got: {'value': Float64}
>>> # Panel data validation (constructs prefixed schema automatically)
>>> df_panel = pl.DataFrame({"panel__s1": [15, 25], "time": [1, 2], "panel__s0": [10, 20]})
>>> expected_schema = {"s0": pl.Int64, "s1": pl.Int64}
>>> result = check_schema(df_panel, expected_schema, groups=["panel"])
>>> list(result.columns)
['time', 'panel__s0', 'panel__s1']

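See Also

- check_inputs (yohou.utils.validation.check_inputs): validates time intervals.
- BaseForecaster (yohou.base.forecaster.BaseForecaster): uses this function to validate incoming data.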

Notes

For panel data, this function automatically constructs the expected schema with prefixes (e.g., "sales__store_1") from the unprefixed expected_schema. The returned DataFrame lists "time" first, then each group's columns in schema order, one group after another in the order of groups.

Source Code

import polars as pl
import polars.selectors as cs


def check_schema(
    df: pl.DataFrame,
    expected_schema: dict[str, pl.DataType],
    groups: list[str] | None = None,
) -> pl.DataFrame:
    """Validate DataFrame schema and return with proper column ordering.

    Ensures that data has the same column names and dtypes as expected,
    and returns the DataFrame with columns in the correct order (time column first,
    followed by schema columns in order).

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame to validate (should include "time" column).
    expected_schema : dict[str, pl.DataType]
        Expected schema for non-time columns.
        For panel data, this should contain unprefixed column names.
    groups : list[str] or None, default=None
        Group prefixes for panel data. If provided, constructs expected
        schema with prefixes (e.g., "panel__series_0"). None for global data.

    Returns
    -------
    pl.DataFrame
        DataFrame with columns in proper order: ["time"] + schema columns.

    Raises
    ------
    ValueError
        If incoming schema doesn't match expected schema.

    Examples
    --------
    >>> import polars as pl
    >>> # Non-panel data validation
    >>> df = pl.DataFrame({"value": [10, 20], "time": [1, 2]})
    >>> expected_schema = {"value": pl.Int64}
    >>> result = check_schema(df, expected_schema)
    >>> list(result.columns)
    ['time', 'value']

    >>> # Schema mismatch raises error
    >>> df_wrong = pl.DataFrame({"time": [1, 2], "value": [10.0, 20.0]})  # Float64
    >>> check_schema(df_wrong, expected_schema)  # doctest: +SKIP
    Traceback (most recent call last):
        ...
    ValueError: Schema mismatch. Expected: {'value': Int64}, got: {'value': Float64}

    >>> # Panel data validation (constructs prefixed schema automatically)
    >>> df_panel = pl.DataFrame({"panel__s1": [15, 25], "time": [1, 2], "panel__s0": [10, 20]})
    >>> expected_schema = {"s0": pl.Int64, "s1": pl.Int64}
    >>> result = check_schema(df_panel, expected_schema, groups=["panel"])
    >>> list(result.columns)
    ['time', 'panel__s0', 'panel__s1']

    See Also
    --------
    - [`check_inputs`][yohou.utils.validation.check_inputs] : Validates time intervals
    - [`BaseForecaster`][yohou.base.forecaster.BaseForecaster] : Uses this function to validate incoming data

    Notes
    -----
    For panel data, this function automatically constructs the expected schema
    with prefixes (e.g., "sales__store_1") from the unprefixed expected_schema.
    The returned DataFrame has columns ordered consistently with the schema.

    """
    # Construct expected column list based on groups
    if groups is None:
        # Non-panel data: use schema as-is
        expected_columns = ["time"] + list(expected_schema.keys())
        expected_full_schema = expected_schema
    else:
        # Panel data: construct prefixed schema
        expected_columns = ["time"]
        expected_full_schema = {}
        for group_name in groups:
            for col, dtype in expected_schema.items():
                prefixed_col = f"{group_name}__{col}"
                expected_columns.append(prefixed_col)
                expected_full_schema[prefixed_col] = dtype

    # Select columns in proper order (also validates presence)
    df = df.select(expected_columns)

    # Extract actual schema (excluding time column) for validation
    incoming_schema = dict(df.select(~cs.by_name("time")).schema)

    # Validate dtypes
    if incoming_schema != expected_full_schema:
        raise ValueError(f"Schema mismatch. Expected: {expected_full_schema}, got: {incoming_schema}")

    return df