Skip to content

plot_correlation_heatmap

yohou.plotting.diagnostics.plot_correlation_heatmap(df, *, columns=None, groups=None, facet_by='group', show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None, colorscale='RdBu_r', show_values=True)

Plot correlation matrix heatmap for multiple time series.

Shows pairwise correlations between different time series columns, useful for understanding relationships and multicollinearity.

Parameters

Name Type Description Default
df DataFrame

Input DataFrame with 'time' column and numeric columns.

required
columns str | list[str] | None

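Column(s) to include. If None, uses all numeric columns except 'time'. When groups is given, entries are matched against the member part of each panel column name (the text after the first "__") rather than against full column names.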

None
groups list[str] | None

Panel group prefixes to plot. When panel data is detected and both this and columns are None, all groups are included. An explicitly passed empty list also selects all groups.

None
facet_by Literal['group', 'member'] | None

Faceting axis for panel data. "group" returns one figure per group (members as axis labels), "member" returns one figure per member (groups as axis labels). None returns a single figure correlating all panel columns. Ignored for non-panel data.

"group"
show_legend bool

Whether to show the legend.

True
title str | None

Plot title.

None
x_label str | None

X-axis label.

None
y_label str | None

Y-axis label.

None
width int | None

Plot width in pixels.

None
height int | None

Plot height in pixels.

None
colorscale str

Plotly colorscale for the heatmap.

"RdBu_r"
show_values bool

Whether to display correlation values on the heatmap cells.

True

Returns

Type Description
Figure | dict[str, Figure]

Single figure when non-panel or facet_by=None. Dict of figures keyed by group or member name otherwise.

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_correlation_heatmap
>>> # Create sample data with multiple series
>>> df = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 1, 10), "1d", eager=True),
...     "y1": [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
...     "y2": [15, 25, 35, 45, 55, 65, 75, 85, 95, 105],
... })
>>> # Plot correlation matrix
>>> fig = plot_correlation_heatmap(df)
>>> len(fig.data) > 0
True

See Also

plot_autocorrelation : Plot autocorrelation function.

Source Code

Show/Hide source
def plot_correlation_heatmap(
    df: pl.DataFrame,
    *,
    columns: str | list[str] | None = None,
    groups: list[str] | None = None,
    facet_by: Literal["group", "member"] | None = "group",
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
    colorscale: str = "RdBu_r",
    show_values: bool = True,
) -> go.Figure | dict[str, go.Figure]:
    """Plot correlation matrix heatmap for multiple time series.

    Shows pairwise correlations between different time series columns,
    useful for understanding relationships and multicollinearity.

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame with 'time' column and numeric columns.
    columns : str | list[str] | None, default=None
        Column(s) to include. If None, uses all numeric columns except 'time'.
        When ``groups`` is given, entries are matched against the member
        part of each panel column name (the text after the first ``"__"``)
        rather than against full column names.
    groups : list[str] | None, default=None
        Panel group prefixes to plot.  When panel data is detected and
        both this and ``columns`` are ``None``, all groups are included.
        An explicitly passed empty list also selects all groups.
    facet_by : Literal["group", "member"] | None, default="group"
        Faceting axis for panel data.  ``"group"`` returns one figure
        per group (members as axis labels), ``"member"`` returns one
        figure per member (groups as axis labels).  ``None`` returns a
        single figure correlating all panel columns.
        Ignored for non-panel data.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label.
    y_label : str | None, default=None
        Y-axis label.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.
    colorscale : str, default="RdBu_r"
        Plotly colorscale for the heatmap.
    show_values : bool, default=True
        Whether to display correlation values on the heatmap cells.

    Returns
    -------
    go.Figure | dict[str, go.Figure]
        Single figure when non-panel or ``facet_by=None``.
        Dict of figures keyed by group or member name otherwise.

    Raises
    ------
    ValueError
        If the panel path is taken and ``facet_by`` is not ``"group"``,
        ``"member"`` or ``None``, or if no panel column matches the
        requested ``groups`` / ``columns``.  The validation helpers called
        at the top may raise as well (their behavior is defined elsewhere).

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_correlation_heatmap

    >>> # Create sample data with multiple series
    >>> df = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 1, 10), "1d", eager=True),
    ...     "y1": [10, 20, 30, 40, 50, 60, 70, 80, 90, 100],
    ...     "y2": [15, 25, 35, 45, 55, 65, 75, 85, 95, 105],
    ... })

    >>> # Plot correlation matrix
    >>> fig = plot_correlation_heatmap(df)
    >>> len(fig.data) > 0
    True

    See Also
    --------
    [`plot_autocorrelation`][yohou.plotting.plot_autocorrelation] : Plot autocorrelation function.
    """
    # Validate inputs
    validate_plotting_data(df)
    validate_plotting_params(width=width, height=height)

    # Styling options are forwarded unchanged to whichever renderer is used,
    # so build them once instead of repeating the keyword list per branch.
    layout_kwargs = {
        "colorscale": colorscale,
        "show_values": show_values,
        "show_legend": show_legend,
        "title": title,
        "x_label": x_label,
        "y_label": y_label,
        "width": width,
        "height": height,
    }

    # Auto-detect panel data: only when the caller constrained neither groups
    # nor columns.  The empty list is the internal marker for "all groups".
    _, panel_groups = inspect_panel(df)
    if groups is None and columns is None and panel_groups:
        groups = []

    if groups is None:
        # Non-panel path
        plot_columns = validate_plotting_data(df, columns=columns, exclude=["time"])
        return _correlation_heatmap_single(df, plot_columns, **layout_kwargs)

    # Panel path.  Reject unknown facet modes explicitly; without this check a
    # misspelled value would silently be rendered as facet_by="member".
    if facet_by not in ("group", "member", None):
        msg = f"facet_by must be 'group', 'member', or None, got {facet_by!r}"
        raise ValueError(msg)

    from yohou.plotting._utils import _group_panel_columns  # noqa: PLC0415

    # Normalize columns to list for member filtering
    col_filter = [columns] if isinstance(columns, str) else columns

    # Filter to requested groups and optionally by member postfix
    group_map: dict[str, list[str]] = {}
    for g, gcols in panel_groups.items():
        if groups and g not in groups:
            continue
        if col_filter is None:
            filtered = gcols
        else:
            # partition() yields "" when the separator is missing, so a
            # malformed column name is filtered out instead of raising
            # IndexError.
            filtered = [c for c in gcols if c.partition("__")[2] in col_filter]
        if filtered:
            group_map[g] = filtered
    if not group_map:
        msg = f"No panel groups found for {groups}. Available groups: {list(panel_groups.keys())}"
        raise ValueError(msg)

    all_panel_cols = [c for gcols in group_map.values() for c in gcols]

    if facet_by is None:
        return _correlation_heatmap_single(df, all_panel_cols, **layout_kwargs)

    if facet_by == "group":
        return _correlation_heatmap_separate(df, group_map, **layout_kwargs)

    # facet_by == "member": regroup so each member maps to its column in
    # every selected group (first match per group).
    _, all_members = _group_panel_columns(all_panel_cols)
    member_groups: dict[str, list[str]] = {}
    for member in all_members:
        member_cols = []
        for gcols in group_map.values():
            col = next((c for c in gcols if _member_name(c) == member), None)
            if col:
                member_cols.append(col)
        if member_cols:
            member_groups[member] = member_cols
    return _correlation_heatmap_by_member(df, member_groups, **layout_kwargs)

Tutorials

The following example notebooks use this component:

  • How to Visualize Correlations


    Visualization

    Pairwise correlation heatmaps, scatter matrices, cross-correlation at multiple lags, and lag scatter plots for multivariate time series diagnostics.

    View · Open in marimo