
Commit 4704078
update
1 parent 0f03fb1

2 files changed: +79 −40 lines

examples/dynamo/aot_plugin.py

Lines changed: 1 addition & 2 deletions
@@ -98,9 +98,8 @@ def add_plugin_aot_impl(
             "x_ptr": f"*{type_str}",
             "n_elements": "i32",
             "y_ptr": f"*{type_str}",
-            "BLOCK_SIZE": "constexpr",
         },
-        constants={
+        constexprs={
             "BLOCK_SIZE": block_size,
         },
     )
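For context, this hunk appears to track a rename in Triton's compiler API: the constexpr parameter is no longer declared inside `signature`, and its value is passed via a `constexprs=` keyword rather than `constants=`. Below is a minimal sketch of the post-commit call shape, assuming a recent Triton whose `ASTSource` accepts `constexprs`; the kernel `add_one_kernel` and the concrete fp32 signature are illustrative placeholders, not taken from the commit.

# Illustrative sketch (hypothetical kernel; mirrors the post-commit call shape,
# assuming a Triton version whose ASTSource takes `constexprs` and omits
# constexpr parameters from `signature`).
import triton
import triton.language as tl
from triton.compiler import ASTSource


@triton.jit
def add_one_kernel(x_ptr, n_elements, y_ptr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1.0, mask=mask)


block_size = 256
src = ASTSource(
    fn=add_one_kernel,
    signature={
        "x_ptr": "*fp32",
        "n_elements": "i32",
        "y_ptr": "*fp32",
        # "BLOCK_SIZE" no longer appears here; it moved to `constexprs`
    },
    constexprs={"BLOCK_SIZE": block_size},
)
compiled = triton.compile(src)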
Lines changed: 78 additions & 38 deletions
@@ -1,11 +1,18 @@
 """
-Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue.
+Using Custom Kernels with NVRTC in TensorRT AOT Plugins
+=======================================================
 
-This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
-fails to compile operations that use fp16 (half-precision) tensors.
+This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library
+to compile custom CUDA kernels at runtime and integrate them into a TensorRT
+Ahead-Of-Time (AOT) plugin.
 
-The issue occurs because the JIT plugin generator doesn't properly declare format
-support for fp16 data types in the generated TensorRT plugin.
+This approach is powerful because it allows you to:
+1. Write raw CUDA C++ code for maximum performance.
+2. Compile it on-the-fly, adapting to the specific GPU architecture.
+3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
+4. Integrate it seamlessly into Torch-TensorRT's compilation flow.
+
+The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).
 """
 
 from typing import List, Tuple, Union
@@ -14,8 +21,13 @@
 
 import torch_tensorrt
 
-# CUDA kernel source (NVRTC) used by the torch custom op
-# Note: TensorRT passes args as: inputs, extra_args, outputs
+# ============================================================================
+# 1. Define the CUDA Kernel Source
+# ============================================================================
+# We define the CUDA kernel source code as a Python string.
+# This code will be compiled by NVRTC.
+# Note that we use extern "C" to avoid name mangling, making it easier to
+# retrieve the kernel function by name later.
 
 cu_code = """
 // Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
@@ -32,47 +44,60 @@
 }
 """
 
-# Prepare NVRTC program, kernel, and stream once (simple eager path)
+# ============================================================================
+# 2. Compile the Kernel using NVRTC (for eager mode)
+# ============================================================================
+# Before defining the Torch custom op, we compile the kernel so we can run it
+# in standard PyTorch (eager mode) for verification and testing.
+# We use the cuda-python library's NVRTC bindings.
+
 from cuda.core.experimental import Device as _CudaDevice
 from cuda.core.experimental import LaunchConfig as _LaunchConfig
 from cuda.core.experimental import Program as _CudaProgram
 from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
 from cuda.core.experimental import launch as _cuda_launch
 
+# Initialize CUDA device and stream
 _cuda_device = _CudaDevice()
 _cuda_device.set_current()
 _cuda_stream = _cuda_device.create_stream()
+
+# Configure compilation options
 _program_options = _CudaProgramOptions(
     std="c++17",
-    arch=f"sm_{_cuda_device.arch}",
+    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
     include_path=["/usr/local/cuda/include"],
 )
+
+# Create and compile the program
 _program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
 _module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
 _kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
 
-# Eager torch custom_op implemented using the CUDA kernel above (no Triton)
-
 
 # ============================================================================
-# Custom Op Registration
+# 3. Register Custom Op in PyTorch
 # ============================================================================
-
+# We register the custom operation with PyTorch so it can be used in models.
+# The 'mutates_args=()' argument tells PyTorch this op is functional (doesn't modify inputs in-place).
 
 @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
 def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
+    """
+    Implementation of the custom op for PyTorch eager execution.
+    This function launches the pre-compiled NVRTC kernel.
+    """
     assert X.is_cuda, "Tensor must be on CUDA device."
     assert X.dtype == torch.float32, "For this test, expected float32 input"
 
     Y = torch.empty_like(X)
     N = int(X.numel())
 
     block = 256
-
     grid_x = max(1, (N + block - 1) // block)
     config = _LaunchConfig(grid=(grid_x), block=(block))
 
-    # Use PyTorch's current stream by wrapping it for cuda.core
+    # Helper class to wrap PyTorch's stream for cuda-python
    class _PyTorchStreamWrapper:
         def __init__(self, pt_stream):
             self.pt_stream = pt_stream
@@ -84,9 +109,7 @@ def __cuda_stream__(self):
     pt_stream = torch.cuda.current_stream()
     s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))
 
-    # Launch kernel with raw pointers as in cuda.core example
-    # Note: argument order is input, size, (matching TensorRT's convention)
-
+    # Launch kernel with raw pointers
     _cuda_launch(
         s,
         config,
@@ -99,34 +122,51 @@ def __cuda_stream__(self):
     return Y
 
 
+# ============================================================================
+# 4. Register Fake Implementation (Meta Kernel)
+# ============================================================================
+# The fake implementation is crucial for TorchDynamo. It tells the compiler
+# about the output shape and data type without actually running the kernel.
+# This is used during the tracing phase.
+
 @torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
 def _(input: torch.Tensor) -> torch.Tensor:
     """Fake implementation for TorchDynamo tracing of base operation."""
     return torch.empty_like(input)
 
 
 # ============================================================================
-# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
+# 5. Define TensorRT AOT Plugin
 # ============================================================================
+# Now we define how this operation should be handled within TensorRT.
+# We use the TensorRT Python Plugin API to register the plugin description,
+# autotuning behavior, and the AOT implementation using NVRTC.
 
 import tensorrt.plugin as trtp
 from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions
 
 
+# 5a. Plugin Description
+# Tells TensorRT about input/output properties (dtypes, formats)
 @trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
 def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
     return (input.like(),)
 
 
+# 5b. Autotuning Support
+# Defines valid data type combinations for the plugin.
 @trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
 def sigmoid_autotune(
     input: trtp.TensorDesc,
     outputs: Tuple[trtp.TensorDesc],
 ) -> List[trtp.AutoTuneCombination]:
-    # Match float32 path; add FP16 if you want both
+    # We specify that this plugin supports FP32 input and FP32 output
     return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")]
 
 
+# 5c. AOT Implementation
+# This is where the magic happens. We provide the compiled PTX code and
+# launch parameters to TensorRT. This code runs during engine building.
 @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
 def sigmoid_aot_nvrtc_impl(
     input: trtp.TensorDesc,
@@ -135,23 +175,22 @@ def sigmoid_aot_nvrtc_impl(
 ) -> Tuple[
     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
 ]:
-
+    # Get the PTX code from our pre-compiled module
     compiled_kernel = _module.code.decode("utf-8")
-    print(type(compiled_kernel))
-    print(compiled_kernel)
-
-    # import pdb; pdb.set_trace()
-
+
+    # Calculate grid and block dimensions based on input shape
     N = input.shape_expr.numel()
     launch_params = trtp.KernelLaunchParams()
     block = 256
     launch_params.grid_x = trtp.cdiv(N, block)
     launch_params.block_x = block
     launch_params.shared_mem = 0
 
+    # Pass the number of elements (N) as an extra argument to the kernel
     extra_args = trtp.SymIntExprs(1)
     extra_args[0] = trtp.SymInt32(N)
 
+    # Return: kernel name, PTX code, launch params, kernel arguments
     return (
         "pointwise_sigmoid_kernel_nvrtc",
         compiled_kernel,
@@ -160,6 +199,13 @@ def sigmoid_aot_nvrtc_impl(
     )
 
 
+# ============================================================================
+# 6. Generate Plugin Converter
+# ============================================================================
+# This registers the mapping between the PyTorch custom op and the TensorRT plugin.
+# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',
+# replace it with the TensorRT plugin we just defined."
+
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "pointwise_sigmoid_ops::pointwise_sigmoid",
     supports_dynamic_shapes=True,
@@ -168,21 +214,15 @@ def sigmoid_aot_nvrtc_impl(
 
 
 # ============================================================================
-# Test Model
+# 7. Test the Model
 # ============================================================================
 
-
 class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
     """
     Test model that uses the TRT wrapper with custom_op() registration.
-
-    When compiled with torch_tensorrt.compile() using fp16 inputs, this will
-    fail with: "could not find any supported formats consistent with input/output
-    data types"
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-
         z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
         return z
 
@@ -191,10 +231,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
 input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)
 
+print("PyTorch baseline result:")
 print(torch.sigmoid(input))
 
+print("Custom Op eager result:")
 print(model(input))
 
+print("\nCompiling with Torch-TensorRT...")
 with torch_tensorrt.logging.debug():
     trt_inputs = [input]
     model_trt = torch_tensorrt.compile(
@@ -204,16 +247,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         min_block_size=1,
     )
     print("Model compiled successfully!")
+
     print("Running inference with compiled model...")
-    print("Compiled model output:")
-    print(model_trt(input))
-    print("Original model output:")
-    print(model(input))
     with torch.no_grad():
         for i in range(10):
             res = model_trt(input)
         assert torch.allclose(
             res, model(input), rtol=1e-2, atol=1e-2
         ), "Results do not match!"
 
-    # print("Inference successful!")
+    print("Inference successful!")

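This file started life as an fp16 bug repro, and the autotune comment removed in this commit noted that FP16 could be registered alongside FP32. A hedged sketch of that extension follows (hypothetical, not part of this commit; it also assumes the CUDA kernel itself is compiled for half precision).

# Hypothetical FP16-enabled autotune registration (not in this commit).
from typing import List, Tuple

import tensorrt.plugin as trtp


@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # Each AutoTuneCombination lists dtypes for (input, output) plus a format.
    return [
        trtp.AutoTuneCombination("FP32, FP32", "LINEAR"),
        trtp.AutoTuneCombination("FP16, FP16", "LINEAR"),  # needs a half kernel
    ]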