Add return_attention support to Generator #2010

Open

QuentinFuxa wants to merge 2 commits into OpenNMT:master from QuentinFuxa:feature/generator-return-attention

Conversation

QuentinFuxa commented Feb 6, 2026

Summary

Expose return_attention in GenerationOptions and GenerationResult, mirroring the existing support in TranslationOptions/TranslationResult. Closes #1994.

The decoding engine (decoding.cc) already computes attention weights when requested. The first commit wires the flag and data through the Generator API layer. The second commit fixes the actual attention capture for decoder-only models and adds a set_alignment_heads() method to control which heads are returned.

Useful for decoder-only models (LLMs) where attention inspection is needed, e.g. for simultaneous translation to detect when the model finishes exploiting source context (AlignAtt policy for instance).
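As a rough illustration of how the returned per-step attention could drive an AlignAtt-style policy, here is a minimal NumPy sketch. It is illustrative only: `alignatt_ready` and the `frontier` threshold are hypothetical names, not part of this PR or of the AlignAtt paper's exact formulation.

```python
import numpy as np

def alignatt_ready(step_attention, source_len, frontier=2):
    # Hypothetical helper: consider the model done exploiting the source
    # once the most-attended context position falls within the last
    # `frontier` positions of the source prompt.
    most_attended = int(np.argmax(step_attention[:source_len]))
    return most_attended >= source_len - frontier

# Synthetic per-step attention vectors over a 6-token context (each sums to 1).
early = np.array([0.05, 0.60, 0.20, 0.10, 0.03, 0.02])  # focused mid-source
late = np.array([0.02, 0.03, 0.05, 0.10, 0.20, 0.60])   # focused at the end
```

With real output, `step_attention` would be `results[0].attention[0][step]` from `generate_batch(..., return_attention=True)`.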

Changes

Commit 1 - Wire return_attention through Generator

  • include/ctranslate2/generation.h: add return_attention to GenerationOptions, add attention field to GenerationResult
  • src/models/language_model.cc: propagate return_attention to DecodingOptions, handle EOS removal for attention rows, transfer attention to final result
  • python/cpp/generator.cc: add return_attention parameter to generate_batch()
  • python/cpp/generation_result.cc: expose attention on Python GenerationResult
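The EOS handling mentioned above can be pictured with a small NumPy sketch (the token strings and shapes below are made up; the point is only that stripping a trailing EOS token must also drop its attention row so sequence and attention lengths stay in sync):

```python
import numpy as np

tokens = ["▁Bonjour", "▁le", "▁monde", "</s>"]
attn = np.random.rand(len(tokens), 7)  # one attention row per generated token

# Sketch of the invariant: when the trailing EOS is removed from the
# result sequence, its attention row is removed as well, so that
# len(sequence) == len(attention) still holds.
if tokens[-1] == "</s>":
    tokens, attn = tokens[:-1], attn[:-1]
```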

Commit 2 - Fix self-attention capture + add set_alignment_heads

  • src/layers/transformer.cc: TransformerDecoderLayer was always passing nullptr as the attention pointer to self-attention. This worked for encoder-decoder models (cross-attention fills it), but for decoder-only models the attention was never captured, causing a Gather: rank >= 1 crash. Fixed by passing the attention pointer to self-attention when there is no encoder-attention.
  • python/cpp/generator.cc, include/ctranslate2/models/language_model.h, src/models/language_model.cc, include/ctranslate2/layers/decoder.h, include/ctranslate2/layers/transformer.h, include/ctranslate2/replica_pool.h: add set_alignment_heads() to the Generator Python API to select specific (layer, head) pairs. By default only head 0 of the last layer is returned; with this method, multiple heads from any layer can be collected, and their outputs are concatenated.
  • src/decoding.cc: fix three code paths (greedy search, build_attention, decode_alternatives) to handle multi-head attention tensors (rank 3) alongside averaged attention (rank 2). Replaced index() calls with pointer arithmetic to support both ranks.
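The rank handling can be sketched in NumPy (illustrative only: the actual fix in decoding.cc works on raw pointer offsets in C++, and `step_attention` is a hypothetical helper). Both an averaged rank-2 tensor and a multi-head rank-3 tensor yield one flat vector per decoding step:

```python
import numpy as np

def step_attention(attn, step):
    # Accept either an averaged rank-2 tensor [target_len, context_len]
    # or a multi-head rank-3 tensor [num_heads, target_len, context_len],
    # returning a flat vector for the given step in both cases.
    if attn.ndim == 2:                   # averaged heads
        return attn[step]
    if attn.ndim == 3:                   # selected heads, concatenated
        return attn[:, step, :].reshape(-1)
    raise ValueError("expected rank 2 or 3")

avg = np.random.rand(5, 9)     # 5 decoding steps over a 9-token context
multi = np.random.rand(4, 5, 9)  # same, with 4 selected heads
```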

Usage

```python
generator = ctranslate2.Generator("model/", compute_type="auto")

# Default: head 0 of last layer
results = generator.generate_batch(
    [tokens], return_attention=True, beam_size=1, max_length=100
)
# results[0].attention[0][step] has length = context_length

# Multi-head: pick specific (layer, head) pairs
generator.set_alignment_heads([(30, 0), (31, 3), (32, 5), (33, 7)])
results = generator.generate_batch(
    [tokens], return_attention=True, beam_size=1, max_length=100
)
# results[0].attention[0][step] has length = num_heads * context_length
# reshape: np.array(vec).reshape(num_heads, -1)
```
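Splitting the concatenated vector back into per-head rows looks like this (synthetic data standing in for `results[0].attention[0][step]`; `num_heads=4` and `context_length=92` match the test setup described under Tests done):

```python
import numpy as np

num_heads, context_length = 4, 92
# Stand-in for results[0].attention[0][step]: num_heads * context_length floats.
vec = np.random.rand(num_heads * context_length)

per_head = np.array(vec).reshape(num_heads, -1)  # shape (4, 92)
top_positions = per_head.argmax(axis=1)          # most-attended token per head
```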

Tests done

Tested on Gemma3 4B (TranslateGemma, FR->EN) on CPU (Apple Silicon, int8):

  1. Default (1 head): returns attention vectors of length = context_length (92 tokens). Translation works correctly.
  2. 4 heads, same layer [(33,0), (33,1), (33,2), (33,3)]: returns vectors of length 368 (4 * 92), correctly reshapable to (4, 92). Each head shows different attention patterns.
  3. 4 heads, different layers [(30,0), (31,3), (32,5), (33,7)]: works correctly across layers. Different layers attend to different parts of the input.
  4. Without return_attention: still works as before, no regression.

The decoding engine already computes attention weights when requested,
but this was only wired through the Translator API. This exposes the
same capability for decoder-only models (Generator) by propagating
the return_attention flag from GenerationOptions to DecodingOptions
and transferring the attention data back to GenerationResult.
QuentinFuxa force-pushed the feature/generator-return-attention branch from 0c6549e to 948f8f2 on February 7, 2026 at 11:55
QuentinFuxa marked this pull request as ready for review on February 7, 2026 at 11:56
…s API

Self-attention was not returning attention weights for decoder-only models
(Generator) because the attention pointer was always nullptr in
TransformerDecoderLayer. Now passes the attention pointer to self-attention
when there is no encoder-attention (decoder-only case).

Also adds set_alignment_heads() to Generator Python API, allowing users to
select specific (layer, head) pairs instead of the default (last layer,
head 0). The attention from selected heads is concatenated in the output
and can be reshaped to (num_heads, context_length).

Fixed multi-head attention handling in decoding.cc to support variable-rank
attention tensors (rank 3 for multi-head vs rank 2 for averaged).
QuentinFuxa force-pushed the feature/generator-return-attention branch from 444b38b to 9200e51 on February 20, 2026 at 08:48
QuentinFuxa (Author) commented Feb 22, 2026

The CI failure is unrelated to this PR; it is a flaky wav2vec2 transcription test:

FAILED test_transformers.py::TestWav2Vec2::test_transformers_wav2vec2[facebook/wav2vec2-large-robust-ft-swbd-300h-expected_transcription0-cpu]

Expected: MISTER QUILTER IS THE APOSSEL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
Got:      MISTER QUILTER IS THE APOSSTEL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL

The model produces a slightly different transcription (APOSSTEL vs APOSSEL) likely due to environment/library version differences. All other 166 tests passed. Could a maintainer re-run the failed job?

Successfully merging this pull request may close these issues:

  • How to return_attention from Generator.generate_tokens?