From 8c9e152ca2ce5e9fcdbe5f765e324d8bdb45c80b Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Tue, 5 May 2026 13:25:35 +0430 Subject: [PATCH] perf: optimize GDN performance on Metal --- aphrodite/metal/v1/cache_policy.py | 69 +++++++++-- aphrodite/metal/v1/model_adapter.py | 14 +-- aphrodite/metal/v1/model_runner.py | 173 +++++++++++++++++++++++++--- 3 files changed, 220 insertions(+), 36 deletions(-) diff --git a/aphrodite/metal/v1/cache_policy.py b/aphrodite/metal/v1/cache_policy.py index 24dbd30446..f8df7b7fe1 100644 --- a/aphrodite/metal/v1/cache_policy.py +++ b/aphrodite/metal/v1/cache_policy.py @@ -636,19 +636,46 @@ def determine_available_memory(self) -> int: "Paged attention backend not initialized for capacity reporting" ) block_size_bytes = self._worker.get_cache_block_size_bytes() - available = backend.num_blocks() * block_size_bytes - logger.info( - "Paged attention: reporting MPS cache capacity " - "(%d blocks × %d bytes = %.2f GB)", - backend.num_blocks(), - block_size_bytes, - available / 1e9, - ) + paged_available = backend.num_blocks() * block_size_bytes + linear_available = 0 + if self._worker.model_runner.is_hybrid: + linear_available = self._hybrid_scheduler_linear_capacity( + block_size_bytes + ) + available = paged_available + linear_available + if linear_available: + logger.info( + "Paged attention: reporting MPS cache capacity " + "(%d blocks × %d bytes + %.2f GB hybrid linear state " + "= %.2f GB)", + backend.num_blocks(), + block_size_bytes, + linear_available / 1e9, + available / 1e9, + ) + else: + logger.info( + "Paged attention: reporting MPS cache capacity " + "(%d blocks × %d bytes = %.2f GB)", + backend.num_blocks(), + block_size_bytes, + available / 1e9, + ) return available - one_sequence_bytes = self._worker._one_sequence_kv_bytes() max_num_seqs = self._worker.model_runner.scheduler_config.max_num_seqs - available = one_sequence_bytes * max_num_seqs + if self._worker.model_runner.is_hybrid: + block_size = self._worker.aphrodite_config.cache_config.block_size + max_model_len = self._worker.model_config.max_model_len + num_blocks = -(-max_model_len // block_size) + block_size_bytes = self._worker.get_cache_block_size_bytes() + available = ( + num_blocks * block_size_bytes * max_num_seqs + + self._hybrid_scheduler_linear_capacity(block_size_bytes) + ) + else: + one_sequence_bytes = self._worker._one_sequence_kv_bytes() + available = one_sequence_bytes * max_num_seqs logger.info( "MLX path: reporting %.2f GB for scheduler admission control " "(%d max-length sequence%s, max_model_len=%d)", @@ -659,6 +686,28 @@ def determine_available_memory(self) -> int: ) return available + def _hybrid_scheduler_linear_capacity(self, block_size_bytes: int) -> int: + """Return scheduler-visible bytes for hybrid linear state. + + The Metal GDN state cache stores the compact recurrent tensors, but the + v1 scheduler unifies hybrid page sizes so each linear layer is budgeted + as one padded page per active sequence. Capacity reporting must mirror + that padded scheduler view or startup can fail after the Metal cache has + already been allocated successfully. + """ + runner = self._worker.model_runner + if runner.num_sdpa_layers <= 0: + return ( + runner.linear_cache_bytes_per_slot() + * runner.scheduler_config.max_num_seqs + ) + attention_page_size = block_size_bytes // runner.num_sdpa_layers + return ( + runner.num_linear_layers + * attention_page_size + * runner.scheduler_config.max_num_seqs + ) + def _paged_attention_plan(self, *, overhead: int) -> _PagedAttentionPlan: block_size = self._worker.aphrodite_config.cache_config.block_size fraction = self._memory_fraction() diff --git a/aphrodite/metal/v1/model_adapter.py b/aphrodite/metal/v1/model_adapter.py index d2323864a9..17dc24c403 100644 --- a/aphrodite/metal/v1/model_adapter.py +++ b/aphrodite/metal/v1/model_adapter.py @@ -64,9 +64,10 @@ def build_sliding_window_per_layer( # gemma4: mlx_vlm forward path produces garbled output vs mlx_lm. _TEXT_BACKBONE_OVERRIDE_TYPES: frozenset[str] = frozenset({"gemma4"}) # Qwen3.5/Qwen3.6 conditional-generation wrappers expose a multimodal config, -# but vllm-metal only serves them in text-only mode. Their FP8 checkpoints ship -# `*_weight_scale_inv` tensors that the mlx_vlm qwen3_5 loader does not -# currently sanitize, while mlx_lm's qwen3_5 text loader handles them. +# but vllm-metal only serves them in text-only mode. Route them through +# mlx_lm's qwen3_5 text loader; the mlx_vlm wrapper adds multimodal processing +# overhead and some local text-only MLX checkpoints do not behave correctly +# through the VLM forward path. _TEXT_BACKBONE_OVERRIDE_ARCHITECTURES: frozenset[str] = frozenset( { "Qwen3_5ForConditionalGeneration", @@ -109,12 +110,7 @@ def _matches_auto_text_backbone_override(self, hf_config: Any) -> bool: ): return False - quantization_config = getattr(hf_config, "quantization_config", None) - if isinstance(quantization_config, dict): - quant_method = quantization_config.get("quant_method") - else: - quant_method = getattr(quantization_config, "quant_method", None) - return quant_method == "fp8" + return True def should_force_text_backbone(self, hf_config: Any) -> bool: """Whether the current serve mode should use the text-only path. diff --git a/aphrodite/metal/v1/model_runner.py b/aphrodite/metal/v1/model_runner.py index 9ed968eff9..283eb7ccf4 100644 --- a/aphrodite/metal/v1/model_runner.py +++ b/aphrodite/metal/v1/model_runner.py @@ -18,7 +18,7 @@ import mlx.core as mx import torch from mlx_lm import stream_generate -from mlx_lm.generate import generate_step +from mlx_lm.generate import generate_step, generation_stream from aphrodite.config import AphroditeConfig from aphrodite.logger import init_logger from aphrodite.sampling_params import SamplingParams @@ -251,6 +251,54 @@ def sampler(logprobs: mx.array) -> mx.array: return sampler +def _generate_greedy_step_no_logprobs( + prompt: mx.array, + model: Any, + *, + prompt_cache: list[Any], + prefill_step_size: int = 2048, +): + """MLX greedy generation without per-token logprob materialization.""" + if len(prompt) == 0: + raise ValueError("prompt must not be empty") + + def _model_call(input_tokens: mx.array) -> mx.array: + output = model(input_tokens[None], cache=prompt_cache) + return output.logits if hasattr(output, "logits") else output + + def _step(input_tokens: mx.array) -> mx.array: + with mx.stream(generation_stream): + logits = _model_call(input_tokens)[:, -1, :] + return mx.argmax(logits, axis=-1) + + with mx.stream(generation_stream): + total_prompt_tokens = len(prompt) + prompt_processed_tokens = 0 + while total_prompt_tokens - prompt_processed_tokens > 1: + remaining = (total_prompt_tokens - prompt_processed_tokens) - 1 + n_to_process = min(prefill_step_size, remaining) + _model_call(prompt[:n_to_process]) + mx.eval([c.state for c in prompt_cache]) + prompt_processed_tokens += n_to_process + prompt = prompt[n_to_process:] + mx.clear_cache() + + y = _step(prompt) + + mx.async_eval(y) + n = 0 + while True: + next_y = _step(y) + mx.async_eval(next_y) + if n == 0: + mx.eval(y) + yield y.item(), None + if n > 0 and n % 256 == 0: + mx.clear_cache() + y = next_y + n += 1 + + def _slice_logprobs_row( logprobs_tensors: LogprobsTensors | None, index: int, @@ -673,10 +721,25 @@ def _make_mlx_lm_generator( self, token_ids: list[int], sampling_params: SamplingParams, + *, + allow_paged_hybrid: bool = False, ) -> tuple[Any, list[AnyCache]] | None: """Create raw MLX generation iterator for the common single-request path.""" - if self._prefix_cache is not None or self._paged_attention_backend is not None: + if self._prefix_cache is not None: + if self.metal_config.debug: + logger.info("Metal raw MLX generator skipped: prefix cache enabled") return None + if self._paged_attention_backend is not None: + if not ( + allow_paged_hybrid + and self.is_hybrid + and self.scheduler_config.max_num_seqs == 1 + ): + if self.metal_config.debug: + logger.info( + "Metal raw MLX generator skipped: paged attention enabled" + ) + return None batch = SamplingBatch( [sampling_params], @@ -690,25 +753,75 @@ def _make_mlx_lm_generator( batch.can_use_native_greedy_for_batch() or batch.can_use_native_random_for_batch() ): + if self.metal_config.debug: + logger.info( + "Metal raw MLX generator skipped: unsupported sampling params %s", + sampling_params, + ) return None - sampler = _make_top_k_first_mlx_sampler( - temperature=sampling_params.temperature, - top_p=sampling_params.top_p, - top_k=sampling_params.top_k, - vocab_size=self._vocab_size, - ) - cache = contiguous_cache.make_prompt_cache(self._forward_model) - generator = generate_step( - mx.array(token_ids, dtype=mx.int32), - self._forward_model, - max_tokens=-1, - sampler=sampler, - prompt_cache=cache, - ) + prompt = mx.array(token_ids, dtype=mx.int32) + if batch.can_use_native_greedy_for_batch(): + generator = _generate_greedy_step_no_logprobs( + prompt, + self._forward_model, + prompt_cache=cache, + ) + else: + sampler = _make_top_k_first_mlx_sampler( + temperature=sampling_params.temperature, + top_p=sampling_params.top_p, + top_k=sampling_params.top_k, + vocab_size=self._vocab_size, + ) + generator = generate_step( + prompt, + self._forward_model, + max_tokens=-1, + sampler=sampler, + prompt_cache=cache, + ) + if self.metal_config.debug: + logger.info( + "Metal raw MLX generator enabled: prompt_tokens=%d temperature=%s " + "top_p=%s top_k=%s", + len(token_ids), + sampling_params.temperature, + sampling_params.top_p, + sampling_params.top_k, + ) return generator, cache + def _try_prefill_paged_hybrid_mlx_generator( + self, + req_id: str, + token_ids: list[int], + sampling_params: SamplingParams, + generator: torch.Generator | None, + ) -> tuple[int, list[AnyCache], LogprobsTensors | None] | None: + """Fast-path one hybrid request through mlx_lm while paged KV is active.""" + if ( + self._paged_attention_backend is None + or not self.is_hybrid + or self.scheduler_config.max_num_seqs != 1 + or generator is not None + ): + return None + + fast_generator = self._make_mlx_lm_generator( + token_ids, + sampling_params, + allow_paged_hybrid=True, + ) + if fast_generator is None: + return None + + generator_iter, cache = fast_generator + next_token, _logprobs = next(generator_iter) + self._pending_mlx_generators[req_id] = generator_iter + return int(next_token), cache, None + def _prefill_single( self, req_id: str, @@ -1202,6 +1315,27 @@ def _handle_new_requests( if self._paged_attention_backend is not None: sched_block_ids = list(new_req.block_ids[0]) + fast_result = self._try_prefill_paged_hybrid_mlx_generator( + req_id, + token_ids, + sampling_params, + generator, + ) + if fast_result is not None: + next_token, cache, logprobs_tensors = fast_result + batch.add_output(req_id, [next_token], logprobs_tensors) + self._request_states[req_id] = RequestState( + token_ids=list(token_ids) + [next_token], + prompt_len=len(token_ids), + cache=cache, + sampling_params=sampling_params, + generator=generator, + mlx_generator=self._pending_mlx_generators.pop(req_id, None), + generated_tokens=1, + block_ids=sched_block_ids, + ) + continue + scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] computed_tokens = new_req.num_computed_tokens prompt_len = len(token_ids) @@ -1319,6 +1453,11 @@ def _collect_cached_requests( batch.add_output(req_id, [0]) continue + if state.mlx_generator is not None: + batch.scheduled_cached_req_ids.append(req_id) + batch.valid_decode_reqs.append((req_id, state)) + continue + if state.generated_tokens == 0: computed_tokens = cached_reqs.num_computed_tokens[idx] scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] @@ -1581,7 +1720,7 @@ def execute_model( "invariant violated." ) - if self._paged_attention_backend is None: + if batch.valid_decode_reqs: self._run_non_paged_decode_batch(batch) # Non-paged path: complete synchronously