-
-
Notifications
You must be signed in to change notification settings - Fork 197
perf: optimize GDN performance on Metal #1670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |||||||||||||
| import mlx.core as mx | ||||||||||||||
| import torch | ||||||||||||||
| from mlx_lm import stream_generate | ||||||||||||||
| from mlx_lm.generate import generate_step | ||||||||||||||
| from mlx_lm.generate import generate_step, generation_stream | ||||||||||||||
| from aphrodite.config import AphroditeConfig | ||||||||||||||
| from aphrodite.logger import init_logger | ||||||||||||||
| from aphrodite.sampling_params import SamplingParams | ||||||||||||||
|
|
@@ -251,6 +251,54 @@ def sampler(logprobs: mx.array) -> mx.array: | |||||||||||||
| return sampler | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def _generate_greedy_step_no_logprobs( | ||||||||||||||
| prompt: mx.array, | ||||||||||||||
| model: Any, | ||||||||||||||
| *, | ||||||||||||||
| prompt_cache: list[Any], | ||||||||||||||
| prefill_step_size: int = 2048, | ||||||||||||||
| ): | ||||||||||||||
| """MLX greedy generation without per-token logprob materialization.""" | ||||||||||||||
| if len(prompt) == 0: | ||||||||||||||
| raise ValueError("prompt must not be empty") | ||||||||||||||
|
|
||||||||||||||
| def _model_call(input_tokens: mx.array) -> mx.array: | ||||||||||||||
| output = model(input_tokens[None], cache=prompt_cache) | ||||||||||||||
| return output.logits if hasattr(output, "logits") else output | ||||||||||||||
|
|
||||||||||||||
| def _step(input_tokens: mx.array) -> mx.array: | ||||||||||||||
| with mx.stream(generation_stream): | ||||||||||||||
| logits = _model_call(input_tokens)[:, -1, :] | ||||||||||||||
| return mx.argmax(logits, axis=-1) | ||||||||||||||
|
|
||||||||||||||
| with mx.stream(generation_stream): | ||||||||||||||
| total_prompt_tokens = len(prompt) | ||||||||||||||
| prompt_processed_tokens = 0 | ||||||||||||||
| while total_prompt_tokens - prompt_processed_tokens > 1: | ||||||||||||||
| remaining = (total_prompt_tokens - prompt_processed_tokens) - 1 | ||||||||||||||
| n_to_process = min(prefill_step_size, remaining) | ||||||||||||||
| _model_call(prompt[:n_to_process]) | ||||||||||||||
| mx.eval([c.state for c in prompt_cache]) | ||||||||||||||
| prompt_processed_tokens += n_to_process | ||||||||||||||
| prompt = prompt[n_to_process:] | ||||||||||||||
| mx.clear_cache() | ||||||||||||||
|
|
||||||||||||||
| y = _step(prompt) | ||||||||||||||
|
|
||||||||||||||
| mx.async_eval(y) | ||||||||||||||
| n = 0 | ||||||||||||||
| while True: | ||||||||||||||
| next_y = _step(y) | ||||||||||||||
| mx.async_eval(next_y) | ||||||||||||||
| if n == 0: | ||||||||||||||
| mx.eval(y) | ||||||||||||||
| yield y.item(), None | ||||||||||||||
| if n > 0 and n % 256 == 0: | ||||||||||||||
| mx.clear_cache() | ||||||||||||||
| y = next_y | ||||||||||||||
| n += 1 | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def _slice_logprobs_row( | ||||||||||||||
| logprobs_tensors: LogprobsTensors | None, | ||||||||||||||
| index: int, | ||||||||||||||
|
|
@@ -673,10 +721,25 @@ def _make_mlx_lm_generator( | |||||||||||||
| self, | ||||||||||||||
| token_ids: list[int], | ||||||||||||||
| sampling_params: SamplingParams, | ||||||||||||||
| *, | ||||||||||||||
| allow_paged_hybrid: bool = False, | ||||||||||||||
| ) -> tuple[Any, list[AnyCache]] | None: | ||||||||||||||
| """Create raw MLX generation iterator for the common single-request path.""" | ||||||||||||||
| if self._prefix_cache is not None or self._paged_attention_backend is not None: | ||||||||||||||
| if self._prefix_cache is not None: | ||||||||||||||
| if self.metal_config.debug: | ||||||||||||||
| logger.info("Metal raw MLX generator skipped: prefix cache enabled") | ||||||||||||||
| return None | ||||||||||||||
| if self._paged_attention_backend is not None: | ||||||||||||||
| if not ( | ||||||||||||||
| allow_paged_hybrid | ||||||||||||||
| and self.is_hybrid | ||||||||||||||
| and self.scheduler_config.max_num_seqs == 1 | ||||||||||||||
| ): | ||||||||||||||
| if self.metal_config.debug: | ||||||||||||||
| logger.info( | ||||||||||||||
| "Metal raw MLX generator skipped: paged attention enabled" | ||||||||||||||
| ) | ||||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| batch = SamplingBatch( | ||||||||||||||
| [sampling_params], | ||||||||||||||
|
|
@@ -690,25 +753,75 @@ def _make_mlx_lm_generator( | |||||||||||||
| batch.can_use_native_greedy_for_batch() | ||||||||||||||
| or batch.can_use_native_random_for_batch() | ||||||||||||||
| ): | ||||||||||||||
| if self.metal_config.debug: | ||||||||||||||
| logger.info( | ||||||||||||||
| "Metal raw MLX generator skipped: unsupported sampling params %s", | ||||||||||||||
| sampling_params, | ||||||||||||||
| ) | ||||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| sampler = _make_top_k_first_mlx_sampler( | ||||||||||||||
| temperature=sampling_params.temperature, | ||||||||||||||
| top_p=sampling_params.top_p, | ||||||||||||||
| top_k=sampling_params.top_k, | ||||||||||||||
| vocab_size=self._vocab_size, | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| cache = contiguous_cache.make_prompt_cache(self._forward_model) | ||||||||||||||
| generator = generate_step( | ||||||||||||||
| mx.array(token_ids, dtype=mx.int32), | ||||||||||||||
| self._forward_model, | ||||||||||||||
| max_tokens=-1, | ||||||||||||||
| sampler=sampler, | ||||||||||||||
| prompt_cache=cache, | ||||||||||||||
| ) | ||||||||||||||
| prompt = mx.array(token_ids, dtype=mx.int32) | ||||||||||||||
| if batch.can_use_native_greedy_for_batch(): | ||||||||||||||
| generator = _generate_greedy_step_no_logprobs( | ||||||||||||||
| prompt, | ||||||||||||||
| self._forward_model, | ||||||||||||||
| prompt_cache=cache, | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| sampler = _make_top_k_first_mlx_sampler( | ||||||||||||||
| temperature=sampling_params.temperature, | ||||||||||||||
| top_p=sampling_params.top_p, | ||||||||||||||
| top_k=sampling_params.top_k, | ||||||||||||||
| vocab_size=self._vocab_size, | ||||||||||||||
| ) | ||||||||||||||
| generator = generate_step( | ||||||||||||||
| prompt, | ||||||||||||||
| self._forward_model, | ||||||||||||||
| max_tokens=-1, | ||||||||||||||
| sampler=sampler, | ||||||||||||||
| prompt_cache=cache, | ||||||||||||||
| ) | ||||||||||||||
| if self.metal_config.debug: | ||||||||||||||
| logger.info( | ||||||||||||||
| "Metal raw MLX generator enabled: prompt_tokens=%d temperature=%s " | ||||||||||||||
| "top_p=%s top_k=%s", | ||||||||||||||
| len(token_ids), | ||||||||||||||
| sampling_params.temperature, | ||||||||||||||
| sampling_params.top_p, | ||||||||||||||
| sampling_params.top_k, | ||||||||||||||
| ) | ||||||||||||||
| return generator, cache | ||||||||||||||
|
|
||||||||||||||
| def _try_prefill_paged_hybrid_mlx_generator( | ||||||||||||||
| self, | ||||||||||||||
| req_id: str, | ||||||||||||||
| token_ids: list[int], | ||||||||||||||
| sampling_params: SamplingParams, | ||||||||||||||
| generator: torch.Generator | None, | ||||||||||||||
| ) -> tuple[int, list[AnyCache], LogprobsTensors | None] | None: | ||||||||||||||
| """Fast-path one hybrid request through mlx_lm while paged KV is active.""" | ||||||||||||||
| if ( | ||||||||||||||
| self._paged_attention_backend is None | ||||||||||||||
| or not self.is_hybrid | ||||||||||||||
| or self.scheduler_config.max_num_seqs != 1 | ||||||||||||||
| or generator is not None | ||||||||||||||
| ): | ||||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| fast_generator = self._make_mlx_lm_generator( | ||||||||||||||
| token_ids, | ||||||||||||||
| sampling_params, | ||||||||||||||
| allow_paged_hybrid=True, | ||||||||||||||
| ) | ||||||||||||||
| if fast_generator is None: | ||||||||||||||
| return None | ||||||||||||||
|
|
||||||||||||||
| generator_iter, cache = fast_generator | ||||||||||||||
| next_token, _logprobs = next(generator_iter) | ||||||||||||||
| self._pending_mlx_generators[req_id] = generator_iter | ||||||||||||||
| return int(next_token), cache, None | ||||||||||||||
|
|
||||||||||||||
| def _prefill_single( | ||||||||||||||
| self, | ||||||||||||||
| req_id: str, | ||||||||||||||
|
|
@@ -1202,6 +1315,27 @@ def _handle_new_requests( | |||||||||||||
|
|
||||||||||||||
| if self._paged_attention_backend is not None: | ||||||||||||||
| sched_block_ids = list(new_req.block_ids[0]) | ||||||||||||||
| fast_result = self._try_prefill_paged_hybrid_mlx_generator( | ||||||||||||||
| req_id, | ||||||||||||||
| token_ids, | ||||||||||||||
| sampling_params, | ||||||||||||||
| generator, | ||||||||||||||
| ) | ||||||||||||||
| if fast_result is not None: | ||||||||||||||
| next_token, cache, logprobs_tensors = fast_result | ||||||||||||||
| batch.add_output(req_id, [next_token], logprobs_tensors) | ||||||||||||||
| self._request_states[req_id] = RequestState( | ||||||||||||||
| token_ids=list(token_ids) + [next_token], | ||||||||||||||
| prompt_len=len(token_ids), | ||||||||||||||
| cache=cache, | ||||||||||||||
| sampling_params=sampling_params, | ||||||||||||||
| generator=generator, | ||||||||||||||
| mlx_generator=self._pending_mlx_generators.pop(req_id, None), | ||||||||||||||
| generated_tokens=1, | ||||||||||||||
| block_ids=sched_block_ids, | ||||||||||||||
| ) | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] | ||||||||||||||
| computed_tokens = new_req.num_computed_tokens | ||||||||||||||
| prompt_len = len(token_ids) | ||||||||||||||
|
|
@@ -1319,6 +1453,11 @@ def _collect_cached_requests( | |||||||||||||
| batch.add_output(req_id, [0]) | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| if state.mlx_generator is not None: | ||||||||||||||
| batch.scheduled_cached_req_ids.append(req_id) | ||||||||||||||
| batch.valid_decode_reqs.append((req_id, state)) | ||||||||||||||
| continue | ||||||||||||||
|
|
||||||||||||||
| if state.generated_tokens == 0: | ||||||||||||||
| computed_tokens = cached_reqs.num_computed_tokens[idx] | ||||||||||||||
| scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] | ||||||||||||||
|
|
@@ -1581,7 +1720,7 @@ def execute_model( | |||||||||||||
| "invariant violated." | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| if self._paged_attention_backend is None: | ||||||||||||||
| if batch.valid_decode_reqs: | ||||||||||||||
| self._run_non_paged_decode_batch(batch) | ||||||||||||||
|
Comment on lines
+1723
to
1724
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| # Non-paged path: complete synchronously | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
mx.evalcall here might fail ifprompt_cachecontainsArraysCacheentries withNonestates, which can happen in hybrid models before full initialization. It is safer to filter outNonevalues before evaluation to prevent runtime errors.