11"""
2- Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue.
2+ Using Custom Kernels with NVRTC in TensorRT AOT Plugins
3+ =======================================================
34
4- This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
5- fails to compile operations that use fp16 (half-precision) tensors.
5+ This example demonstrates how to use the NVIDIA Runtime Compilation (NVRTC) library
6+ to compile custom CUDA kernels at runtime and integrate them into a TensorRT
7+ Ahead-Of-Time (AOT) plugin.
68
7- The issue occurs because the JIT plugin generator doesn't properly declare format
8- support for fp16 data types in the generated TensorRT plugin.
9+ This approach is powerful because it allows you to:
10+ 1. Write raw CUDA C++ code for maximum performance.
11+ 2. Compile it on-the-fly, adapting to the specific GPU architecture.
12+ 3. Wrap it in a TensorRT plugin without writing a separate C++ plugin library.
13+ 4. Integrate it seamlessly into Torch-TensorRT's compilation flow.
14+
15+ The example performs a simple pointwise Sigmoid operation: f(x) = 1 / (1 + exp(-x)).
916"""
1017
from typing import List, Tuple, Union

import numpy as np
import torch

import torch_tensorrt

# ============================================================================
# 1. Define the CUDA Kernel Source
# ============================================================================
# The CUDA kernel source is defined as a Python string and compiled by NVRTC at
# runtime. The extern "C" declaration avoids C++ name mangling, so the kernel can
# be retrieved by name after compilation. Note that TensorRT passes plugin kernel
# arguments in the order: inputs, extra args, outputs.

cu_code = """
// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
// Arguments follow the TensorRT plugin convention: inputs, extra args, outputs.
extern "C" __global__ void pointwise_sigmoid_kernel_nvrtc(const float* input,
                                                          int n,
                                                          float* output) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        output[idx] = 1.0f / (1.0f + expf(-input[idx]));
    }
}
"""

# ============================================================================
# 2. Compile the Kernel using NVRTC (for eager mode)
# ============================================================================
# Before defining the Torch custom op, we compile the kernel so we can run it
# in standard PyTorch (eager mode) for verification and testing.
# We use the cuda-python library's NVRTC bindings (cuda.core.experimental).

from cuda.core.experimental import Device as _CudaDevice
from cuda.core.experimental import LaunchConfig as _LaunchConfig
from cuda.core.experimental import Program as _CudaProgram
from cuda.core.experimental import ProgramOptions as _CudaProgramOptions
from cuda.core.experimental import launch as _cuda_launch

# Initialize the CUDA device and create a stream
_cuda_device = _CudaDevice()
_cuda_device.set_current()
_cuda_stream = _cuda_device.create_stream()

# Configure compilation options
_program_options = _CudaProgramOptions(
    std="c++17",
    arch=f"sm_{_cuda_device.arch}",  # Target the current GPU architecture
    include_path=["/usr/local/cuda/include"],
)

# Create and compile the program to PTX, then load the kernel by name
_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
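
# Note: `_module.code` now holds the PTX generated for this kernel. The AOT plugin
# implementation below hands these same bytes to TensorRT, so the eager path and the
# TensorRT engine execute the same device code.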


# ============================================================================
# 3. Register the Custom Op in PyTorch
# ============================================================================
# We register the custom operation with PyTorch so it can be used in models and
# traced by TorchDynamo. `mutates_args=()` tells PyTorch that the op is functional
# (it does not modify its inputs in place). The eager implementation below launches
# the NVRTC-compiled kernel directly.

@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    """
    Eager-mode implementation of the custom op.
    This function launches the pre-compiled NVRTC kernel.
    """
    assert X.is_cuda, "Tensor must be on a CUDA device."
    assert X.dtype == torch.float32, "This example expects float32 input."

    Y = torch.empty_like(X)
    N = int(X.numel())

    # One thread per element, 256 threads per block
    block = 256
    grid_x = max(1, (N + block - 1) // block)
    config = _LaunchConfig(grid=grid_x, block=block)

    # Run the kernel on PyTorch's current CUDA stream by wrapping it in an object
    # that exposes the __cuda_stream__ protocol expected by cuda.core
    class _PyTorchStreamWrapper:
        def __init__(self, pt_stream):
            self.pt_stream = pt_stream

        def __cuda_stream__(self):
            # (protocol version, raw stream handle), as expected by cuda.core
            return (0, self.pt_stream.cuda_stream)

    pt_stream = torch.cuda.current_stream()
    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))

    # Launch the kernel with raw device pointers.
    # Argument order matches the TensorRT plugin convention: inputs, extra args, outputs.
    _cuda_launch(
        s,
        config,
        _kernel,
        X.data_ptr(),
        np.int32(N),  # element count, matching the kernel's `int n` parameter
        Y.data_ptr(),
    )

    return Y


# ============================================================================
# 4. Register a Fake Implementation (Meta Kernel)
# ============================================================================
# The fake implementation is crucial for TorchDynamo: it tells the compiler the
# output shape and data type without actually running the kernel. It is used
# during the tracing phase.


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)

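# For an op whose output shape differed from its input, that shape logic would live in
# the fake implementation. A hypothetical sketch (op name and shape rule are
# illustrative only):
#
#   @torch.library.register_fake("my_ops::halve_last_dim")
#   def _(input: torch.Tensor) -> torch.Tensor:
#       return input.new_empty((*input.shape[:-1], input.shape[-1] // 2))
#
# For the pointwise sigmoid here, the output simply mirrors the input.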

# ============================================================================
# 5. Define the TensorRT AOT Plugin
# ============================================================================
# Now we define how this operation should be handled within TensorRT.
# We use the TensorRT Python plugin API to register the plugin description,
# autotuning behavior, and the AOT implementation backed by the NVRTC-compiled PTX.

import tensorrt.plugin as trtp


# 5a. Plugin description
# Declares the plugin's I/O signature: input.like() produces an output descriptor
# with the same shape and dtype as the input.
@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return (input.like(),)


# 5b. Autotuning support
# Defines the valid data type combinations for the plugin.
@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    # This plugin supports FP32 input and FP32 output in LINEAR format
    return [trtp.AutoTuneCombination("FP32, FP32", "LINEAR")]
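
# Note: the kernel above is written for float32 only. Supporting half precision would
# require an FP16 variant of the kernel plus an additional combination here, e.g.
# trtp.AutoTuneCombination("FP16, FP16", "LINEAR"); that extension is not shown in
# this example.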


# 5c. AOT implementation
# This is where the PTX is handed to TensorRT: we return the compiled kernel together
# with its launch parameters. This function runs during engine building, not at
# inference time.
@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    # Get the PTX code from the pre-compiled NVRTC module
    compiled_kernel = _module.code.decode("utf-8")

    # Compute grid and block dimensions from the (possibly symbolic) input shape
    N = input.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    block = 256
    launch_params.grid_x = trtp.cdiv(N, block)
    launch_params.block_x = block
    launch_params.shared_mem = 0

    # Pass the number of elements (N) as an extra argument to the kernel
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    # Return: kernel name, PTX code, launch params, extra kernel arguments
    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )

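# Because this is an ahead-of-time (AOT) plugin, the PTX returned above is baked into
# the TensorRT engine at build time; no Python is executed when the engine later runs
# the kernel.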

# ============================================================================
# 6. Generate the Plugin Converter
# ============================================================================
# This registers the mapping between the PyTorch custom op and the TensorRT plugin.
# It tells Torch-TensorRT: "When you see 'pointwise_sigmoid_ops::pointwise_sigmoid',
# replace it with the TensorRT plugin defined above."

torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
)

170216# ============================================================================
171- # Test Model
217+ # 7. Test the Model
172218# ============================================================================
173219
174-
175220class PointwiseSigmoidModel_WithTRTWrapper (torch .nn .Module ):
176221 """
177222 Test model that uses the TRT wrapper with custom_op() registration.
178-
179- When compiled with torch_tensorrt.compile() using fp16 inputs, this will
180- fail with: "could not find any supported formats consistent with input/output
181- data types"
182223 """
183224
184225 def forward (self , input : torch .Tensor ) -> torch .Tensor :
185-
186226 z = torch .ops .pointwise_sigmoid_ops .pointwise_sigmoid (input )
187227 return z
188228

model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
input = torch.randn(1, 1024, device="cuda", dtype=torch.float32)

print("PyTorch baseline result:")
print(torch.sigmoid(input))

print("Custom Op eager result:")
print(model(input))

print("\nCompiling with Torch-TensorRT...")
with torch_tensorrt.logging.debug():
    trt_inputs = [input]
    model_trt = torch_tensorrt.compile(
        model,
        inputs=trt_inputs,
        min_block_size=1,
    )
    print("Model compiled successfully!")

    print("Running inference with compiled model...")
    with torch.no_grad():
        for i in range(10):
            res = model_trt(input)
            assert torch.allclose(
                res, model(input), rtol=1e-2, atol=1e-2
            ), "Results do not match!"

    print("Inference successful!")