Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions aphrodite/metal/v1/cache_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,19 +636,46 @@ def determine_available_memory(self) -> int:
"Paged attention backend not initialized for capacity reporting"
)
block_size_bytes = self._worker.get_cache_block_size_bytes()
available = backend.num_blocks() * block_size_bytes
logger.info(
"Paged attention: reporting MPS cache capacity "
"(%d blocks × %d bytes = %.2f GB)",
backend.num_blocks(),
block_size_bytes,
available / 1e9,
)
paged_available = backend.num_blocks() * block_size_bytes
linear_available = 0
if self._worker.model_runner.is_hybrid:
linear_available = self._hybrid_scheduler_linear_capacity(
block_size_bytes
)
available = paged_available + linear_available
if linear_available:
logger.info(
"Paged attention: reporting MPS cache capacity "
"(%d blocks × %d bytes + %.2f GB hybrid linear state "
"= %.2f GB)",
backend.num_blocks(),
block_size_bytes,
linear_available / 1e9,
available / 1e9,
)
else:
logger.info(
"Paged attention: reporting MPS cache capacity "
"(%d blocks × %d bytes = %.2f GB)",
backend.num_blocks(),
block_size_bytes,
available / 1e9,
)
return available

one_sequence_bytes = self._worker._one_sequence_kv_bytes()
max_num_seqs = self._worker.model_runner.scheduler_config.max_num_seqs
available = one_sequence_bytes * max_num_seqs
if self._worker.model_runner.is_hybrid:
block_size = self._worker.aphrodite_config.cache_config.block_size
max_model_len = self._worker.model_config.max_model_len
num_blocks = -(-max_model_len // block_size)
block_size_bytes = self._worker.get_cache_block_size_bytes()
available = (
num_blocks * block_size_bytes * max_num_seqs
+ self._hybrid_scheduler_linear_capacity(block_size_bytes)
)
else:
one_sequence_bytes = self._worker._one_sequence_kv_bytes()
available = one_sequence_bytes * max_num_seqs
logger.info(
"MLX path: reporting %.2f GB for scheduler admission control "
"(%d max-length sequence%s, max_model_len=%d)",
Expand All @@ -659,6 +686,28 @@ def determine_available_memory(self) -> int:
)
return available

def _hybrid_scheduler_linear_capacity(self, block_size_bytes: int) -> int:
    """Compute the bytes of hybrid linear state the scheduler accounts for.

    The Metal GDN state cache only holds the compact recurrent tensors, yet
    the v1 scheduler unifies hybrid page sizes and therefore charges every
    linear layer one padded page per active sequence. The figure reported
    here has to follow that padded view; otherwise startup can fail even
    though the Metal cache itself was allocated successfully.

    Args:
        block_size_bytes: Size in bytes of one cache block, as supplied by
            the caller. It is divided evenly (floor division) across the
            SDPA layers to obtain the per-layer page size.

    Returns:
        Per-sequence linear-state bytes multiplied by the scheduler's
        ``max_num_seqs``.
    """
    model_runner = self._worker.model_runner
    sdpa_layer_count = model_runner.num_sdpa_layers

    if sdpa_layer_count > 0:
        # One padded attention-sized page is budgeted per linear layer.
        per_layer_page_bytes = block_size_bytes // sdpa_layer_count
        per_sequence_bytes = model_runner.num_linear_layers * per_layer_page_bytes
    else:
        # No SDPA layers to derive a page size from: fall back to the
        # runner's own per-slot linear cache size.
        per_sequence_bytes = model_runner.linear_cache_bytes_per_slot()

    return per_sequence_bytes * model_runner.scheduler_config.max_num_seqs

def _paged_attention_plan(self, *, overhead: int) -> _PagedAttentionPlan:
block_size = self._worker.aphrodite_config.cache_config.block_size
fraction = self._memory_fraction()
Expand Down
14 changes: 5 additions & 9 deletions aphrodite/metal/v1/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ def build_sliding_window_per_layer(
# gemma4: mlx_vlm forward path produces garbled output vs mlx_lm.
_TEXT_BACKBONE_OVERRIDE_TYPES: frozenset[str] = frozenset({"gemma4"})
# Qwen3.5/Qwen3.6 conditional-generation wrappers expose a multimodal config,
# but vllm-metal only serves them in text-only mode. Their FP8 checkpoints ship
# `*_weight_scale_inv` tensors that the mlx_vlm qwen3_5 loader does not
# currently sanitize, while mlx_lm's qwen3_5 text loader handles them.
# but vllm-metal only serves them in text-only mode. Route them through
# mlx_lm's qwen3_5 text loader; the mlx_vlm wrapper adds multimodal processing
# overhead and some local text-only MLX checkpoints do not behave correctly
# through the VLM forward path.
_TEXT_BACKBONE_OVERRIDE_ARCHITECTURES: frozenset[str] = frozenset(
{
"Qwen3_5ForConditionalGeneration",
Expand Down Expand Up @@ -109,12 +110,7 @@ def _matches_auto_text_backbone_override(self, hf_config: Any) -> bool:
):
return False

quantization_config = getattr(hf_config, "quantization_config", None)
if isinstance(quantization_config, dict):
quant_method = quantization_config.get("quant_method")
else:
quant_method = getattr(quantization_config, "quant_method", None)
return quant_method == "fp8"
return True

def should_force_text_backbone(self, hf_config: Any) -> bool:
"""Whether the current serve mode should use the text-only path.
Expand Down
173 changes: 156 additions & 17 deletions aphrodite/metal/v1/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import mlx.core as mx
import torch
from mlx_lm import stream_generate
from mlx_lm.generate import generate_step
from mlx_lm.generate import generate_step, generation_stream
from aphrodite.config import AphroditeConfig
from aphrodite.logger import init_logger
from aphrodite.sampling_params import SamplingParams
Expand Down Expand Up @@ -251,6 +251,54 @@ def sampler(logprobs: mx.array) -> mx.array:
return sampler


def _generate_greedy_step_no_logprobs(
prompt: mx.array,
model: Any,
*,
prompt_cache: list[Any],
prefill_step_size: int = 2048,
):
"""MLX greedy generation without per-token logprob materialization."""
if len(prompt) == 0:
raise ValueError("prompt must not be empty")

def _model_call(input_tokens: mx.array) -> mx.array:
output = model(input_tokens[None], cache=prompt_cache)
return output.logits if hasattr(output, "logits") else output

def _step(input_tokens: mx.array) -> mx.array:
with mx.stream(generation_stream):
logits = _model_call(input_tokens)[:, -1, :]
return mx.argmax(logits, axis=-1)

with mx.stream(generation_stream):
total_prompt_tokens = len(prompt)
prompt_processed_tokens = 0
while total_prompt_tokens - prompt_processed_tokens > 1:
remaining = (total_prompt_tokens - prompt_processed_tokens) - 1
n_to_process = min(prefill_step_size, remaining)
_model_call(prompt[:n_to_process])
mx.eval([c.state for c in prompt_cache])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The mx.eval call here might fail if prompt_cache contains ArraysCache entries with None states, which can happen in hybrid models before full initialization. It is safer to filter out None values before evaluation to prevent runtime errors.

Suggested change
mx.eval([c.state for c in prompt_cache])
mx.eval([s for c in prompt_cache for s in (c.state if isinstance(c.state, list) else [c.state]) if s is not None])

prompt_processed_tokens += n_to_process
prompt = prompt[n_to_process:]
mx.clear_cache()

y = _step(prompt)

mx.async_eval(y)
n = 0
while True:
next_y = _step(y)
mx.async_eval(next_y)
if n == 0:
mx.eval(y)
yield y.item(), None
if n > 0 and n % 256 == 0:
mx.clear_cache()
y = next_y
n += 1


def _slice_logprobs_row(
logprobs_tensors: LogprobsTensors | None,
index: int,
Expand Down Expand Up @@ -673,10 +721,25 @@ def _make_mlx_lm_generator(
self,
token_ids: list[int],
sampling_params: SamplingParams,
*,
allow_paged_hybrid: bool = False,
) -> tuple[Any, list[AnyCache]] | None:
"""Create raw MLX generation iterator for the common single-request path."""
if self._prefix_cache is not None or self._paged_attention_backend is not None:
if self._prefix_cache is not None:
if self.metal_config.debug:
logger.info("Metal raw MLX generator skipped: prefix cache enabled")
return None
if self._paged_attention_backend is not None:
if not (
allow_paged_hybrid
and self.is_hybrid
and self.scheduler_config.max_num_seqs == 1
):
if self.metal_config.debug:
logger.info(
"Metal raw MLX generator skipped: paged attention enabled"
)
return None

batch = SamplingBatch(
[sampling_params],
Expand All @@ -690,25 +753,75 @@ def _make_mlx_lm_generator(
batch.can_use_native_greedy_for_batch()
or batch.can_use_native_random_for_batch()
):
if self.metal_config.debug:
logger.info(
"Metal raw MLX generator skipped: unsupported sampling params %s",
sampling_params,
)
return None

sampler = _make_top_k_first_mlx_sampler(
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
vocab_size=self._vocab_size,
)

cache = contiguous_cache.make_prompt_cache(self._forward_model)
generator = generate_step(
mx.array(token_ids, dtype=mx.int32),
self._forward_model,
max_tokens=-1,
sampler=sampler,
prompt_cache=cache,
)
prompt = mx.array(token_ids, dtype=mx.int32)
if batch.can_use_native_greedy_for_batch():
generator = _generate_greedy_step_no_logprobs(
prompt,
self._forward_model,
prompt_cache=cache,
)
else:
sampler = _make_top_k_first_mlx_sampler(
temperature=sampling_params.temperature,
top_p=sampling_params.top_p,
top_k=sampling_params.top_k,
vocab_size=self._vocab_size,
)
generator = generate_step(
prompt,
self._forward_model,
max_tokens=-1,
sampler=sampler,
prompt_cache=cache,
)
if self.metal_config.debug:
logger.info(
"Metal raw MLX generator enabled: prompt_tokens=%d temperature=%s "
"top_p=%s top_k=%s",
len(token_ids),
sampling_params.temperature,
sampling_params.top_p,
sampling_params.top_k,
)
return generator, cache

def _try_prefill_paged_hybrid_mlx_generator(
    self,
    req_id: str,
    token_ids: list[int],
    sampling_params: SamplingParams,
    generator: torch.Generator | None,
) -> tuple[int, list[AnyCache], LogprobsTensors | None] | None:
    """Attempt the raw mlx_lm fast path for a hybrid request under paged KV.

    The fast path is only taken when paged attention is active, the model
    is hybrid, the scheduler is limited to a single sequence, and the
    request carries no explicit ``generator``.

    Args:
        req_id: Request identifier; used as the key under which the live
            generator is stored in ``self._pending_mlx_generators``.
        token_ids: Prompt token ids forwarded to ``_make_mlx_lm_generator``.
        sampling_params: Sampling parameters forwarded unchanged.
        generator: Per-request torch generator; any non-``None`` value
            disables the fast path.

    Returns:
        ``None`` when the fast path does not apply or
        ``_make_mlx_lm_generator`` declines. Otherwise a tuple of the first
        generated token (as ``int``), the prompt cache, and ``None`` in the
        logprobs position.
    """
    fast_path_eligible = (
        self._paged_attention_backend is not None
        and self.is_hybrid
        and self.scheduler_config.max_num_seqs == 1
        and generator is None
    )
    if not fast_path_eligible:
        return None

    made = self._make_mlx_lm_generator(
        token_ids,
        sampling_params,
        allow_paged_hybrid=True,
    )
    if made is None:
        return None
    token_stream, prompt_cache = made

    # Pull the first token now; the logprobs slot of the yielded pair is
    # discarded and the fast path reports no logprobs.
    first_token, _ = next(token_stream)

    # Keep the partially consumed iterator so later steps can continue it.
    self._pending_mlx_generators[req_id] = token_stream
    return int(first_token), prompt_cache, None

def _prefill_single(
self,
req_id: str,
Expand Down Expand Up @@ -1202,6 +1315,27 @@ def _handle_new_requests(

if self._paged_attention_backend is not None:
sched_block_ids = list(new_req.block_ids[0])
fast_result = self._try_prefill_paged_hybrid_mlx_generator(
req_id,
token_ids,
sampling_params,
generator,
)
if fast_result is not None:
next_token, cache, logprobs_tensors = fast_result
batch.add_output(req_id, [next_token], logprobs_tensors)
self._request_states[req_id] = RequestState(
token_ids=list(token_ids) + [next_token],
prompt_len=len(token_ids),
cache=cache,
sampling_params=sampling_params,
generator=generator,
mlx_generator=self._pending_mlx_generators.pop(req_id, None),
generated_tokens=1,
block_ids=sched_block_ids,
)
continue

scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
computed_tokens = new_req.num_computed_tokens
prompt_len = len(token_ids)
Expand Down Expand Up @@ -1319,6 +1453,11 @@ def _collect_cached_requests(
batch.add_output(req_id, [0])
continue

if state.mlx_generator is not None:
batch.scheduled_cached_req_ids.append(req_id)
batch.valid_decode_reqs.append((req_id, state))
continue

if state.generated_tokens == 0:
computed_tokens = cached_reqs.num_computed_tokens[idx]
scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
Expand Down Expand Up @@ -1581,7 +1720,7 @@ def execute_model(
"invariant violated."
)

if self._paged_attention_backend is None:
if batch.valid_decode_reqs:
self._run_non_paged_decode_batch(batch)
Comment on lines +1723 to 1724
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _run_non_paged_decode_batch call is currently skipped if has_paged_work() is true. While the fast-path for hybrid models is restricted to max_num_seqs=1, this logic could lead to silent failures or missing outputs if multiple requests (one paged, one MLX-native) are ever scheduled together in future iterations. Moving this call before the paged-work check ensures all scheduled requests are processed.

Suggested change
if batch.valid_decode_reqs:
self._run_non_paged_decode_batch(batch)
if batch.valid_decode_reqs:
self._run_non_paged_decode_batch(batch)
if self._paged_attention_backend is not None and batch.has_paged_work():


# Non-paged path: complete synchronously
Expand Down
Loading