def redact_many(
    self,
    texts: Sequence[str],
    *,
    decode: DecodeOptions | None = None,
    max_tokens_per_forward: int = 16384,
) -> list[str] | list[RedactionResult]:
    """Run redaction on a batch of input strings in shared forward passes.

    Intended as the batched counterpart of :meth:`redact`: windows from
    multiple inputs share one or more batched ``model.forward()`` calls,
    which is useful when serving concurrent requests on the same instance.

    Args:
        texts: Input strings to redact. The iterable is materialized into
            a list exactly once; a bare ``str`` is rejected so that a
            single text is never silently split into characters.
        decode: Optional per-call decode overrides shared across the batch.
        max_tokens_per_forward: Bin-pack target (``rows * max_seq``) for
            batched forward passes. A single window longer than this value
            still runs by itself.

    Returns:
        One result per input text, in input order. Plain redacted strings
        when the instance is configured with ``output_text_only``,
        otherwise ``RedactionResult`` objects. An empty input yields ``[]``
        without loading the prediction components.

    Raises:
        TypeError: If ``texts`` is a single string or contains a non-string
            item (the message names the first offending index and type).
        ValueError: If ``max_tokens_per_forward`` is not positive.
    """
    if isinstance(texts, str):
        raise TypeError("texts must be a sequence of strings, not a single string")
    texts_list = list(texts)
    for index, text in enumerate(texts_list):
        if not isinstance(text, str):
            # Name the offender: in a large batch "contains a non-string"
            # alone gives the caller nothing to act on.
            raise TypeError(
                "texts must contain only strings; "
                f"item {index} is {type(text).__name__}"
            )
    if max_tokens_per_forward <= 0:
        raise ValueError("max_tokens_per_forward must be positive")
    if not texts_list:
        # Nothing to do; skip resolving the runtime/decoder entirely.
        return []

    runtime, decoder = self.get_prediction_components(decode=decode)
    predictions = predict_texts(
        runtime,
        texts_list,
        decoder=decoder,
        max_tokens_per_forward=max_tokens_per_forward,
    )
    if self._output_text_only:
        return [
            _redact_text(prediction.text, prediction.spans)
            for prediction in predictions
        ]
    return [
        RedactionResult(
            schema_version=SCHEMA_VERSION,
            summary=build_detection_summary(
                output_mode=runtime.output_mode,
                labels=[span.label for span in prediction.spans],
                decoded_mismatch=prediction.decoded_mismatch,
            ),
            text=prediction.text,
            detected_spans=tuple(prediction.spans),
            redacted_text=_redact_text(prediction.text, prediction.spans),
            warning=_warning_for_prediction(prediction),
        )
        for prediction in predictions
    ]
def _accumulate_window(
    aggregation: ExampleAggregation,
    window,
    log_probs: torch.Tensor,
    example_id: str,
) -> None:
    """Merge one window's [seq, n_classes] log-probs into the aggregation.

    For every window position whose mask entry is truthy and whose offset
    is non-negative, the position's score vector is folded into the
    aggregation slot for that offset: the first contribution is stored as a
    clone, later ones are combined with ``torch.logaddexp``. The slot's
    count is incremented, the token id is recorded, and
    ``aggregation.length`` grows to cover the highest offset seen.

    Args:
        aggregation: Per-example accumulator, mutated in place.
        window: Object exposing parallel ``mask``, ``offsets`` and
            ``tokens`` sequences for one window.
        log_probs: Scores for this window, indexed by window position.
        example_id: Identifier forwarded to ``record_token_id``.
    """
    for token_pos, is_valid in enumerate(window.mask):
        if not bool(is_valid):
            continue
        token_idx = int(window.offsets[token_pos])
        if token_idx < 0:
            continue
        aggregation.ensure_capacity(token_idx)
        score_vec = log_probs[token_pos]
        existing = aggregation.logprob_logsumexp[token_idx]
        if existing is None:
            # Clone so the stored vector does not alias the batch tensor.
            aggregation.logprob_logsumexp[token_idx] = score_vec.clone()
        else:
            aggregation.logprob_logsumexp[token_idx] = torch.logaddexp(
                existing, score_vec
            )
        aggregation.counts[token_idx] += 1
        aggregation.record_token_id(
            token_idx, int(window.tokens[token_pos]), example_id
        )
        aggregation.length = max(aggregation.length, token_idx + 1)


@torch.inference_mode()
def predict_texts(
    runtime: InferenceRuntime,
    texts: Sequence[str],
    *,
    decoder: ViterbiCRFDecoder | None,
    max_tokens_per_forward: int = 16384,
) -> list[PredictionResult]:
    """Run prediction on a batch of texts in shared model.forward() calls.

    Tokenizes each input, splits it into windows with
    ``example_to_windows``, then bin-packs the non-empty windows from all
    texts into one or more batched forward passes. ``max_tokens_per_forward``
    is a bin-pack target for ``rows * max_seq``; a single longer window
    still runs by itself.

    Padding positions are marked ``False`` in the ``attention_mask`` handed
    to the model. NOTE(review): this relies on the model ignoring masked
    positions -- confirm against the model implementation.

    The result is intended to match calling ``predict_text`` on each text
    individually. NOTE(review): windows are accumulated in length-sorted
    order and run padded inside a batch, so scores may differ at
    floating-point level -- confirm the tolerance callers rely on.

    Args:
        runtime: Loaded inference runtime (encoding, model, label info).
        texts: Input strings. Materialized into a list once; a bare ``str``
            is rejected rather than iterated character by character.
        decoder: Optional decoder forwarded to the shared post-processing.
        max_tokens_per_forward: Positive bin-pack target, see above.

    Returns:
        One PredictionResult per input text, in input order. Texts that
        tokenize to nothing are decoded from an empty aggregation.

    Raises:
        TypeError: If ``texts`` is a single string.
        ValueError: If ``max_tokens_per_forward`` is not positive.
    """
    if isinstance(texts, str):
        raise TypeError("texts must be a sequence of strings, not a single string")
    if max_tokens_per_forward <= 0:
        raise ValueError("max_tokens_per_forward must be positive")
    texts_list = list(texts)
    if not texts_list:
        return []

    background = int(runtime.label_info.background_token_label)
    token_id_lists: list[tuple[int, ...]] = []
    aggregations: list[ExampleAggregation] = []
    # (text index, window) pairs in text order, then window order.
    pending: list = []
    for text_idx, text in enumerate(texts_list):
        token_ids = tuple(
            int(tok)
            for tok in runtime.encoding.encode(text, allowed_special="all")
        )
        token_id_lists.append(token_ids)
        aggregations.append(
            ExampleAggregation(
                logprob_logsumexp=[], counts=[], labels=[], token_ids=[]
            )
        )
        if not token_ids:
            continue
        example = TokenizedExample(
            tokens=token_ids,
            labels=(background,) * len(token_ids),
            example_id=f"batch-{text_idx}",
            text=text,
        )
        pending.extend(
            (text_idx, window)
            for window in example_to_windows(example, runtime.n_ctx)
            if window.tokens
        )

    if pending:
        pad_id = int(runtime.pad_token_id)
        # Longest first so similar-size windows pack together and padding
        # waste per batch stays low. The sort is stable, so equal-length
        # windows keep their text/window order.
        ordered = sorted(
            pending, key=lambda item: len(item[1].tokens), reverse=True
        )
        chunk: list = []
        chunk_max_seq = 0
        for item in ordered:
            seq_len = len(item[1].tokens)
            new_max_seq = max(chunk_max_seq, seq_len)
            projected = (len(chunk) + 1) * new_max_seq
            if chunk and projected > max_tokens_per_forward:
                # Adding this window would exceed the budget: flush first.
                _run_chunk(runtime, chunk, pad_id, aggregations)
                chunk = [item]
                chunk_max_seq = seq_len
            else:
                chunk.append(item)
                chunk_max_seq = new_max_seq
        if chunk:
            _run_chunk(runtime, chunk, pad_id, aggregations)

    return [
        _prediction_from_aggregation(
            runtime,
            text=text,
            token_ids=token_ids,
            aggregation=aggregation,
            decoder=decoder,
        )
        for text, token_ids, aggregation in zip(
            texts_list, token_id_lists, aggregations
        )
    ]


def _run_chunk(
    runtime: InferenceRuntime,
    chunk: Sequence[tuple[int, object]],
    pad_id: int,
    aggregations: list[ExampleAggregation],
) -> None:
    """Run one batched forward over `chunk` and demux back to per-text aggregations.

    Args:
        runtime: Provides ``device`` and the callable ``model``.
        chunk: ``(text_idx, window)`` pairs to run together. An empty chunk
            is a no-op.
        pad_id: Token id used to right-pad shorter windows.
        aggregations: Per-text accumulators indexed by ``text_idx``;
            mutated in place.

    Raises:
        ValueError: If the model output's leading two dimensions do not
            match the ``(batch, max_seq)`` input shape.
    """
    if not chunk:
        return
    lengths = [len(window.tokens) for _, window in chunk]
    max_seq = max(lengths)
    # Pad on the host and move the whole batch in one transfer instead of
    # creating one device tensor per row.
    padded_rows = [
        list(window.tokens) + [pad_id] * (max_seq - seq_len)
        for (_, window), seq_len in zip(chunk, lengths)
    ]
    tokens_tensor = torch.tensor(
        padded_rows, dtype=torch.int32, device=runtime.device
    )
    # True for real tokens, False for right-padding.
    positions = torch.arange(max_seq, device=runtime.device)
    lengths_tensor = torch.tensor(lengths, device=runtime.device)
    mask_tensor = positions.unsqueeze(0) < lengths_tensor.unsqueeze(1)

    logits = runtime.model(tokens_tensor, attention_mask=mask_tensor)
    if logits.shape[:2] != tokens_tensor.shape:
        raise ValueError(
            "Logit output batch/sequence shape does not match input shape: "
            f"logits={tuple(logits.shape[:2])} input={tuple(tokens_tensor.shape)}"
        )
    log_probs = F.log_softmax(logits.float(), dim=-1).cpu()
    for row, (text_idx, window) in enumerate(chunk):
        # Slice off the padded tail so only real positions are accumulated.
        _accumulate_window(
            aggregations[text_idx],
            window,
            log_probs[row, : lengths[row]],
            f"batch-{text_idx}",
        )