Skip to content

resolve_weight_to_array

yohou.utils.weighting.resolve_weight_to_array(weight, key_series, join_column, group_name=None)

Resolve a weight specification to a raw numpy array.

Accepts three formats (callable, DataFrame, or dict) and returns an unnormalized weight array aligned to key_series. Validation is performed via validate_weight_array after resolution.

Parameters

Name Type Description Default
weight callable, pl.DataFrame, or dict

Weight specification in one of the supported formats.

required
key_series Series

Data series whose values are used for alignment (e.g., time values, forecasting steps, vintage times).

required
join_column str

Column name used for DataFrame joins ("time", "forecasting_step", or "vintage_time").

required
group_name str or None

Panel group name for 2-parameter callables and group-specific DataFrame columns. None for global (non-panel) data.

None

Returns

Type Description
ndarray

Raw (unnormalized) weight array of same length as key_series.

Raises

Type Description
ValueError

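If weight is not a callable, DataFrame, or dict; if a callable does not return a pl.Series with the same length as key_series; if a weight DataFrame has neither a usable weight column nor a weight for every key; or if the resolved weights fail validation.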

Source Code

Show/Hide source
def resolve_weight_to_array(
    weight: Callable | pl.DataFrame | dict,
    key_series: pl.Series,
    join_column: str,
    group_name: str | None = None,
) -> np.ndarray:
    """Resolve a weight specification to a raw numpy array.

    Accepts three formats (callable, DataFrame, or dict) and returns an
    unnormalized weight array aligned to ``key_series``.  Validation is
    performed via `validate_weight_array` after resolution.

    Parameters
    ----------
    weight : callable, pl.DataFrame, or dict
        Weight specification in one of the supported formats.

        - dict: resolved by `resolve_dict_weights` against the key values.
        - callable: called as ``weight(key_series)`` or
          ``weight(key_series, group_name)`` depending on the parameter
          count reported by `validate_callable_signature`; must return a
          ``pl.Series`` with one weight per key.
        - DataFrame: left-joined onto the keys via ``join_column``; the
          ``"<group_name>_weight"`` column is preferred when
          ``group_name`` is given and the column exists, otherwise the
          ``"weight"`` column is used.
    key_series : polars.Series
        Data series whose values are used for alignment (e.g., time
        values, forecasting steps, vintage times).
    join_column : str
        Column name used for DataFrame joins (``"time"``,
        ``"forecasting_step"``, or ``"vintage_time"``).
    group_name : str or None, default=None
        Panel group name for 2-parameter callables and group-specific
        DataFrame columns.  ``None`` for global (non-panel) data.

    Returns
    -------
    numpy.ndarray
        Raw (unnormalized) weight array of same length as ``key_series``.

    Raises
    ------
    ValueError
        If ``weight`` is not callable, DataFrame, or dict; if a callable
        does not return a ``pl.Series`` of the expected length; if a
        weight DataFrame lacks a usable weight column, contains
        duplicate ``join_column`` values that would break the one-to-one
        alignment, or has no weight for some key; or if the resolved
        weights fail validation.

    """
    if isinstance(weight, dict):
        weights_np = resolve_dict_weights(weight, key_series.to_list())
    elif callable(weight):
        n_params = validate_callable_signature(weight)
        weights_series = weight(key_series) if n_params == 1 else weight(key_series, group_name)  # type: ignore[call-arg]  # ty: ignore[call-top-callable]

        if not isinstance(weights_series, pl.Series):
            raise ValueError(f"Weight callable must return pl.Series, got {type(weights_series).__name__}")
        if len(weights_series) != len(key_series):
            raise ValueError(f"Weight callable returned {len(weights_series)} weights, expected {len(key_series)} rows")
        weights_np = weights_series.to_numpy().astype(np.float64)
    elif isinstance(weight, pl.DataFrame):
        key_df = pl.DataFrame({join_column: key_series})
        # NOTE(review): the result is consumed positionally, so this relies on
        # the left join keeping the rows of ``key_df`` in their original order.
        # Recent polars releases expose a ``maintain_order`` join parameter and
        # document the default order as unspecified - confirm against the
        # polars versions this project supports.
        joined = key_df.join(weight, on=join_column, how="left")

        # Prefer the group-specific column, fall back to the shared one.
        weight_col = None
        if group_name is not None:
            group_col = f"{group_name}_weight"
            if group_col in joined.columns:
                weight_col = group_col
        if weight_col is None and "weight" in joined.columns:
            weight_col = "weight"
        if weight_col is None:
            if group_name is not None:
                raise ValueError(
                    f"Weight DataFrame missing both '{group_name}_weight' and "
                    f"'weight' columns for panel group '{group_name}'"
                )
            raise ValueError("Weight DataFrame must have 'weight' column")

        # A left join emits one output row per matching right-hand row, so
        # duplicate keys in the weight table would fan out the result and
        # silently break the one-weight-per-key alignment.
        if len(joined) != len(key_series):
            weight_keys = weight[join_column]
            duplicated_keys = weight_keys.filter(weight_keys.is_duplicated()).unique().to_list()
            raise ValueError(
                f"Weight DataFrame has duplicate {join_column} values "
                f"{duplicated_keys}: join produced {len(joined)} rows, "
                f"expected {len(key_series)}"
            )

        weights_np = joined[weight_col].to_numpy().astype(np.float64)

        # Check for NaN from unmatched keys (left join produces null for missing keys)
        nan_mask = np.isnan(weights_np)
        if nan_mask.any():
            missing_keys = key_series.filter(pl.Series(nan_mask)).unique().to_list()
            raise ValueError(f"Weight DataFrame has no values for {join_column}s: {missing_keys}")
    else:
        raise ValueError(f"Weight must be callable, pl.DataFrame, dict, or None, got {type(weight).__name__}")

    # Validate
    try:
        validate_weight_array(weights_np, name=f"{join_column} weight")
    except ValueError as exc:
        # For dict specifications an all-zero result usually means none of the
        # requested keys exist in the data, so re-raise with both key sets.
        # Every other validation failure propagates unchanged.
        if "All weights are zero" in str(exc) and isinstance(weight, dict):
            requested = {k for k in weight if k != "*"}
            available = set(key_series.unique().to_list())
            raise ValueError(
                f"All weights are zero for {join_column} weight. "
                f"Requested keys {requested} vs available keys "
                f"{available}"
            ) from exc
        raise

    return weights_np