get_group_df

yohou.utils.panel.get_group_df(df, group_name, schema, key_cols=('time',))

Extract and rename columns for a specific panel group.

Selects columns matching the group prefix pattern (<group_name>__*), renames them to remove the prefix, and returns a DataFrame with key columns and the unprefixed columns. Also handles global columns (no prefix) that are shared across all groups.

Parameters

Name Type Description Default
df DataFrame

Input DataFrame with panel data columns. Must contain columns listed in key_cols.

required
group_name str

Group prefix to extract (e.g., "sales", "inventory"). Columns matching <group_name>__* will be selected.

required
schema dict of str to pl.DataType

Schema mapping unprefixed column names to their data types. Used to determine which columns to extract. Can contain both local columns (will have group prefix in df) and global columns (no prefix in df). Example: {"store_1": pl.Int64, "store_2": pl.Int64, "holiday": pl.Boolean}

required
key_cols tuple of str

Index columns to preserve in the output (e.g. ("time",) for y/X_actual/X_future, ("vintage_time", "time") for X_forecast).

("time",)

Returns

Type Description
DataFrame

DataFrame with the key columns (key_cols, "time" by default) and unprefixed group columns. Local columns are renamed from <group_name>__<col> to <col>. Global columns keep their original names.

Examples

>>> import polars as pl
>>> df = pl.DataFrame({
...     "time": [1, 2, 3],
...     "sales__store_1": [100, 110, 120],
...     "sales__store_2": [150, 160, 170],
...     "holiday": [True, False, True],  # Global column
...     "inventory__store_1": [50, 55, 60],
... })
>>> # Schema includes both local and global columns
>>> schema = {"store_1": pl.Int64, "store_2": pl.Int64, "holiday": pl.Boolean}
>>> df_sales = get_group_df(df, "sales", schema)
>>> df_sales.columns
['time', 'store_1', 'store_2', 'holiday']
>>> df_sales.shape
(3, 4)
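
The lookup order applied to each schema entry can be sketched in plain Python, without polars. This is an illustration only: `resolve_group_columns` is a hypothetical helper, not part of yohou, and it works on a list of column names rather than a DataFrame.

```python
def resolve_group_columns(columns, group_name, schema_names):
    """Hypothetical helper mirroring the per-column lookup order."""
    local_cols, global_cols, rename_map = [], [], {}
    for name in schema_names:
        prefixed = f"{group_name}__{name}"
        if prefixed in columns:
            # The prefixed (local) name is checked first, so it takes
            # precedence over a global column with the same name.
            local_cols.append(prefixed)
            rename_map[prefixed] = name
        elif name in columns:
            # No prefixed match: fall back to the shared, unprefixed column.
            global_cols.append(name)
        else:
            raise ValueError(
                f"Column '{name}' not found as '{prefixed}' or '{name}'"
            )
    return local_cols, global_cols, rename_map


columns = ["time", "sales__store_1", "sales__store_2", "holiday", "inventory__store_1"]
local_cols, global_cols, rename_map = resolve_group_columns(
    columns, "sales", ["store_1", "store_2", "holiday"]
)
print(local_cols)
print(global_cols)
print(rename_map)
```

Because the selection concatenates key, local, and global columns in that order, the output column order follows this grouping rather than the order of entries in the schema. A schema entry that matches neither form raises a ValueError.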

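See Also

yohou.utils.panel.inspect_panel : Inspect DataFrame to identify global and local columns
yohou.utils.panel.select_panel_columns : Filter DataFrame to panel group columns and global columns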

Notes

This function is used internally by forecasters to extract individual panel groups for processing. It bridges the two naming conventions: schemas store unprefixed column names, while the panel DataFrame stores local columns under their group prefix.

For X (feature) data, the schema typically combines local_X_actual_schema_ and shared_X_actual_schema_, allowing each group to access both its own features and shared features.
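
How the forecasters merge the two schemas is not shown on this page. One plausible way, sketched here under that assumption, is a plain dict merge. The schema contents below are hypothetical, and strings stand in for pl.DataType values so the snippet runs without polars.

```python
# Hypothetical fitted-schema contents; strings stand in for pl.DataType values.
local_X_actual_schema_ = {"price": "Float64", "promo": "Boolean"}
shared_X_actual_schema_ = {"holiday": "Boolean"}

# Merge into a single schema of unprefixed names, suitable for the
# `schema` argument of get_group_df.
schema = {**local_X_actual_schema_, **shared_X_actual_schema_}
print(list(schema))
```

With a schema built this way, "price" and "promo" would be looked up under the group prefix (e.g. "sales__price"), while "holiday" would be found under its unprefixed name.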

Source Code

def get_group_df(
    df: pl.DataFrame,
    group_name: str,
    schema: dict[str, pl.DataType],
    key_cols: tuple[str, ...] = ("time",),
) -> pl.DataFrame:
    """Extract and rename columns for a specific panel group.

    Selects columns matching the group prefix pattern (<group_name>__*),
    renames them to remove the prefix, and returns a DataFrame with key
    columns and the unprefixed columns. Also handles global columns (no
    prefix) that are shared across all groups.

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame with panel data columns.
        Must contain columns listed in ``key_cols``.
    group_name : str
        Group prefix to extract (e.g., "sales", "inventory").
        Columns matching <group_name>__* will be selected.
    schema : dict of str to pl.DataType
        Schema mapping unprefixed column names to their data types.
        Used to determine which columns to extract.
        Can contain both local columns (will have group prefix in df) and
        global columns (no prefix in df).
        Example: {"store_1": pl.Int64, "store_2": pl.Int64, "holiday": pl.Boolean}
    key_cols : tuple of str, default=("time",)
        Index columns to preserve in the output (e.g. ``("time",)`` for
        y/X_actual/X_future, ``("vintage_time", "time")`` for X_forecast).

    Returns
    -------
    pl.DataFrame
        DataFrame with the ``key_cols`` columns and unprefixed group columns.
        Local columns are renamed from <group_name>__<col> to <col>.
        Global columns keep their original names.

    Examples
    --------
    >>> import polars as pl
    >>> df = pl.DataFrame({
    ...     "time": [1, 2, 3],
    ...     "sales__store_1": [100, 110, 120],
    ...     "sales__store_2": [150, 160, 170],
    ...     "holiday": [True, False, True],  # Global column
    ...     "inventory__store_1": [50, 55, 60],
    ... })
    >>> # Schema includes both local and global columns
    >>> schema = {"store_1": pl.Int64, "store_2": pl.Int64, "holiday": pl.Boolean}
    >>> df_sales = get_group_df(df, "sales", schema)
    >>> df_sales.columns
    ['time', 'store_1', 'store_2', 'holiday']
    >>> df_sales.shape
    (3, 4)

    See Also
    --------
    - [`inspect_panel`][yohou.utils.panel.inspect_panel] : Inspect DataFrame to identify global and local columns
    - [`select_panel_columns`][yohou.utils.panel.select_panel_columns] : Filter DataFrame to panel group columns and global columns

    Notes
    -----
    This function is used internally by forecasters to extract individual
    panel groups for processing, particularly in the context of the new
    architecture where schemas store unprefixed column names.

    For X (feature) data, the schema typically combines local_X_actual_schema_ and
    shared_X_actual_schema_, allowing each group to access both its own features
    and shared features.
    """
    # Separate local (prefixed) and global (unprefixed) columns
    local_cols = []
    global_cols = []
    rename_map = {}

    for col_name in schema:
        prefixed_col = f"{group_name}__{col_name}"
        if prefixed_col in df.columns:
            # Local column (has group prefix)
            local_cols.append(prefixed_col)
            rename_map[prefixed_col] = col_name
        elif col_name in df.columns:
            # Global column (no prefix)
            global_cols.append(col_name)
        else:
            # Column not found
            raise ValueError(
                f"Column '{col_name}' not found as either '{prefixed_col}' (local) "
                f"or '{col_name}' (global) in DataFrame. "
                f"Available columns: {df.columns}"
            )

    # Select key + local + global columns
    df_group = df.select(list(key_cols) + local_cols + global_cols)

    # Rename only local columns to remove prefix (global columns keep their names)
    if rename_map:
        df_group = df_group.rename(rename_map)

    return df_group