feat: moe kernel tuning #482
base: main
Conversation
This PR has been inactive for 10 days and is now marked as stale.
Comment @cursor review or bugbot run to trigger another review on this PR
src/pruna/engine/model_checks.py
Outdated
```python
    bool
        True if the model is a MoE LM, False otherwise.
    """
    return hasattr(model, "num_experts")
```
Function checks wrong attribute for MoE model detection
High Severity
The is_moe_lm function checks hasattr(model, "num_experts"), but the code in _apply uses model.config.num_experts. For transformers models, num_experts typically lives on the config object, not on the model itself. This mismatch causes is_moe_lm to incorrectly return False for valid MoE models, leading to the wrong branch being taken at line 195, where the code chooses between num_experts_per_tok and moe_topk[0], likely causing an AttributeError.
Additional Locations (1)
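A minimal sketch of one possible fix, checking the config object as well as the model object; the attribute names follow the finding above, not the actual patch:

```python
def is_moe_lm(model) -> bool:
    """Return True if the model looks like a Mixture-of-Experts LM."""
    # For transformers models, expert counts usually live on `model.config`
    # rather than on the model object itself, so check both places.
    config = getattr(model, "config", None)
    return hasattr(model, "num_experts") or (
        config is not None and hasattr(config, "num_experts")
    )
```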
```python
pruna_logger.error(
    "MoE kernel tuner artifacts not found in SmashConfig. "
    "Ensure the tuner ran successfully before saving/loading."
)
```
Missing return value when payload is absent
High Severity
The load_moe_kernel_tuner function logs an error when payload is falsy but doesn't return anything, causing the function to implicitly return None. Since this function is called by the model loading flow (which expects a model object), returning None will cause downstream failures when the caller tries to use the returned model.
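A minimal sketch of the suggested control flow, reusing the logging call from the diff context; the function signature and the artifacts key are assumptions:

```python
def load_moe_kernel_tuner(model, smash_config):
    payload = smash_config.artifacts.get("moe_kernel_tuner")  # hypothetical key
    if not payload:
        pruna_logger.error(
            "MoE kernel tuner artifacts not found in SmashConfig. "
            "Ensure the tuner ran successfully before saving/loading."
        )
        # Return the untouched model instead of implicitly returning None.
        return model
    # ... apply the tuned kernel configuration here ...
    return model
```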
```python
intermediate_size = (
    model_config.moe_intermediate_size
    if model_config.moe_intermediate_size is not None
    else model_config.intermediate_size
```
Attribute access may fail for missing moe_intermediate_size
Medium Severity
The code directly accesses model_config.moe_intermediate_size before checking if it's None. If the attribute doesn't exist on the config (as may be the case for some MoE models like Mixtral which use intermediate_size directly), this will raise an AttributeError before the None check is evaluated. Using getattr(model_config, "moe_intermediate_size", None) would prevent this issue.
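A minimal sketch of the getattr-based guard suggested above; attribute names mirror the diff context:

```python
# Fall back gracefully when the config has no moe_intermediate_size attribute.
moe_intermediate = getattr(model_config, "moe_intermediate_size", None)
intermediate_size = (
    moe_intermediate
    if moe_intermediate is not None
    else model_config.intermediate_size
)
```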
```python
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# use ray to parallelize the tuning
ray.init()
```
Ray cluster never shut down after tuning completes
Medium Severity
The ray.init() call starts a Ray cluster for parallel tuning but there is no corresponding ray.shutdown() after ray.get(outputs) completes. This leaves Ray processes and resources allocated after the tuning finishes, which can cause resource leaks, interfere with subsequent Ray usage, or cause issues if the function is called multiple times.
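A minimal sketch of tearing Ray down once the parallel tuning finishes, even if a worker raises; the remote-function and variable names are illustrative:

```python
import ray

ray.init()
try:
    # tune.remote(...) stands in for the PR's remote tuning function.
    outputs = [tune.remote(batch_size) for batch_size in batch_sizes]
    results = ray.get(outputs)
finally:
    # Always release Ray workers and resources, even on failure.
    ray.shutdown()
```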
```python
now = datetime.now()
pruna_logger.info(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
```
Assertion fails if all kernel configs exceed resources
Medium Severity
The tune remote function asserts best_config is not None at the end, but if every configuration in search_space throws OutOfResources and is caught by the try/except block, best_config remains None. This causes an AssertionError to propagate through ray.get(), crashing the tuning process without a meaningful error message. This could occur on GPUs with limited resources.
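A minimal sketch of replacing the assertion with a descriptive error for the case where every candidate configuration fails:

```python
if best_config is None:
    raise RuntimeError(
        f"MoE kernel tuning found no viable config for batch_size={num_tokens}: "
        "every candidate in the search space raised OutOfResources on this GPU."
    )
```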
| "device", | ||
| "device_map", | ||
| "cache_dir", | ||
| "artifacts", |
Artifacts not reset when configuration is flushed
Low Severity
The new artifacts attribute is added to ADDITIONAL_ARGS and initialized in __init__, but the flush_configuration() method doesn't reset it. This is inconsistent with how save_fns, load_fns, and reapply_after_load are cleared during flush. After calling flush_configuration(), stale artifacts from a previous smash operation would persist, potentially causing unexpected behavior if the config is reused.
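A minimal sketch of resetting artifacts during flush; only the relevant fields are shown, and their container types are assumptions:

```python
def flush_configuration(self) -> None:
    # Clear per-smash state so a reused config does not carry stale data.
    self.save_fns = []
    self.load_fns = []
    self.reapply_after_load = {}
    self.artifacts = {}  # reset alongside the other stateful fields
```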
Description
This PR is inspired by the vLLM benchmarks (the benchmark_config function is copied from here) and enables tuning of the MoE (Triton) kernel used in vLLM.
This new algorithm, MoeKernelTuner, does not modify the model. It generates a tuned configuration that is saved in the SmashConfig (the kernels lib will make use of the optimized config). The core modifications are in: […]; ray is used for parallelization; (iv) the best configurations are saved in the hf, vllm, and smash_config formats.
Related Issue
Fixes #(issue number)
Type of Change
How Has This Been Tested?
Checklist
Additional Notes
A notebook for testing with vLLM is available here. On an H100 with Qwen3-Coder, latency goes from 6.43 ms (before tuning) to 5.83 ms (after tuning) while using vLLM.
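For context, a hypothetical end-to-end invocation: SmashConfig and smash are pruna's public entry points, but the "kernel" group name and the "moe_kernel_tuner" value below are assumptions, not something this PR confirms:

```python
from transformers import AutoModelForCausalLM
from pruna import SmashConfig, smash

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Coder-30B-A3B-Instruct")

smash_config = SmashConfig()
smash_config["kernel"] = "moe_kernel_tuner"  # assumed algorithm group and name
smashed_model = smash(model=model, smash_config=smash_config)
# The tuned MoE kernel configuration would then live in the SmashConfig
# artifacts and could be exported in the hf/vllm formats described above.
```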