Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions opf/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import json
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Literal, TypeVar

Expand All @@ -16,6 +17,7 @@
build_detection_summary,
load_inference_runtime,
predict_text,
predict_texts,
)


Expand Down Expand Up @@ -273,6 +275,68 @@ def redact(
warning=_warning_for_prediction(prediction),
)

def redact_many(
self,
texts: Sequence[str],
*,
decode: DecodeOptions | None = None,
max_tokens_per_forward: int = 16384,
) -> list[str] | list[RedactionResult]:
"""Run redaction on a batch of input strings in shared forward passes.

Equivalent to calling :meth:`redact` on each text individually, but
windows from multiple inputs share one or more batched
``model.forward()`` calls — useful when serving concurrent requests
on the same instance.

Args:
texts: Input strings to redact.
decode: Optional per-call decode overrides shared across the batch.
max_tokens_per_forward: Bin-pack target (``rows * max_seq``) for
batched forward passes. A single window longer than this value
still runs by itself.

Returns:
One result per input text in input order. Type matches the
return shape of :meth:`redact` based on ``output_text_only``.
"""
if isinstance(texts, str):
raise TypeError("texts must be a sequence of strings, not a single string")
texts_list = list(texts)
if any(not isinstance(text, str) for text in texts_list):
raise TypeError("texts must contain only strings")
if max_tokens_per_forward <= 0:
raise ValueError("max_tokens_per_forward must be positive")
if not texts_list:
return []
runtime, decoder = self.get_prediction_components(decode=decode)
predictions = predict_texts(
runtime,
texts_list,
decoder=decoder,
max_tokens_per_forward=max_tokens_per_forward,
)
if self._output_text_only:
return [
_redact_text(prediction.text, prediction.spans)
for prediction in predictions
]
return [
RedactionResult(
schema_version=SCHEMA_VERSION,
summary=build_detection_summary(
output_mode=runtime.output_mode,
labels=[span.label for span in prediction.spans],
decoded_mismatch=prediction.decoded_mismatch,
),
text=prediction.text,
detected_spans=tuple(prediction.spans),
redacted_text=_redact_text(prediction.text, prediction.spans),
warning=_warning_for_prediction(prediction),
)
for prediction in predictions
]

def set_model_path(self, model_path: str | os.PathLike[str]) -> OPF:
"""Update the checkpoint directory used by this redactor.

Expand Down
196 changes: 177 additions & 19 deletions opf/_core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,28 +296,53 @@ def predict_text(
log_probs = F.log_softmax(logits.float(), dim=-1)[0].cpu()
if log_probs.shape[0] != len(window.tokens):
raise ValueError("Logprob output length does not match window length")
_accumulate_window(aggregation, window, log_probs, example_id)

for token_pos, is_valid in enumerate(window.mask):
if not bool(is_valid):
continue
token_idx = int(window.offsets[token_pos])
if token_idx < 0:
continue
aggregation.ensure_capacity(token_idx)
score_vec = log_probs[token_pos]
existing = aggregation.logprob_logsumexp[token_idx]
if existing is None:
aggregation.logprob_logsumexp[token_idx] = score_vec.clone()
else:
aggregation.logprob_logsumexp[token_idx] = torch.logaddexp(
existing, score_vec
)
aggregation.counts[token_idx] += 1
aggregation.record_token_id(
token_idx, int(window.tokens[token_pos]), example_id
return _prediction_from_aggregation(
runtime, text=text, token_ids=token_ids, aggregation=aggregation, decoder=decoder
)


def _accumulate_window(
aggregation: ExampleAggregation,
window,
log_probs: torch.Tensor,
example_id: str,
) -> None:
"""Merge one window's [seq, n_classes] log-probs into the aggregation."""
for token_pos, is_valid in enumerate(window.mask):
if not bool(is_valid):
continue
token_idx = int(window.offsets[token_pos])
if token_idx < 0:
continue
aggregation.ensure_capacity(token_idx)
score_vec = log_probs[token_pos]
existing = aggregation.logprob_logsumexp[token_idx]
if existing is None:
aggregation.logprob_logsumexp[token_idx] = score_vec.clone()
else:
aggregation.logprob_logsumexp[token_idx] = torch.logaddexp(
existing, score_vec
)
aggregation.length = max(aggregation.length, token_idx + 1)
aggregation.counts[token_idx] += 1
aggregation.record_token_id(
token_idx, int(window.tokens[token_pos]), example_id
)
aggregation.length = max(aggregation.length, token_idx + 1)


def _prediction_from_aggregation(
runtime: InferenceRuntime,
*,
text: str,
token_ids: tuple[int, ...],
aggregation: ExampleAggregation,
decoder: ViterbiCRFDecoder | None,
) -> PredictionResult:
"""Decode an aggregated score buffer into a PredictionResult.

Shared post-processing for ``predict_text`` and ``predict_texts``."""
token_positions: list[int] = []
token_score_vectors: list[torch.Tensor] = []
for token_idx in range(aggregation.length):
Expand Down Expand Up @@ -393,3 +418,136 @@ def predict_text(
spans=tuple(display_spans),
decoded_mismatch=decoded_mismatch,
)


@torch.inference_mode()
def predict_texts(
runtime: InferenceRuntime,
texts: Sequence[str],
*,
decoder: ViterbiCRFDecoder | None,
max_tokens_per_forward: int = 16384,
) -> list[PredictionResult]:
"""Run prediction on a batch of texts in shared model.forward() calls.

Tokenizes each input, windows it as ``predict_text`` does, then bin-packs
windows from all texts into one or more batched forward passes. The
``max_tokens_per_forward`` value is a bin-pack target for ``rows * max_seq``;
a single longer window still runs by itself.

Padding positions are masked out via ``attention_mask`` so they cannot
influence real tokens through the banded attention path.

Returns one PredictionResult per input text, in input order. Output is
identical to calling ``predict_text`` on each text individually."""
if max_tokens_per_forward <= 0:
raise ValueError("max_tokens_per_forward must be positive")
if not texts:
return []

background = int(runtime.label_info.background_token_label)
token_id_lists: list[tuple[int, ...]] = []
windows_per_text: list[list] = []
aggregations: list[ExampleAggregation] = []
for text_idx, text in enumerate(texts):
token_ids = tuple(
int(tok)
for tok in runtime.encoding.encode(text, allowed_special="all")
)
token_id_lists.append(token_ids)
aggregations.append(
ExampleAggregation(
logprob_logsumexp=[], counts=[], labels=[], token_ids=[]
)
)
windows: list = []
if token_ids:
example = TokenizedExample(
tokens=token_ids,
labels=tuple(background for _ in token_ids),
example_id=f"batch-{text_idx}",
text=text,
)
for window in example_to_windows(example, runtime.n_ctx):
if window.tokens:
windows.append(window)
windows_per_text.append(windows)

flat: list[tuple[int, object]] = [
(text_idx, window)
for text_idx, windows in enumerate(windows_per_text)
for window in windows
]

if flat:
pad_id = int(runtime.pad_token_id)
# Sort descending by length so similar-size windows pack together
# and padding waste in each batch is minimized.
ordered = sorted(flat, key=lambda item: -len(item[1].tokens))
chunk: list[tuple[int, object]] = []
chunk_max_seq = 0
for item in ordered:
seq_len = len(item[1].tokens)
new_max_seq = max(chunk_max_seq, seq_len)
projected = (len(chunk) + 1) * new_max_seq
if chunk and projected > max_tokens_per_forward:
_run_chunk(runtime, chunk, pad_id, aggregations)
chunk = [item]
chunk_max_seq = seq_len
else:
chunk.append(item)
chunk_max_seq = new_max_seq
if chunk:
_run_chunk(runtime, chunk, pad_id, aggregations)

return [
_prediction_from_aggregation(
runtime,
text=texts[text_idx],
token_ids=token_id_lists[text_idx],
aggregation=aggregations[text_idx],
decoder=decoder,
)
for text_idx in range(len(texts))
]


def _run_chunk(
runtime: InferenceRuntime,
chunk: Sequence[tuple[int, object]],
pad_id: int,
aggregations: list[ExampleAggregation],
) -> None:
"""Run one batched forward over `chunk` and demux back to per-text aggregations."""
max_seq = max(len(window.tokens) for _, window in chunk)
batch_size = len(chunk)
tokens_tensor = torch.full(
(batch_size, max_seq),
pad_id,
device=runtime.device,
dtype=torch.int32,
)
mask_tensor = torch.zeros(
(batch_size, max_seq), device=runtime.device, dtype=torch.bool
)
for row, (_, window) in enumerate(chunk):
seq_len = len(window.tokens)
tokens_tensor[row, :seq_len] = torch.tensor(
list(window.tokens), dtype=torch.int32, device=runtime.device
)
mask_tensor[row, :seq_len] = True
logits = runtime.model(tokens_tensor, attention_mask=mask_tensor)
if logits.shape[:2] != tokens_tensor.shape:
raise ValueError(
"Logit output batch/sequence shape does not match input shape: "
f"logits={tuple(logits.shape[:2])} input={tuple(tokens_tensor.shape)}"
)
log_probs = F.log_softmax(logits.float(), dim=-1).cpu()
for row, (text_idx, window) in enumerate(chunk):
seq_len = len(window.tokens)
_accumulate_window(
aggregations[text_idx],
window,
log_probs[row, :seq_len],
f"batch-{text_idx}",
)