"""
Violin plot functionality for ggpubpy.
This module contains the violin plot function with statistical annotations.
"""
from typing import Dict, List, Optional, Tuple, cast
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import PolyCollection
from .helper import (
_get_palette_for_data,
_perform_statistical_tests,
_validate_inputs,
format_p_value,
significance_stars,
)
# NOTE: a "[docs]" marker stood here; it was a Sphinx viewcode link artifact, not code.
def plot_violin_with_stats(
    df: pd.DataFrame,
    x: str,
    y: str,
    *,
    x_label: Optional[str] = None,
    y_label: Optional[str] = None,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    order: Optional[List] = None,
    palette: Optional[Dict] = None,
    figsize: Tuple[int, int] = (6, 6),
    figsize_scale: float = 1.0,
    add_jitter: bool = True,
    jitter_std: float = 0.04,
    alpha: Optional[float] = None,
    violin_width: float = 0.6,
    box_width: float = 0.15,
    global_test: bool = True,
    pairwise_test: bool = True,
    parametric: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw a violin + boxplot + jitter + stats.

    One violin (with an inner white boxplot and optional jittered points) is
    drawn per category of ``x`` at integer positions 1..k. Pairwise
    comparisons are drawn as stacked brackets above the data, and the global
    test p-value is written above the brackets.

    Parameters
    ----------
    df : pd.DataFrame
        Your data.
    x : str
        Categorical column name.
    y : str
        Numeric column name. Missing values are dropped per group.
    x_label : str, optional
        Custom label for the x-axis. Falls back to ``x`` when not given
        (or empty). Also used as the legend title.
    y_label : str, optional
        Custom label for the y-axis. Falls back to ``y`` when not given
        (or empty).
    title, subtitle : str, optional
        Overall plot title and optional subtitle. Either may be given on its
        own; when both are given they are joined by a newline.
    order : list, optional
        Order of x categories. Defaults to sorted unique values.
    palette : dict, optional
        Mapping from category -> color.
    figsize : tuple
        Figure size.
    figsize_scale : float
        Scale factor applied to both components of ``figsize``.
    add_jitter : bool
        Whether to add jittered points.
    jitter_std : float
        Standard deviation for horizontal jitter.
    alpha : float, optional
        Transparency for jittered points (0-1). Defaults to 0.6.
    violin_width : float
        Width of violin plots.
    box_width : float
        Width of boxplots inside violins.
    global_test : bool
        Whether to annotate the global statistical test.
    pairwise_test : bool
        Whether to annotate pairwise statistical tests.
    parametric : bool
        If True, use parametric tests (ANOVA + t-test).
        If False, use non-parametric tests (Kruskal-Wallis + Mann-Whitney U).

    Returns
    -------
    tuple
        (figure, axes) matplotlib objects.

    Raises
    ------
    AssertionError
        If ``figsize`` is not a tuple/list of length 2, if ``figsize_scale``,
        ``violin_width`` or ``box_width`` is not positive, if ``jitter_std``
        is negative, or if ``parametric`` is not a bool.
    ValueError
        If no category contains any non-missing ``y`` value.

    Notes
    -----
    ``_validate_inputs`` is called first and presumably raises on an invalid
    ``df``/``x``/``y``/``order``; its exact behavior is defined in
    ``.helper`` (TODO confirm).
    """
    # Local import, as in the original code.
    import matplotlib.patches as mpatches

    # ---- Validate inputs -------------------------------------------------
    _validate_inputs(df, x, y, order)
    # NOTE: asserts are kept (rather than raising ValueError) so that the
    # exception type seen by existing callers does not change.
    assert (
        isinstance(figsize, (tuple, list)) and len(figsize) == 2
    ), "figsize must be a tuple/list of length 2"
    assert figsize_scale > 0, "figsize_scale must be positive"
    assert jitter_std >= 0, "jitter_std must be non-negative"
    assert violin_width > 0, "violin_width must be positive"
    assert box_width > 0, "box_width must be positive"
    assert isinstance(parametric, bool), "parametric must be a boolean"

    # ---- Prepare category levels and corresponding data -------------------
    levels = order if order is not None else sorted(df[x].unique())
    groups = [df[df[x] == lvl][y].dropna().to_numpy() for lvl in levels]
    positions = np.arange(len(levels)) + 1

    # ---- Color palette ----------------------------------------------------
    color_palette = _get_palette_for_data(levels, palette)

    # ---- Statistical tests -------------------------------------------------
    # Each pairwise entry is (i, j, p-value), where i and j index into
    # ``groups`` / ``positions``.
    global_stat, global_p, pairwise = _perform_statistical_tests(groups, parametric)
    if not pairwise_test:
        pairwise = []

    # ---- Figure ------------------------------------------------------------
    scaled_figsize = (figsize[0] * figsize_scale, figsize[1] * figsize_scale)
    fig, ax = plt.subplots(figsize=scaled_figsize)

    # ---- Violin plots ------------------------------------------------------
    violin_parts = ax.violinplot(
        groups,
        positions=positions,
        widths=violin_width,
        showextrema=True,
        showmedians=False,
        showmeans=False,
    )
    bodies = cast(List[PolyCollection], violin_parts["bodies"])
    for body, level in zip(bodies, levels):
        body.set_facecolor(color_palette[level])
        body.set_edgecolor("black")
        body.set_alpha(1.0)  # Fully filled violin

    # Style the extrema lines if present (these are LineCollection objects).
    for part_name in ("cmins", "cmaxes", "cbars"):
        part = violin_parts.get(part_name)
        if part is not None:
            part.set_color("black")
            part.set_linewidth(1)

    # ---- Boxplots (white background) ---------------------------------------
    ax.boxplot(
        groups,
        positions=positions,
        widths=box_width,
        patch_artist=True,
        showfliers=False,
        boxprops=dict(facecolor="white", color="black"),
        whiskerprops=dict(color="black"),
        capprops=dict(color="black"),
        medianprops=dict(color="black"),
    )

    # ---- Jittered points ---------------------------------------------------
    if add_jitter:
        # Fixed seed so repeated calls produce the same jitter.
        rng = np.random.default_rng(0)
        alpha_points = 0.6 if alpha is None else float(alpha)
        for pos, values in zip(positions, groups):
            xs = rng.normal(pos, jitter_std, size=len(values))
            ax.scatter(xs, values, s=15, color="k", alpha=alpha_points, zorder=3)

    # ---- Vertical layout for the statistical annotations -------------------
    non_empty = [g for g in groups if len(g) > 0]
    if not non_empty:
        raise ValueError(f"No non-missing values of '{y}' found in any group")
    data_min = float(min(np.min(g) for g in non_empty))
    data_max = float(max(np.max(g) for g in non_empty))
    span = data_max - data_min
    if not span > 0:
        # All observations are identical: a zero span would collapse every
        # annotation onto one y-coordinate and give set_ylim identical
        # limits, so fall back to a non-zero scale.
        span = abs(data_max) or 1.0
    base = data_max + 0.1 * span
    step = 0.1 * span

    # ---- Pairwise annotations (one bracket per comparison, stacked) --------
    for idx, (i, j, pval) in enumerate(pairwise):
        i_pos, j_pos = positions[i], positions[j]
        y0 = base + step * idx
        bracket_top = y0 + 0.02 * span
        ax.plot(
            [i_pos, i_pos, j_pos, j_pos],
            [y0, bracket_top, bracket_top, y0],
            color="black",
        )
        ax.text(
            (i_pos + j_pos) / 2,
            y0 + 0.03 * span,
            significance_stars(pval),
            ha="center",
            va="bottom",
        )

    # ---- Global test annotation (placed above the last bracket) ------------
    if global_test and not np.isnan(global_p):
        test_name = "One-way ANOVA" if parametric else "Kruskal-Wallis"
        p_formatted = format_p_value(global_p)
        ax.text(
            positions[0],
            base + step * (len(pairwise) + 0.4),
            f"{test_name} p = {p_formatted}",
            fontsize=10,
            va="bottom",
        )

    # ---- Axis labels -------------------------------------------------------
    ax.set_xticks(positions)
    ax.set_xticklabels(levels)
    ax.set_xlabel(x_label or x)
    ax.set_ylabel(y_label or y)

    # ---- Legend ------------------------------------------------------------
    handles = [
        mpatches.Patch(color=color_palette[level], label=level) for level in levels
    ]
    ax.legend(handles=handles, title=(x_label or x))

    # ---- Optional overall title/subtitle -----------------------------------
    # Join only the parts that were actually supplied, so a subtitle without
    # a title is not rendered as "None\n<subtitle>".
    title_parts = [part for part in (title, subtitle) if part]
    if title_parts:
        fig.suptitle("\n".join(title_parts), fontsize=14, fontweight="bold", y=0.98)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Leave headroom for every bracket plus the global-test text.
    ax.set_ylim(data_min - 0.05 * span, base + step * (len(pairwise) + 0.6))
    plt.tight_layout()
    return fig, ax