Source code for ggpubpy.shiftplot

"""
Shift plot functionality for ggpubpy.

This module contains the shift plot function for comparing distributions.
"""

from typing import List, Optional, Tuple, Union, cast

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import PolyCollection
from scipy.stats import mannwhitneyu, ttest_ind

from .helper import _bias_corrected_ci, format_p_value, harrelldavis


def plot_shift(
    x: np.ndarray,
    y: np.ndarray,
    *,
    paired: bool = False,
    n_boot: int = 1000,
    percentiles: np.ndarray = np.arange(10, 100, 10),
    confidence: float = 0.95,
    seed: Optional[int] = None,
    show_median: bool = True,
    violin: bool = True,
    show_quantiles: bool = False,
    show_quantile_diff: bool = False,
    parametric: bool = False,
    x_name: str = "X",
    y_name: str = "Y",
    # Backward/forward compatibility keyword args (optional)
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    color: Optional[str] = None,
    line_color: Optional[str] = None,
    alpha: Optional[float] = None,
    figsize: Optional[Tuple[float, float]] = None,
) -> plt.Figure:
    """Shift plot.

    Draws the two samples as horizontal box/scatter (and optionally
    half-violin) rows, annotates the axes with the p-value of a two-sample
    test, and can additionally connect matching quantiles and show the
    bootstrapped quantile differences in a second subplot.

    Parameters
    ----------
    x, y : array_like
        First and second set of observations. Must be 1D, free of NaN and
        contain at least 10 observations each.
    paired : bool
        If True, x and y are paired samples (and must have the same size).
    n_boot : int
        Number of bootstrap iterations.
    percentiles : array_like
        Sequence of percentiles (0-100) to compute.
    confidence : float
        Confidence level for intervals, strictly between 0 and 1.
    seed : int or None
        Random seed.
    show_median : bool
        If True, show median lines.
    violin : bool
        If True, plot half-violin densities.
    show_quantiles : bool
        If True, show quantile connection lines between distributions.
    show_quantile_diff : bool
        If True, show bottom subplot with quantile differences.
    parametric : bool
        If True, use a t-test (paired t-test when ``paired``); else the
        Mann-Whitney U test (Wilcoxon signed-rank when ``paired``).
    x_name, y_name : str
        Tick labels for the two samples.
    x_label, y_label : str or None
        When not None, override ``x_name`` / ``y_name`` respectively.
    title, subtitle : str or None
        Optional figure-level title and subtitle. They are joined with a
        newline; either one may be given on its own.
    color : str or None
        Colour of the ``y`` sample (points and violin face). Defaults to
        ``"#88bedc"``.
    line_color : str or None
        Colour of violin edges and quantile connection lines. When None,
        violin edges and dashed guides are black and the quantile connectors
        are coloured by the sign of the confidence interval.
    alpha : float or None
        Opacity of points and violins. Defaults to 0.6 for the points and
        0.8 for the violins.
    figsize : tuple of float or None
        Figure size. Defaults to ``(10, 6)`` with the quantile-difference
        subplot and ``(10, 4)`` without it.

    Returns
    -------
    fig : matplotlib Figure instance

    Raises
    ------
    AssertionError
        If any of the input checks fails.
    """
    # Safety checks. These stay as assertions (rather than ValueError) so that
    # callers relying on AssertionError keep working.
    x = np.asarray(x)
    y = np.asarray(y)
    # A fresh array is produced here, so the shared default is never mutated.
    pct = np.asarray(percentiles) / 100
    assert x.ndim == 1 and y.ndim == 1, "x and y must be 1D."
    assert (
        not np.isnan(x).any() and not np.isnan(y).any()
    ), "Missing values not allowed."
    nx, ny = x.size, y.size
    assert nx >= 10 and ny >= 10, "Each sample must have at least 10 observations."
    assert 0 < confidence < 1, "confidence must be between 0 and 1."

    # Optional figsize validation
    if figsize is not None:
        assert isinstance(figsize, tuple) and len(figsize) == 2, "figsize must be a tuple"

    # Map optional labels (for compatibility with docs/examples)
    if x_label is not None:
        x_name = x_label
    if y_label is not None:
        y_name = y_label

    if paired:
        assert nx == ny, "x and y must have same size when paired=True."

    # Harrell-Davis quantiles
    x_per: np.ndarray = harrelldavis(x, pct)
    y_per: np.ndarray = harrelldavis(y, pct)
    delta: np.ndarray = y_per - x_per

    # Statistical test for comparison (only the p-value is displayed)
    if parametric:
        if paired:
            from scipy.stats import ttest_rel

            _, p_val = ttest_rel(x, y)
            test_name = "Paired t-test"
        else:
            _, p_val = ttest_ind(x, y)
            test_name = "Independent t-test"
    else:
        if paired:
            from scipy.stats import wilcoxon

            _, p_val = wilcoxon(x, y)
            test_name = "Wilcoxon signed-rank"
        else:
            _, p_val = mannwhitneyu(x, y, alternative="two-sided")
            test_name = "Mann-Whitney U"

    # Bootstrap differences. The order of the random draws is kept exactly as
    # before so that a given seed yields the same intervals.
    rng = np.random.default_rng(seed)
    if paired:
        # Resample index pairs so that each x stays matched with its y.
        bootsam = rng.choice(nx, size=(nx, n_boot), replace=True)
        x_boot_q = np.array(
            [harrelldavis(x[bootsam[:, i]], pct) for i in range(n_boot)]
        ).T
        y_boot_q = np.array(
            [harrelldavis(y[bootsam[:, i]], pct) for i in range(n_boot)]
        ).T
    else:
        x_boot_q = np.array(
            [harrelldavis(rng.choice(x, nx), pct) for _ in range(n_boot)]
        ).T
        y_boot_q = np.array(
            [harrelldavis(rng.choice(y, ny), pct) for _ in range(n_boot)]
        ).T
    bootstat = y_boot_q - x_boot_q

    # Confidence intervals, one per requested percentile
    lowers_list: List[float] = []
    medians_list: List[float] = []
    uppers_list: List[float] = []
    for i, d in enumerate(delta):
        ci = _bias_corrected_ci(bootstat[i], d, alpha=1 - confidence)
        ci_low, ci_high = ci[0], ci[1]
        # The alpha=1 interval is used as the central (median) estimate.
        med_val_ci = _bias_corrected_ci(bootstat[i], d, alpha=1)
        med_val = med_val_ci[0]
        lowers_list.append(float(ci_low))
        uppers_list.append(float(ci_high))
        medians_list.append(float(med_val))

    lowers = np.array(lowers_list)
    medians = np.array(medians_list)
    uppers = np.array(uppers_list)

    # Plot distributions
    if show_quantile_diff:
        fig = plt.figure(figsize=(figsize if figsize is not None else (10, 6)))
        ax1 = plt.subplot2grid((3, 3), (0, 0), rowspan=2, colspan=3, fig=fig)
    else:
        fig, ax1 = plt.subplots(figsize=(figsize if figsize is not None else (10, 4)))

    # Custom boxplots
    def adj_vals(vals: np.ndarray) -> Tuple[float, float, float, float, float]:
        """Return (q1, median, q3, lower whisker, upper whisker) of *vals*."""
        percentiles_arr = cast(np.ndarray, np.percentile(vals, [25, 50, 75]))
        q1, med, q3 = percentiles_arr[0], percentiles_arr[1], percentiles_arr[2]
        iqr = q3 - q1
        # Whiskers reach 1.5 * IQR but never beyond the observed range.
        lower = np.clip(q1 - 1.5 * iqr, vals.min(), q1)
        upper = np.clip(q3 + 1.5 * iqr, q3, vals.max())
        return float(q1), float(med), float(q3), float(lower), float(upper)

    for arr, y0 in zip([x, y], [1.2, -0.2]):
        q1, med, q3, lo, hi = adj_vals(np.sort(arr))
        ax1.plot(med, y0, "o", color="white", zorder=10)
        ax1.hlines(y0, q1, q3, color="k", lw=7, zorder=9)
        ax1.hlines(y0, lo, hi, color="k", lw=2, zorder=9)

    # Colors and alpha (with sensible defaults)
    x_color = "#cfcfcf"
    y_color = color or "#88bedc"
    alpha_points = 0.6 if alpha is None else float(alpha)

    # Scatter raw data points without jitter. The row positions are built as
    # float arrays explicitly: np.full_like would inherit an integer dtype from
    # integer samples and truncate 1.2 / -0.2 to 1 / 0.
    ax1.scatter(
        x, np.full(x.shape, 1.2), color=x_color, s=10, alpha=alpha_points, zorder=3
    )
    ax1.scatter(
        y, np.full(y.shape, -0.2), color=y_color, s=10, alpha=alpha_points, zorder=3
    )

    if violin:
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", PendingDeprecationWarning)
            vl = ax1.violinplot([y, x], showextrema=False, vert=False, widths=1)

        bodies = cast(List[PolyCollection], vl["bodies"])
        # Body 0 is the y sample, body 1 is the x sample (see call above).
        for idx, offset in zip([0, 1], [-1.2, -0.8]):
            path = bodies[idx].get_paths()[0]
            verts: np.ndarray = cast(np.ndarray, path.vertices)
            # Flatten one side of each violin to get a half-violin, then shift
            # it next to the matching box/scatter row.
            if idx == 0:
                verts[verts[:, 1] >= 1, 1] = 1
            else:
                verts[verts[:, 1] <= 2, 1] = 2
            verts[:, 1] += offset
            bodies[idx].set_edgecolor(line_color or "k")
            face_col = y_color if idx == 0 else x_color
            bodies[idx].set_facecolor(face_col)
            bodies[idx].set_alpha(0.8 if alpha is None else float(alpha))

    if show_quantile_diff:
        ax1.set_ylim(2.2, -1.2)
    else:
        ax1.set_ylim(1.8, -0.8)

    # Connect quantiles (optional)
    if show_quantiles:
        guide_color = line_color or "k"
        for i in range(len(pct)):
            # An explicit line_color wins; otherwise colour by CI sign.
            if uppers[i] < 0:
                col = line_color or "#4c72b0"
            elif lowers[i] > 0:
                col = line_color or "#c34e52"
            else:
                col = line_color or "darkgray"
            ax1.plot([y_per[i], x_per[i]], [0.2, 0.8], "o-", color=col, zorder=10)
            ax1.plot(
                [x_per[i]] * 2, [0.8, 1.2], color=guide_color, linestyle="--", zorder=9
            )
            ax1.plot(
                [y_per[i]] * 2, [-0.2, 0.2], color=guide_color, linestyle="--", zorder=9
            )

    if show_median:
        m_x, m_y = np.median(x), np.median(y)
        ax1.plot([m_x] * 2, [0.8, 1.2], "k-")
        ax1.plot([m_y] * 2, [-0.2, 0.2], "k-")

    ax1.set_xlabel("Scores (a.u.)", size=15)
    ax1.set_yticks([1.2, -0.2])
    ax1.set_yticklabels([x_name, y_name], size=15)
    ax1.set_ylabel("")

    # Add statistical test result to title
    p_formatted = format_p_value(p_val)
    ax1.set_title(f"{test_name}: p = {p_formatted}", fontsize=12, pad=10)

    # Optional overall title/subtitle. Only the parts actually provided are
    # joined, so a subtitle without a title no longer renders as "None\n...".
    heading = "\n".join(part for part in (title, subtitle) if part)
    if heading:
        fig.suptitle(heading, fontsize=14, fontweight="bold", y=0.98)

    # Quantile shift plot (optional)
    if show_quantile_diff:
        ax2 = plt.subplot2grid((3, 3), (2, 0), colspan=3, fig=fig)
        for i, per in enumerate(x_per):
            if uppers[i] < 0:
                col = "#4c72b0"
            elif lowers[i] > 0:
                col = "#c34e52"
            else:
                col = "darkgray"
            ax2.plot([per] * 2, [uppers[i], lowers[i]], lw=3, color=col, zorder=10)
            ax2.plot(per, medians[i], "o", ms=10, color=col, zorder=10)

        ax2.axhline(0, ls="--", lw=2, color="gray")
        ax2.set_xlabel(f"{x_name} quantiles", size=15)
        ax2.set_ylabel(f"{y_name} - {x_name} quantiles\ndifferences (a.u.)", size=10)

    fig.tight_layout()
    return fig