
Commit fb3072d

perf(hybrid): optimize multimodal single-turn and fix KV clear bug

- Added a 100% match "FAST PATH" in Llama.generate to bypass N-1 truncation for hybrid models when caching is disabled.
- Fixed a bug where failed rollbacks on disabled caches would wipe the KV cache, causing multimodal pseudo-token crashes.
- Updated MTMDChatHandler to suppress cache-related logs and anchoring logic when max_checkpoints <= 0.

1 parent 850ed2e · commit fb3072d
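For orientation before the diffs: the sketch below is a minimal, self-contained model of the decision the patched `Llama.generate` now makes. It is not code from this commit; `longest_token_prefix` is simplified to a plain longest-common-prefix over token ids, and `plan_eval` is a hypothetical helper name.

```python
from typing import List, Optional

def longest_token_prefix(a: List[int], b: List[int]) -> int:
    """Length of the common prefix of two token-id sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def plan_eval(
    cached: List[int],
    prompt: List[int],
    is_hybrid: bool,
    max_checkpoints: Optional[int],
) -> str:
    cache_disabled = max_checkpoints is None or max_checkpoints <= 0
    full = longest_token_prefix(cached, prompt)
    # FAST PATH: hybrid model, checkpointing disabled, and the prompt is
    # identical to what already sits in the KV cache -> evaluate nothing.
    if full == len(prompt) and full == len(cached) and is_hybrid and cache_disabled:
        return "reuse fresh logits, evaluate 0 tokens"
    # STANDARD PATH: match against prompt[:-1] so the final token is always
    # re-evaluated, which refreshes the sampling logits.
    prefix = longest_token_prefix(cached, prompt[:-1])
    if prefix > 0:
        return f"keep {prefix} cached tokens, evaluate {len(prompt) - prefix}"
    return "clear KV cache, evaluate the full prompt"

# The multimodal single-turn case this commit targets:
print(plan_eval([1, 2, 3], [1, 2, 3], is_hybrid=True, max_checkpoints=0))
# -> reuse fresh logits, evaluate 0 tokens
```

The point of the fast path: when the prompt is a strict 100% match, forcing an N-1 re-evaluation buys nothing for a handler that just produced fresh logits, and on a hybrid model with checkpointing disabled the resulting rollback has nowhere to go and degenerates into a full cache wipe.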

2 files changed · 95 additions & 54 deletions

llama_cpp/llama.py · 77 additions & 45 deletions
```diff
@@ -1170,59 +1170,91 @@ def generate(
         original_tokens = list(tokens)
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
-            longest_prefix = self.longest_token_prefix(self._input_ids, tokens[:-1])
-            if longest_prefix > 0:
+            # 1. First, check for a 100% exact match of the entire sequence
+            full_match_prefix = self.longest_token_prefix(self._input_ids, tokens)
+
+            # --- FAST PATH: Zero-latency bypass for Hybrid Single-Turn & Multimodal ---
+            # If the cache is disabled (max_checkpoints <= 0) and we have a 100% match,
+            # we completely skip the N-1 truncation. This ensures that multimodal handlers
+            # (which just finished evaluating and already hold fresh logits) don't trigger
+            # unnecessary N-1 rollbacks or catastrophic KV cache clears.
+            if (
+                full_match_prefix == len(tokens)
+                and full_match_prefix == self.n_tokens
+                and self.is_hybrid
+                and (self._hybrid_cache_mgr is None or self._hybrid_cache_mgr.max_checkpoints <= 0)
+            ):
                 reset = False
+                longest_prefix = len(tokens)
+                tokens = tokens[longest_prefix:]  # Empties the tokens array to bypass evaluation
+                if self.verbose:
+                    print(f"Llama.generate: Hybrid single-turn full match ({longest_prefix} tokens). Bypassing rollback/truncation.", file=sys.stderr)
 
-                if longest_prefix == len(tokens):
-                    if self.verbose:
-                        print(f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token.", file=sys.stderr)
-                    longest_prefix -= 1
+            # --- STANDARD PATH: Force N-1 re-evaluation ---
+            else:
+                # By matching against `tokens[:-1]`, we intentionally drop the last token.
+                # This forces the engine to re-evaluate the final token to refresh sampling logits.
+                longest_prefix = self.longest_token_prefix(self._input_ids, tokens[:-1])
 
-                # Physically erase trailing "ghost" tokens from the C++ KV cache
-                # to prevent attention misalignment in multi-round chats.
-                if longest_prefix < self.n_tokens:
-                    if self.is_hybrid and self._hybrid_cache_mgr is not None:
-                        if self.verbose:
-                            print(f"Llama.generate: Hybrid model rollback triggered.", file=sys.stderr)
+                if longest_prefix > 0:
+                    reset = False
 
-                        best_ckpt = self._hybrid_cache_mgr.find_best_checkpoint(original_tokens, 0)
-                        if best_ckpt is not None and self._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
-                            actual_prefix = best_ckpt.pos
+                    # Note: Kept for legacy compatibility. Triggers if the prefix matching
+                    # somehow equals the full token length (e.g., edge cases in tokenization).
+                    if longest_prefix == len(tokens):
+                        if self.is_hybrid and (self._hybrid_cache_mgr is None or self._hybrid_cache_mgr.max_checkpoints <= 0):
+                            if self.verbose:
+                                print(f"Llama.generate: Full match on disabled hybrid cache. Skipping prefix-- to use existing fresh logits.", file=sys.stderr)
                         else:
-                            actual_prefix = 0
-                            self._hybrid_cache_mgr.clear()
-                            self._ctx.memory_clear(True)
+                            if self.verbose:
+                                print(f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token.", file=sys.stderr)
+                            longest_prefix -= 1
 
-                        self.n_tokens = actual_prefix
-                        tokens = original_tokens[actual_prefix:]
-                        if self.verbose:
-                            print(
-                                f"Llama.generate: {actual_prefix} prefix-match hit, "
-                                f"remaining {len(tokens)} prompt tokens to eval",
-                                file=sys.stderr,
-                            )
-                    else:
-                        if self.verbose:
-                            print(f"Llama.generate: Truncating KV cache size from {self.n_tokens} to {longest_prefix}", file=sys.stderr)
-                        self._ctx.memory_seq_rm(0, longest_prefix, -1)
+                    # Physically erase trailing "ghost" tokens from the C++ KV cache
+                    # to prevent attention misalignment in multi-round chats.
+                    if longest_prefix < self.n_tokens:
+                        if self.is_hybrid and self._hybrid_cache_mgr is not None:
+                            if self.verbose:
+                                print(f"Llama.generate: Hybrid model rollback triggered.", file=sys.stderr)
 
-                    # Adjust the tokens array and cursor to reuse the matched cache
-                    self.n_tokens = longest_prefix
-                    tokens = tokens[longest_prefix:]
+                            best_ckpt = self._hybrid_cache_mgr.find_best_checkpoint(original_tokens, 0)
+                            if best_ckpt is not None and self._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
+                                actual_prefix = best_ckpt.pos
+                            else:
+                                # Fallback: No checkpoint found, must fully clear the context to prevent poisoning
+                                actual_prefix = 0
+                                self._hybrid_cache_mgr.clear()
+                                self._ctx.memory_clear(True)
 
-                    if self.verbose:
-                        print(
-                            f"Llama.generate: {longest_prefix} prefix-match hit, "
-                            f"remaining {len(tokens)} prompt tokens to eval",
-                            file=sys.stderr,
-                        )
-            else:
-                # No prefix matched. Completely clear the KV cache to prevent context poisoning.
-                self.n_tokens = 0
-                self._ctx.memory_clear(True)
-                if self.is_hybrid and self._hybrid_cache_mgr is not None:
-                    self._hybrid_cache_mgr.clear()
+                            self.n_tokens = actual_prefix
+                            tokens = original_tokens[actual_prefix:]
+                            if self.verbose:
+                                print(
+                                    f"Llama.generate: {actual_prefix} prefix-match hit, "
+                                    f"remaining {len(tokens)} prompt tokens to eval",
+                                    file=sys.stderr,
+                                )
+                        else:
+                            if self.verbose:
+                                print(f"Llama.generate: Truncating KV cache size from {self.n_tokens} to {longest_prefix}", file=sys.stderr)
+                            self._ctx.memory_seq_rm(0, longest_prefix, -1)
+
+                        # Adjust the tokens array and cursor to reuse the matched cache
+                        self.n_tokens = longest_prefix
+                        tokens = tokens[longest_prefix:]
+
+                        if self.verbose:
+                            print(
+                                f"Llama.generate: {longest_prefix} prefix-match hit, "
+                                f"remaining {len(tokens)} prompt tokens to eval",
+                                file=sys.stderr,
+                            )
+                else:
+                    # No prefix matched at all. Completely clear the KV cache to prevent context poisoning.
+                    self.n_tokens = 0
+                    self._ctx.memory_clear(True)
+                    if self.is_hybrid and self._hybrid_cache_mgr is not None:
+                        self._hybrid_cache_mgr.clear()
 
         # Reset mirostat sampling
         params = LlamaSamplingParams(
```
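To see why the fast path matters in practice, consider the call sequence a multimodal handler produces: it evaluates the full prompt itself, then hands the very same tokens to generate. The sketch below is hypothetical (model path and prompt are illustrative, not from this commit); `tokenize`, `eval`, `generate`, and `detokenize` are existing `Llama` methods.

```python
from llama_cpp import Llama

# Illustrative setup: a hybrid multimodal model with checkpointing disabled.
llama = Llama(model_path="./hybrid-vlm.gguf", n_ctx=4096, verbose=True)

tokens = llama.tokenize(b"Describe the image.")
llama.eval(tokens)  # handler-style eval: KV cache filled, logits fresh

# Before this commit: generate() matched against tokens[:-1], found no
# checkpoint to roll back to, and wiped the entire KV cache -- including the
# image pseudo-tokens the handler had just evaluated, crashing generation.
# After: the 100% match fast path keeps the cache and reuses the fresh logits.
for token in llama.generate(tokens, reset=True):
    print(llama.detokenize([token]))
    break
```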

llama_cpp/llama_chat_format.py · 18 additions & 9 deletions
```diff
@@ -3271,15 +3271,20 @@ def __call__(
 
         if longest_prefix < llama.n_tokens:
             if llama.is_hybrid and llama._hybrid_cache_mgr is not None:
-                if self.verbose:
-                    print(f"{self.log_prefix}(__call__): Hybrid prefix mismatch (matched {longest_prefix}/{llama.n_tokens}). "
-                          f"Searching for nearest checkpoint...", file=sys.stderr)
-
-                best_ckpt = llama._hybrid_cache_mgr.find_best_checkpoint(full_prompt_ids, seq_id=0)
-                if best_ckpt and llama._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
-                    llama.n_tokens = best_ckpt.pos
+                if llama._hybrid_cache_mgr.max_checkpoints > 0:
                     if self.verbose:
-                        print(f"{self.log_prefix}(__call__): Successfully rolled back to checkpoint at pos {llama.n_tokens}.", file=sys.stderr)
+                        print(f"{self.log_prefix}(__call__): Hybrid prefix mismatch (matched {longest_prefix}/{llama.n_tokens}). "
+                              f"Searching for nearest checkpoint...", file=sys.stderr)
+
+                    best_ckpt = llama._hybrid_cache_mgr.find_best_checkpoint(full_prompt_ids, seq_id=0)
+                    if best_ckpt and llama._hybrid_cache_mgr.restore_checkpoint(best_ckpt, seq_id=0):
+                        llama.n_tokens = best_ckpt.pos
+                        if self.verbose:
+                            print(f"{self.log_prefix}(__call__): Successfully rolled back to checkpoint at pos {llama.n_tokens}.", file=sys.stderr)
+                    else:
+                        llama._hybrid_cache_mgr.clear()
+                        llama._ctx.memory_clear(True)
+                        llama.n_tokens = 0
                 else:
                     llama._hybrid_cache_mgr.clear()
                     llama._ctx.memory_clear(True)
```
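One detail worth calling out in the hunk above: the failed-restore branch now also resets `llama.n_tokens = 0`. `memory_clear(True)` empties the C++ side of the KV cache, and if the Python-side cursor were left pointing past the end of an empty cache, the next eval would append at positions the cache no longer backs. Below is a toy model of that invariant; `ToyKV` is hypothetical and not from the codebase.

```python
from typing import List

class ToyKV:
    """Toy stand-in for the context: a cache plus a Python-side cursor."""

    def __init__(self) -> None:
        self.cells: List[int] = []  # stands in for the C++ KV cache
        self.n_tokens = 0           # Python-side cursor

    def memory_clear(self) -> None:
        self.cells.clear()          # C++ side wiped; cursor untouched

    def eval(self, tokens: List[int]) -> None:
        # Appends at position self.n_tokens; if the cursor was not reset
        # after a clear, positions and cache contents disagree.
        assert self.n_tokens == len(self.cells), "stale cursor after clear"
        self.cells.extend(tokens)
        self.n_tokens += len(tokens)

kv = ToyKV()
kv.eval([1, 2, 3])
kv.memory_clear()
kv.n_tokens = 0   # the fix: keep the cursor in sync with the emptied cache
kv.eval([4, 5])   # safe again
```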
```diff
@@ -3382,7 +3387,11 @@ def __call__(
 
         # End-of-Turn Checkpoint
         # Anchors the state ONLY after the entire multi-modal turn is processed
-        if llama.is_hybrid and llama._hybrid_cache_mgr is not None:
+        if (
+            llama.is_hybrid
+            and llama._hybrid_cache_mgr is not None
+            and llama._hybrid_cache_mgr.max_checkpoints > 0
+        ):
             if self.verbose:
                 print(f"{self.log_prefix}(__call__): [End-of-Turn Checkpoint] Anchoring full prompt state at pos {llama.n_tokens}.", file=sys.stderr)
 
```
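Taken together, `max_checkpoints <= 0` now acts as a single switch that all three touched sites consult. The helpers below are a hypothetical condensation of the guards added in this commit, not new behavior; `mgr` stands in for `llama._hybrid_cache_mgr`.

```python
def cache_disabled(mgr) -> bool:
    return mgr is None or mgr.max_checkpoints <= 0

# llama.py: take the zero-eval fast path on a 100% prompt match.
def should_fast_path(full_match: bool, is_hybrid: bool, mgr) -> bool:
    return full_match and is_hybrid and cache_disabled(mgr)

# llama_chat_format.py, rollback hunk: only search checkpoints when enabled.
def should_search_checkpoints(mgr) -> bool:
    return mgr is not None and mgr.max_checkpoints > 0

# llama_chat_format.py, anchor hunk: only anchor end-of-turn state when enabled.
def should_anchor_end_of_turn(is_hybrid: bool, mgr) -> bool:
    return is_hybrid and mgr is not None and mgr.max_checkpoints > 0
```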
