From e3e03db0a39890b81f90c80f295b6048990bdb1a Mon Sep 17 00:00:00 2001 From: Abimael Martell Date: Tue, 12 May 2026 14:53:00 -0400 Subject: [PATCH] perf(runtime): fast-path predict_text for non-overlapping windows --- opf/_core/runtime.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/opf/_core/runtime.py b/opf/_core/runtime.py index 2c3034e..8d42be3 100644 --- a/opf/_core/runtime.py +++ b/opf/_core/runtime.py @@ -297,6 +297,27 @@ def predict_text( if log_probs.shape[0] != len(window.tokens): raise ValueError("Logprob output length does not match window length") + # Fast path for the common non-overlapping configuration + # (example_to_windows produces stride == window_size, so every + # token sits in exactly one window). Skip the per-token Python + # loop entirely and bulk-extend the aggregation lists. On + # long inputs this removes O(N_tokens) iterations per window, + # each running ensure_capacity / logaddexp / record_token_id + # on count == 1 data — purely no-op work. + first_offset = int(window.offsets[0]) + if ( + all(window.mask) + and first_offset == aggregation.length + and tuple(int(o) for o in window.offsets) + == tuple(range(first_offset, first_offset + len(window.tokens))) + ): + aggregation.logprob_logsumexp.extend(log_probs.unbind(0)) + aggregation.counts.extend([1] * len(window.tokens)) + aggregation.labels.extend([None] * len(window.tokens)) + aggregation.token_ids.extend(int(t) for t in window.tokens) + aggregation.length = first_offset + len(window.tokens) + continue + for token_pos, is_valid in enumerate(window.mask): if not bool(is_valid): continue