
plot_cross_correlation

yohou.plotting.diagnostics.plot_cross_correlation(df, *, columns=None, max_lags=None, confidence_level=0.95, groups=None, facet_n_cols=2, color_palette=None, show_legend=True, title=None, x_label=None, y_label=None, width=None, height=None)

Plot cross-correlation function (CCF) between time series pairs.

Computes and visualizes the cross-correlation between pairs of series at various lags, useful for identifying lead-lag relationships and temporal dependencies.

When columns is None (or has more than 2 entries), all unique upper-triangle pairs are computed and arranged in a subplot grid. When exactly two columns are given, a single CCF plot is produced (matching legacy behaviour).
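As a rough illustration of the pairing and grid rules just described, the sketch below enumerates the pairs for three columns and maps each pair to a subplot cell. The helper names `upper_triangle_pairs` and `grid_position` are hypothetical stand-ins, and the pair ordering from `itertools.combinations` is an assumption; the row and column arithmetic follows the source listing further down this page.

```python
import math
from itertools import combinations


def upper_triangle_pairs(names):
    """Return every unordered pair (a, b) where a is listed before b."""
    return list(combinations(names, 2))


def grid_position(idx, n_pairs, facet_n_cols=2):
    """Map a 0-based pair index to a 1-based (row, col) subplot cell."""
    n_cols = min(n_pairs, facet_n_cols)
    return idx // n_cols + 1, idx % n_cols + 1


pairs = upper_triangle_pairs(["x", "y", "z"])
n_cols = min(len(pairs), 2)
n_rows = math.ceil(len(pairs) / n_cols)
cells = [grid_position(i, len(pairs)) for i in range(len(pairs))]

print(pairs)           # three pairs for three columns
print(n_rows, n_cols)  # grid shape
print(cells)           # one (row, col) cell per pair
```

With exactly two columns there is only one pair, which is why that case produces a single plot rather than a grid.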

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | Input DataFrame with a `'time'` column and numeric columns. | required |
| `columns` | `list[str] \| None` | Column names to cross-correlate. When `None`, all numeric columns are used and every unique upper-triangle pair is plotted. When exactly two columns are given, a single CCF plot is produced. | `None` |
| `max_lags` | `int \| None` | Maximum lag. The CCF is computed at every integer lag from `-max_lags` to `+max_lags`. If `None`, uses `min(len(df) // 2, 40)`. | `None` |
| `confidence_level` | `float` | Confidence level for the confidence bands (e.g. `0.95` for 95%). Bands are drawn only when the value is greater than 0. | `0.95` |
| `groups` | `list[str] \| None` | Panel group prefixes. When set (or auto-detected), the CCF is computed between members within each group using an upper-triangle matrix layout. Panel columns are expected to be named `{group}__{member}`. | `None` |
| `facet_n_cols` | `int` | Number of columns in the subplot grid. | `2` |
| `color_palette` | `list[str] \| None` | Custom color palette for pair traces. | `None` |
| `show_legend` | `bool` | Whether to show the legend. | `True` |
| `title` | `str \| None` | Plot title. | `None` |
| `x_label` | `str \| None` | X-axis label. Defaults to `"Lag"`. | `None` |
| `y_label` | `str \| None` | Y-axis label. Defaults to `"Cross-Correlation"`. | `None` |
| `width` | `int \| None` | Plot width in pixels. | `None` |
| `height` | `int \| None` | Plot height in pixels. | `None` |
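
The `max_lags` default and the confidence bands can be reproduced by hand. The sketch below mirrors the arithmetic in the source listing on this page, with one substitution: it uses the standard library's `statistics.NormalDist().inv_cdf` in place of the `norm.ppf` call that appears in the source. The function names are hypothetical and not part of the library.

```python
import math
from statistics import NormalDist


def default_max_lags(n_rows):
    """Default lag count: half the series length, capped at 40."""
    return min(n_rows // 2, 40)


def lag_values(max_lags):
    """Integer lags from -max_lags to +max_lags inclusive."""
    return list(range(-max_lags, max_lags + 1))


def confidence_bound(confidence_level, n_points):
    """Half-width of the band: two-sided normal quantile over sqrt(n)."""
    z = NormalDist().inv_cdf(1 - (1 - confidence_level) / 2)
    return z / math.sqrt(n_points)


print(default_max_lags(91))         # 91 daily rows are capped at 40 lags
print(lag_values(2))                # symmetric around lag 0
print(confidence_bound(0.95, 100))  # roughly 0.196 for 100 points
```

The bands are drawn at plus and minus this bound, so a bar that extends beyond them is unlikely under the usual white-noise approximation.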

Returns

| Type | Description |
| --- | --- |
| `Figure` | Plotly figure object. |

Examples

>>> import polars as pl
>>> from yohou.plotting import plot_cross_correlation
>>> # Create two linearly trending series; y is x shifted up by a constant 5
>>> df = pl.DataFrame({
...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True),
...     "x": [100 + i for i in range(91)],
...     "y": [105 + i for i in range(91)],  # y = x + 5 (value offset, not a time lag)
... })
>>> # Plot cross-correlation
>>> fig = plot_cross_correlation(df, columns=["x", "y"], max_lags=20)
>>> len(fig.data) > 0
True
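
The two series in the example above differ only by a constant value, so they do not contain a time lag. To see how a genuine lag shows up as a peak, here is a minimal pure-Python sketch that does not use the library. The estimator `ccf_at_lag` is hypothetical: it takes the Pearson correlation of `x[t]` with `y[t + lag]` over the overlapping samples, which may differ from the library's internal `_compute_ccf` in both normalization and sign convention.

```python
import math


def ccf_at_lag(x, y, lag):
    """Pearson correlation between x[t] and y[t + lag] on the overlap."""
    if lag >= 0:
        a, b = x[: len(x) - lag], y[lag:]
    else:
        a, b = x[-lag:], y[: len(y) + lag]
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    return cov / math.sqrt(var_a * var_b)


# Deterministic signal; y repeats x five steps later, i.e. y[t + 5] == x[t].
signal = [math.sin(t / 3) + 0.5 * math.sin(t / 7) for t in range(125)]
x = signal[5:]
y = signal[:-5]

max_lags = 10
ccf = {lag: ccf_at_lag(x, y, lag) for lag in range(-max_lags, max_lags + 1)}
best_lag = max(ccf, key=ccf.get)
print(best_lag)  # the lag with the strongest correlation
```

Under this convention a peak at a positive lag means `x` leads `y`. Check which convention a given tool uses before reading direction off a CCF plot.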

See Also

plot_autocorrelation : Plot autocorrelation function.

Source Code

def plot_cross_correlation(
    df: pl.DataFrame,
    *,
    columns: list[str] | None = None,
    max_lags: int | None = None,
    confidence_level: float = 0.95,
    groups: list[str] | None = None,
    facet_n_cols: int = 2,
    color_palette: list[str] | None = None,
    show_legend: bool = True,
    title: str | None = None,
    x_label: str | None = None,
    y_label: str | None = None,
    width: int | None = None,
    height: int | None = None,
) -> go.Figure:
    """Plot cross-correlation function (CCF) between time series pairs.

    Computes and visualizes the cross-correlation between pairs of series at
    various lags, useful for identifying lead-lag relationships and temporal
    dependencies.

    When ``columns`` is ``None`` (or has more than 2 entries), all unique
    upper-triangle pairs are computed and arranged in a subplot grid.  When
    exactly two columns are given, a single CCF plot is produced (matching
    legacy behaviour).

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame with 'time' column and numeric columns.
    columns : list[str] | None, default=None
        Column names to cross-correlate.  When ``None``, all numeric
        columns are used and every unique upper-triangle pair is plotted.
        When exactly two columns are given, a single CCF plot is produced.
    max_lags : int | None, default=None
        Number of lags to compute (both positive and negative).
        If None, uses ``min(len(df) // 2, 40)``.
    confidence_level : float, default=0.95
        Confidence level for confidence bands (e.g. ``0.95`` for 95%).
    groups : list[str] | None, default=None
        Panel group prefixes.  When set (or auto-detected), CCF is
        computed between members within each group using an
        upper-triangle matrix layout.
    facet_n_cols : int, default=2
        Number of columns in the subplot grid.
    color_palette : list[str] | None, default=None
        Custom color palette for pair traces.
    show_legend : bool, default=True
        Whether to show the legend.
    title : str | None, default=None
        Plot title.
    x_label : str | None, default=None
        X-axis label. Defaults to ``"Lag"``.
    y_label : str | None, default=None
        Y-axis label. Defaults to ``"Cross-Correlation"``.
    width : int | None, default=None
        Plot width in pixels.
    height : int | None, default=None
        Plot height in pixels.

    Returns
    -------
    go.Figure
        Plotly figure object.

    Examples
    --------
    >>> import polars as pl
    >>> from yohou.plotting import plot_cross_correlation

    >>> # Create two linearly trending series; y is x shifted up by a constant 5
    >>> df = pl.DataFrame({
    ...     "time": pl.date_range(pl.date(2020, 1, 1), pl.date(2020, 3, 31), "1d", eager=True),
    ...     "x": [100 + i for i in range(91)],
    ...     "y": [105 + i for i in range(91)],  # y = x + 5 (value offset, not a time lag)
    ... })

    >>> # Plot cross-correlation
    >>> fig = plot_cross_correlation(df, columns=["x", "y"], max_lags=20)
    >>> len(fig.data) > 0
    True

    See Also
    --------
    [`plot_autocorrelation`][yohou.plotting.plot_autocorrelation] : Plot autocorrelation function.
    """
    # Validate inputs
    validate_plotting_data(df)
    validate_plotting_params(width=width, height=height)

    if max_lags is None:
        max_lags = min(len(df) // 2, 40)

    _lag_vals = list(range(-max_lags, max_lags + 1))

    def _add_ccf_bar(
        fig: go.Figure,
        lag_vals: list[int],
        ccf: list[float],
        pair_color: str,
        pair_label: str,
        show_in_legend: bool,
        *,
        row: int | None = None,
        col: int | None = None,
    ) -> None:
        """Add a single ACF-style bar trace."""
        trace_kwargs: dict = {}
        if row is not None and col is not None:
            trace_kwargs["row"] = row
            trace_kwargs["col"] = col
        fig.add_trace(
            go.Bar(
                x=lag_vals,
                y=ccf,
                marker={"color": pair_color},
                name=pair_label,
                showlegend=show_in_legend,
                legendgroup=pair_label,
                hovertemplate=_make_hovertemplate(pair_label, "Lag", "CCF", decimals=3),
            ),
            **trace_kwargs,
        )

    def _add_ccf_bands(
        fig: go.Figure,
        n_pts: int,
        *,
        row: int | None = None,
        col: int | None = None,
    ) -> None:
        """Add confidence bands and zero line."""
        hline_kw: dict = {}
        if row is not None and col is not None:
            hline_kw["row"] = row
            hline_kw["col"] = col
        if confidence_level > 0:
            cb = norm.ppf(1 - (1 - confidence_level) / 2) / math.sqrt(n_pts)
            fig.add_hline(y=cb, line={"dash": "dash", "color": "#DC2626", "width": 1}, **hline_kw)
            fig.add_hline(y=-cb, line={"dash": "dash", "color": "#DC2626", "width": 1}, **hline_kw)
        fig.add_hline(y=0, line={"color": "#64748B", "width": 1}, **hline_kw)

    # Auto-detect panel data
    if groups is None and _auto_detect_panel(df) and columns is None:
        groups = []

    if groups is not None:
        _panel_cols = resolve_panel_columns(df, groups, columns)
        grouped, _all_members = _group_panel_columns(_panel_cols)
        color_mgr = PanelColorManager(color_palette)

        # Build per-group pair lists
        group_pairs: list[tuple[str, list[tuple[str, str]]]] = []
        for gname, gcols in grouped.items():
            member_names = [_member_name(c) for c in gcols]
            pairs = _upper_triangle_pairs(member_names)
            if pairs:
                group_pairs.append((gname, pairs))

        if not group_pairs:
            msg = f"Need at least 2 members per group for cross-correlation. Groups: {list(grouped.keys())}"
            raise ValueError(msg)

        # All groups + pairs in one figure
        all_cells: list[tuple[str, str, str]] = []
        for gname, pairs in group_pairs:
            for a, b in pairs:
                all_cells.append((gname, a, b))

        n = len(all_cells)
        n_cols_grid = min(n, facet_n_cols)
        n_rows_grid = math.ceil(n / n_cols_grid)
        legend_tracker = LegendTracker(show_legend=show_legend)

        fig = make_subplots(
            rows=n_rows_grid,
            cols=n_cols_grid,
            subplot_titles=[f"{gn}: {a} vs {b}" for gn, a, b in all_cells],
            vertical_spacing=_subplot_spacing(n_rows_grid),
            horizontal_spacing=_subplot_spacing(n_cols_grid) if n_cols_grid > 1 else 0.08,
        )

        for idx, (gname, a_name, b_name) in enumerate(all_cells):
            r = idx // n_cols_grid + 1
            c = idx % n_cols_grid + 1
            col_a = f"{gname}__{a_name}"
            col_b = f"{gname}__{b_name}"
            clean = df.select(col_a, col_b).drop_nulls()
            x_arr = clean[col_a].to_numpy()
            y_arr = clean[col_b].to_numpy()
            ccf = _compute_ccf(x_arr, y_arr, max_lags)
            pair_label = f"{a_name} vs {b_name}"
            pair_color = color_mgr.get_color(pair_label)
            _add_ccf_bar(
                fig,
                _lag_vals,
                ccf,
                pair_color,
                pair_label,
                legend_tracker.should_show(pair_label),
                row=r,
                col=c,
            )
            _add_ccf_bands(fig, len(x_arr), row=r, col=c)

        fig = apply_default_layout(
            fig,
            title=title or "Cross-Correlation",
            x_label=x_label or "Lag",
            y_label=y_label or "Cross-Correlation",
            width=width,
            height=height,
        )
        fig.update_layout(showlegend=show_legend)
        return fig

    plot_columns = validate_plotting_data(df, columns=columns, exclude=["time"])

    if len(plot_columns) < 2:  # noqa: PLR2004
        msg = f"Cross-correlation requires at least 2 columns, got {len(plot_columns)}: {plot_columns}"
        raise ValueError(msg)

    pairs = _upper_triangle_pairs(plot_columns)
    color_mgr = PanelColorManager(color_palette)
    legend_tracker = LegendTracker()

    if len(pairs) == 1:
        # Single pair - flat figure
        x_column, y_column = pairs[0]
        clean = df.select(x_column, y_column).drop_nulls()
        x_arr = clean[x_column].to_numpy()
        y_arr = clean[y_column].to_numpy()
        n_pts = len(x_arr)
        ccf = _compute_ccf(x_arr, y_arr, max_lags)
        pair_label = f"{x_column} vs {y_column}"
        pair_color = color_mgr.get_color(pair_label)

        fig = go.Figure()
        _add_ccf_bar(fig, _lag_vals, ccf, pair_color, pair_label, True)

        if confidence_level > 0:
            cb = norm.ppf(1 - (1 - confidence_level) / 2) / math.sqrt(n_pts)
            ci_pct = f"{confidence_level:.0%}"
            fig.add_hline(
                y=cb,
                line={"dash": "dash", "color": "#DC2626", "width": 1},
                annotation_text=f"{ci_pct} CI",
                annotation_position="right",
            )
            fig.add_hline(y=-cb, line={"dash": "dash", "color": "#DC2626", "width": 1})
        fig.add_hline(y=0, line={"color": "#64748B", "width": 1})

        fig = apply_default_layout(
            fig,
            title=title or "Cross-Correlation",
            x_label=x_label or "Lag",
            y_label=y_label or "Cross-Correlation",
            width=width,
            height=height,
        )
        fig.update_layout(showlegend=show_legend)
        return fig

    # Multiple pairs - upper-triangle subplot grid
    n = len(pairs)
    n_cols_grid = min(n, facet_n_cols)
    n_rows_grid = math.ceil(n / n_cols_grid)

    fig = make_subplots(
        rows=n_rows_grid,
        cols=n_cols_grid,
        subplot_titles=[f"{a} vs {b}" for a, b in pairs],
        vertical_spacing=_subplot_spacing(n_rows_grid),
        horizontal_spacing=_subplot_spacing(n_cols_grid) if n_cols_grid > 1 else 0.08,
    )

    for idx, (col_a, col_b) in enumerate(pairs):
        r = idx // n_cols_grid + 1
        c = idx % n_cols_grid + 1
        clean = df.select(col_a, col_b).drop_nulls()
        x_arr = clean[col_a].to_numpy()
        y_arr = clean[col_b].to_numpy()
        ccf = _compute_ccf(x_arr, y_arr, max_lags)
        pair_label = f"{col_a} vs {col_b}"
        pair_color = color_mgr.get_color(pair_label)
        _add_ccf_bar(
            fig,
            _lag_vals,
            ccf,
            pair_color,
            pair_label,
            legend_tracker.should_show(pair_label),
            row=r,
            col=c,
        )
        _add_ccf_bands(fig, len(x_arr), row=r, col=c)

    fig = apply_default_layout(
        fig,
        title=title or "Cross-Correlation",
        x_label=x_label or "Lag",
        y_label=y_label or "Cross-Correlation",
        width=width,
        height=height,
    )
    fig.update_layout(showlegend=show_legend)
    return fig
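
The panel branch in the listing above rebuilds column names as `f"{gname}__{a_name}"` and titles each subplot `"{group}: {a} vs {b}"`. The sketch below shows how such columns could be split into groups and paired within each group. The helpers `group_panel_columns` and `panel_cells` are hypothetical illustrations, not the library's `_group_panel_columns` or `_upper_triangle_pairs`, and the column names are made up.

```python
from collections import defaultdict
from itertools import combinations


def group_panel_columns(columns, sep="__"):
    """Group 'group__member' column names by prefix; skip names without sep."""
    grouped = defaultdict(list)
    for name in columns:
        if sep in name:
            group, member = name.split(sep, 1)
            grouped[group].append(member)
    return dict(grouped)


def panel_cells(columns):
    """List (group, member_a, member_b) for every within-group pair."""
    cells = []
    for group, members in group_panel_columns(columns).items():
        for a, b in combinations(members, 2):
            cells.append((group, a, b))
    return cells


columns = ["time", "store__sales", "store__visits", "web__sales", "web__visits", "web__ads"]
cells = panel_cells(columns)
titles = [f"{g}: {a} vs {b}" for g, a, b in cells]
print(titles)
```

Pairs are only formed within a group, so a group with a single member contributes no subplot. This matches the source raising a `ValueError` when no group has at least two members.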

Tutorials

The following example notebooks use this component:

  • How to Visualize Correlations


    Visualization

    Pairwise correlation heatmaps, scatter matrices, cross-correlation at multiple lags, and lag scatter plots for multivariate time series diagnostics.
