diff --git a/include/ctranslate2/generation.h b/include/ctranslate2/generation.h index bd76146ff..b4bfa732e 100644 --- a/include/ctranslate2/generation.h +++ b/include/ctranslate2/generation.h @@ -53,6 +53,8 @@ namespace ctranslate2 { // Include scores in the result. bool return_scores = false; + // Store attention vectors in the GenerationResult class. + bool return_attention = false; // Include log probs of each token in the result bool return_logits_vocab = false; @@ -81,6 +83,7 @@ namespace ctranslate2 { std::vector<std::vector<std::string>> sequences; std::vector<std::vector<size_t>> sequences_ids; std::vector<float> scores; + std::vector<std::vector<std::vector<float>>> attention; std::vector<std::vector<float>> logits; size_t num_sequences() const { diff --git a/include/ctranslate2/layers/decoder.h b/include/ctranslate2/layers/decoder.h index 7d7a1f51a..266c1af20 100644 --- a/include/ctranslate2/layers/decoder.h +++ b/include/ctranslate2/layers/decoder.h @@ -20,6 +20,11 @@ namespace ctranslate2 { public: Decoder(Device device); + // Configure which attention heads to collect when return_attention is enabled. + virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) { + (void)alignment_heads; + } + virtual DecoderState initial_state(bool iterative_decoding = true) const = 0; // Forwards one step.
diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index 01a7694c6..26eac7da8 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -185,7 +185,7 @@ namespace ctranslate2 { StorageView* attention = nullptr) override; void set_alignment_heads(const dim_t layer, const dim_t num_heads_to_average); - void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads); + void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override; std::unique_ptr<StorageView> get_layer_alignment_heads(const dim_t layer, const dim_t batch_size) const; diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 7532b9a3a..6adb1f8f0 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -58,6 +58,12 @@ namespace ctranslate2 { const StorageView& lengths, const bool return_log_probs); + // Configure which attention heads to collect when return_attention is enabled. + // Each pair is (layer_index, head_index). + virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) { + (void)alignment_heads; + } + protected: virtual bool skip_scoring(const std::vector<std::string>& tokens, const ScoringOptions& options, @@ -89,6 +95,8 @@ namespace ctranslate2 { DecoderReplica(const std::shared_ptr<const LanguageModel>& model, std::unique_ptr<layers::Decoder> decoder); + void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override; + protected: bool skip_scoring(const std::vector<std::string>& tokens, const ScoringOptions& options, diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index 8c8e15d8e..fd5ee8ef1 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -152,6 +152,15 @@ namespace ctranslate2 { return worker.replica(); } + // Apply a function to each replica. Not thread-safe.
+ template <typename Func> + void for_each_replica(Func func) { + for (size_t i = 0; i < num_replicas(); ++i) { + auto& worker = static_cast<ReplicaWorker<Replica>&>(_thread_pool->get_worker(i)); + func(worker.replica()); + } + } + protected: template <typename Result> std::vector<std::future<Result>> diff --git a/python/cpp/generation_result.cc b/python/cpp/generation_result.cc index 5964a5d5b..113c8543e 100644 --- a/python/cpp/generation_result.cc +++ b/python/cpp/generation_result.cc @@ -49,6 +49,8 @@ namespace ctranslate2 { "Generated sequences of token IDs.") .def_readonly("scores", &GenerationResult::scores, "Score of each sequence (empty if :obj:`return_scores` was disabled).") + .def_readonly("attention", &GenerationResult::attention, + "Attention matrix of each sequence (empty if :obj:`return_attention` was disabled).") .def_readonly("logits", &GenerationResult::logits, "Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).") @@ -56,6 +58,7 @@ namespace ctranslate2 { return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences))) + ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids))) + ", scores=" + std::string(py::repr(py::cast(result.scores))) + + ", attention=" + std::string(py::repr(py::cast(result.attention))) + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ")"; }) diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 0647552d0..664f87743 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -11,6 +11,12 @@ namespace ctranslate2 { public: using ReplicaPoolHelper::ReplicaPoolHelper; + void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) { + _pool->for_each_replica([&](models::SequenceGeneratorReplica& replica) { + replica.set_alignment_heads(alignment_heads); + }); + } + std::variant<std::vector<GenerationResult>, std::vector<AsyncResult<GenerationResult>>> generate_batch(const BatchTokens& tokens, @@ -33,6 +39,7 @@ bool cache_static_prompt, bool include_prompt_in_result, bool return_scores, + bool return_attention, bool
return_logits_vocab, bool return_alternatives, float min_alternative_expansion_prob, @@ -59,6 +66,7 @@ namespace ctranslate2 { options.num_hypotheses = num_hypotheses; options.return_end_token = return_end_token; options.return_scores = return_scores; + options.return_attention = return_attention; options.return_logits_vocab = return_logits_vocab; options.return_alternatives = return_alternatives; options.cache_static_prompt = cache_static_prompt; @@ -183,6 +191,23 @@ namespace ctranslate2 { .def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") + .def("set_alignment_heads", &GeneratorWrapper::set_alignment_heads, + py::arg("alignment_heads"), + R"pbdoc( + Configure which attention heads to collect when ``return_attention=True``. + + By default, only head 0 of the last layer is returned (averaged). + Use this method to select specific (layer, head) pairs. The attention + from the selected heads will be concatenated in the output. + + Arguments: + alignment_heads: List of (layer_index, head_index) pairs to collect. + + Example: + + >>> generator.set_alignment_heads([(31, 0), (31, 3), (33, 7)]) + )pbdoc") + .def("generate_batch", &GeneratorWrapper::generate_batch, py::arg("start_tokens"), py::kw_only(), @@ -205,6 +230,7 @@ namespace ctranslate2 { py::arg("cache_static_prompt")=true, py::arg("include_prompt_in_result")=true, py::arg("return_scores")=false, + py::arg("return_attention")=false, py::arg("return_logits_vocab")=false, py::arg("return_alternatives")=false, py::arg("min_alternative_expansion_prob")=0, @@ -263,6 +289,7 @@ namespace ctranslate2 { reuse it for future generations using the same static prompt. include_prompt_in_result: Include the :obj:`start_tokens` in the result. return_scores: Include the scores in the output. + return_attention: Include the attention matrices in the output. 
return_logits_vocab: Include log probs for each token in the output return_alternatives: Return alternatives at the first unconstrained decoding position. min_alternative_expansion_prob: Minimum initial probability to expand an alternative. diff --git a/src/decoding.cc b/src/decoding.cc index 84f39ac37..97ef5daee 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -146,13 +146,23 @@ namespace ctranslate2 { if (!history) return {}; - const auto source_length = history.dim(-1); + // For averaged attention: history is (batch, beam, steps, ctx) + // For per-head attention: history is (batch, beam, steps, heads, ctx) + // Compute total floats per time step (ctx or heads*ctx). + dim_t step_size = 1; + for (dim_t d = 3; d < history.rank(); ++d) + step_size *= history.dim(d); std::vector<std::vector<float>> attention; attention.reserve(end - start); + // Compute stride for the steps dimension: step_size floats per step. + // Base offset for (batch, beam) = batch * (beam_stride) + beam * (steps * step_size). + const dim_t steps = history.dim(2); + const dim_t beam_stride = steps * step_size; + const float* base = history.data<float>() + batch * history.dim(1) * beam_stride + beam * beam_stride; for (dim_t t = start; t < end; ++t) { - const auto* vector = history.index<float>({batch, beam, t, 0}); - attention.emplace_back(vector, vector + source_length); + const float* vector = base + t * step_size; + attention.emplace_back(vector, vector + step_size); } return attention; } @@ -911,8 +921,11 @@ namespace ctranslate2 { && (return_prefix || step >= prefix_length)) { results[batch_id].hypotheses[0].push_back(word_id); if (attention_step) { - const auto* attn = attention_step.index<float>({i, 0}); - results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1)); + // For averaged attention: shape (batch, ctx) -> take ctx floats + // For per-head attention: shape (batch, heads, ctx) -> take heads*ctx floats + const dim_t attn_size = attention_step.size() / attention_step.dim(0); + const auto* attn =
attention_step.data<float>() + i * attn_size; + results[batch_id].attention[0].emplace_back(attn, attn + attn_size); } } @@ -1166,9 +1179,11 @@ namespace ctranslate2 { if (options.return_attention) { if (attention.device() != Device::CPU) attention = attention.to_float32().to(Device::CPU); + // Compute floats per time step (ctx or heads*ctx for multi-head). + const dim_t step_size = attention.size() / (attention.dim(0) * attention.dim(1)); for (dim_t t = 0; t < prefix_length; ++t) { - const float* vector = attention.index<float>({0, t, 0}); - result.attention[i].emplace_back(vector, vector + attention.dim(-1)); + const float* vector = attention.data<float>() + t * step_size; + result.attention[i].emplace_back(vector, vector + step_size); } } } diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 4fea80f7f..8e895e0cb 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -222,7 +222,7 @@ namespace ctranslate2 { context, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ? nullptr : attention, input_padder, input_padder, true, @@ -291,7 +291,7 @@ namespace ctranslate2 { attn, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ? nullptr : attention, input_padder, input_padder, true, @@ -315,7 +315,7 @@ namespace ctranslate2 { output, cached_self_attn_keys, cached_self_attn_values, - nullptr, + _encoder_attention ?
nullptr : attention, input_padder, input_padder, true, diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 5a23fa35a..8914a3ab4 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -110,6 +110,11 @@ namespace ctranslate2 { { } + void DecoderReplica::set_alignment_heads( + const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) { + _decoder->set_alignment_heads(alignment_heads); + } + std::vector<ScoringResult> DecoderReplica::run_scoring(const std::vector<std::vector<std::string>>& tokens, const ScoringOptions& options) { @@ -165,6 +170,7 @@ namespace ctranslate2 { decoding_options.sampling_temperature = options.sampling_temperature; decoding_options.num_hypotheses = options.num_hypotheses; decoding_options.return_scores = options.return_scores; + decoding_options.return_attention = options.return_attention; decoding_options.return_logits_vocab = options.return_logits_vocab; decoding_options.return_alternatives = options.return_alternatives; decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob; @@ -251,9 +257,13 @@ namespace ctranslate2 { // Remove EOS token. if (!options.return_end_token) { - for (auto& sequence : result.hypotheses) { - while (!sequence.empty() && is_eos(sequence.back(), end_ids)) - sequence.pop_back(); + for (size_t h = 0; h < result.hypotheses.size(); ++h) { + while (!result.hypotheses[h].empty() + && is_eos(result.hypotheses[h].back(), end_ids)) { + result.hypotheses[h].pop_back(); + if (!result.attention.empty()) + result.attention[h].pop_back(); + } } } @@ -269,6 +279,7 @@ namespace ctranslate2 { final_result.sequences = vocabulary.to_tokens(result.hypotheses); final_result.sequences_ids = std::move(result.hypotheses); final_result.scores = std::move(result.scores); + final_result.attention = std::move(result.attention); final_result.logits = std::move(result.logits_vocab); final_results.emplace_back(std::move(final_result)); }