Add return_attention support to Generator#2010
Open
QuentinFuxa wants to merge 2 commits into OpenNMT:master from
Conversation
The decoding engine already computes attention weights when requested, but this was only wired through the Translator API. This exposes the same capability for decoder-only models (Generator) by propagating the return_attention flag from GenerationOptions to DecodingOptions and transferring the attention data back to GenerationResult.
QuentinFuxa force-pushed from 0c6549e to 948f8f2
…s API

Self-attention was not returning attention weights for decoder-only models (Generator) because the attention pointer was always nullptr in TransformerDecoderLayer. Now passes the attention pointer to self-attention when there is no encoder-attention (decoder-only case).

Also adds set_alignment_heads() to the Generator Python API, allowing users to select specific (layer, head) pairs instead of the default (last layer, head 0). The attention from the selected heads is concatenated in the output and can be reshaped to (num_heads, context_length).

Fixed multi-head attention handling in decoding.cc to support variable-rank attention tensors (rank 3 for multi-head vs rank 2 for averaged).
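The rank-2 vs rank-3 distinction in the last paragraph can be illustrated with a small numpy sketch (shapes are hypothetical; the real tensors are C++ StorageView objects inside ctranslate2):

```python
import numpy as np

context_length = 92  # hypothetical prompt length
num_heads = 4        # hypothetical number of selected heads

# Averaged attention (the default): one weight vector per decoding step.
averaged = np.ones((1, context_length))               # rank 2

# Multi-head attention (after head selection): one vector per head.
multi_head = np.ones((1, num_heads, context_length))  # rank 3

def step_vectors(attention):
    # Flattening the leading dimensions lets one code path serve both
    # ranks, which mirrors what replacing index() calls with pointer
    # arithmetic achieves on the C++ side.
    flat = attention.reshape(-1, attention.shape[-1])
    return list(flat)

assert len(step_vectors(averaged)) == 1
assert len(step_vectors(multi_head)) == num_heads
```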
QuentinFuxa force-pushed from 444b38b to 9200e51
Author

The CI failure is unrelated to this PR; it's a flaky wav2vec2 transcription test in which the model produces a slightly different transcription.
Summary
Expose return_attention in GenerationOptions and GenerationResult, mirroring the existing support in TranslationOptions/TranslationResult. Closes #1994.

The decoding engine (decoding.cc) already computes attention weights when requested. The first commit wires the flag and data through the Generator API layer. The second commit fixes the actual attention capture for decoder-only models and adds a set_alignment_heads() method to control which heads are returned.

This is useful for decoder-only models (LLMs) where attention inspection is needed, e.g. for simultaneous translation to detect when the model finishes exploiting the source context (the AlignAtt policy, for instance).
Changes
Commit 1 - Wire return_attention through Generator

- include/ctranslate2/generation.h: add return_attention to GenerationOptions, add attention field to GenerationResult
- src/models/language_model.cc: propagate return_attention to DecodingOptions, handle EOS removal for attention rows, transfer attention to the final result
- python/cpp/generator.cc: add return_attention parameter to generate_batch()
- python/cpp/generation_result.cc: expose attention on the Python GenerationResult

Commit 2 - Fix self-attention capture + add set_alignment_heads

- src/layers/transformer.cc: TransformerDecoderLayer was always passing nullptr as the attention pointer to self-attention. This worked for encoder-decoder models (cross-attention fills it), but for decoder-only models the attention was never captured, causing a "Gather: rank >= 1" crash. Fixed by passing the attention pointer to self-attention when there is no encoder-attention.
- python/cpp/generator.cc, include/ctranslate2/models/language_model.h, src/models/language_model.cc, include/ctranslate2/layers/decoder.h, include/ctranslate2/layers/transformer.h, include/ctranslate2/replica_pool.h: add set_alignment_heads() on the Generator Python API to select specific (layer, head) pairs. By default only head 0 of the last layer is returned. With this method, multiple heads from any layer can be collected, and the output is concatenated.
- src/decoding.cc: fix three code paths (greedy search, build_attention, decode_alternatives) to handle multi-head attention tensors (rank 3) alongside averaged attention (rank 2). Replaced index() calls with pointer arithmetic to support both ranks.

Usage
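The body of the Usage section did not survive the page scrape. A minimal sketch of how the new options are presumably meant to be combined (the model path, start tokens, and the attention_matrix helper are placeholders of ours, not part of the PR):

```python
import numpy as np

def attention_matrix(step_attention, num_heads):
    """Reshape one generated token's concatenated per-head attention
    into (num_heads, context_length)."""
    return np.asarray(step_attention, dtype=np.float32).reshape(num_heads, -1)

if __name__ == "__main__":
    import ctranslate2  # requires a converted decoder-only model on disk

    generator = ctranslate2.Generator("gemma3-ct2/", device="cpu")
    # Collect heads 0-3 of the last layer, as in the tests below.
    generator.set_alignment_heads([(33, 0), (33, 1), (33, 2), (33, 3)])
    results = generator.generate_batch(
        [["<bos>", "▁Translate", "▁this"]],  # placeholder start tokens
        max_length=64,
        return_attention=True,  # the flag added by this PR
    )
    # One attention entry per generated token; each reshapes per head.
    for step in results[0].attention:
        print(attention_matrix(step, num_heads=4).shape)
```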
Tests done
Tested on Gemma3 4B (TranslateGemma, FR->EN) on CPU (Apple Silicon, int8):
- [(33,0), (33,1), (33,2), (33,3)]: returns vectors of length 368 (4 * 92), correctly reshapable to (4, 92). Each head shows different attention patterns.
- [(30,0), (31,3), (32,5), (33,7)]: works correctly across layers. Different layers attend to different parts of the input.
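As a sanity check of the reshape in the first test case, a small numpy sketch with synthetic data (4 heads, context length 92, as above):

```python
import numpy as np

num_heads, context_length = 4, 92

# One generated token's attention with heads (33,0)..(33,3) selected:
# the four per-head weight vectors concatenated into a length-368 vector.
concatenated = np.arange(num_heads * context_length, dtype=np.float32)
assert concatenated.shape == (368,)

per_head = concatenated.reshape(num_heads, context_length)
assert per_head.shape == (4, 92)

# Head h occupies the h-th contiguous block of the concatenated vector.
assert np.array_equal(per_head[1], concatenated[92:184])
```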