@@ -1170,59 +1170,91 @@ def generate(
11701170 original_tokens = list (tokens )
11711171 # Check for kv cache prefix match
11721172 if reset and self .n_tokens > 0 :
1173- longest_prefix = self .longest_token_prefix (self ._input_ids , tokens [:- 1 ])
1174- if longest_prefix > 0 :
1173+ # 1. First, check for a 100% exact match of the entire sequence
1174+ full_match_prefix = self .longest_token_prefix (self ._input_ids , tokens )
1175+
1176+ # --- FAST PATH: Zero-latency bypass for Hybrid Single-Turn & Multimodal ---
1177+ # If the cache is disabled (max_checkpoints <= 0) and we have a 100% match,
1178+ # we completely skip the N-1 truncation. This ensures that multimodal handlers
1179+ # (which just finished evaluating and already hold fresh logits) don't trigger
1180+ # unnecessary N-1 rollbacks or catastrophic KV cache clears.
1181+ if (
1182+ full_match_prefix == len (tokens )
1183+ and full_match_prefix == self .n_tokens
1184+ and self .is_hybrid
1185+ and (self ._hybrid_cache_mgr is None or self ._hybrid_cache_mgr .max_checkpoints <= 0 )
1186+ ):
11751187 reset = False
1188+ longest_prefix = len (tokens )
1189+ tokens = tokens [longest_prefix :] # Empties the tokens array to bypass evaluation
1190+ if self .verbose :
1191+ print (f"Llama.generate: Hybrid single-turn full match ({ longest_prefix } tokens). Bypassing rollback/truncation." , file = sys .stderr )
11761192
1177- if longest_prefix == len (tokens ):
1178- if self .verbose :
1179- print (f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token." , file = sys .stderr )
1180- longest_prefix -= 1
1193+ # --- STANDARD PATH: Force N-1 re-evaluation ---
1194+ else :
1195+ # By matching against `tokens[:-1]`, we intentionally drop the last token.
1196+ # This forces the engine to re-evaluate the final token to refresh sampling logits.
1197+ longest_prefix = self .longest_token_prefix (self ._input_ids , tokens [:- 1 ])
11811198
1182- # Physically erase trailing "ghost" tokens from the C++ KV cache
1183- # to prevent attention misalignment in multi-round chats.
1184- if longest_prefix < self .n_tokens :
1185- if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1186- if self .verbose :
1187- print (f"Llama.generate: Hybrid model rollback triggered." , file = sys .stderr )
1199+ if longest_prefix > 0 :
1200+ reset = False
11881201
1189- best_ckpt = self ._hybrid_cache_mgr .find_best_checkpoint (original_tokens , 0 )
1190- if best_ckpt is not None and self ._hybrid_cache_mgr .restore_checkpoint (best_ckpt , seq_id = 0 ):
1191- actual_prefix = best_ckpt .pos
1202+ # Note: Kept for legacy compatibility. `longest_prefix` was matched against `tokens[:-1]` above, so it
1203+ # should never reach len(tokens); this branch looks unreachable (TODO: confirm against longest_token_prefix).
1204+ if longest_prefix == len (tokens ):
1205+ if self .is_hybrid and (self ._hybrid_cache_mgr is None or self ._hybrid_cache_mgr .max_checkpoints <= 0 ):
1206+ if self .verbose :
1207+ print (f"Llama.generate: Full match on disabled hybrid cache. Skipping prefix-- to use existing fresh logits." , file = sys .stderr )
11921208 else :
1193- actual_prefix = 0
1194- self . _hybrid_cache_mgr . clear ( )
1195- self . _ctx . memory_clear ( True )
1209+ if self . verbose :
1210+ print ( f"Llama.generate: Full match. Forcing prefix-- to evaluate 1 token." , file = sys . stderr )
1211+ longest_prefix -= 1
11961212
1197- self .n_tokens = actual_prefix
1198- tokens = original_tokens [actual_prefix :]
1199- if self .verbose :
1200- print (
1201- f"Llama.generate: { actual_prefix } prefix-match hit, "
1202- f"remaining { len (tokens )} prompt tokens to eval" ,
1203- file = sys .stderr ,
1204- )
1205- else :
1206- if self .verbose :
1207- print (f"Llama.generate: Truncating KV cache size from { self .n_tokens } to { longest_prefix } " , file = sys .stderr )
1208- self ._ctx .memory_seq_rm (0 , longest_prefix , - 1 )
1213+ # Physically erase trailing "ghost" tokens from the C++ KV cache
1214+ # to prevent attention misalignment in multi-round chats.
1215+ if longest_prefix < self .n_tokens :
1216+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1217+ if self .verbose :
1218+ print (f"Llama.generate: Hybrid model rollback triggered." , file = sys .stderr )
12091219
1210- # Adjust the tokens array and cursor to reuse the matched cache
1211- self .n_tokens = longest_prefix
1212- tokens = tokens [longest_prefix :]
1220+ best_ckpt = self ._hybrid_cache_mgr .find_best_checkpoint (original_tokens , 0 )
1221+ if best_ckpt is not None and self ._hybrid_cache_mgr .restore_checkpoint (best_ckpt , seq_id = 0 ):
1222+ actual_prefix = best_ckpt .pos
1223+ else :
1224+ # Fallback: No checkpoint found, must fully clear the context to prevent poisoning
1225+ actual_prefix = 0
1226+ self ._hybrid_cache_mgr .clear ()
1227+ self ._ctx .memory_clear (True )
12131228
1214- if self .verbose :
1215- print (
1216- f"Llama.generate: { longest_prefix } prefix-match hit, "
1217- f"remaining { len (tokens )} prompt tokens to eval" ,
1218- file = sys .stderr ,
1219- )
1220- else :
1221- # No prefix matched. Completely clear the KV cache to prevent context poisoning.
1222- self .n_tokens = 0
1223- self ._ctx .memory_clear (True )
1224- if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1225- self ._hybrid_cache_mgr .clear ()
1229+ self .n_tokens = actual_prefix
1230+ tokens = original_tokens [actual_prefix :]
1231+ if self .verbose :
1232+ print (
1233+ f"Llama.generate: { actual_prefix } prefix-match hit, "
1234+ f"remaining { len (tokens )} prompt tokens to eval" ,
1235+ file = sys .stderr ,
1236+ )
1237+ else :
1238+ if self .verbose :
1239+ print (f"Llama.generate: Truncating KV cache size from { self .n_tokens } to { longest_prefix } " , file = sys .stderr )
1240+ self ._ctx .memory_seq_rm (0 , longest_prefix , - 1 )
1241+
1242+ # Adjust the tokens array and cursor to reuse the matched cache
1243+ self .n_tokens = longest_prefix
1244+ tokens = tokens [longest_prefix :]
1245+
1246+ if self .verbose :
1247+ print (
1248+ f"Llama.generate: { longest_prefix } prefix-match hit, "
1249+ f"remaining { len (tokens )} prompt tokens to eval" ,
1250+ file = sys .stderr ,
1251+ )
1252+ else :
1253+ # No prefix matched at all. Completely clear the KV cache to prevent context poisoning.
1254+ self .n_tokens = 0
1255+ self ._ctx .memory_clear (True )
1256+ if self .is_hybrid and self ._hybrid_cache_mgr is not None :
1257+ self ._hybrid_cache_mgr .clear ()
12261258
12271259 # Reset mirostat sampling
12281260 params = LlamaSamplingParams (
0 commit comments