From fcb8975c46b2ef9b5b2156620dcad047b8a53a6b Mon Sep 17 00:00:00 2001 From: Kevin Chen Date: Thu, 21 May 2026 10:04:01 -0700 Subject: [PATCH] Guard removed APIs in TRT 11 Signed-off-by: Kevin Chen --- .../tensorrt/tensorrt_execution_provider.cc | 49 ++++++++++++++----- .../tensorrt/tensorrt_execution_provider.h | 8 +++ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 629c4cd3fc29c..0a97280b40436 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -64,7 +64,7 @@ bool FindCycleHelper(size_t i, gsl::span> adjacency_ st[i] = false; return false; } - +#if NV_TENSORRT_MAJOR < 11 bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map& dynamic_range_map) { // Set dynamic range for input tensors for (int i = 0; i < network.getNbInputs(); ++i) { @@ -153,6 +153,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map SplitToStringVec(std::string const& s, char separator) { std::vector splitted; @@ -1231,7 +1232,7 @@ void TensorrtExecutionProvider::PerThreadContext::ResetTensorRTContext(std::stri bool TensorrtExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr context) { if (!context) { - context = std::make_unique(); + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "UpdateTensorRTContext: Provided context was nullptr!")); } trt_context_map_[fused_node] = std::move(context); @@ -1251,12 +1252,10 @@ bool TensorrtExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::st nvinfer1::IExecutionContext& TensorrtExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) { auto it = trt_context_map_.find(fused_node); - if (it != trt_context_map_.end()) { - return *(it->second); // dereference 
shared pointer + if (it == trt_context_map_.end()) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GetTensorRTContext: Requested context was not found!")); } - auto context = std::make_unique(); - trt_context_map_[fused_node] = std::move(context); - return *(trt_context_map_[fused_node]); // dereference shared pointer + return *(it->second); // dereference shared pointer } void TensorrtExecutionProvider::ReleasePerThreadContext() const { @@ -1649,6 +1648,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } } + // In TRT 11.0 precision flags are removed. +#if NV_TENSORRT_MAJOR >= 11 + if (fp16_enable_ || bf16_enable_ || int8_enable_) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT EP was compiled for TensorRT version >= 11.0 - precision flags (BF16 / FP16 / INT8) have been removed and no longer have an effect. Strongly-typed will be used for all networks."; + fp16_enable_ = false; + bf16_enable_ = false; + int8_enable_ = false; + } +#endif + // Validate setting if (max_partition_iterations_ <= 0) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000"; @@ -2382,7 +2391,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 +#if NV_TENSORRT_MAJOR >= 11 + network_flags |= 0; +#elif NV_TENSORRT_MAJOR > 8 network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 
0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); @@ -3144,7 +3155,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_); auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; -#if NV_TENSORRT_MAJOR > 8 +#if NV_TENSORRT_MAJOR >= 11 + network_flags |= 0; +#elif NV_TENSORRT_MAJOR > 8 network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); @@ -3171,6 +3184,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow +#if NV_TENSORRT_MAJOR < 11 #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -3191,6 +3205,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #if defined(_MSC_VER) #pragma warning(pop) #endif +#endif // NV_TENSORRT_MAJOR < 11 int num_inputs = trt_network->getNbInputs(); int num_outputs = trt_network->getNbOutputs(); @@ -3326,6 +3341,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView trt_profiles.push_back(trt_builder->createOptimizationProfile()); } +#if NV_TENSORRT_MAJOR < 11 // Check platform availability for low precision if (fp16_enable_ || bf16_enable_) { #if defined(_MSC_VER) @@ -3355,6 +3371,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; } } +#endif // NV_TENSORRT_MAJOR < 11 // Load INT8 calibration table std::unordered_map dynamic_range_map; @@ -3364,13 +3381,14 @@ 
Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); } } + std::string trt_node_name_with_precision = fused_node.Name(); +#if NV_TENSORRT_MAJOR < 11 #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) #endif // Set precision flags - std::string trt_node_name_with_precision = fused_node.Name(); if (fp16_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); trt_node_name_with_precision += "_fp16"; @@ -3389,6 +3407,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #if defined(_MSC_VER) #pragma warning(pop) #endif +#endif // NV_TENSORRT_MAJOR < 11 // Set DLA if (fp16_enable_ || int8_enable_) { if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 @@ -3581,6 +3600,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } } else { +#if NV_TENSORRT_MAJOR < 11 #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -3596,6 +3616,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name()); } } +#endif // NV_TENSORRT_MAJOR < 11 // Load timing cache from file. 
Create a fresh cache if the file doesn't exist std::unique_ptr timing_cache = nullptr; @@ -3995,6 +4016,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView for (auto trt_profile : trt_profiles) { trt_config->addOptimizationProfile(trt_profile); } +#if NV_TENSORRT_MAJOR < 11 #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -4029,6 +4051,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #if defined(_MSC_VER) #pragma warning(pop) #endif +#endif // NV_TENSORRT_MAJOR < 11 // Set DLA (DLA can only run with FP16 or INT8) if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; @@ -4302,7 +4325,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView #pragma warning(push) #pragma warning(disable : 4996) #endif - size_t mem_size = trt_engine->getDeviceMemorySize(); + size_t mem_size = trt_engine->getDeviceMemorySizeV2(); #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -4445,7 +4468,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con #pragma warning(push) #pragma warning(disable : 4996) #endif - size_t mem_size = trt_engine->getDeviceMemorySize(); + size_t mem_size = trt_engine->getDeviceMemorySizeV2(); #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -4633,7 +4656,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con #pragma warning(push) #pragma warning(disable : 4996) #endif - size_t mem_size = trt_engine->getDeviceMemorySize(); + size_t mem_size = trt_engine->getDeviceMemorySizeV2(); #if defined(_MSC_VER) #pragma warning(pop) #endif diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e817fc51237c0..c2cbb2227fae2 100644 --- 
a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -16,6 +16,14 @@ typedef void* cudnnStatus_t; #include "core/providers/cuda/cuda_graph.h" #include "tensorrt_execution_provider_info.h" +// These types used to come from NvOnnxParser.h, but they've been removed. +#if NV_TENSORRT_MAJOR >= 11 +#include <utility> +#include <vector> +using SubGraph_t = std::pair<std::vector<size_t>, bool>; +using SubGraphCollection_t = std::vector<SubGraph_t>; +#endif + namespace onnxruntime { namespace tensorrt_env_vars {