
@llcnt llcnt commented Dec 23, 2025

Description

This PR is inspired by the vLLM benchmarks (the benchmark_config function is copied from here) and enables tuning of the Triton MoE kernel used in vLLM.
The new MoeKernelTuner algorithm does not modify the model. It generates a tuned configuration that is saved in:

  • the vLLM configs folder (so that running the model on the same GPU afterwards makes vLLM use the optimized config); a sketch of what such a config file looks like is shown after this list;
  • the RedHatAI kernels folder on the Hugging Face Hub (so that loading the MoE kernels through the kernels library will use the optimized config);
  • the smash_config (so it can be saved and later reused without waiting for tuning again).
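
For illustration, here is the rough shape of the tuned configuration written to the vLLM configs folder. The field names follow vLLM's fused-MoE Triton config convention, but the file name, batch-size keys, and all values below are made up for this sketch and may differ from what the tuner actually produces:

```python
# Hypothetical example only: a tuned fused-MoE kernel config keyed by batch size M.
# Files in the vLLM configs folder are typically named like
# "E=<num_experts>,N=<intermediate_size>,device_name=<GPU>.json".
tuned_config = {
    "1":    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
    "64":   {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,  "num_warps": 4, "num_stages": 4},
    "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 8,  "num_warps": 8, "num_stages": 4},
}
```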

The core modifications are in:

  • the new moe_kernel_tuner.py file ((i) it does not modify the model, so it is compatible with any other algorithm applied before or after it; (ii) the user can select the dtypes as well as the size of the parameter grid search; (iii) the kernel is tuned for batch sizes (i.e. the input dimension M) from 1 to 8192, using Ray for parallelization, as sketched after this list; (iv) the best configurations are saved to the HF Hub, vLLM, and the smash_config);
  • the smash_config.py file (adding a new artifacts field for saving any additional dict into the SmashConfig);
  • the load.py file (for re-saving the tuned config into the vLLM/HF cache when loading a smashed model).
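
A minimal sketch of how the Ray-parallel tuning loop can be structured. This is not the PR's actual code: benchmark_config is stubbed out here (in the PR it is copied from the vLLM benchmarks), and the search space shown is a toy one.

```python
import ray

def benchmark_config(config: dict, num_tokens: int) -> float:
    """Placeholder for the benchmark_config copied from the vLLM benchmarks."""
    raise NotImplementedError

@ray.remote(num_gpus=1)
def tune(num_tokens: int, search_space: list[dict]) -> dict | None:
    # Benchmark every candidate config for this batch size and keep the fastest.
    best_config, best_time = None, float("inf")
    for config in search_space:
        try:
            kernel_time = benchmark_config(config, num_tokens)
        except Exception:
            continue  # e.g. the config does not fit on this GPU
        if kernel_time < best_time:
            best_config, best_time = config, kernel_time
    return best_config

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
search_space = [{"BLOCK_SIZE_M": m, "num_warps": w} for m in (16, 64) for w in (4, 8)]

ray.init()
try:
    # One Ray task per batch size (input dimension M).
    outputs = [tune.remote(m, search_space) for m in batch_sizes]
    best_configs = dict(zip(batch_sizes, ray.get(outputs)))
finally:
    ray.shutdown()
```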

Related Issue

Fixes #(issue number)

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

A notebook for testing with vLLM is available here. On an H100 with Qwen3-Coder, latency drops from 6.43 ms (before tuning) to 5.83 ms (after tuning) when serving with vLLM.

@llcnt llcnt force-pushed the feat/moe_kernel_tuning branch from 78c6657 to 5764274 on December 23, 2025 17:02
github-actions bot commented Jan 6, 2026

This PR has been inactive for 10 days and is now marked as stale.

@github-actions github-actions bot added the stale label Jan 6, 2026
@llcnt llcnt removed the stale label Jan 7, 2026
@llcnt llcnt marked this pull request as ready for review January 7, 2026 14:21

@cursor cursor bot left a comment



bool
True if the model is a MoE LM, False otherwise.
"""
return hasattr(model, "num_experts")

Function checks wrong attribute for MoE model detection

High Severity

The is_moe_lm function checks hasattr(model, "num_experts"), but the code in _apply uses model.config.num_experts. For transformers models, num_experts typically lives on the config object, not on the model itself. This mismatch makes is_moe_lm incorrectly return False for valid MoE models, so the wrong branch is taken at line 195 when choosing between num_experts_per_tok and moe_topk[0], likely causing an AttributeError.
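
A sketch of one possible fix (not taken from the PR) is to look at both the model and its config:

```python
def is_moe_lm(model) -> bool:
    # Transformers models usually expose num_experts on model.config, so check
    # both places before deciding this is not a MoE LM.
    return hasattr(model, "num_experts") or hasattr(getattr(model, "config", None), "num_experts")
```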


pruna_logger.error(
    "MoE kernel tuner artifacts not found in SmashConfig. "
    "Ensure the tuner ran successfully before saving/loading."
)

Missing return value when payload is absent

High Severity

The load_moe_kernel_tuner function logs an error when payload is falsy but doesn't return anything, causing the function to implicitly return None. Since this function is called by the model loading flow (which expects a model object), returning None will cause downstream failures when the caller tries to use the returned model.
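
A sketch of one way to guard this path (the artifacts key and surrounding names are illustrative, not the PR's exact code): return the model unchanged when the payload is missing.

```python
def load_moe_kernel_tuner(model, smash_config):
    # Hypothetical key name; the real lookup in the PR may differ.
    payload = smash_config.artifacts.get("moe_kernel_tuner")
    if not payload:
        pruna_logger.error(
            "MoE kernel tuner artifacts not found in SmashConfig. "
            "Ensure the tuner ran successfully before saving/loading."
        )
        return model  # keep the loading flow alive instead of returning None
    # ... re-save the tuned config into the vLLM / HF caches ...
    return model
```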


intermediate_size = (
    model_config.moe_intermediate_size
    if model_config.moe_intermediate_size is not None
    else model_config.intermediate_size

Attribute access may fail for missing moe_intermediate_size

Medium Severity

The code directly accesses model_config.moe_intermediate_size before checking if it's None. If the attribute doesn't exist on the config (as may be the case for some MoE models like Mixtral, which use intermediate_size directly), this will raise an AttributeError before the None check is evaluated. Using getattr(model_config, "moe_intermediate_size", None) would prevent this issue.
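
A sketch of the suggested getattr-based fix:

```python
# Fall back to intermediate_size when the config has no moe_intermediate_size
# attribute (e.g. Mixtral-style configs), instead of raising AttributeError.
moe_intermediate_size = getattr(model_config, "moe_intermediate_size", None)
intermediate_size = (
    moe_intermediate_size if moe_intermediate_size is not None else model_config.intermediate_size
)
```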


batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

# use ray to parallelize the tuning
ray.init()

Ray cluster never shut down after tuning completes

Medium Severity

The ray.init() call starts a Ray cluster for parallel tuning but there is no corresponding ray.shutdown() after ray.get(outputs) completes. This leaves Ray processes and resources allocated after the tuning finishes, which can cause resource leaks, interfere with subsequent Ray usage, or cause issues if the function is called multiple times.
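
A sketch of the corresponding fix: wrap the tuning in try/finally so the Ray cluster is always torn down.

```python
ray.init()
try:
    outputs = [tune.remote(batch_size, search_space) for batch_size in batch_sizes]
    results = ray.get(outputs)
finally:
    ray.shutdown()  # release Ray workers and resources even if tuning fails
```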



now = datetime.now()
pruna_logger.info(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None

Assertion fails if all kernel configs exceed resources

Medium Severity

The tune remote function asserts best_config is not None at the end, but if every configuration in search_space throws OutOfResources and is caught by the try/except block, best_config remains None. This causes an AssertionError to propagate through ray.get(), crashing the tuning process without a meaningful error message. This could occur on GPUs with limited resources.
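
A sketch of a friendlier failure mode than the bare assert:

```python
if best_config is None:
    # Every candidate hit OutOfResources, so surface an actionable error
    # instead of an opaque AssertionError propagated through ray.get().
    raise RuntimeError(
        f"No MoE kernel config could be benchmarked for batch_size={num_tokens}; "
        "all candidates exceeded this GPU's resources."
    )
```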


"device",
"device_map",
"cache_dir",
"artifacts",

Artifacts not reset when configuration is flushed

Low Severity

The new artifacts attribute is added to ADDITIONAL_ARGS and initialized in __init__, but the flush_configuration() method doesn't reset it. This is inconsistent with how save_fns, load_fns, and reapply_after_load are cleared during flush. After calling flush_configuration(), stale artifacts from a previous smash operation would persist, potentially causing unexpected behavior if the config is reused.
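
A sketch of the suggested change to flush_configuration (the attribute names come from the comment above; the reset value for artifacts is an assumption, not pruna's actual implementation):

```python
def flush_configuration(self) -> None:
    # ... existing resets of save_fns, load_fns, reapply_after_load ...
    # Also drop artifacts so a reused SmashConfig does not carry stale
    # tuner outputs from a previous smash.
    self.artifacts = {}
```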

