Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/ctranslate2/generation.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ namespace ctranslate2 {

// Include scores in the result.
bool return_scores = false;
// Store attention vectors in the GenerationResult class.
bool return_attention = false;
// Include log probs of each token in the result.
bool return_logits_vocab = false;

Expand Down Expand Up @@ -81,6 +83,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<std::vector<StorageView>> logits;

size_t num_sequences() const {
Expand Down
5 changes: 5 additions & 0 deletions include/ctranslate2/layers/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ namespace ctranslate2 {
public:
Decoder(Device device);

// Configure which attention heads to collect when return_attention is enabled.
// Base-class default: the configuration is accepted but ignored; decoders that
// support per-head attention collection override this.
virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
  static_cast<void>(alignment_heads);
}

virtual DecoderState initial_state(bool iterative_decoding = true) const = 0;

// Forwards one step.
Expand Down
2 changes: 1 addition & 1 deletion include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ namespace ctranslate2 {
StorageView* attention = nullptr) override;

void set_alignment_heads(const dim_t layer, const dim_t num_heads_to_average);
void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads);
void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override;

std::unique_ptr<StorageView>
get_layer_alignment_heads(const dim_t layer, const dim_t batch_size) const;
Expand Down
8 changes: 8 additions & 0 deletions include/ctranslate2/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ namespace ctranslate2 {
const StorageView& lengths,
const bool return_log_probs);

// Configure which attention heads to collect when return_attention is enabled.
// Each pair is (layer_index, head_index).
// Default implementation is a no-op; replicas that support per-head attention
// collection override this.
virtual void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
  static_cast<void>(alignment_heads);
}

protected:
virtual bool skip_scoring(const std::vector<std::string>& tokens,
const ScoringOptions& options,
Expand Down Expand Up @@ -89,6 +95,8 @@ namespace ctranslate2 {
DecoderReplica(const std::shared_ptr<const LanguageModel>& model,
std::unique_ptr<layers::Decoder> decoder);

void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) override;

protected:
bool skip_scoring(const std::vector<std::string>& tokens,
const ScoringOptions& options,
Expand Down
9 changes: 9 additions & 0 deletions include/ctranslate2/replica_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,15 @@ namespace ctranslate2 {
return worker.replica();
}

// Apply a function to each replica. Not thread-safe: callers must ensure no
// work is being processed by the pool while iterating.
template <typename Func>
void for_each_replica(Func func) {
  const size_t replica_count = num_replicas();
  for (size_t index = 0; index < replica_count; ++index) {
    auto& replica_worker = static_cast<ReplicaWorker<Replica>&>(_thread_pool->get_worker(index));
    func(replica_worker.replica());
  }
}

protected:
template <typename Result, typename Func>
std::vector<std::future<Result>>
Expand Down
3 changes: 3 additions & 0 deletions python/cpp/generation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ namespace ctranslate2 {
"Generated sequences of token IDs.")
.def_readonly("scores", &GenerationResult::scores,
"Score of each sequence (empty if :obj:`return_scores` was disabled).")
.def_readonly("attention", &GenerationResult::attention,
"Attention matrix of each sequence (empty if :obj:`return_attention` was disabled).")
.def_readonly("logits", &GenerationResult::logits,
"Logits of each sequence (empty if :obj:`return_logits_vocab` was disabled).")

.def("__repr__", [](const GenerationResult& result) {
return "GenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences)))
+ ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})
Expand Down
27 changes: 27 additions & 0 deletions python/cpp/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace ctranslate2 {
public:
using ReplicaPoolHelper::ReplicaPoolHelper;

// Forward the (layer_index, head_index) selection to every replica in the
// pool so all workers collect the same heads when return_attention is enabled.
void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
  _pool->for_each_replica(
    [&alignment_heads](models::SequenceGeneratorReplica& replica) {
      replica.set_alignment_heads(alignment_heads);
    });
}

std::variant<std::vector<GenerationResult>,
std::vector<AsyncResult<GenerationResult>>>
generate_batch(const BatchTokens& tokens,
Expand All @@ -33,6 +39,7 @@ namespace ctranslate2 {
bool cache_static_prompt,
bool include_prompt_in_result,
bool return_scores,
bool return_attention,
bool return_logits_vocab,
bool return_alternatives,
float min_alternative_expansion_prob,
Expand All @@ -59,6 +66,7 @@ namespace ctranslate2 {
options.num_hypotheses = num_hypotheses;
options.return_end_token = return_end_token;
options.return_scores = return_scores;
options.return_attention = return_attention;
options.return_logits_vocab = return_logits_vocab;
options.return_alternatives = return_alternatives;
options.cache_static_prompt = cache_static_prompt;
Expand Down Expand Up @@ -183,6 +191,23 @@ namespace ctranslate2 {
.def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")

.def("set_alignment_heads", &GeneratorWrapper::set_alignment_heads,
py::arg("alignment_heads"),
R"pbdoc(
Configure which attention heads to collect when ``return_attention=True``.

By default, only head 0 of the last layer is returned (averaged).
Use this method to select specific (layer, head) pairs. The attention
from the selected heads will be concatenated in the output.

Arguments:
alignment_heads: List of (layer_index, head_index) pairs to collect.

Example:

>>> generator.set_alignment_heads([(31, 0), (31, 3), (33, 7)])
)pbdoc")

.def("generate_batch", &GeneratorWrapper::generate_batch,
py::arg("start_tokens"),
py::kw_only(),
Expand All @@ -205,6 +230,7 @@ namespace ctranslate2 {
py::arg("cache_static_prompt")=true,
py::arg("include_prompt_in_result")=true,
py::arg("return_scores")=false,
py::arg("return_attention")=false,
py::arg("return_logits_vocab")=false,
py::arg("return_alternatives")=false,
py::arg("min_alternative_expansion_prob")=0,
Expand Down Expand Up @@ -263,6 +289,7 @@ namespace ctranslate2 {
reuse it for future generations using the same static prompt.
include_prompt_in_result: Include the :obj:`start_tokens` in the result.
return_scores: Include the scores in the output.
return_attention: Include the attention matrices in the output.
return_logits_vocab: Include log probs for each token in the output
return_alternatives: Return alternatives at the first unconstrained decoding position.
min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
Expand Down
29 changes: 22 additions & 7 deletions src/decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,23 @@ namespace ctranslate2 {
if (!history)
return {};

const auto source_length = history.dim(-1);
// For averaged attention: history is (batch, beam, steps, ctx)
// For per-head attention: history is (batch, beam, steps, heads, ctx)
// Compute total floats per time step (ctx or heads*ctx).
dim_t step_size = 1;
for (dim_t d = 3; d < history.rank(); ++d)
step_size *= history.dim(d);

std::vector<std::vector<float>> attention;
attention.reserve(end - start);
// Compute stride for the steps dimension: step_size floats per step.
// Base offset for (batch, beam) = batch * (beam_stride) + beam * (steps * step_size).
const dim_t steps = history.dim(2);
const dim_t beam_stride = steps * step_size;
const float* base = history.data<float>() + batch * history.dim(1) * beam_stride + beam * beam_stride;
for (dim_t t = start; t < end; ++t) {
const auto* vector = history.index<float>({batch, beam, t, 0});
attention.emplace_back(vector, vector + source_length);
const float* vector = base + t * step_size;
attention.emplace_back(vector, vector + step_size);
}
return attention;
}
Expand Down Expand Up @@ -911,8 +921,11 @@ namespace ctranslate2 {
&& (return_prefix || step >= prefix_length)) {
results[batch_id].hypotheses[0].push_back(word_id);
if (attention_step) {
const auto* attn = attention_step.index<float>({i, 0});
results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
// For averaged attention: shape (batch, ctx) -> take ctx floats
// For per-head attention: shape (batch, heads, ctx) -> take heads*ctx floats
const dim_t attn_size = attention_step.size() / attention_step.dim(0);
const auto* attn = attention_step.data<float>() + i * attn_size;
results[batch_id].attention[0].emplace_back(attn, attn + attn_size);
}
}

Expand Down Expand Up @@ -1166,9 +1179,11 @@ namespace ctranslate2 {
if (options.return_attention) {
if (attention.device() != Device::CPU)
attention = attention.to_float32().to(Device::CPU);
// Compute floats per time step (ctx or heads*ctx for multi-head).
const dim_t step_size = attention.size() / (attention.dim(0) * attention.dim(1));
for (dim_t t = 0; t < prefix_length; ++t) {
const float* vector = attention.index<float>({0, t, 0});
result.attention[i].emplace_back(vector, vector + attention.dim(-1));
const float* vector = attention.data<float>() + t * step_size;
result.attention[i].emplace_back(vector, vector + step_size);
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ namespace ctranslate2 {
context,
cached_self_attn_keys,
cached_self_attn_values,
nullptr,
_encoder_attention ? nullptr : attention,
input_padder,
input_padder,
true,
Expand Down Expand Up @@ -291,7 +291,7 @@ namespace ctranslate2 {
attn,
cached_self_attn_keys,
cached_self_attn_values,
nullptr,
_encoder_attention ? nullptr : attention,
input_padder,
input_padder,
true,
Expand All @@ -315,7 +315,7 @@ namespace ctranslate2 {
output,
cached_self_attn_keys,
cached_self_attn_values,
nullptr,
_encoder_attention ? nullptr : attention,
input_padder,
input_padder,
true,
Expand Down
17 changes: 14 additions & 3 deletions src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ namespace ctranslate2 {
{
}

// Configure which attention heads to collect when return_attention is enabled.
// Delegates the (layer_index, head_index) pairs to the underlying decoder.
void DecoderReplica::set_alignment_heads(
const std::vector<std::pair<dim_t, dim_t>>& alignment_heads) {
_decoder->set_alignment_heads(alignment_heads);
}

std::vector<ScoringResult>
DecoderReplica::run_scoring(const std::vector<std::vector<std::string>>& tokens,
const ScoringOptions& options) {
Expand Down Expand Up @@ -165,6 +170,7 @@ namespace ctranslate2 {
decoding_options.sampling_temperature = options.sampling_temperature;
decoding_options.num_hypotheses = options.num_hypotheses;
decoding_options.return_scores = options.return_scores;
decoding_options.return_attention = options.return_attention;
decoding_options.return_logits_vocab = options.return_logits_vocab;
decoding_options.return_alternatives = options.return_alternatives;
decoding_options.min_alternative_expansion_prob = options.min_alternative_expansion_prob;
Expand Down Expand Up @@ -251,9 +257,13 @@ namespace ctranslate2 {

// Remove EOS token.
if (!options.return_end_token) {
for (auto& sequence : result.hypotheses) {
while (!sequence.empty() && is_eos(sequence.back(), end_ids))
sequence.pop_back();
for (size_t h = 0; h < result.hypotheses.size(); ++h) {
while (!result.hypotheses[h].empty()
&& is_eos(result.hypotheses[h].back(), end_ids)) {
result.hypotheses[h].pop_back();
if (!result.attention.empty())
result.attention[h].pop_back();
}
}
}

Expand All @@ -269,6 +279,7 @@ namespace ctranslate2 {
final_result.sequences = vocabulary.to_tokens(result.hypotheses);
final_result.sequences_ids = std::move(result.hypotheses);
final_result.scores = std::move(result.scores);
final_result.attention = std::move(result.attention);
final_result.logits = std::move(result.logits_vocab);
final_results.emplace_back(std::move(final_result));
}
Expand Down
Loading