Skip to content

Codegen: Optional Expression Rewriter#385

Open
eugenebokhan wants to merge 9 commits intomainfrom
type-aware-optional-expression-rewriter
Open

Codegen: Optional Expression Rewriter#385
eugenebokhan wants to merge 9 commits intomainfrom
type-aware-optional-expression-rewriter

Conversation

@eugenebokhan
Copy link
Copy Markdown
Contributor

@eugenebokhan eugenebokhan commented May 4, 2026

What changed

  • OPTIONAL(...) conditions like OPTIONAL(quant_method == QuantizationMethod::AWQ) now are available to use.
  • Replaced use_zero_points / use_mlx_quant bool pair in Qmv, QmvFast, QmmTransposed kernels with single quant_method: QuantizationMethod enum.
  • Split bindgen unreadable spaghetti into a typed sum-of-products module.

Bindgen decoupling

build/metal/bindgen/
├── mod.rs                       # orchestrator
├── arguments.rs                 # ArgumentEmission { Buffer | Constant | IndirectDispatch }
├── specialize.rs                # SpecializeEmission
├── variants.rs                  # VariantBind
├── variant_path_rewriter.rs     # variant ident → self.field rewriter
├── dispatch.rs                  # DispatchEmission (axis / direct / indirect)
└── trait_wiring.rs              # TraitWiring

Each kernel argument is parsed once into a typed variant; each emission stream is one iter().filter_map() pass:

pub enum ArgumentEmission {
    Buffer(BufferArgument), 
    Constant(ConstantArgument), 
    IndirectDispatch(IndirectDispatchArgument),
}

impl ArgumentEmission {
    pub fn struct_field(&self)               -> Option<TokenStream>;
    pub fn struct_initializer(&self)         -> Option<TokenStream>;
    pub fn encode_argument_definition(&self) -> TokenStream;
    pub fn encode_lifetime(&self)            -> Option<TokenStream>;
    pub fn encode_deconstruct(&self)         -> Option<TokenStream>;
    pub fn encode_access(&self)              -> Option<TokenStream>;
    pub fn encode_set(&self)                 -> TokenStream;
}

Replaced a .multiunzip() over a 7-tuple of Vec<Option<TokenStream>> and three mut accumulators (arg_count, indirect_flag, referenced_variants) with typed fields and a VariantPathRewriter consumed via .finish().

mod.rs orchestrator — parse → emit → assemble:

let variant_binds         = variants::parse(kernel)?;
let argument_emissions    = arguments::parse(kernel, enum_path_rewriter)?;
let specialize_emission   = specialize::parse(kernel, base_index, kernel_name)?;
let trait_wiring          = trait_wiring::build(kernel, &trait_name, &struct_name);
let mut variant_rewriter  = VariantPathRewriter::new(&variant_binds, kernel_name);
let dispatch_emission     = dispatch::parse(kernel, &mut variant_rewriter)?;
let referenced            = variant_rewriter.finish();

// One pass per output stream
let conditional_buffer_fields: Vec<_> =
    argument_emissions.iter().filter_map(|a| a.struct_field()).collect();
// …etc

quote! { /* assemble struct + impl from the named pieces */ }

@eugenebokhan eugenebokhan self-assigned this May 4, 2026
@eugenebokhan eugenebokhan changed the title Codegen: Type-aware Optional Expression Rewriter Codegen: Optional Expression Rewriter May 5, 2026
@eugenebokhan eugenebokhan marked this pull request as ready for review May 5, 2026 05:16
@eugenebokhan eugenebokhan removed request for CC-Yeh and LuckyIYI May 5, 2026 05:16
Comment on lines +34 to +38
Err(_) => return format!("static_cast<bool>({condition})"),
};
match emit_metal_expr(&parsed, self) {
Ok(body) => format!("static_cast<bool>({body})"),
Err(_) => format!("static_cast<bool>({condition})"),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we let it fail earlier instead of getting cryptic metal compile error?


namespace uzu::quantization_method {
enum class QuantizationMethod : uint32_t {
MLX = 0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about ZeroPoint and Bias?

.enumerate()
.map(|(i, a)| {
let c_type = a.c_type.trim_start_matches("const ");
let c_type = specialize_constant_type(&a.c_type);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we use #define/#undef here instead of having rewrite_for_metal substitute bare specialize names with their prefixed counterparts, wouldn't that keep all name resolution in one place?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants