Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ bool FindCycleHelper(size_t i, gsl::span<const InlinedVector<size_t>> adjacency_
st[i] = false;
return false;
}

#if NV_TENSORRT_MAJOR < 11
bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<std::string, float>& dynamic_range_map) {
// Set dynamic range for input tensors
for (int i = 0; i < network.getNbInputs(); ++i) {
Expand Down Expand Up @@ -153,6 +153,7 @@ bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<s
}
return true;
}
#endif // NV_TENSORRT_MAJOR < 11

std::vector<std::string> SplitToStringVec(std::string const& s, char separator) {
std::vector<std::string> splitted;
Expand Down Expand Up @@ -1231,7 +1232,7 @@ void TensorrtExecutionProvider::PerThreadContext::ResetTensorRTContext(std::stri

bool TensorrtExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr<nvinfer1::IExecutionContext> context) {
if (!context) {
context = std::make_unique<nvinfer1::IExecutionContext>();
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "UpdateTensorRTContext: Provided context was nullptr!"));
}
trt_context_map_[fused_node] = std::move(context);

Expand All @@ -1251,12 +1252,10 @@ bool TensorrtExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::st

nvinfer1::IExecutionContext& TensorrtExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) {
auto it = trt_context_map_.find(fused_node);
if (it != trt_context_map_.end()) {
return *(it->second); // dereference shared pointer
if (it == trt_context_map_.end()) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GetTensorRTContext: Requested context was not found!"));
}
auto context = std::make_unique<nvinfer1::IExecutionContext>();
trt_context_map_[fused_node] = std::move(context);
return *(trt_context_map_[fused_node]); // dereference shared pointer
return *(it->second); // dereference shared pointer
}

void TensorrtExecutionProvider::ReleasePerThreadContext() const {
Expand Down Expand Up @@ -1649,6 +1648,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
}

// In TRT 11.0 precision flags are removed.
#if NV_TENSORRT_MAJOR >= 11
if (fp16_enable_ || bf16_enable_ || int8_enable_) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT EP was compiled for TensorRT version >= 11.0 - precision flags (BF16 / FP16 / INT8) have been removed and no longer have an effect. Strongly-typed will be used for all networks.";
fp16_enable_ = false;
bf16_enable_ = false;
int8_enable_ = false;
}
#endif

// Validate setting
if (max_partition_iterations_ <= 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000";
Expand Down Expand Up @@ -2382,7 +2391,9 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
auto trt_builder = GetBuilder(trt_logger);
auto network_flags = 0;
#if NV_TENSORRT_MAJOR > 8
#if NV_TENSORRT_VERSION >= 11
network_flags |= 0;
#elif NV_TENSORRT_MAJOR > 8
network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
#else
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
Expand Down Expand Up @@ -3144,7 +3155,9 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
auto trt_builder = GetBuilder(trt_logger);
auto network_flags = 0;
#if NV_TENSORRT_MAJOR > 8
#if NV_TENSORRT_VERSION >= 11
network_flags |= 0;
#elif NV_TENSORRT_MAJOR > 8
network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
#else
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
Expand All @@ -3171,6 +3184,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
}

// Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow
#if NV_TENSORRT_MAJOR < 11
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
Expand All @@ -3191,6 +3205,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
#endif // NV_TENSORRT_MAJOR < 11

int num_inputs = trt_network->getNbInputs();
int num_outputs = trt_network->getNbOutputs();
Expand Down Expand Up @@ -3326,6 +3341,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
trt_profiles.push_back(trt_builder->createOptimizationProfile());
}

#if NV_TENSORRT_MAJOR < 11
// Check platform availability for low precision
if (fp16_enable_ || bf16_enable_) {
#if defined(_MSC_VER)
Expand Down Expand Up @@ -3355,6 +3371,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
}
}
#endif // NV_TENSORRT_MAJOR < 11

// Load INT8 calibration table
std::unordered_map<std::string, float> dynamic_range_map;
Expand All @@ -3364,13 +3381,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
}
}
std::string trt_node_name_with_precision = fused_node.Name();

#if NV_TENSORRT_MAJOR < 11
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
// Set precision flags
std::string trt_node_name_with_precision = fused_node.Name();
if (fp16_enable_) {
trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
trt_node_name_with_precision += "_fp16";
Expand All @@ -3389,6 +3407,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
#endif // NV_TENSORRT_MAJOR < 11
// Set DLA
if (fp16_enable_ || int8_enable_) {
if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8
Expand Down Expand Up @@ -3581,6 +3600,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
"TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
} else {
#if NV_TENSORRT_MAJOR < 11
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
Expand All @@ -3596,6 +3616,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
"TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
}
}
#endif // NV_TENSORRT_MAJOR < 11

// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
Expand Down Expand Up @@ -3995,6 +4016,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
for (auto trt_profile : trt_profiles) {
trt_config->addOptimizationProfile(trt_profile);
}
#if NV_TENSORRT_MAJOR < 11
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
Expand Down Expand Up @@ -4029,6 +4051,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
#endif // NV_TENSORRT_MAJOR < 11
// Set DLA (DLA can only run with FP16 or INT8)
if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
Expand Down Expand Up @@ -4302,7 +4325,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
size_t mem_size = trt_engine->getDeviceMemorySize();
size_t mem_size = trt_engine->getDeviceMemorySizeV2();
Copy link
Copy Markdown
Contributor

@chilo-ms chilo-ms May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this API call need to be guarded as well? The current TRT EP still supports building with older TRT versions, e.g. TRT 8. We might consider removing support for TRT 8, though.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with removing support for TRT 8

Copy link
Copy Markdown
Contributor

@chilo-ms chilo-ms May 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this PR, let's keep support for TRT 8 for now, as there might be some customers still using TRT 8 and we should notify them before we actually remove the support.
So I think we should put a guard on getDeviceMemorySizeV2 here so that it is only used for TRT >= 10?

#if defined(_MSC_VER)
#pragma warning(pop)
#endif
Expand Down Expand Up @@ -4445,7 +4468,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
size_t mem_size = trt_engine->getDeviceMemorySize();
size_t mem_size = trt_engine->getDeviceMemorySizeV2();
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
Expand Down Expand Up @@ -4633,7 +4656,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
size_t mem_size = trt_engine->getDeviceMemorySize();
size_t mem_size = trt_engine->getDeviceMemorySizeV2();
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ typedef void* cudnnStatus_t;
#include "core/providers/cuda/cuda_graph.h"
#include "tensorrt_execution_provider_info.h"

// These types used to come from NvOnnxParser.h, but they've been removed.
// They are declared here only when building against TensorRT 11 or newer
// (NV_TENSORRT_MAJOR >= 11); for lower major versions this block is compiled out.
#if NV_TENSORRT_MAJOR >= 11
#include <utility>
#include <vector>
// A list of size_t values paired with a bool.
// NOTE(review): presumably the node indices of one subgraph plus a "supported" flag,
// matching the meaning the aliases had in the parser header - confirm against callers.
using SubGraph_t = std::pair<std::vector<size_t>, bool>;
// An ordered collection of SubGraph_t entries.
using SubGraphCollection_t = std::vector<SubGraph_t>;
#endif

namespace onnxruntime {

namespace tensorrt_env_vars {
Expand Down
Loading