plot_seasonal_heatmap

yohou.plotting.diagnostics.plot_seasonal_heatmap(df, columns=None, *, x_period='hour', y_period='month', agg='mean', groups=None, facet_by='group', facet_n_cols=2, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, colorscale='Viridis', show_values=True, value_format='.1f', reverse_y=False)

Plot a 2-D heatmap of aggregated values across two time dimensions.

Reveals seasonal patterns by showing aggregated values in a grid of two temporal periods (e.g. hour-of-day x month-of-year).
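Each cell holds the aggregate of every observation that falls into the same pair of periods. Below is a minimal pure-Python sketch of that idea. It assumes month on the y-axis, hour-of-day on the x-axis, and agg="mean". The helper `seasonal_grid` is hypothetical and not part of yohou, which performs the same steps with a Polars `group_by` followed by a `pivot`.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from statistics import mean


def seasonal_grid(times, values, value_format=".1f"):
    """Hypothetical sketch: aggregate values into a month x hour-of-day grid.

    Returns the sorted row keys, the sorted column keys, the matrix of means
    (None where a cell has no observations) and the per-cell text labels.
    """
    # Bucket each observation by its (y_period, x_period) pair.
    buckets = defaultdict(list)
    for t, v in zip(times, values):
        buckets[(t.month, t.hour)].append(v)

    y_keys = sorted({y for y, _ in buckets})
    x_keys = sorted({x for _, x in buckets})

    # "Pivot": one row per y key, one column per x key; check membership
    # first so the defaultdict does not create empty buckets.
    z = [
        [mean(buckets[(y, x)]) if (y, x) in buckets else None for x in x_keys]
        for y in y_keys
    ]
    # Cell text: the formatted value, or "" for empty cells.
    text = [[format(v, value_format) if v is not None else "" for v in row] for row in z]
    return y_keys, x_keys, z, text


# Two days of hourly data in January, rising 0.5 per hour of the day.
start = datetime(2020, 1, 1)
times = [start + timedelta(hours=i) for i in range(48)]
values = [20.0 + (i % 24) * 0.5 for i in range(48)]
y_keys, x_keys, z, text = seasonal_grid(times, values)
print(y_keys)       # [1]
print(len(x_keys))  # 24
print(text[0][:3])  # ['20.0', '20.5', '21.0']
```

Empty cells stay `None` with an empty label, mirroring how the function leaves cells without data unlabelled.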

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| df | DataFrame | Input Polars DataFrame with a "time" column and the target columns. | required |
| columns | str, list of str, or None | Numeric column(s) to aggregate and visualise. When None, all numeric columns (excluding "time") are used. For panel data, this selects member postfixes within each group. | None |
| x_period | str | Temporal period for the x-axis. Options: "hour", "day_of_week", "day_of_month", "week", "month", "quarter", "year". | "hour" |
| y_period | str | Temporal period for the y-axis. Same options as x_period. | "month" |
| agg | str | Aggregation applied to each cell: "mean", "median", "sum", "count", "std", "min", or "max". | "mean" |
| groups | list[str] or None | Panel group prefixes. When panel data is detected, columns are resolved as member postfixes within each group. When None and panel columns are present, all groups are auto-detected. | None |
| facet_by | Literal['group', 'member'] or None | Faceting axis for panel data: "group" creates one subplot per group, "member" one per member, and None disables faceting. Ignored for non-panel data. | "group" |
| facet_n_cols | int | Number of columns in the facet grid. | 2 |
| show_legend | bool | Whether to show the legend. | True |
| title | str or None | Plot title. When None, "Seasonal Heatmap" is used. | None |
| x_label | str or None | X-axis label. When None, x_period in title case is used (e.g. "Day Of Week"). | None |
| y_label | str or None | Y-axis label. When None, y_period in title case is used. | None |
| width | int or None | Plot width in pixels. | None |
| height | int or None | Plot height in pixels. | None |
| colorscale | str | Name of the Plotly colorscale for the heatmap. | "Viridis" |
| show_values | bool | Whether to print each cell's value as text on the heatmap. | True |
| value_format | str | Python format specification applied to each cell value when show_values is True (e.g. ".1f" renders 3.14159 as "3.1"). Cells without data are left blank. | ".1f" |
| reverse_y | bool | Whether to reverse the y-axis. | False |
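Each period name reduces a timestamp to an integer key that becomes one row or column of the grid. The sketch below is a hypothetical stdlib stand-in for the internal `_extract_period` helper. It assumes the same conventions as Polars' `dt` accessors (Monday = 1 for "day_of_week", ISO week numbers for "week"), which may not match yohou's actual behaviour exactly; `PERIOD_EXTRACTORS` and `extract_period` are illustrative names only.

```python
from datetime import datetime

# Hypothetical mapping from x_period / y_period names to integer keys.
# Assumes Polars-style conventions: ISO weekday (Monday=1) and ISO week.
PERIOD_EXTRACTORS = {
    "hour": lambda t: t.hour,
    "day_of_week": lambda t: t.isoweekday(),
    "day_of_month": lambda t: t.day,
    "week": lambda t: t.isocalendar()[1],
    "month": lambda t: t.month,
    "quarter": lambda t: (t.month - 1) // 3 + 1,
    "year": lambda t: t.year,
}


def extract_period(t: datetime, period: str) -> int:
    """Return the integer key of *t* for the given period name."""
    if period not in PERIOD_EXTRACTORS:
        valid = ", ".join(PERIOD_EXTRACTORS)
        raise ValueError(f"Unknown period: {period!r}. Valid options: {valid}")
    return PERIOD_EXTRACTORS[period](t)


t = datetime(2020, 12, 31, 23)
print(extract_period(t, "day_of_week"))  # 4 (Thursday)
print(extract_period(t, "week"))         # 53
print(extract_period(t, "quarter"))      # 4
```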

Returns

| Type | Description |
|---|---|
| Figure | Plotly figure object (plotly.graph_objects.Figure). |

Raises

| Type | Description |
|---|---|
| TypeError | If df is not a Polars DataFrame. |
| ValueError | If requested columns do not exist, x_period/y_period are invalid, or agg is not a supported aggregation. |
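The source below checks agg against a lookup table before doing any work, so an unknown name fails fast with a ValueError. Here is a stdlib sketch of that validate-then-dispatch pattern. `AGG_FUNCS` and `aggregate` are hypothetical names. Using the sample standard deviation for "std" is an assumption based on Polars' default `ddof=1`.

```python
import statistics

# Hypothetical stdlib stand-ins for the aggregations accepted by ``agg``.
# "std" is the sample standard deviation (ddof=1), as assumed for Polars.
AGG_FUNCS = {
    "mean": statistics.mean,
    "median": statistics.median,
    "sum": sum,
    "count": len,
    "std": statistics.stdev,
    "min": min,
    "max": max,
}


def aggregate(values: list[float], agg: str) -> float:
    """Validate *agg* up front, then apply it to one cell's values."""
    if agg not in AGG_FUNCS:
        msg = f"Unknown agg: {agg!r}. Valid options: {', '.join(sorted(AGG_FUNCS))}"
        raise ValueError(msg)
    return AGG_FUNCS[agg](values)


cell = [2.0, 4.0, 4.0, 5.0]
print(aggregate(cell, "mean"))   # 3.75
print(aggregate(cell, "count"))  # 4
```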

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_seasonal_heatmap
>>> df = pl.DataFrame({
...     "time": pl.datetime_range(
...         pl.datetime(2020, 1, 1),
...         pl.datetime(2020, 12, 31, 23),
...         "1h",
...         eager=True,
...     ),
...     "temp": [20.0 + (i % 24) * 0.5 for i in range(8784)],
... })
>>> fig = plot_seasonal_heatmap(df, "temp", x_period="hour", y_period="month")
>>> len(fig.data) > 0
True
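The list comprehension in the example must produce exactly as many values as the time range has rows, or Polars will refuse to build the DataFrame. The literal 8784 works because 2020 is a leap year. A quick stdlib check of that count:

```python
import calendar
from datetime import datetime, timedelta

start = datetime(2020, 1, 1)
end = datetime(2020, 12, 31, 23)

# Inclusive hourly range: whole hours between the endpoints, plus one.
n_hours = int((end - start) / timedelta(hours=1)) + 1
print(n_hours)                # 8784
print(calendar.isleap(2020))  # True, so 366 days * 24 hours
```

Deriving the length from the time column itself (for example `range(len(times))`) avoids hard-coding it.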

See Also

plot_seasonality : Plot seasonal overlay.
plot_subseasonality : Plot seasonal subseries.

Source Code

def plot_seasonal_heatmap(
    df: pl.DataFrame,
    columns: str | list[str] | None = None,
    *,
    x_period: str = "hour",
    y_period: str = "month",
    agg: str = "mean",
    groups: list[str] | None = None,
    facet_by: Literal["group", "member"] | None = "group",
    facet_n_cols: int = 2,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    colorscale: str = "Viridis",
    show_values: bool = True,
    value_format: str = ".1f",
    reverse_y: bool = False,
) -> go.Figure:
    """Plot a 2-D heatmap of aggregated values across two time dimensions.

    Reveals seasonal patterns by showing aggregated values in a grid
    of two temporal periods (e.g. hour-of-day x month-of-year).

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame with 'time' column and the target *columns*.
    columns : str, list of str, or None, default=None
        Numeric column(s) to aggregate and visualise.  When ``None``,
        all numeric columns (excluding ``"time"``) are used.  For panel
        data this selects member postfixes within each group.
    x_period : str, default="hour"
        Temporal period for the x-axis.  Options: "hour", "day_of_week",
        "day_of_month", "week", "month", "quarter", "year".
    y_period : str, default="month"
        Temporal period for the y-axis.  Same options as *x_period*.
    agg : str, default="mean"
        Aggregation function: "mean", "median", "sum", "count",
        "std", "min", or "max".
    groups : list[str] | None, default=None
        Panel group prefixes.  When panel data is detected the *columns*
        are resolved as member postfixes within each group.  When
        ``None`` and panel columns are present, auto-detects all groups.
    facet_by : Literal["group", "member"] | None, default="group"
        Faceting axis for panel data.  ``"group"`` creates one subplot per
        group, ``"member"`` one per member.  ``None`` disables faceting.
        Ignored for non-panel data.
    facet_n_cols : int, default=2
        Columns in the facet grid.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.
    colorscale : str, default="Viridis"
        Plotly colorscale for the heatmap.
    show_values : bool, default=True
        Whether to display values on the heatmap cells.
    value_format : str, default=".1f"
        Format string for heatmap cell values.
    reverse_y : bool, default=False
        Whether to reverse the y-axis.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Raises
    ------
    TypeError
        If *df* is not a Polars DataFrame.
    ValueError
        If requested *columns* do not exist, *x_period*/*y_period*
        are invalid, or *agg* is not a supported aggregation.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_seasonal_heatmap

    >>> df = pl.DataFrame({
    ...     "time": pl.datetime_range(
    ...         pl.datetime(2020, 1, 1),
    ...         pl.datetime(2020, 12, 31, 23),
    ...         "1h",
    ...         eager=True,
    ...     ),
    ...     "temp": [20.0 + (i % 24) * 0.5 for i in range(8784)],
    ... })
    >>> fig = plot_seasonal_heatmap(df, "temp", x_period="hour", y_period="month")
    >>> len(fig.data) > 0
    True

    See Also
    --------
    [`plot_seasonality`][yohou.plotting.plot_seasonality] : Plot seasonal overlay.
    [`plot_subseasonality`][yohou.plotting.plot_subseasonality] : Plot seasonal subseries.
    """
    # Validate
    validate_plotting_data(df)
    validate_plotting_params(width=width, height=height)

    # Map aggregation name to a Polars expression builder
    _AGG_MAP = {
        "mean": lambda c: pl.col(c).mean(),
        "median": lambda c: pl.col(c).median(),
        "sum": lambda c: pl.col(c).sum(),
        "count": lambda c: pl.col(c).count(),
        "std": lambda c: pl.col(c).std(),
        "min": lambda c: pl.col(c).min(),
        "max": lambda c: pl.col(c).max(),
    }
    if agg not in _AGG_MAP:
        msg = f"Unknown agg: {agg!r}. Valid options: {', '.join(sorted(_AGG_MAP))}"
        raise ValueError(msg)

    def _build_heatmap(
        frame: pl.DataFrame,
        target_col: str,
        display_name: str | None = None,
    ) -> go.Heatmap:
        """Build a single Heatmap trace from *frame*."""
        # Extract periods
        df_aug = frame.with_columns(
            _extract_period(pl.col("time"), x_period).alias("_x"),
            _extract_period(pl.col("time"), y_period).alias("_y"),
        )

        # Aggregate
        agg_expr = _AGG_MAP[agg](target_col).alias("_val")
        df_agg = df_aug.group_by("_y", "_x").agg(agg_expr).sort("_y", "_x")

        # Pivot to matrix
        pivot = df_agg.pivot(on="_x", index="_y", values="_val").sort("_y")
        x_cols = sorted([c for c in pivot.columns if c != "_y"], key=int)
        y_vals = pivot["_y"].to_list()
        z = pivot.select(x_cols).to_numpy()

        x_display = _resolve_axis_labels([int(c) for c in x_cols], x_period)
        y_display = _resolve_axis_labels([int(v) for v in y_vals], y_period)

        text_ann = None
        if show_values:
            text_ann = [[f"{v:{value_format}}" if v is not None and not np.isnan(v) else "" for v in row] for row in z]

        label = display_name or target_col
        return go.Heatmap(
            z=z,
            x=x_display,
            y=y_display,
            colorscale=colorscale,
            text=text_ann,
            texttemplate="%{text}" if show_values else None,
            hovertemplate=(
                f"<b>{label}</b><br>{x_period}: %{{x}}<br>{y_period}: %{{y}}<br>{agg}: %{{z:.2f}}<extra></extra>"
            ),
        )

    # Auto-detect panel data
    if groups is None and _auto_detect_panel(df):
        groups = []

    if groups is not None:
        from yohou.plotting._utils import _group_panel_columns  # noqa: PLC0415

        panel_cols = resolve_panel_columns(df, groups, columns)
        grouped, _all_members = _group_panel_columns(panel_cols)

        # Build flat list of (group_name, member_col, display_name) tuples
        all_cells: list[tuple[str, str, str]] = []
        for gname, gcols in grouped.items():
            for col in gcols:
                mname = _member_name(col)
                all_cells.append((gname, col, f"{gname}: {mname}"))

        n = len(all_cells)
        n_cols_grid = min(n, facet_n_cols)
        n_rows_grid = math.ceil(n / n_cols_grid)

        fig = make_subplots(
            rows=n_rows_grid,
            cols=n_cols_grid,
            subplot_titles=[cell[2] for cell in all_cells],
            vertical_spacing=_subplot_spacing(n_rows_grid),
            horizontal_spacing=_subplot_spacing(n_cols_grid) if n_cols_grid > 1 else 0.08,
        )

        for idx, (_gname, member_col, display_label) in enumerate(all_cells):
            r = idx // n_cols_grid + 1
            c = idx % n_cols_grid + 1
            mname = _member_name(member_col)
            sub_df = df.select("time", pl.col(member_col).alias(mname))
            trace = _build_heatmap(sub_df, mname, display_label)
            trace.showscale = idx == n - 1
            fig.add_trace(trace, row=r, col=c)

        fig = apply_default_layout(
            fig,
            title=title or "Seasonal Heatmap",
            x_label=x_label or x_period.replace("_", " ").title(),
            y_label=y_label or y_period.replace("_", " ").title(),
            width=width,
            height=height,
        )
        if reverse_y:
            fig.update_yaxes(autorange="reversed")
        fig.update_layout(showlegend=show_legend)
        return fig

    # Non-panel case: column-mode facet_figure
    plot_columns = validate_plotting_data(df, columns=columns, exclude=["time"])
    _last_col = plot_columns[-1]

    def _render_heatmap(ctx: RenderContext) -> None:
        """Render a single heatmap trace into a subplot."""
        col_name = ctx.display_name
        trace = _build_heatmap(ctx.sub_df, col_name, display_name=col_name)
        trace.showscale = col_name == _last_col
        ctx.fig.add_trace(trace, row=ctx.row, col=ctx.col)

    fig = facet_figure(
        df,
        _render_heatmap,
        columns=plot_columns,
        facet_n_cols=facet_n_cols,
        title=title or "Seasonal Heatmap",
        x_label=x_label or x_period.replace("_", " ").title(),
        y_label=y_label or y_period.replace("_", " ").title(),
        width=width,
        height=height,
        shared_xaxes=False,
    )
    if reverse_y:
        fig.update_yaxes(autorange="reversed")
    fig.update_layout(showlegend=show_legend)
    return fig
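In the panel branch above, each (group, member) cell gets one subplot. The grid has `min(n, facet_n_cols)` columns and `ceil(n / n_cols)` rows. Cell `idx` goes to row `idx // n_cols + 1` and column `idx % n_cols + 1`, and only the last trace shows its colour bar. Below is a standalone sketch of that arithmetic. `facet_positions` is a hypothetical helper, not part of yohou, and its empty-input guard is an addition the source does not have.

```python
import math


def facet_positions(n_cells: int, facet_n_cols: int = 2):
    """Return (n_rows, n_cols, [(row, col), ...]) for a facet grid.

    Rows and columns are 1-based, as in plotly's make_subplots; cells are
    filled left to right, then top to bottom.
    """
    if n_cells < 1:
        # Guard added for this sketch; avoids dividing by zero below.
        raise ValueError("need at least one cell")
    n_cols = min(n_cells, facet_n_cols)
    n_rows = math.ceil(n_cells / n_cols)
    positions = [(idx // n_cols + 1, idx % n_cols + 1) for idx in range(n_cells)]
    return n_rows, n_cols, positions


n_rows, n_cols, positions = facet_positions(3)
print(n_rows, n_cols, positions)  # 2 2 [(1, 1), (1, 2), (2, 1)]
# Only the last subplot displays its colour scale.
show_scale = [idx == len(positions) - 1 for idx in range(len(positions))]
print(show_scale)  # [False, False, True]
```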

Tutorials

The following example notebooks use this component:

  • Seasonal Analysis (Visualization): seasonal overlays, subseasonal structure, ACF/PACF correlation patterns, and STL decomposition for monthly, quarterly, and long-cycle datasets.