Conversation

@mergennachin
Contributor

@mergennachin mergennachin commented Dec 2, 2025

Summary:

In encoder-decoder models like Whisper, the encoder output tensor is
used as input to every decoder iteration, which previously incurred
unnecessary CPU->GPU->CPU->GPU copies.

Implemented a "keep on device" caching mechanism in the CUDA backend
that:

  • Caches encoder output in persistent GPU memory after the encoder runs
  • Uses fast GPU-to-GPU copies for decoder iterations instead of slow CPU-to-GPU copies (see the sketch below)
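
A minimal sketch of the idea, assuming hypothetical buffer and size names (cached_encoder_output, encoder_output_gpu, decoder_encoder_input_gpu, and encoder_output_nbytes are illustrative, not the backend's actual variables):

// Sketch only. After the encoder runs, its output is copied once into a
// persistent device buffer (error checks omitted for brevity).
void* cached_encoder_output = nullptr;
cudaMalloc(&cached_encoder_output, encoder_output_nbytes);
cudaMemcpy(cached_encoder_output, encoder_output_gpu, encoder_output_nbytes,
           cudaMemcpyDeviceToDevice);

// Each decoder iteration then fills its encoder_output input slot with a fast
// device-to-device copy, avoiding the CPU round trip.
cudaMemcpy(decoder_encoder_input_gpu, cached_encoder_output, encoder_output_nbytes,
           cudaMemcpyDeviceToDevice);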

Test Plan:

make whisper-cuda

Reviewers:

Subscribers:

Tasks:

Tags:

Copilot AI review requested due to automatic review settings December 2, 2025 23:15
@pytorch-bot

pytorch-bot bot commented Dec 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16056

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures

As of commit 37c47d4 with merge base 33ec615:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 2, 2025
@github-actions

github-actions bot commented Dec 2, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Contributor

Copilot AI left a comment


Pull request overview

This PR implements a GPU device caching mechanism in the CUDA backend to optimize encoder-decoder models like Whisper. The encoder output tensor is cached in persistent GPU memory after the encoder runs, allowing subsequent decoder iterations to use fast GPU-to-GPU copies instead of slow CPU-to-GPU transfers for the encoder output.

Key changes:

  • Added global device cache (g_device_cache) to store GPU tensors by name (sketched below)
  • Implemented set_option API to configure cache behavior via "cache_output" and "use_cache_input" options
  • Modified encoder-decoder workflow in runner.cpp to leverage caching for encoder output across decoder iterations
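
A rough sketch of the pieces named above, inferred from the excerpts quoted later in this review (fields of CachedGpuData other than data_ptr are assumptions):

// Cached device buffer plus its size; only data_ptr appears in the excerpts
// below, the nbytes field is assumed here for illustration.
struct CachedGpuData {
  void* data_ptr = nullptr; // device pointer returned by cudaMalloc
  size_t nbytes = 0;        // size of the cached tensor in bytes (assumed)
};

// Global device cache - maps a user-chosen name (e.g. "encoder_output") to GPU data.
static std::unordered_map<std::string, CachedGpuData> g_device_cache;

// Runner-side configuration, as quoted later in this review; the option value is
// a "<slot>:<name>" string.
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");
::executorch::runtime::set_option("CudaBackend", opts.view());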

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 10 comments.

File | Description
extension/asr/runner/runner.cpp | Sets cache options to cache encoder output (slot 0) and use it for decoder input (slot 2)
backends/cuda/runtime/cuda_backend.cpp | Implements caching infrastructure with CachedGpuData structure, set_option handler, and GPU-to-GPU copy logic for cached tensors


Comment on lines +431 to +435
ET_CHECK_OR_RETURN_ERROR(
copy_err == cudaSuccess,
Internal,
"Failed to copy output to GPU cache: %s",
cudaGetErrorString(copy_err));

Copilot AI Dec 2, 2025


Memory leak: If cudaMemcpy fails (line 426-430), the function returns early via ET_CHECK_OR_RETURN_ERROR, but the cache_ptr allocated at line 418 is never freed. Consider adding cudaFree(cache_ptr) before returning on error, or use a RAII wrapper like a unique_ptr with a custom deleter to ensure automatic cleanup.

Suggested change
ET_CHECK_OR_RETURN_ERROR(
copy_err == cudaSuccess,
Internal,
"Failed to copy output to GPU cache: %s",
cudaGetErrorString(copy_err));
if (copy_err != cudaSuccess) {
cudaFree(cache_ptr);
ET_CHECK_OR_RETURN_ERROR(
false,
Internal,
"Failed to copy output to GPU cache: %s",
cudaGetErrorString(copy_err));
}
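
An alternative shape for the RAII approach mentioned in the comment, using a unique_ptr with a custom deleter (a sketch only, requiring <memory>; src_ptr and nbytes stand in for the actual source pointer and size used in the file):

// Guard the freshly allocated cache buffer so any early return frees it.
auto cuda_deleter = [](void* p) { cudaFree(p); };
std::unique_ptr<void, decltype(cuda_deleter)> cache_guard(cache_ptr, cuda_deleter);

cudaError_t copy_err =
    cudaMemcpy(cache_ptr, src_ptr, nbytes, cudaMemcpyDeviceToDevice);
ET_CHECK_OR_RETURN_ERROR(
    copy_err == cudaSuccess,
    Internal,
    "Failed to copy output to GPU cache: %s",
    cudaGetErrorString(copy_err));

// Success: transfer ownership to the cache entry so the guard no longer frees it.
g_device_cache[cache_output_name_].data_ptr = cache_guard.release();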

std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
cache_output_slot_ = std::stoi(val.substr(0, colon_pos));

Copilot AI Dec 2, 2025


The std::stoi call can throw std::invalid_argument or std::out_of_range exceptions if the input string is malformed or the number is too large. Since this function should return an Error rather than throwing exceptions, consider wrapping this in a try-catch block and returning an appropriate error, or validate the input string before calling std::stoi.
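
One exception-free way to harden this parse, sketched against the snippet above (the InvalidArgument error code and the strtol-based validation are suggestions, not what the PR currently does; requires <cstdlib> and <climits>):

// Parse "<slot>:<name>" without std::stoi, so a malformed value returns an
// Error instead of throwing.
std::string val(arr->data());
auto colon_pos = val.find(':');
ET_CHECK_OR_RETURN_ERROR(
    colon_pos != std::string::npos && colon_pos > 0,
    InvalidArgument,
    "Expected option of the form <slot>:<name>, got: %s",
    val.c_str());

char* end = nullptr;
long slot = std::strtol(val.c_str(), &end, 10);
ET_CHECK_OR_RETURN_ERROR(
    end == val.c_str() + colon_pos && slot >= 0 && slot <= INT_MAX,
    InvalidArgument,
    "Invalid slot index in option value: %s",
    val.c_str());

cache_output_slot_ = static_cast<int>(slot);
cache_output_name_ = val.substr(colon_pos + 1);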

Comment on lines +268 to +276
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");
auto err =
::executorch::runtime::set_option("CudaBackend", opts.view());
if (err != ::executorch::runtime::Error::Ok) {
ET_LOG(Info, "Failed to set use_cache_input option (backend may not support caching)");
}
}

Copilot AI Dec 2, 2025


The cache input option is set before the decoder loop but never cleared afterward. This means the cache will persist and be used for subsequent calls to transcribe(), which may not be intended. If the encoder is run again with different input, the cached encoder output will be stale but still used. Consider clearing the cache input option after the decoding loop completes (after line 332) using the "clear_cache_input" option, or document this behavior clearly.
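
If the backend does expose the "clear_cache_input" option named above (the excerpts in this review do not show whether it exists), the runner-side reset could mirror how the option was set, e.g.:

// Sketch: after the decoding loop, tell the backend to stop substituting the
// cached encoder output, so a later transcribe() call cannot pick up stale data.
{
  ::executorch::runtime::BackendOptions<1> opts;
  opts.set_option("clear_cache_input", "");
  auto err = ::executorch::runtime::set_option("CudaBackend", opts.view());
  if (err != ::executorch::runtime::Error::Ok) {
    ET_LOG(Info, "Failed to clear use_cache_input option");
  }
}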

// Free old cache if exists
auto old_it = g_device_cache.find(cache_output_name_);
if (old_it != g_device_cache.end()) {
cudaFree(old_it->second.data_ptr);

Copilot AI Dec 2, 2025


The old cached GPU memory is freed without checking for errors. While cudaFree rarely fails, if it does, the error is silently ignored, potentially indicating GPU state corruption. Consider checking the return value and logging a warning if the free operation fails, similar to how other CUDA operations are checked in this file.

Suggested change
cudaFree(old_it->second.data_ptr);
cudaError_t free_err = cudaFree(old_it->second.data_ptr);
if (free_err != cudaSuccess) {
std::fprintf(
stderr,
"Warning: Failed to free old GPU cache memory: %s\n",
cudaGetErrorString(free_err));
}


// Global device cache - maps name to cached GPU data
// Using raw GPU pointers instead of tensor handles for format independence
static std::unordered_map<std::string, CachedGpuData> g_device_cache;

Copilot AI Dec 2, 2025


The global g_device_cache is never cleaned up, causing a memory leak. GPU memory allocated via cudaMalloc (line 418) is stored in this cache but never freed when the backend is destroyed. Consider adding cleanup logic in the destroy() method to iterate through g_device_cache and call cudaFree() on each cached data_ptr, or implement a RAII wrapper for cache management.
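
A sketch of the suggested cleanup, e.g. run from the backend's destroy() path (exact placement is a suggestion):

// Free every cached device buffer and drop the entries.
for (auto& entry : g_device_cache) {
  cudaError_t free_err = cudaFree(entry.second.data_ptr);
  if (free_err != cudaSuccess) {
    ET_LOG(
        Error,
        "Failed to free cached GPU buffer '%s': %s",
        entry.first.c_str(),
        cudaGetErrorString(free_err));
  }
}
g_device_cache.clear();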


// Global device cache - maps name to cached GPU data
// Using raw GPU pointers instead of tensor handles for format independence
static std::unordered_map<std::string, CachedGpuData> g_device_cache;

Copilot AI Dec 2, 2025


The global g_device_cache is accessed without synchronization, creating potential race conditions in multi-threaded environments. If multiple threads call execute() concurrently, simultaneous reads/writes to the cache could cause data corruption or crashes. Consider using a mutex (e.g., std::mutex) to protect all accesses to g_device_cache, or document that the CUDA backend is not thread-safe.
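
A sketch of the mutex-based option (requires <mutex>; lock placement shown for the read path only):

// Protect the global cache; every read and write site takes the lock.
static std::mutex g_device_cache_mutex;

{
  std::lock_guard<std::mutex> guard(g_device_cache_mutex);
  auto it = g_device_cache.find(use_cache_input_name_);
  if (it != g_device_cache.end()) {
    // ... issue the device-to-device copy while the lock is held ...
  }
}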


// Reset cache_output settings after caching
cache_output_slot_ = -1;
cache_output_name_.clear();

Copilot AI Dec 2, 2025


Unlike cache_output_slot_ which is reset after use (lines 454-455), the use_cache_input_slot_ and use_cache_input_name_ are never cleared after execution. This means once set, they will continue to be used for all subsequent executions, which may not be the intended behavior. Consider resetting these fields after the execute() method completes to match the pattern used for cache_output settings.

Suggested change
cache_output_name_.clear();
cache_output_name_.clear();
// Reset cache_input settings after use (fix for CodeQL warning)
use_cache_input_slot_ = -1;
use_cache_input_name_.clear();

std::string val(arr->data());
auto colon_pos = val.find(':');
if (colon_pos != std::string::npos) {
use_cache_input_slot_ = std::stoi(val.substr(0, colon_pos));

Copilot AI Dec 2, 2025


The std::stoi call can throw std::invalid_argument or std::out_of_range exceptions if the input string is malformed or the number is too large. Since this function should return an Error rather than throwing exceptions, consider wrapping this in a try-catch block and returning an appropriate error, or validate the input string before calling std::stoi.

Comment on lines +67 to +70
mutable int cache_output_slot_ = -1; // Which output slot to cache (-1 = none)
mutable std::string cache_output_name_; // Name to cache output under
mutable int use_cache_input_slot_ = -1; // Which input slot to use cache for (-1 = none)
mutable std::string use_cache_input_name_; // Name of cached tensor to use

Copilot AI Dec 2, 2025


These cache control fields are marked mutable and modified in const methods (set_option and execute), which is unusual and suggests potential design issues. The mutable keyword typically indicates thread-safety concerns or shared state that bypasses const-correctness. Since execute() is const but modifies these fields, concurrent calls to execute() on the same backend instance will have race conditions when accessing these member variables. Consider using proper synchronization or redesigning to avoid mutable state in const methods.

Comment on lines +266 to +270
// Tell CUDA backend to use cached encoder output for decoder input slot 2
// Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");

Copilot AI Dec 2, 2025


The comment states "Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]", but the actual order in the code is: decoder_input_ptr[0] (input_ids), encoder_output_ptr[1], cache_position_ptr[2]. The encoder_output is at index 1, not index 2. Either the comment is wrong or the cache input slot should be "1:encoder_output" instead of "2:encoder_output".

Suggested change
// Tell CUDA backend to use cached encoder output for decoder input slot 2
// Note: Decoder input order in AOTI is: input_ids[0], cache_position[1], encoder_output[2]
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "2:encoder_output");
// Tell CUDA backend to use cached encoder output for decoder input slot 1
// Note: Decoder input order in AOTI is: input_ids[0], encoder_output[1], cache_position[2]
{
::executorch::runtime::BackendOptions<1> opts;
opts.set_option("use_cache_input", "1:encoder_output");

@Gasoonjia Gasoonjia temporarily deployed to upload-benchmark-results December 3, 2025 01:13 — with GitHub Actions Inactive
@mergennachin
Contributor Author

Newer version here: #16060
