torch-cudagraph-debug is a PyTorch extension for inspecting tensor values from
CUDA Graph replay. A tensor probe returns its input tensor unchanged while
inserting graph-captured debug work:
- device-to-host copy into pinned host staging memory;
- CUDA host callback;
- native print, record, or compare logic on CPU memory.
The first domain is tensor_debug:
- TensorPrint: print a compact tensor summary and sample values.
- TensorRecord: persist replay snapshots into a native ring buffer for later Python reads.
- TensorCompare: compare replay values against CPU or NumPy ground truth.
This is a v0.1 source-built package for Linux CUDA environments. It targets PyTorch 2.6+ and builds against the CUDA-enabled PyTorch already installed in the runtime where you install it. Prebuilt wheels are intentionally out of scope for the first release.
Install CUDA-enabled PyTorch first, then install the v0.1.0 release tag from GitHub with build isolation disabled so the extension builds against that exact PyTorch:
```bash
pip install --no-build-isolation \
  "git+https://github.com/buptzyb/torch-cudagraph-debug.git@v0.1.0"
```

To test the latest development branch instead, install main:

```bash
pip install --no-build-isolation \
  "git+https://github.com/buptzyb/torch-cudagraph-debug.git@main"
```

From a source checkout:

```bash
cd torch-cudagraph-debug
pip install --no-build-isolation .
```

Enabled probes require a native extension built against CUDA-enabled PyTorch. Source-tree imports and all-disabled probes can run without the extension, but normal source installation for runtime use should happen in the target CUDA environment.
- API reference: signatures, method semantics, execution modes, multi-action behavior, examples, and troubleshooting.
- Release checklist: validation steps for source releases and GPU smoke tests.
- Changelog: release notes and compatibility notes.
```python
import torch

from torch_cudagraph_debug.tensor_debug import (
    CudaGraphTensorProbe,
    TensorCompare,
    TensorPrint,
    TensorRecord,
)

static_x = torch.ones(4, device="cuda")
expected = torch.full((4,), 3.0, device="cpu")

probe = CudaGraphTensorProbe(
    "mid",
    actions=[
        TensorPrint(max_items=8),
        TensorRecord(capacity=16),
        TensorCompare(expected, rtol=1e-5, atol=1e-8),
    ],
)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    mid = static_x + 2
    mid = probe(mid)

g.replay()
torch.cuda.synchronize()

snapshots = probe.records()
probe.assert_ok()
probe.close()
```

CudaGraphTensorProbe.__call__ always returns the original tensor object. By
default, eager and warmup calls are transparent no-ops; debug side effects are
installed only while the current CUDA stream is being captured.
TensorPrint is useful when you need a quick value check from graph replay:
probe = CudaGraphTensorProbe("activation", [TensorPrint(max_items=16, every=10)])TensorRecord keeps CPU snapshots in a fixed-size native ring buffer:
probe = CudaGraphTensorProbe("activation", [TensorRecord(capacity=8)])
snapshots = probe.records()
last_cpu_tensor = snapshots[-1].tensorEach TensorSnapshot contains probe_name, replay_index, shape, dtype,
device, and the recorded CPU tensor.
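A post-replay loop over the snapshots might look like this sketch, assuming replay and synchronization have already happened:

```python
# Sketch: walk the recorded snapshots after replay + torch.cuda.synchronize().
for snap in probe.records():
    print(snap.probe_name, snap.replay_index, snap.shape, snap.dtype, snap.device)
    print(snap.tensor.float().mean().item())  # snap.tensor is a CPU tensor
```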
TensorCompare checks replay values against a CPU tensor or NumPy array. A
mismatch sets a sticky failure flag; call probe.assert_ok() after replay and
synchronization.
```python
probe = CudaGraphTensorProbe(
    "activation",
    [TensorCompare(expected_cpu_tensor, rtol=1e-4, atol=1e-5)],
)
```

Every action accepts enabled=False for config-driven toggles. Disabled actions
are not installed into the native probe. If every action is disabled, the probe
is a pure no-op: it returns the input tensor unchanged, exposes no records, and
does not require the native extension to be loaded.
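A config-driven toggle can then look like this sketch, where DEBUG_TENSORS is a hypothetical environment flag:

```python
import os

from torch_cudagraph_debug.tensor_debug import (
    CudaGraphTensorProbe,
    TensorPrint,
    TensorRecord,
)

# Hypothetical flag; wire this into your own config system as needed.
DEBUG_TENSORS = os.environ.get("DEBUG_TENSORS", "0") == "1"

probe = CudaGraphTensorProbe(
    "activation",
    [
        TensorPrint(max_items=8, enabled=DEBUG_TENSORS),
        TensorRecord(capacity=8, enabled=DEBUG_TENSORS),
    ],
)
# With the flag unset, every action is disabled and the probe is a pure no-op.
```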
Use probe.attach_grad(tensor) to register an autograd hook that probes a
tensor's backward gradient. It returns the original tensor, so activation probes
can stay inline:
```python
hidden = activation_value_probe(hidden)
hidden = activation_grad_probe.attach_grad(hidden)
```

For parameter gradients, prefer side-effect style so the code does not look like it replaces module state:

```python
weight_grad_probe.attach_grad(module.weight)
```

This parameter hook observes the gradient when autograd produces it. If you want
the final .grad buffer after backward(), probe that buffer explicitly:

```python
loss.backward()
if module.weight.grad is not None:
    final_weight_grad_probe(module.weight.grad)
```

If you need to remove a long-lived hook, request the PyTorch hook handle:

```python
_, handle = weight_grad_probe.attach_grad(module.weight, return_handle=True)
handle.remove()
```

For a module-level example that inserts a probe into an internal hidden tensor
and controls actions from a small config object, see
examples/transformer_block_probe.py.
For the common forward/gradient probe patterns, see
examples/grad_probe_patterns.py.
For TensorBoard summaries from recorded snapshots, see
examples/tensorboard_export_records.py.
Custom Python actions are intentionally not executed inside CUDA host callbacks.
Use TensorRecord to bring graph replay values back to CPU, then consume them
after synchronization:
```python
from torch_cudagraph_debug.tensor_debug.postprocess import drain_records

probe = CudaGraphTensorProbe("hidden", [TensorRecord(capacity=8)])

# after graph replay
drain_records(probe, lambda snapshot: my_custom_action(snapshot.tensor))
```

drain_records() synchronizes CUDA by default, passes each TensorSnapshot to
the consumer, and clears the probe records after successful consumption. For a
complete custom-consumer example, see
examples/custom_record_consumer.py.
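A named consumer works the same way as the lambda above. For example, this sketch collects per-replay means into a Python list (consume and means are illustrative names):

```python
# Sketch: collect a scalar statistic from each drained snapshot.
means = []

def consume(snapshot):
    # snapshot.tensor is already on CPU, so arbitrary Python is safe here.
    means.append(snapshot.tensor.float().mean().item())

drain_records(probe, consume)  # synchronizes, consumes, then clears records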
The default execution mode is mode="capture":
probe = CudaGraphTensorProbe("activation", [TensorRecord(capacity=8)])In this mode, probe(tensor) is completely transparent during eager warmup: it
does not validate the input tensor, allocate staging memory, enqueue D2H copies,
launch callbacks, print, record, or compare. When the same call runs inside
torch.cuda.graph(...), the probe captures the debug D2H copy and one host
callback node.
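The warmup-then-capture flow therefore needs no special casing. A sketch of the usual pattern:

```python
# Sketch: with the default mode="capture", the eager warmup call is a
# transparent no-op; only the captured call installs debug work.
static_x = torch.ones(4, device="cuda")
probe = CudaGraphTensorProbe("activation", [TensorRecord(capacity=8)])

y = probe(static_x + 2)  # warmup: no validation, no staging, no recording

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = probe(static_x + 2)  # capture: D2H copy + host callback captured

g.replay()
torch.cuda.synchronize()
assert len(probe.records()) == 1  # the replay, not warmup, made the snapshot
```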
mode="always" is an explicit escape hatch:
```python
probe = CudaGraphTensorProbe(
    "activation",
    [TensorRecord(capacity=8)],
    mode="always",
)
```

Keep this mode for cases where eager side effects are intentional: testing the probe without writing a CUDA graph, recording eager and graph values with the same probe machinery, debugging non-graph CUDA stream code, or covering native enqueue behavior in this package's tests. It is not the default because it can pollute warmup records, print during warmup, or set compare failures before graph replay.
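For example, mode="always" can exercise the recording path from plain eager code, with no CUDA graph involved (a sketch):

```python
# Sketch: eager-only recording with mode="always".
probe = CudaGraphTensorProbe("eager", [TensorRecord(capacity=4)], mode="always")
probe(torch.randn(8, device="cuda"))
torch.cuda.synchronize()
print(len(probe.records()))  # the eager call itself produced a snapshot
```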
The default non-contiguous policy is fail-fast when debug work is actually installed:
probe = CudaGraphTensorProbe("view", [TensorRecord()])
with torch.cuda.graph(g):
probe(non_contiguous_tensor) # raisesIn the default mode="capture", eager warmup calls are no-ops, so this policy is
checked during capture or explicit probe.prepare(...). This avoids hidden CUDA
graph memory pressure. If you explicitly allow it, the probe creates an internal
contiguous CUDA copy only for the debug path, captures the copy before the D2H
node, keeps that internal storage alive with the captured probe payload, and
still returns the original tensor:
```python
probe = CudaGraphTensorProbe(
    "view",
    [TensorRecord(capacity=4)],
    non_contiguous="copy",
)
```

The copy policy costs roughly tensor.numel() * tensor.element_size() additional
CUDA graph-pool memory for each captured non-contiguous probe site. In tight
memory captures, prefer probing a smaller slice or making the debug copy explicit
in your model code so the memory cost is visible.
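One way to keep that cost visible is to make the debug copy explicit in model code and probe the copy instead. In this sketch, slice_probe is an illustrative probe over a small slice:

```python
# Sketch: probe a small explicit contiguous copy instead of letting the
# probe materialize a hidden copy of the full non-contiguous tensor.
debug_slice = non_contiguous_tensor[:16].contiguous()  # cost is visible here
debug_slice = slice_probe(debug_slice)
```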
- Keep the probe alive for at least as long as any captured torch.cuda.CUDAGraph that contains it can replay.
- Call probe.close() only after those graphs will never replay again.
- Call torch.cuda.synchronize() before reading records or asserting compare status if replay was launched asynchronously.
- Host callbacks do not call Python or CUDA APIs. Print, record, and compare are native CPU operations on pinned staging memory.
- Linux/CUDA only.
- One probe is not designed for concurrent replay on multiple streams.
- Inputs must use supported dense dtypes: float64, float32, float16, bfloat16, int64, int32, int16, int8, uint8, or bool.
- TensorCompare expected values must be CPU tensors or NumPy arrays with the same shape and dtype as the probed tensor.
The v0.1.0 tag has been validated from the public GitHub install path in
nvcr.io/nvidia/pytorch:26.03-py3 on Computelab GH200/GH100 nodes. The release
gate built the native extension from source, reported
native_extension_available True, passed the full CUDA pytest suite
(36 passed), and ran the custom record consumer and gradient probe examples.
The package is intended to build against CUDA-enabled PyTorch 2.6+ installations, but each CUDA/PyTorch/container combination should be verified in the target environment before relying on it in a larger training workflow.
Local metadata and Python checks:
```bash
python -m py_compile $(find src tests examples -name '*.py')
python -m build --sdist
```

GPU validation requires a CUDA-enabled PyTorch environment:

```bash
pip install --no-build-isolation --no-deps .
pytest -q
```