
chenzc24/academic_plot


academic_plot

Publication-quality plot toolkit with one consistent style across all chart types. Inspired by Nature / Science journal figures — thin axes, inward ticks, colour-blind-safe palette, sans-serif fonts, 600 DPI output.


Environment setup (new machine)

1. Prerequisites

| Requirement | Version |
|---|---|
| Python | >= 3.10 |
| pip | latest (python -m pip install --upgrade pip) |

2. Clone and install

git clone <your-repo-url> academic_plot
cd academic_plot

# Option A: editable install (recommended — import from anywhere)
pip install -e .

# Option B: just install dependencies
pip install -r requirements.txt

3. Verify

python examples/demo_linear.py
# → outputs/linear_single.png + .svg

4. Dependencies

| Package | Minimum version | Why |
|---|---|---|
| matplotlib | >= 3.7 | All plotting |
| numpy | >= 1.24 | Array operations |
| scipy | >= 1.10 | Curve fitting (curve_fit), Q-Q plots (probplot), kernel density estimates (gaussian_kde) |
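To confirm that a new machine meets the minimum versions listed above, a small standalone check can read the installed versions with the standard-library importlib.metadata. This is a sketch that is not shipped with the repository; check_dependencies and parse_version are hypothetical helpers, and the parser deliberately ignores pre-release suffixes such as "rc1".

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimum versions copied from the dependency table above.
MINIMUMS = {"matplotlib": "3.7", "numpy": "1.24", "scipy": "1.10"}


def parse_version(text: str) -> tuple[int, ...]:
    """Turn '1.26.4' or '2.0.0rc1' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare numerically, so '1.10' counts as newer than '1.9'."""
    return parse_version(installed) >= parse_version(minimum)


def check_dependencies(minimums: dict[str, str] = MINIMUMS) -> dict[str, bool]:
    """Return {package: ok}; a package that is not installed maps to False."""
    report = {}
    for name, minimum in minimums.items():
        try:
            report[name] = meets_minimum(version(name), minimum)
        except PackageNotFoundError:
            report[name] = False
    return report


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name:<12} {'OK' if ok else 'MISSING or TOO OLD'}")
```

Comparing integer tuples rather than strings is the point of the helper: as strings, "1.9" would sort after "1.10".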

Quick start

from academic_plot import linear_plot, savefig
import numpy as np

x = np.array([1, 2, 3, 4, 5])
y = np.array([2.1, 4.0, 5.8, 8.2, 9.9])

fig = linear_plot(x, y, xlabel="Voltage (V)", ylabel="Current (mA)")
savefig(fig, "my_plot")  # → outputs/my_plot.png + .svg

Importing academic_plot automatically applies the global style.


Project structure

academic_plot/               # repo root (= git root)
├── __init__.py              # re-exports every public function
├── style.py                 # global style, palette, savefig()
├── utils.py                 # r2_score, sample data generators
├── linear.py                # linear plots
├── nonlinear.py             # curve-fitting plots
├── radar.py                 # radar / spider charts
├── bar.py                   # bar chart variants
├── distribution.py          # box, violin, histogram
├── scatter_error.py         # scatter with error bars
├── heatmap.py               # annotated heatmap
├── contour.py               # contour / surface plots
├── bode.py                  # Bode (frequency response)
├── fill_between.py          # confidence-band plots
├── forest.py                # forest (meta-analysis)
├── qq.py                    # Q-Q plots
├── multipanel.py            # multi-panel grids
├── examples/                # ready-to-run demos
│   ├── demo_linear.py
│   ├── demo_nonlinear.py
│   ├── demo_radar.py
│   ├── demo_bar.py
│   ├── demo_distribution.py
│   ├── demo_scatter_error.py
│   ├── demo_heatmap.py
│   ├── demo_contour.py
│   ├── demo_bode.py
│   ├── demo_fill_between.py
│   ├── demo_forest.py
│   ├── demo_qq.py
│   └── demo_multipanel.py
├── outputs/                 # generated figures (git-ignored)
├── pyproject.toml           # package metadata & deps
├── requirements.txt         # pip-only fallback
├── .gitignore
└── README.md

Global style & utilities

savefig(fig, name, out_dir="outputs", formats=("png","svg"), dpi=600)

Save a figure to disk. Returns a list of Path objects.

savefig(fig, "result", formats=("png", "svg", "pdf"))
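The following sketch is a guess at what a helper with this signature might do, written only against matplotlib's own Figure.savefig. savefig_sketch is a hypothetical name and this is not the package's implementation; it writes one file per requested format into out_dir and returns their paths.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def savefig_sketch(fig, name, out_dir="outputs", formats=("png", "svg"), dpi=600):
    """Hypothetical stand-in: write <out_dir>/<name>.<fmt> for each format."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)  # create outputs/ on first use
    paths = []
    for fmt in formats:
        path = out / f"{name}.{fmt}"
        # dpi sets the resolution of raster output (PNG); SVG and PDF are vector.
        fig.savefig(path, format=fmt, dpi=dpi, bbox_inches="tight")
        paths.append(path)
    return paths
```

bbox_inches="tight" trims surplus whitespace around the axes, which is usually what journals want; whether the real savefig does this is an assumption.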

apply_style()

Re-apply the global rcParams. Called automatically on import.

COLORS / PALETTE

from academic_plot import COLORS
ax.plot(x, y, color=COLORS["blue"])   # "#2171B5"

| Name | Hex | Name | Hex |
|---|---|---|---|
| blue | #2171B5 | teal | #006D75 |
| red | #CB181D | brown | #8C564B |
| green | #238B45 | pink | #E377C2 |
| orange | #D94801 | gray | #7F7F7F |
| purple | #6A51A3 | olive | #808000 |
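When a figure has more series than you want to colour by hand, the palette can simply be cycled. The sketch below uses a standalone copy of the hex values from the table (COLORS_SKETCH is a hypothetical name); the column-by-column order is an assumption, since the order of PALETTE is not documented here.

```python
from itertools import cycle

# Hex values copied from the palette table above (column by column).
COLORS_SKETCH = {
    "blue": "#2171B5", "red": "#CB181D", "green": "#238B45",
    "orange": "#D94801", "purple": "#6A51A3", "teal": "#006D75",
    "brown": "#8C564B", "pink": "#E377C2", "gray": "#7F7F7F",
    "olive": "#808000",
}


def hex_to_rgb(hex_code: str) -> tuple[float, float, float]:
    """Convert '#RRGGBB' to an (r, g, b) tuple in 0-1, a form matplotlib accepts."""
    h = hex_code.lstrip("#")
    return tuple(int(h[i:i + 2], 16) / 255 for i in (0, 2, 4))


def assign_colors(series_names, palette=COLORS_SKETCH):
    """Give each series the next palette colour, wrapping after the last one."""
    colours = cycle(palette.values())
    return {name: next(colours) for name in series_names}
```

With ten colours, an eleventh series wraps back to blue; beyond that point, varying marker or linestyle is a clearer way to keep series distinguishable.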

Plot types — data requirements & usage

1. Linear — linear.py

linear_plot(x, y, ...)

Scatter plot with optional least-squares regression line.

| Parameter | Type | Description |
|---|---|---|
| x | np.ndarray (1-D) | Independent variable |
| y | np.ndarray (1-D) | Dependent variable (same length as x) |

fig = linear_plot(x, y, xlabel="Time (s)", ylabel="Voltage (V)",
                  show_fit=True, color="#2171B5")
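linear_plot's internals are not shown here, so the following is only an assumed equivalent of the fitted line: ordinary least squares via np.polyfit, with R² computed as 1 - SS_res / SS_tot. The repo's utils.py provides its own r2_score, which may differ in detail. The data are those from the Quick start example.

```python
import numpy as np


def least_squares_line(x, y):
    """Fit y = slope * x + intercept by ordinary least squares and report R²."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, deg=1)  # coefficients, highest power first
    y_hat = slope * x + intercept
    ss_res = np.sum((y - y_hat) ** 2)           # residual sum of squares
    ss_tot = np.sum((y - y.mean()) ** 2)        # total sum of squares
    return slope, intercept, 1.0 - ss_res / ss_tot


# Data from the Quick start example.
slope, intercept, r2 = least_squares_line([1, 2, 3, 4, 5], [2.1, 4.0, 5.8, 8.2, 9.9])
print(f"y = {slope:.3f}x + {intercept:.3f}, R² = {r2:.4f}")
# → y = 1.980x + 0.060, R² = 0.9976
```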

multi_line_plot(datasets, ...)

Overlay multiple line/scatter series.

Each dict in datasets:

| Key | Required | Type |
|---|---|---|
| x | Yes | np.ndarray (1-D) |
| y | Yes | np.ndarray (1-D) |
| label | No | str |
| color | No | str (hex) |
| marker | No | str (e.g. "o", "s") |
| linestyle | No | str (e.g. "-", "--") |

datasets = [
    {"x": t, "y": y1, "label": "Sensor A", "color": "#2171B5"},
    {"x": t, "y": y2, "label": "Sensor B", "color": "#CB181D"},
]
fig = multi_line_plot(datasets, xlabel="Time (s)", ylabel="Signal")
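Because each series is a plain dict, a typo in a key fails silently or late. A small pre-flight check against the key table can catch that early. find_problems is a hypothetical helper, and it is stricter than multi_line_plot may be: it reports any key not listed in the table.

```python
REQUIRED_KEYS = {"x", "y"}
OPTIONAL_KEYS = {"label", "color", "marker", "linestyle"}


def find_problems(datasets):
    """Return a list of human-readable problems; an empty list means all OK."""
    problems = []
    for i, ds in enumerate(datasets):
        keys = set(ds)
        missing = REQUIRED_KEYS - keys
        if missing:
            problems.append(f"dataset {i}: missing required key(s) {sorted(missing)}")
            continue  # lengths cannot be compared without both arrays
        unknown = keys - REQUIRED_KEYS - OPTIONAL_KEYS
        if unknown:
            problems.append(f"dataset {i}: unexpected key(s) {sorted(unknown)}")
        if len(ds["x"]) != len(ds["y"]):
            problems.append(
                f"dataset {i}: x and y lengths differ ({len(ds['x'])} vs {len(ds['y'])})"
            )
    return problems
```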

2. Nonlinear curve fitting — nonlinear.py

All fit functions automatically annotate the plot with the fitted equation and its R².

| Function | Model equation | Notes |
|---|---|---|
| poly_fit_plot(x, y, degree) | y = aₙxⁿ + … + a₁x + a₀ | Set degree=3 for cubic, etc. |
| exp_fit_plot(x, y) | y = a·exp(b·x) + c | Good for growth / decay |
| log_fit_plot(x, y) | y = a·ln(x) + b | All x must be > 0 |
| sigmoid_fit_plot(x, y) | y = L/(1+exp(-k(x-x₀))) + b | S-curves, dose-response |
| power_fit_plot(x, y) | y = a·x^b + c | All x must be > 0 |
| multi_fit_plot(x, y, models) | Overlays multiple models | e.g. models=("poly2","exp","sigmoid") |

All fit functions take the same data inputs:

| Parameter | Type | Description |
|---|---|---|
| x | np.ndarray (1-D) | Independent variable |
| y | np.ndarray (1-D) | Observed values |

fig = exp_fit_plot(x, y, xlabel="Time (s)", ylabel="Amplitude",
                   title="Exponential Growth")
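The dependency table lists scipy's curve_fit for fitting. As a sketch of how the exponential model in the table can be fitted and scored (not the package's own code), the example below fits y = a·exp(b·x) + c to synthetic data with known parameters and computes R² for the annotation.

```python
import numpy as np
from scipy.optimize import curve_fit


def exp_model(x, a, b, c):
    """Model from the table above: y = a·exp(b·x) + c."""
    return a * np.exp(b * x) + c


# Synthetic data generated from a=2.0, b=0.5, c=1.0 plus small Gaussian noise.
rng = np.random.default_rng(0)
x = np.linspace(0, 4, 50)
y = exp_model(x, 2.0, 0.5, 1.0) + rng.normal(0, 0.05, x.size)

# p0 is an initial guess; exponential fits often fail to converge without one.
popt, pcov = curve_fit(exp_model, x, y, p0=(1.0, 0.3, 0.0))
a, b, c = popt

# R² for the on-plot annotation: 1 - SS_res / SS_tot.
residuals = y - exp_model(x, *popt)
r2 = 1 - np.sum(residuals**2) / np.sum((y - y.mean())**2)
print(f"y = {a:.3f}·exp({b:.3f}x) + {c:.3f}   R² = {r2:.4f}")
```

The square roots of the diagonal of pcov give approximate standard errors for a, b and c, which are worth reporting alongside the equation.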

3. Radar / Spider — radar.py

radar_plot(values, labels, ...)

Single-series radar chart.

| Parameter | Type | Description |
|---|---|---|
| values | np.ndarray (1-D) | Metric values (one per axis) |
| labels | list[str] | Axis names (same length as values) |

fig = radar_plot(
    np.array([0.92, 0.88, 0.85, 0.90, 0.78, 0.82]),
    ["Sensitivity", "Precision", "F1", "Recall", "Speed", "Robustness"],
    title="Model Performance", label="Proposed",
)

multi_radar_plot(data, labels, ...)

Each dict in data:

| Key | Required | Type |
|---|---|---|
| values | Yes | np.ndarray (1-D) |
| label | No | str |
| color | No | str (hex) |
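The geometry behind a radar chart is simple, and the sketch below shows it under the assumption that radar.py follows the common approach: one axis per value at evenly spaced angles, with the first point repeated at the end so the outline closes. radar_polygon is a hypothetical helper.

```python
import numpy as np


def radar_polygon(values):
    """Return (angles, radii) tracing a closed radar polygon.

    Axes are spaced evenly around the circle; the first point is appended
    again at the end so that a line plot closes the shape.
    """
    values = np.asarray(values, dtype=float)
    angles = np.linspace(0, 2 * np.pi, len(values), endpoint=False)
    return np.append(angles, angles[0]), np.append(values, values[0])


angles, radii = radar_polygon([0.92, 0.88, 0.85, 0.90, 0.78, 0.82])
```

On a polar matplotlib Axes, ax.plot(angles, radii) draws the outline and ax.fill(angles, radii, alpha=0.25) shades it; the axis names go on angles[:-1].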

4. Bar charts — bar.py

bar_plot(values, labels, ...)

Simple bar chart.

| Parameter | Type | Description |
|---|---|---|
| values | np.ndarray (1-D) | Bar heights |
| labels | list[str] | Category names |

grouped_bar_plot(data, labels, ...) / stacked_bar_plot(data, labels, ...)

Each dict in data:

| Key | Required | Type |
|---|---|---|
| values | Yes | np.ndarray (1-D), length = number of groups |
| label | No | str |
| color | No | str (hex) |

data = [
    {"values": [0.90, 0.85, 0.78], "label": "Method A"},
    {"values": [0.87, 0.88, 0.82], "label": "Method B"},
]
fig = grouped_bar_plot(data, ["Task 1", "Task 2", "Task 3"],
                       ylabel="Accuracy")
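Grouped and stacked bars differ only in where each series' bars are placed. The sketch below computes both layouts for the data above; the total group width of 0.8 is an assumption, and grouped_offsets / stacked_bottoms are hypothetical helpers rather than functions from bar.py.

```python
import numpy as np


def grouped_offsets(n_series, group_width=0.8):
    """Centre offset of each series' bar within a group, plus the bar width."""
    width = group_width / n_series
    offsets = (np.arange(n_series) - (n_series - 1) / 2) * width
    return offsets, width


def stacked_bottoms(series_values):
    """Baseline of each series in a stack: cumulative sum of the series below it."""
    values = np.asarray(series_values, dtype=float)  # shape (n_series, n_groups)
    return np.vstack([np.zeros(values.shape[1]), np.cumsum(values, axis=0)[:-1]])


# Two series (Method A, Method B) over three tasks, as in the example above.
data = [[0.90, 0.85, 0.78], [0.87, 0.88, 0.82]]
offsets, width = grouped_offsets(len(data))  # offsets -0.2 and +0.2, width 0.4
group_x = np.arange(3)                       # one x position per task label
bottoms = stacked_bottoms(data)              # row 0 all zeros, row 1 = Method A
```

Series i is then drawn with ax.bar(group_x + offsets[i], data[i], width) for grouped bars, or ax.bar(group_x, data[i], bottom=bottoms[i]) for stacked bars.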

5. Distribution — distribution.py

box_plot(data, labels, ...) / violin_plot(data, labels, ...)

| Parameter | Type | Description |
|---|---|---|
| data | list[np.ndarray] | One 1-D array per group |
| labels | list[str] | Group names |

fig = box_plot(
    [group_control, group_drug_a, group_drug_b],
    ["Control", "Drug A", "Drug B"],
    ylabel="Response (mV)",
)
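To read a box plot correctly it helps to know what each element marks. The sketch below computes the conventional Tukey summary that matplotlib's boxplot uses by default (quartiles, whiskers at the most extreme points within 1.5 × IQR of the box, and outliers beyond). Whether box_plot changes matplotlib's whisker setting is not documented here, and tukey_box_stats is a hypothetical helper.

```python
import numpy as np


def tukey_box_stats(sample, whis=1.5):
    """Quartiles, whisker ends and outliers for a conventional box plot."""
    data = np.sort(np.asarray(sample, dtype=float))
    q1, median, q3 = np.percentile(data, [25, 50, 75])
    iqr = q3 - q1
    lo_fence, hi_fence = q1 - whis * iqr, q3 + whis * iqr
    inside = data[(data >= lo_fence) & (data <= hi_fence)]
    return {
        "q1": q1, "median": median, "q3": q3,
        "whisker_low": inside.min(),   # most extreme points still inside the fences
        "whisker_high": inside.max(),
        "outliers": data[(data < lo_fence) | (data > hi_fence)],
    }


stats = tukey_box_stats([4.1, 4.5, 4.8, 5.0, 5.2, 5.3, 5.9, 9.7])
```

Here 9.7 lies beyond Q3 + 1.5 × IQR, so it is drawn as an individual outlier and the upper whisker stops at 5.9.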

histogram_plot(data, ...)

| Parameter | Type | Description |
|---|---|---|
| data | np.ndarray (1-D) | Raw observations |

fig = histogram_plot(data, bins=30, xlabel="Value", show_kde=True)
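A KDE curve and a count histogram are on different scales: gaussian_kde returns a density whose area is 1, while bar heights are counts per bin. The sketch below shows one standard way to overlay them (scale the density by N × bin width). That histogram_plot does exactly this when show_kde=True is an assumption, and the data are synthetic.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=1.0, size=1000)  # synthetic observations

counts, edges = np.histogram(data, bins=30)
bin_width = edges[1] - edges[0]

# The KDE is a probability density (area 1); multiplying by N * bin_width
# puts it on the same "count per bin" scale as the histogram bars.
kde = gaussian_kde(data)
grid = np.linspace(edges[0], edges[-1], 200)
kde_counts = kde(grid) * data.size * bin_width
```

If the histogram is instead drawn with density=True, the KDE can be plotted unscaled.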

6. Scatter with error bars — scatter_error.py

errorbar_plot(x, y, ...)

| Parameter | Type | Description |
|---|---|---|
| x, y | np.ndarray (1-D) | Data coordinates |
| xerr, yerr | np.ndarray, float, or None | Error magnitude(s): a scalar (same error for every point), a 1-D array (symmetric, one value per point), or a 2×N array (asymmetric; first row lower, second row upper) |

fig = errorbar_plot(voltage, current, yerr=uncertainty,
                    xlabel="Voltage (V)", ylabel="Current (mA)")
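matplotlib's Axes.errorbar accepts the same three forms of xerr / yerr. The sketch below makes their meaning explicit by normalising any of them to a (2, N) array of [lower, upper] magnitudes; as_lo_hi is a hypothetical helper, not part of scatter_error.py.

```python
import numpy as np


def as_lo_hi(err, n):
    """Normalise an error spec to a (2, n) array of [lower, upper] magnitudes.

    Accepts a scalar (same error everywhere), a length-n array (symmetric),
    or a (2, n) array (asymmetric: row 0 lower, row 1 upper).
    """
    arr = np.asarray(err, dtype=float)
    if arr.ndim == 0:
        return np.full((2, n), float(arr))
    if arr.shape == (n,):
        return np.vstack([arr, arr])
    if arr.shape == (2, n):
        return arr
    raise ValueError(f"expected scalar, shape ({n},) or (2, {n}); got {arr.shape}")


y = np.array([1.0, 2.0, 3.0])
lo_hi = as_lo_hi([[0.1, 0.1, 0.1], [0.3, 0.3, 0.3]], len(y))
y_lower, y_upper = y - lo_hi[0], y + lo_hi[1]  # the extent each error bar spans
```

Note that the values are magnitudes measured from each point, not absolute bounds: passing interval endpoints directly is a common mistake.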

multi_errorbar_plot(datasets, ...)

Each dict in datasets:

| Key | Required | Type |
|---|---|---|
| x, y | Yes | np.ndarray (1-D) |
| xerr, yerr | No | np.ndarray or float |
| label | No | str |
| color | No | str (hex) |

7. Heatmap — heatmap.py

heatmap(data, ...)

| Parameter | Type | Description |
|---|---|---|
| data | np.ndarray (2-D) | Matrix of shape (rows, cols) |
| xlabels | list[str] or None | Column labels |
| ylabels | list[str] or None | Row labels |
| cmap | str | Colormap name (e.g. "Blues", "RdBu_r", "viridis") |
| vmin, vmax | float or None | Colour scale limits |

# Correlation matrix
C = np.corrcoef(data_matrix.T)
fig = heatmap(C, xlabels=feature_names, ylabels=feature_names,
              cmap="RdBu_r", vmin=-1, vmax=1)

# Confusion matrix
fig = heatmap(cm, xlabels=classes, ylabels=classes, fmt=".0f",
              cmap="Blues", colorbar_label="Count")
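The two heatmap examples assume matrices C and cm that are built elsewhere. The sketch below shows one way to produce each from synthetic data using only numpy. The orientation (true classes as rows, so they appear along the y-axis via ylabels) is a convention assumed here, not something the README specifies.

```python
import numpy as np

rng = np.random.default_rng(0)

# Correlation: rows are samples, columns are features. np.corrcoef treats each
# row as a variable by default, hence the transpose (equivalently rowvar=False).
data_matrix = rng.normal(size=(100, 4))  # 100 samples, 4 features
C = np.corrcoef(data_matrix.T)           # shape (4, 4), ones on the diagonal

# Confusion matrix: rows = true class, columns = predicted class.
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0, 2])
n_classes = 3
cm = np.zeros((n_classes, n_classes))
np.add.at(cm, (y_true, y_pred), 1)       # count each (true, predicted) pair
```

np.add.at is used instead of cm[y_true, y_pred] += 1 because the latter counts repeated index pairs only once.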

8. Contour — contour.py

contour_plot(fn=..., Z=..., x=..., y=...)

Two input modes:

Mode A — callable:

| Parameter | Type | Description |
|---|---|---|
| fn | callable(X, Y) -> Z | Function evaluated over a meshgrid |

def rosenbrock(X, Y):
    return (1 - X)**2 + 100 * (Y - X**2)**2

fig = contour_plot(rosenbrock,
                   x=np.linspace(-2, 2, 200),
                   y=np.linspace(-1, 3, 200),
                   levels=25, cmap="viridis")

Mode B — precomputed:

| Parameter | Type | Description |
|---|---|---|
| Z | np.ndarray (2-D) | Precomputed surface values |
| x, y | np.ndarray (1-D) | Coordinate vectors |
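For Mode B, Z must line up with x and y. The sketch below builds Z for the Rosenbrock function from Mode A using np.meshgrid. That contour_plot expects Z with shape (len(y), len(x)) is an assumption based on matplotlib's contour convention, which is what meshgrid's default "xy" indexing produces.

```python
import numpy as np


def rosenbrock(X, Y):
    return (1 - X)**2 + 100 * (Y - X**2)**2


x = np.linspace(-2, 2, 201)
y = np.linspace(-1, 3, 201)

# Default 'xy' indexing gives arrays of shape (len(y), len(x)):
# rows follow y and columns follow x, matching matplotlib's contour convention.
X, Y = np.meshgrid(x, y)
Z = rosenbrock(X, Y)

# Locate the grid minimum; the analytic minimum is at (1, 1).
row, col = np.unravel_index(np.argmin(Z), Z.shape)
x_min, y_min = x[col], y[row]
```

Indexing x with the column index and y with the row index is the detail that is easiest to get backwards.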

9. Bode plot — bode.py

bode_plot(tf=..., freq=..., mag=..., phase=...)

Two input modes:

Mode A — transfer function:

| Parameter | Type | Description |
|---|---|---|
| tf | callable(f) -> complex | Takes frequency f in Hz and returns the complex response H(f) |

def tf(f):
    s = 1j * 2 * np.pi * f
    return 1e6 / (s**2 + 600*s + 1e6)

fig = bode_plot(tf, title="2nd-Order Low-Pass")

Mode B — precomputed arrays:

| Parameter | Type | Description |
|---|---|---|
| freq | np.ndarray (1-D) | Frequency points (Hz) |
| mag | np.ndarray (1-D) | Magnitude in dB |
| phase | np.ndarray (1-D) | Phase in degrees |
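Mode B arrays can be derived from a Mode A transfer function with a few numpy calls: magnitude in dB is 20·log10|H| and phase is the unwrapped complex angle in degrees. The sketch below does this for the second-order low-pass above (ωn = 1000 rad/s, ζ = 0.3, since 2ζωn = 600). The 1 Hz to 10 kHz range is an arbitrary choice for illustration.

```python
import numpy as np


def tf(f):
    """Second-order low-pass from Mode A: wn = 1000 rad/s, zeta = 0.3."""
    s = 1j * 2 * np.pi * f
    return 1e6 / (s**2 + 600 * s + 1e6)


freq = np.logspace(0, 4, 400)               # 1 Hz to 10 kHz, log-spaced
H = tf(freq)
mag = 20 * np.log10(np.abs(H))              # magnitude in dB
phase = np.degrees(np.unwrap(np.angle(H)))  # unwrap removes ±360° jumps
```

At the natural frequency f = 1000/(2π) ≈ 159 Hz the response is purely imaginary, so the phase is -90° and the gain is 1/(2ζ), about +4.4 dB of resonant peaking.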

10. Confidence band — fill_between.py

confidence_band_plot(x, y, y_low, y_high, ...)

| Parameter | Type | Description |
|---|---|---|
| x | np.ndarray (1-D) | Shared x coordinates |
| y | np.ndarray (1-D) | Central (mean) values |
| y_low | np.ndarray (1-D) | Lower bound |
| y_high | np.ndarray (1-D) | Upper bound |

fig = confidence_band_plot(time, mean_signal, ci_low, ci_high,
                           xlabel="Time (s)", ylabel="Signal (V)",
                           band_label="95% CI")
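confidence_band_plot takes the bounds as inputs, so they have to be computed beforehand. One common choice, sketched below on synthetic repeated trials, is a t-based 95% confidence interval for the mean at each x: mean ± t(0.975, n-1) × SEM. The variable names match the example call above; the data are made up.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)
time = np.linspace(0, 1, 100)
# Synthetic data: 20 repeated trials of a noisy sine, shape (n_trials, n_points).
trials = np.sin(2 * np.pi * time) + rng.normal(0, 0.3, size=(20, time.size))

n = trials.shape[0]
mean_signal = trials.mean(axis=0)
sem = trials.std(axis=0, ddof=1) / np.sqrt(n)  # standard error of the mean
t_crit = stats.t.ppf(0.975, df=n - 1)          # two-sided 95% critical value
ci_low = mean_signal - t_crit * sem
ci_high = mean_signal + t_crit * sem
```

This band describes uncertainty in the mean; a band meant to cover individual trials would use the standard deviation instead of the SEM and be much wider.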

multi_confidence_band_plot(datasets, ...)

Each dict in datasets:

| Key | Required | Type |
|---|---|---|
| x | Yes | np.ndarray (1-D) |
| y | Yes | np.ndarray (1-D) |
| y_low | Yes | np.ndarray (1-D) |
| y_high | Yes | np.ndarray (1-D) |
| label | No | str |
| color | No | str (hex) |

11. Forest plot — forest.py

forest_plot(labels, estimates, ci_low, ci_high, ...)

| Parameter | Type | Description |
|---|---|---|
| labels | list[str] | Study / group names (y-axis) |
| estimates | np.ndarray (1-D) | Point estimates |
| ci_low | np.ndarray (1-D) | Lower CI bounds |
| ci_high | np.ndarray (1-D) | Upper CI bounds |
| reference_line | float or None | Vertical null-effect line (e.g. 0.0 for differences, 1.0 for ratios such as odds ratios) |
| summary_estimate | float or None | Pooled estimate (shown as a diamond) |
| summary_ci | (float, float) or None | Pooled CI bounds |

fig = forest_plot(
    ["Smith 2019", "Lee 2020", "Garcia 2021"],
    np.array([0.45, 0.62, 0.38]),
    np.array([0.20, 0.40, 0.15]),
    np.array([0.70, 0.84, 0.61]),
    xlabel="Odds Ratio",
    reference_line=1.0,      # null value for an odds ratio (no effect)
    summary_estimate=0.48,
    summary_ci=(0.35, 0.61),
)
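summary_estimate and summary_ci have to come from a meta-analysis done beforehand. The sketch below uses the standard fixed-effect inverse-variance method on the log-odds-ratio scale, recovering each study's standard error from its 95% CI. pool_fixed_effect is a hypothetical helper; the study values in the example above are illustrative, so the pooled result will not necessarily match the summary numbers shown there. With heterogeneous studies, a random-effects model (e.g. DerSimonian-Laird) is usually preferred.

```python
import numpy as np

Z95 = 1.959963984540054  # two-sided 95% standard-normal quantile


def pool_fixed_effect(estimates, ci_low, ci_high):
    """Inverse-variance fixed-effect pooling of ratio estimates (e.g. odds ratios).

    Works on the log scale, back-calculating each study's standard error from
    the width of its 95% CI, then back-transforms the pooled value and its CI.
    """
    log_est = np.log(estimates)
    se = (np.log(ci_high) - np.log(ci_low)) / (2 * Z95)
    w = 1.0 / se**2                          # inverse-variance weights
    pooled = np.sum(w * log_est) / np.sum(w)
    pooled_se = 1.0 / np.sqrt(np.sum(w))
    ci = (np.exp(pooled - Z95 * pooled_se), np.exp(pooled + Z95 * pooled_se))
    return np.exp(pooled), ci


est, (lo, hi) = pool_fixed_effect(
    np.array([0.45, 0.62, 0.38]),
    np.array([0.20, 0.40, 0.15]),
    np.array([0.70, 0.84, 0.61]),
)
```

The results can then be passed on as summary_estimate=est and summary_ci=(lo, hi).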

12. Q-Q plot — qq.py

qq_plot(data, ...)

| Parameter | Type | Description |
|---|---|---|
| data | np.ndarray (1-D) | Sample observations |
| dist | str | Theoretical distribution for scipy.stats.probplot (default "norm") |

fig = qq_plot(residuals, dist="norm", title="Normality Check")

Points that fall close to the reference line indicate that the data follow the chosen distribution; systematic departures from it (curved or S-shaped patterns) indicate skewness or heavy tails.
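Since qq_plot is built on scipy.stats.probplot, the underlying numbers can also be inspected directly. The sketch below calls probplot on a synthetic normal sample and on a skewed one; the correlation coefficient r of the fitted line gives a one-number summary of how straight each Q-Q plot is.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
residuals = rng.normal(size=500)  # synthetic sample for illustration

# probplot returns the theoretical quantiles (osm), the ordered sample (osr),
# and a least-squares line through them with its correlation coefficient r.
(osm, osr), (slope, intercept, r) = stats.probplot(residuals, dist="norm")

# A right-skewed sample bends away from the line, which lowers r.
skewed = rng.exponential(size=500)
(_, _), (_, _, r_skewed) = stats.probplot(skewed, dist="norm")
```

r close to 1 is necessary but not sufficient for normality; the shape of the departures in the plot is more informative than r alone.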


13. Multi-panel — multipanel.py

multipanel_plot(panels, nrows=..., ncols=..., ...)

Each dict in panels:

| Key | Required | Type | Description |
|---|---|---|---|
| plot_fn | Yes | callable(ax) | Draws on a single plt.Axes |
| title | No | str | Subplot title |
| xlabel | No | str | x-axis label |
| ylabel | No | str | y-axis label |

panels = [
    {"plot_fn": lambda ax: ax.plot(t, signal), "title": "Raw",
     "xlabel": "Time (s)", "ylabel": "V"},
    {"plot_fn": lambda ax: ax.hist(data, bins=20), "title": "Histogram",
     "xlabel": "Value", "ylabel": "Count"},
]
fig = multipanel_plot(panels, nrows=1, ncols=2, title="Overview")
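The plot_fn contract is easiest to understand by seeing how a grid helper could consume it. multipanel_sketch below is a hypothetical stand-in written with plain matplotlib, not the package's implementation: it creates the grid, hands each panel its own Axes, applies the optional labels, and hides unused cells. The data for t, signal and data are synthetic.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts and CI
import matplotlib.pyplot as plt


def multipanel_sketch(panels, nrows, ncols, title=None):
    """Hypothetical stand-in for multipanel_plot: one plot_fn call per Axes."""
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(4 * ncols, 3 * nrows))
    flat = axes.ravel()
    for ax, panel in zip(flat, panels):
        panel["plot_fn"](ax)  # the caller draws on the Axes it is given
        ax.set_title(panel.get("title", ""))
        ax.set_xlabel(panel.get("xlabel", ""))
        ax.set_ylabel(panel.get("ylabel", ""))
    for ax in flat[len(panels):]:
        ax.set_visible(False)  # hide grid cells that have no panel
    if title:
        fig.suptitle(title)
    return fig


t = np.linspace(0, 1, 200)
signal = np.sin(2 * np.pi * 5 * t)
data = np.random.default_rng(0).normal(size=300)
panels = [
    {"plot_fn": lambda ax: ax.plot(t, signal), "title": "Raw",
     "xlabel": "Time (s)", "ylabel": "V"},
    {"plot_fn": lambda ax: ax.hist(data, bins=20), "title": "Histogram",
     "xlabel": "Value", "ylabel": "Count"},
]
fig = multipanel_sketch(panels, nrows=1, ncols=2, title="Overview")
```

squeeze=False keeps axes two-dimensional even for a single row or column, so the same loop works for any grid shape.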

Common parameters (shared by most functions)

| Parameter | Type | Default | What it controls |
|---|---|---|---|
| xlabel | str | "x" | x-axis label text |
| ylabel | str | "y" | y-axis label text |
| title | str or None | None | Subplot title |
| figsize | (float, float) | varies | Figure size in inches |
| show_grid | bool | True | Show/hide background grid |
| grid_alpha | float | 0.25 | Grid transparency (0–1) |
| grid_linewidth | float | 0.3 | Grid line thickness |
| grid_color | str | "#888888" | Grid line colour |
| color | str or None | auto | Hex colour for main data |
| legend_loc | str | "best" | Legend placement |

All functions return a plt.Figure that can be further customised or passed to savefig().


Running the demos

python examples/demo_linear.py
python examples/demo_nonlinear.py
python examples/demo_radar.py
python examples/demo_bar.py
python examples/demo_distribution.py
python examples/demo_scatter_error.py
python examples/demo_heatmap.py
python examples/demo_contour.py
python examples/demo_bode.py
python examples/demo_fill_between.py
python examples/demo_forest.py
python examples/demo_qq.py
python examples/demo_multipanel.py

All output is saved to outputs/ as PNG (600 DPI) and SVG.


Using in your own project

After pip install -e ., import from anywhere:

from academic_plot import linear_plot, savefig, COLORS

fig = linear_plot(x, y, color=COLORS["blue"])
savefig(fig, "my_figure", formats=("png", "pdf"))
