diff --git a/utils/bench_serving/backend_request_func.py b/utils/bench_serving/backend_request_func.py
index 32331a398..af030720e 100644
--- a/utils/bench_serving/backend_request_func.py
+++ b/utils/bench_serving/backend_request_func.py
@@ -439,6 +439,76 @@ def get_model(pretrained_model_name_or_path: str) -> str:
     return pretrained_model_name_or_path
 
 
+def _fix_tokenizer_for_sglang(tokenizer, model_path):
+    """Fix transformers v5 tokenizer to match sglang server-side behavior.
+
+    Root cause: transformers v5 (>= 5.0) changed how tokenizers are loaded.
+    Specifically, LlamaTokenizerFast.__init__ in v5 rebuilds the pre_tokenizer
+    and decoder from scratch using class-specific components, discarding the
+    originals from tokenizer.json. For models like DeepSeek-R1 that declare
+    LlamaTokenizerFast but actually use a ByteLevel/Sequence tokenizer
+    architecture, v5 incorrectly replaces the original Sequence pre_tokenizer
+    with Metaspace, and the original ByteLevel decoder with Sequence.
+    See: https://github.com/sgl-project/sglang/blob/9238bd08a2895fa3b7ec79ea567e5c27ac951343/python/sglang/srt/utils/hf_transformers_utils.py#L836
+
+    The sglang server applies fixes for this in hf_transformers_utils.py
+    (_fix_v5_tokenizer_components and _fix_v5_add_bos_eos_token), but the
+    benchmark client loads the tokenizer directly via AutoTokenizer without
+    these fixes. This mismatch causes the client to encode text differently
+    from the server -- e.g. a 7000-token prompt on the client becomes ~35000
+    tokens on the server, leading to ~5x TTFT inflation and false performance
+    regressions in benchmarks.
+
+    This function replicates the same fixes so the benchmark client tokenizes
+    identically to the sglang server. It is a no-op on transformers v4.
+    """
+    import json
+    from pathlib import Path
+
+    backend = getattr(tokenizer, "_tokenizer", None)
+    if backend is not None:
+        try:
+            from tokenizers import Tokenizer as RawTokenizer
+            tok_file = Path(model_path) / "tokenizer.json"
+            if tok_file.is_file():
+                raw = RawTokenizer.from_file(str(tok_file))
+                raw_pre = type(raw.pre_tokenizer).__name__ if raw.pre_tokenizer else None
+                loaded_pre = type(backend.pre_tokenizer).__name__ if backend.pre_tokenizer else None
+                if raw_pre and loaded_pre and raw_pre != loaded_pre:
+                    backend.pre_tokenizer = raw.pre_tokenizer
+                    backend.decoder = raw.decoder
+        except Exception:
+            pass
+
+    try:
+        config_file = Path(model_path) / "tokenizer_config.json"
+        if config_file.is_file():
+            with open(config_file) as f:
+                config = json.load(f)
+            tok_class = config.get("tokenizer_class", "")
+            bos_eos_classes = {
+                "LlamaTokenizer", "LlamaTokenizerFast",
+                "CodeLlamaTokenizer", "CodeLlamaTokenizerFast",
+                "GemmaTokenizer", "GemmaTokenizerFast", "CohereTokenizerFast",
+            }
+            if tok_class in bos_eos_classes:
+                defaults = {"add_bos_token": True, "add_eos_token": False}
+                changed = False
+                for attr in ("add_bos_token", "add_eos_token"):
+                    val = config.get(attr)
+                    if val is None:
+                        val = defaults.get(attr, False)
+                    if getattr(tokenizer, attr, None) != val:
+                        setattr(tokenizer, f"_{attr}", val)
+                        changed = True
+                if changed and hasattr(tokenizer, "update_post_processor"):
+                    tokenizer.update_post_processor()
+    except Exception:
+        pass
+
+    return tokenizer
+
+
 def get_tokenizer(
     pretrained_model_name_or_path: str,
     tokenizer_mode: str = "auto",
@@ -464,11 +534,12 @@ def get_tokenizer(
         return MistralTokenizer.from_pretrained(
             str(pretrained_model_name_or_path))
     else:
-        return AutoTokenizer.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
             **kwargs,
         )
+        return _fix_tokenizer_for_sglang(tokenizer, pretrained_model_name_or_path)
 
 
 ASYNC_REQUEST_FUNCS = {
@@ -481,4 +552,4 @@ def get_tokenizer(
     "tensorrt-llm": async_request_trt_llm,
     "scalellm": async_request_openai_completions,
     "sglang": async_request_openai_completions,
-}
\ No newline at end of file
+}
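
A quick standalone way to observe the mismatch this patch addresses is to compare the ids produced by the AutoTokenizer-loaded tokenizer against the raw tokenizer.json backend. The sketch below is illustrative and not part of the diff: the model directory path is a placeholder, and it assumes the transformers and tokenizers packages are installed locally.

# Illustrative sketch, not part of the patch. "/path/to/DeepSeek-R1" is a
# placeholder for a local model directory containing tokenizer.json.
from pathlib import Path

from tokenizers import Tokenizer as RawTokenizer
from transformers import AutoTokenizer

model_path = "/path/to/DeepSeek-R1"
hf_tok = AutoTokenizer.from_pretrained(model_path)
raw_tok = RawTokenizer.from_file(str(Path(model_path) / "tokenizer.json"))

text = "An example prompt used to compare client and server token counts."
hf_ids = hf_tok.encode(text, add_special_tokens=False)
raw_ids = raw_tok.encode(text, add_special_tokens=False).ids

# With transformers v5 and no fix applied, the rebuilt pre_tokenizer/decoder
# can make these counts diverge sharply (the ~5x blow-up described above).
print(len(hf_ids), len(raw_ids), hf_ids == raw_ids)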