Skip to content

validate_plotting_data

yohou.utils.validate_data.validate_plotting_data(df, *, columns=None, groups=None, min_rows=1, exclude=None, include_categorical=False)

Validate a DataFrame for plotting and resolve columns.

Combines type checking, structural validation, and column resolution into a single entry point for all plotting functions.

Parameters

Name Type Description Default
df DataFrame

Input DataFrame to validate.

required
columns str, list of str, or None

Column specification. When groups is None this selects standard columns (None means all numeric). For panel data this selects member postfixes within the requested groups.

None
groups list of str or None

When not None, resolve panel columns (group__member pattern) instead of plain columns.

None
min_rows int

Minimum number of rows required.

1
exclude list of str or None

Column names to exclude when columns=None and groups is None.

None
include_categorical bool

When True and columns=None, also include pl.String, pl.Categorical, and pl.Enum columns alongside numeric columns.

False

Returns

Type Description
list[str]

Resolved column names to plot.

Raises

Type Description
TypeError

If df is not a pl.DataFrame.

ValueError

If the DataFrame has too few rows (see min_rows), is missing a "time" column, the requested columns do not exist, or no panel columns match the requested groups/columns.

Examples

>>> import polars as pl
>>> from yohou.utils.validate_data import validate_plotting_data
>>> df = pl.DataFrame({"time": [1, 2, 3], "y": [10.0, 20.0, 30.0]})
>>> validate_plotting_data(df, exclude=["time"])
['y']
>>> df_panel = pl.DataFrame({
...     "time": [1, 2],
...     "sales__a": [10, 20],
...     "sales__b": [30, 40],
... })
>>> validate_plotting_data(df_panel, groups=["sales"], columns="a")
['sales__a']

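See Also

validate_plotting_params : Validate common plotting parameters.
get_numeric_columns : Resolve numeric columns from a DataFrame.
inspect_panel : Detect panel group structure.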

Source Code

Show/Hide source
def validate_plotting_data(
    df: pl.DataFrame,
    *,
    columns: str | list[str] | None = None,
    groups: list[str] | None = None,
    min_rows: int = 1,
    exclude: list[str] | None = None,
    include_categorical: bool = False,
) -> list[str]:
    """Validate a DataFrame for plotting and resolve columns.

    Combines type checking, structural validation, and column resolution
    into a single entry point for all plotting functions.

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame to validate.
    columns : str, list of str, or None, default=None
        Column specification.  When ``groups`` is ``None`` this
        selects standard columns (``None`` means all numeric).  For panel
        data this selects *member postfixes* within the requested groups.
    groups : list of str or None, default=None
        When not ``None``, resolve panel columns
        (``group__member`` pattern) instead of plain columns.
    min_rows : int, default=1
        Minimum number of rows required.
    exclude : list of str or None, default=None
        Column names to exclude when ``columns=None`` and
        ``groups`` is ``None``.
    include_categorical : bool, default=False
        When ``True`` and ``columns=None``, also include
        ``pl.String``, ``pl.Categorical``, and ``pl.Enum`` columns
        alongside numeric columns.

    Returns
    -------
    list[str]
        Resolved column names to plot.  This is always a new list; the
        caller's ``columns`` argument is never returned by reference, so
        mutating the result cannot alter the caller's input.

    Raises
    ------
    TypeError
        If *df* is not a ``pl.DataFrame``.
    ValueError
        If the DataFrame has too few rows (see ``min_rows``), is missing a
        ``"time"`` column, the requested columns do not exist, or no panel
        columns match the requested ``groups``/``columns``.

    Notes
    -----
    A ``"time"`` column whose dtype is neither ``pl.Datetime`` nor
    ``pl.Date`` (e.g. an integer index) is accepted as-is; the temporal
    time-column checks are skipped in that case.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.utils.validate_data import validate_plotting_data
    >>> df = pl.DataFrame({"time": [1, 2, 3], "y": [10.0, 20.0, 30.0]})
    >>> validate_plotting_data(df, exclude=["time"])
    ['y']

    >>> df_panel = pl.DataFrame({
    ...     "time": [1, 2],
    ...     "sales__a": [10, 20],
    ...     "sales__b": [30, 40],
    ... })
    >>> validate_plotting_data(df_panel, groups=["sales"], columns="a")
    ['sales__a']

    See Also
    --------
    - [`validate_plotting_params`][yohou.utils.validate_data.validate_plotting_params] : Validate common plotting parameters.
    - [`get_numeric_columns`][yohou.utils.polars.get_numeric_columns] : Resolve numeric columns from a DataFrame.
    - [`inspect_panel`][yohou.utils.panel.inspect_panel] : Detect panel group structure.

    """
    if not isinstance(df, pl.DataFrame):
        msg = f"Expected pl.DataFrame, got {type(df).__name__}"
        raise TypeError(msg)

    check_sufficient_rows(df, min_rows=min_rows, context="for plotting")

    # A non-temporal "time" column (e.g. integer indices) only needs to exist,
    # so the stricter time checks run only when that shortcut does not apply.
    has_plain_time = "time" in df.columns and not isinstance(df["time"].dtype, (pl.Datetime, pl.Date))
    if not has_plain_time:
        if "vintage_time" in df.columns:
            _check_multi_vintage_time(df)
        else:
            check_time_column(df)

    # Accept a single column/member name as shorthand for a one-element list.
    if isinstance(columns, str):
        columns = [columns]

    # Panel column resolution: keep members of the requested groups, optionally
    # filtered by their postfix (the part after the first "__").
    if groups is not None:
        wanted_groups = set(groups)
        wanted_members = None if columns is None else set(columns)

        _, panels = inspect_panel(df)
        cols: list[str] = []
        for prefix, members in panels.items():
            if prefix not in wanted_groups:
                continue
            if wanted_members is None:
                cols.extend(members)
            else:
                cols.extend(member for member in members if member.partition("__")[2] in wanted_members)

        if not cols:
            if columns is not None:
                msg = f"No panel columns found for groups={groups} with members={columns}"
            else:
                msg = f"No panel columns found for groups: {groups}"
            raise ValueError(msg)
        return cols

    # Standard column resolution: default to numeric (and optionally categorical) columns.
    if columns is None:
        # Copy so the list returned by get_numeric_columns is never mutated here.
        result = list(get_numeric_columns(df, exclude=exclude))
        if include_categorical:
            excluded = set(exclude or [])
            already_selected = set(result)
            result.extend(
                c
                for c in df.columns
                if c not in excluded and c not in already_selected and is_categorical_dtype(df[c].dtype)
            )
        return result

    available = set(df.columns)
    missing = [col for col in columns if col not in available]
    if missing:
        msg = f"Columns not found in DataFrame: {missing}"
        raise ValueError(msg)

    # Return a fresh list so callers mutating the result don't alter their input.
    return list(columns)