From 948f8f2677e389e4e0ae0d4eff1f6b562df793cb Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 6 Feb 2026 18:33:41 +0100 Subject: [PATCH 1/2] Add return_attention support to Generator The decoding engine already computes attention weights when requested, but this was only wired through the Translator API. This exposes the same capability for decoder-only models (Generator) by propagating the return_attention flag from GenerationOptions to DecodingOptions and transferring the attention data back to GenerationResult. --- include/ctranslate2/generation.h | 3 +++ python/cpp/generation_result.cc | 3 +++ python/cpp/generator.cc | 4 ++++ src/models/language_model.cc | 12 +++++++++--- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/include/ctranslate2/generation.h b/include/ctranslate2/generation.h index bd76146ff..b4bfa732e 100644 --- a/include/ctranslate2/generation.h +++ b/include/ctranslate2/generation.h @@ -53,6 +53,8 @@ namespace ctranslate2 { // Include scores in the result. bool return_scores = false; + // Store attention vectors in the GenerationResult class. + bool return_attention = false; // Include log probs of each token in the result bool return_logits_vocab = false; @@ -81,6 +83,7 @@ namespace ctranslate2 { std::vector> sequences; std::vector> sequences_ids; std::vector scores; + std::vector>> attention; std::vector> logits; size_t num_sequences() const { diff --git a/python/cpp/generation_result.cc b/python/cpp/generation_result.cc index 5964a5d5b..113c8543e 100644 --- a/python/cpp/generation_result.cc +++ b/python/cpp/generation_result.cc @@ -49,6 +49,8 @@ namespace ctranslate2 { "Generated sequences of token IDs.") .def_readonly("scores", &GenerationResult::scores, "Score of each sequence (empty if :obj:`return_scores` was disabled).") + .def_readonly("attention", &GenerationResult::attention, + "Attention matrix of each sequence (empty if :obj:`return_attention` was disabled).") .def_readonly("logits", &GenerationResult::logits, "Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).") @@ -56,6 +58,7 @@ namespace ctranslate2 { return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences))) + ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids))) + ", scores=" + std::string(py::repr(py::cast(result.scores))) + + ", attention=" + std::string(py::repr(py::cast(result.attention))) + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ")"; }) diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 0647552d0..4d9292cee 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -33,6 +33,7 @@ namespace ctranslate2 { bool cache_static_prompt, bool include_prompt_in_result, bool return_scores, + bool return_attention, bool return_logits_vocab, bool return_alternatives, float min_alternative_expansion_prob, @@ -59,6 +60,7 @@ namespace ctranslate2 { options.num_hypotheses = num_hypotheses; options.return_end_token = return_end_token; options.return_scores = return_scores; + options.return_attention = return_attention; options.return_logits_vocab = return_logits_vocab; options.return_alternatives = return_alternatives; options.cache_static_prompt = cache_static_prompt; @@ -205,6 +207,7 @@ namespace ctranslate2 { py::arg("cache_static_prompt")=true, py::arg("include_prompt_in_result")=true, py::arg("return_scores")=false, + py::arg("return_attention")=false, py::arg("return_logits_vocab")=false, py::arg("return_alternatives")=false, py::arg("min_alternative_expansion_prob")=0, @@ -263,6 +266,7 @@ namespace ctranslate2 { reuse it for future generations using the same static prompt. include_prompt_in_result: Include the :obj:`start_tokens` in the result. return_scores: Include the scores in the output. + return_attention: Include the attention matrices in the output. return_logits_vocab: Include log probs for each token in the output return_alternatives: Return alternatives at the first unconstrained decoding position. min_alternative_expansion_prob: Minimum initial probability to expand an alternative. diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 5a23fa35a..d52928e16 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -165,6 +165,7 @@ namespace ctranslate2 { decoding_options.sampling_temperature = options.sampling_temperature; decoding_options.num_hypotheses = options.num_hypotheses; decoding_options.return_scores = options.return_scores; + decoding_options.return_attention = options.return_attention; decoding_options.return_logits_vocab = options.return_logits_vocab; decoding_options.return_alternatives = options.return_alternatives; decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob; @@ -251,9 +252,13 @@ namespace ctranslate2 { // Remove EOS token. if (!options.return_end_token) { - for (auto& sequence : result.hypotheses) { - while (!sequence.empty() && is_eos(sequence.back(), end_ids)) - sequence.pop_back(); + for (size_t h = 0; h < result.hypotheses.size(); ++h) { + while (!result.hypotheses[h].empty() + && is_eos(result.hypotheses[h].back(), end_ids)) { + result.hypotheses[h].pop_back(); + if (!result.attention.empty()) + result.attention[h].pop_back(); + } } } @@ -269,6 +274,7 @@ namespace ctranslate2 { final_result.sequences = vocabulary.to_tokens(result.hypotheses); final_result.sequences_ids = std::move(result.hypotheses); final_result.scores = std::move(result.scores); + final_result.attention = std::move(result.attention); final_result.logits = std::move(result.logits_vocab); final_results.emplace_back(std::move(final_result)); } From 9200e51e188662b0d4573c466bcf5c2c6f335552 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 20 Feb 2026 08:59:53 +0100 Subject: [PATCH 2/2] Fix self-attention for decoder-only models and add set_alignment_heads API Self-attention was not returning attention weights for decoder-only models (Generator) because the attention pointer was always nullptr in TransformerDecoderLayer. Now passes the attention pointer to self-attention when there is no encoder-attention (decoder-only case). Also adds set_alignment_heads() to Generator Python API, allowing users to select specific (layer, head) pairs instead of the default (last layer, head 0). The attention from selected heads is concatenated in the output and can be reshaped to (num_heads, context_length). Fixed multi-head attention handling in decoding.cc to support variable-rank attention tensors (rank 3 for multi-head vs rank 2 for averaged). --- include/ctranslate2/layers/decoder.h | 5 ++++ include/ctranslate2/layers/transformer.h | 2 +- include/ctranslate2/models/language_model.h | 8 ++++++ include/ctranslate2/replica_pool.h | 9 +++++++ python/cpp/generator.cc | 23 ++++++++++++++++ src/decoding.cc | 29 ++++++++++++++++----- src/layers/transformer.cc | 6 ++--- src/models/language_model.cc | 5 ++++ 8 files changed, 76 insertions(+), 11 deletions(-) diff --git a/include/ctranslate2/layers/decoder.h b/include/ctranslate2/layers/decoder.h index 7d7a1f51a..266c1af20 100644 --- a/include/ctranslate2/layers/decoder.h +++ b/include/ctranslate2/layers/decoder.h @@ -20,6 +20,11 @@ namespace ctranslate2 { public: Decoder(Device device); + // Configure which attention heads to collect when return_attention is enabled. + virtual void set_alignment_heads(const std::vector>& alignment_heads) { + (void)alignment_heads; + } + virtual DecoderState initial_state(bool iterative_decoding = true) const = 0; // Forwards one step. diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index 01a7694c6..26eac7da8 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -185,7 +185,7 @@ namespace ctranslate2 { StorageView* attention = nullptr) override; void set_alignment_heads(const dim_t layer, const dim_t num_heads_to_average); - void set_alignment_heads(const std::vector>& alignment_heads); + void set_alignment_heads(const std::vector>& alignment_heads) override; std::unique_ptr get_layer_alignment_heads(const dim_t layer, const dim_t batch_size) const; diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 7532b9a3a..6adb1f8f0 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -58,6 +58,12 @@ namespace ctranslate2 { const StorageView& lengths, const bool return_log_probs); + // Configure which attention heads to collect when return_attention is enabled. + // Each pair is (layer_index, head_index). + virtual void set_alignment_heads(const std::vector>& alignment_heads) { + (void)alignment_heads; + } + protected: virtual bool skip_scoring(const std::vector& tokens, const ScoringOptions& options, @@ -89,6 +95,8 @@ namespace ctranslate2 { DecoderReplica(const std::shared_ptr& model, std::unique_ptr decoder); + void set_alignment_heads(const std::vector>& alignment_heads) override; + protected: bool skip_scoring(const std::vector& tokens, const ScoringOptions& options, diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index 8c8e15d8e..fd5ee8ef1 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -152,6 +152,15 @@ namespace ctranslate2 { return worker.replica(); } + // Apply a function to each replica. Not thread-safe. + template + void for_each_replica(Func func) { + for (size_t i = 0; i < num_replicas(); ++i) { + auto& worker = static_cast&>(_thread_pool->get_worker(i)); + func(worker.replica()); + } + } + protected: template std::vector> diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 4d9292cee..664f87743 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -11,6 +11,12 @@ namespace ctranslate2 { public: using ReplicaPoolHelper::ReplicaPoolHelper; + void set_alignment_heads(const std::vector>& alignment_heads) { + _pool->for_each_replica([&](models::SequenceGeneratorReplica& replica) { + replica.set_alignment_heads(alignment_heads); + }); + } + std::variant, std::vector>> generate_batch(const BatchTokens& tokens, @@ -185,6 +191,23 @@ namespace ctranslate2 { .def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") + .def("set_alignment_heads", &GeneratorWrapper::set_alignment_heads, + py::arg("alignment_heads"), + R"pbdoc( + Configure which attention heads to collect when ``return_attention=True``. + + By default, only head 0 of the last layer is returned (averaged). + Use this method to select specific (layer, head) pairs. The attention + from the selected heads will be concatenated in the output. + + Arguments: + alignment_heads: List of (layer_index, head_index) pairs to collect. + + Example: + + >>> generator.set_alignment_heads([(31, 0), (31, 3), (33, 7)]) + )pbdoc") + .def("generate_batch", &GeneratorWrapper::generate_batch, py::arg("start_tokens"), py::kw_only(), diff --git a/src/decoding.cc b/src/decoding.cc index 84f39ac37..97ef5daee 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -146,13 +146,23 @@ namespace ctranslate2 { if (!history) return {}; - const auto source_length = history.dim(-1); + // For averaged attention: history is (batch, beam, steps, ctx) + // For per-head attention: history is (batch, beam, steps, heads, ctx) + // Compute total floats per time step (ctx or heads*ctx). + dim_t step_size = 1; + for (dim_t d = 3; d < history.rank(); ++d) + step_size *= history.dim(d); std::vector> attention; attention.reserve(end - start); + // Compute stride for the steps dimension: step_size floats per step. + // Base offset for (batch, beam) = batch * (beam_stride) + beam * (steps * step_size). + const dim_t steps = history.dim(2); + const dim_t beam_stride = steps * step_size; + const float* base = history.data() + batch * history.dim(1) * beam_stride + beam * beam_stride; for (dim_t t = start; t < end; ++t) { - const auto* vector = history.index({batch, beam, t, 0}); - attention.emplace_back(vector, vector + source_length); + const float* vector = base + t * step_size; + attention.emplace_back(vector, vector + step_size); } return attention; } @@ -911,8 +921,11 @@ namespace ctranslate2 { && (return_prefix || step >= prefix_length)) { results[batch_id].hypotheses[0].push_back(word_id); if (attention_step) { - const auto* attn = attention_step.index({i, 0}); - results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1)); + // For averaged attention: shape (batch, ctx) -> take ctx floats + // For per-head attention: shape (batch, heads, ctx) -> take heads*ctx floats + const dim_t attn_size = attention_step.size() / attention_step.dim(0); + const auto* attn = attention_step.data() + i * attn_size; + results[batch_id].attention[0].emplace_back(attn, attn + attn_size); } } @@ -1166,9 +1179,11 @@ namespace ctranslate2 { if (options.return_attention) { if (attention.device() != Device::CPU) attention = attention.to_float32().to(Device::CPU); + // Compute floats per time step (ctx or heads*ctx for multi-head). + const dim_t step_size = attention.size() / (attention.dim(0) * attention.dim(1)); for (dim_t t = 0; t < prefix_length; ++t) { - const float* vector = attention.index({0, t, 0}); - result.attention[i].emplace_back(vector, vector + attention.dim(-1)); + const float* vector = attention.data() + t * step_size; + result.attention[i].emplace_back(vector, vector + step_size); } } } diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 4fea80f7f..8e895e0cb 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -222,7 +222,7 @@ namespace ctranslate2 { context, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ? nullptr : attention, input_padder, input_padder, true, @@ -291,7 +291,7 @@ namespace ctranslate2 { attn, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ? nullptr : attention, input_padder, input_padder, true, @@ -315,7 +315,7 @@ namespace ctranslate2 { output, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ? nullptr : attention, input_padder, input_padder, true, diff --git a/src/models/language_model.cc b/src/models/language_model.cc index d52928e16..8914a3ab4 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -110,6 +110,11 @@ namespace ctranslate2 { { } + void DecoderReplica::set_alignment_heads( + const std::vector>& alignment_heads) { + _decoder->set_alignment_heads(alignment_heads); + } + std::vector DecoderReplica::run_scoring(const std::vector>& tokens, const ScoringOptions& options) {