
Add support of float atomics and generic dtypes to shared memory on Vulkan and Apple Metal. #432

Open
duburcqa wants to merge 15 commits into main from duburcqa/fix_metal_atomics_shared_mem

Conversation

@duburcqa
Contributor

@duburcqa duburcqa commented Mar 29, 2026

Brief Summary

  1. Fix float atomics on shared memory for Metal and Vulkan:
    • allocate shared float arrays as uint, bitcast at load/store
    • only retype arrays targeted by atomics (pre-scan via scan_shared_atomic_allocs)
  2. Add f16 shared memory float atomics support:
    • back f16 arrays with u32 (Metal/Vulkan lack 16-bit atomics), with width conversion at load/store/CAS boundaries
  3. Add support of shared memory of arbitrary dtype on Metal and Vulkan:
    • flatten nested tensor types (vec3 etc.), handle OpPtrAccessChain for component access
  4. Add official support of multiple shared arrays on Metal:
    • enable existing test on Metal
  5. Fix shared memory float atomics on GPUs with native float atomic support:
    • handle dest_is_ptr before at_buffer, disable native float atomics for uint-retyped shared arrays

Accompanying PR: Genesis-Embodied-AI/SPIRV-Cross#1
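The retype-and-bitcast scheme in points 1–2 can be sketched in plain Python. This is a minimal illustration, not the PR's actual code: all names here are hypothetical stand-ins, and a real GPU performs the compare-and-swap atomically in hardware, whereas this sketch is single-threaded.

```python
import struct

def f32_to_u32(x: float) -> int:
    # Bitcast: reinterpret the float32 bits as uint32 (no numeric conversion).
    return struct.unpack("<I", struct.pack("<f", x))[0]

def u32_to_f32(u: int) -> float:
    return struct.unpack("<f", struct.pack("<I", u))[0]

class SharedCell:
    # Hypothetical stand-in for one u32-typed shared-memory word.
    def __init__(self, bits: int = 0):
        self.bits = bits

    def atomic_cas(self, expected: int, new: int) -> int:
        # Returns the old value; swaps only if it matched `expected`.
        old = self.bits
        if old == expected:
            self.bits = new
        return old

def atomic_add_f32(cell: SharedCell, operand: float) -> None:
    # CAS loop over the uint backing, mirroring the emitted SPIR-V.
    while True:
        old_u = cell.bits                            # OpAtomicLoad (u32)
        new_f = u32_to_f32(old_u) + operand          # bitcast to float, add
        new_u = f32_to_u32(new_f)                    # bitcast back to uint
        if cell.atomic_cas(old_u, new_u) == old_u:   # OpAtomicCompareExchange
            return

cell = SharedCell(f32_to_u32(0.0))
for tid in range(32):
    atomic_add_f32(cell, float(tid))
print(u32_to_f32(cell.bits))  # 496.0 == sum(range(32))
```

The key point the sketch shows is that the atomic itself only ever sees uint values; the float semantics live entirely in the bitcasts around the add.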

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch 4 times, most recently from af08498 to 02fb882 on March 30, 2026 07:05
@duburcqa duburcqa changed the title Fix support of float atomics on shared memory for Apple Metal. Fix support of float atomics and generic dtypes on shared memory for Apple Metal. Mar 30, 2026
@duburcqa duburcqa changed the title Fix support of float atomics and generic dtypes on shared memory for Apple Metal. Add support of float atomics and generic dtypes to shared memory for Apple Metal. Mar 30, 2026
@duburcqa duburcqa changed the title Add support of float atomics and generic dtypes to shared memory for Apple Metal. Add support of float atomics and generic dtypes to shared memory on Vulkan and Apple Metal. Mar 30, 2026
@duburcqa
Contributor Author

I was assisted by Claude Opus in writing this PR. I have read and reviewed every line added in this PR, and I take full responsibility for the lines added and removed. I won't blame any issue on Claude Opus.

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch 4 times, most recently from c7cbb8b to 4a42abe on March 30, 2026 08:43
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
ir_->get_primitive_type(tensor_type->get_element_type());
DataType elem_dt = tensor_type->get_element_type();
Collaborator

elem_dt and elem_type are very confusing. Could we either give more intuitive names, or at least add a comment on what is the difference between them?

Contributor Author

It should be better now.

// float atomics).
if (alloca->is_shared && is_real(elem_dt)) {
elem_type =
ir_->get_primitive_type(ir_->get_quadrants_uint_type(elem_dt));
Collaborator

it's not clear to me from the name what get_quadrants_uint_type does, specifically around the number of bits. Could we add a comment to clarify what is happening in this line?

Contributor Author

It should be better now.

spirv::Value offset_val = ir_->query_value(stmt->offset->raw_name());
auto dt = stmt->element_type().ptr_removed();
// Flatten nested tensor types to scalar (e.g., vec3 to f32)
if (auto nested = dt->cast<TensorType>()) {
Collaborator

this seems very similar to what happens above. Could this be factorized into a helper function?

Contributor Author

done.

ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
elem_type = ir_->get_primitive_type(ir_->get_quadrants_uint_type(dt));
Collaborator

ditto for the question about a helper function.

Contributor Author

done.

std::unordered_map<int, GetRootStmt *>
root_stmts_; // maps root id to get root stmt
std::unordered_map<const Stmt *, BufferInfo> ptr_to_buffers_;
// Shared float arrays retyped to uint (Metal lacks threadgroup float atomics)
Collaborator

can we give a bit more detail about what we are storing here, and why.

Contributor Author

done.

spirv::SType ptr_type = ir_->get_pointer_type(
ir_->get_primitive_type(dt), origin_val.stype.storage_class);
auto elem_type = ir_->get_primitive_type(dt);
if (shared_float_retyped_.count(stmt->origin)) {
Collaborator

can we add a comment about what this if statement is checking for intuitively

Contributor Author

done.

spirv::Value offset_bytes = ir_->mul(dt_bytes, offset_val);
ptr_val = ir_->add(origin_val, offset_bytes);
ptr_to_buffers_[stmt] = ptr_to_buffers_[stmt->origin];
} else if (origin_val.stype.flag == TypeKind::kPtr) {
Collaborator

can we add a comment about what this new else if block is checking for

Contributor Author

done.

stmt->op_type == AtomicOpType::add) {
addr_ptr = at_buffer(stmt->dest, dt);
} else {
addr_ptr = dest_is_ptr
Collaborator

this is just refactoring, right? Seems like a nice refactoring, if I've understood correctly.

Contributor Author

done.

}

// Shared float arrays are retyped to uint, so native float atomics
// (which require a float pointer) cannot be used on them.
Collaborator

I'm not sure I follow. I thought the purpose of changing the backing type to uint was to enable the spirv atomics? Could you give a little more clarification (in the comments) about this point please.

Also, how do we know if we are dealing with a shared array here?

Contributor Author

Added some comment on the PR itself to clarify this.


@test_utils.test(arch=[qd.cuda])
@pytest.mark.parametrize("op", ["add", "sub", "min", "max"])
@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal, qd.amdgpu])
Collaborator

do we have a simpler way of doing this like conceptually:

  • not cpu?, or
  • gpu?

Contributor Author

done.

def test_shared_array_float_atomics(op):
N = 256
block_dim = 32
total = block_dim * (block_dim - 1) / 2.0
Collaborator

total what? total_threads? Why are we dividing by 2.0? Oh, perhaps we are doing some kind of arithmetic progression or similar, and this is the expected_sum of that progression? Could we update the name to make the meaning more intuitive please. By the way, this calculation could be done using ints. Could we make this something that needs actual floats? Like, e.g. multiply each term in the progression by 0.333, which is pretty incompatible with binary representation.
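The guess above is right: `block_dim * (block_dim - 1) / 2.0` is the closed form of the arithmetic progression 0 + 1 + … + (block_dim - 1). A quick Python sanity check, using an arbitrary fractional multiplier as an illustration (0.1523 is hypothetical, not the test's actual constant):

```python
block_dim = 32
scale = 0.1523  # arbitrary fractional multiplier, purely illustrative

# Per-thread contributions tid * scale, accumulated term by term...
expected = sum(tid * scale for tid in range(block_dim))
# ...versus the closed form scale * n * (n - 1) / 2 of the progression.
closed_form = scale * block_dim * (block_dim - 1) / 2.0

assert abs(expected - closed_form) < 1e-9
```

Scaling each term by a value with no exact binary representation makes the test exercise genuine float arithmetic rather than integers that happen to be stored as floats.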

Contributor Author

fixed.

sharr = qd.simt.block.SharedArray((block_dim,), qd.f32)
sharr[tid] = qd.f32(tid)
qd.simt.block.sync()
atomic_fn(sharr[0], qd.f32(tid))
Collaborator

let's multiply qd.f32(tid) by some fractional float, like 0.3333, or maybe something arbitrary like 0.1523f

Contributor Author

done.



@test_utils.test(arch=[qd.cuda], debug=True)
@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal])
Collaborator

why is this excluding amdgpu?

again, can we use something simpler?

  • 'not cpu', or
  • 'gpu'?

Contributor Author

done.

@hughperkins
Collaborator

Opus review:

Thoughts

What's good:
• The approach of retyping float shared arrays to uint and bitcasting at the boundaries is a well-known pattern for working around missing float atomics. It's clean and
localized.
• The taint-tracking via shared_float_retyped_ is a reasonable way to propagate the retyping through the IR without modifying the IR itself.
• Test coverage looks solid — the new test_shared_array_float_atomics is parametrized over all four atomic ops.

Things I'd flag:

  1. The shared_float_retyped_ set is a std::unordered_set<const Stmt*> tracking raw pointers. This is fine as long as stmts outlive the codegen pass, but it's worth
    confirming there's no stmt lifetime issue. Given this is a visitor pass over existing IR, it should be safe.
  2. The retyping applies to all real types on shared memory, not just f32. If someone uses f16 or f64 shared arrays, the bitcast to the corresponding uint type should work,
    but there's no explicit test for those. The get_quadrants_uint_type utility presumably handles the width mapping, but worth verifying for f16 especially.
  3. The early dest_is_ptr branch in AtomicOpStmt unconditionally uses dest_val for shared memory, which means it never goes through at_buffer. This is correct (shared
    memory isn't in a buffer), but the original code had some interleaving of the dest_is_ptr check inside each float type branch. The refactor is cleaner, but it's a behavior
    change — previously a shared f32 array on a GPU that does report spirv_has_atomic_float_add would have used dest_val directly (correct), and now it still does (correct),
    but the native-atomics flag is also explicitly suppressed. The comment explains why (shared arrays are uint-retyped so native float atomics can't work), which is good.
  4. The test_shared_array_float_atomics test initializes sharr[tid] = qd.f32(tid) then atomically accumulates into sharr[0]. The initial value of sharr[0] is 0.0 (from
    tid=0), then all threads do atomic_op(sharr[0], tid). For sub, the expected result is -total which assumes the initial value 0.0 stays as-is and only the atomic subtracts
    are applied. That's correct since atomic_sub(sharr[0], 0) is a no-op for tid=0.
  5. Minor: test_shared_array_matrix dropped debug=True — intentional? Could be that debug mode was causing issues on Vulkan/Metal or just wasn't needed.

Overall, this is a well-structured branch. The Metal/Vulkan shared memory work is the kind of backend plumbing that's easy to get wrong, but the approach here is principled and
the test coverage backs it up.

@hughperkins
Collaborator

from the AI review, please could we address at least:
2. The retyping applies to all real types on shared memory, not just f32. If someone uses f16 or f64 shared arrays, the bitcast to the corresponding uint type should work,
but there's no explicit test for those. The get_quadrants_uint_type utility presumably handles the width mapping, but worth verifying for f16 especially.
5. Minor: test_shared_array_matrix dropped debug=True — intentional? Could be that debug mode was causing issues on Vulkan/Metal or just wasn't needed.

@hughperkins
Collaborator

(so AI and myself concur about the ambiguity over what get_quadrants_uint_type does)

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 4a42abe to 4b0ea62 on March 30, 2026 15:42
// type
DataType get_quadrants_uint_type(const DataType &dt) const;
// Return the SPIR-V uint type with the same bit-width as dt (e.g. f32->u32).
SType get_bitcast_uint_stype(const DataType &dt) const;
Collaborator

stype does at least give some clue, but I would prefer the more explicit _spirv_dtype I feel. (or _spirv_dt is ok for me too, or potentially _spv_dt, if you want it really short.)

Contributor Author

stype is already at many places before this PR. I think it is better to be consistent with the existing naming conventions.

// Convert a value from float dt to shared-memory uint backing.
Value float_to_shared_uint(Value val, const DataType &dt);
// Get the pointer type that points to value_type
SType get_storage_pointer_type(const SType &value_type);
Collaborator

How does the SType here relate to the stype name above?

Contributor Author

SType is the SPIR-V type struct (defined in spirv_ir_builder.h:48). The _stype suffix in get_bitcast_uint_stype indicates it returns an SType, as opposed to _dtype which returns a Quadrants DataType. get_storage_pointer_type is pre-existing code, not part of this PR.



@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.amdgpu])
@test_utils.test(arch=[qd.cuda, qd.vulkan, qd.metal, qd.amdgpu])
Collaborator

my question about whether we can simply exclude cpu, or say to run only on gpu, seems not to have been addressed? (or I didn't see the response perhaps)

Contributor Author

We can! Done.


@test_utils.test(arch=[qd.cuda])
@pytest.mark.parametrize("op", ["add", "sub", "min", "max"])
@pytest.mark.parametrize("dtype", [qd.f16, qd.f32])
Collaborator

f64? (seems more common than f16 tbh)

Collaborator

( i mean, we can test both)

Collaborator

oh I guess f64 doesn't work on Metal, right?

Contributor Author

It does not, but it is easy to skip if unsupported in a clean way.

rtol = 1e-3 if dtype == qd.f16 else 1e-6
arr = qd.ndarray(qd.f32, (N))
make_kernel(atomic_op)(arr)
qd.sync()
Collaborator

why do we need a sync? Is this the existing Metal bug you've mentioned recently?

Contributor Author

We do not actually. I was being extra cautious for no reason.

make_kernel(atomic_op)(arr)
qd.sync()
assert arr[0] == test_utils.approx(expected[op], rel=rtol)
assert arr[32] == test_utils.approx(expected[op], rel=rtol)
Collaborator

can we check also 31 and 255?

Contributor Author

done.

@duburcqa
Contributor Author

For the record, the original SPIRV-Cross fix makes SPIRV-Cross emit atomic_uint instead of atomic_float for threadgroup pointers when the CAS result type is integer but the pointee is float. This would let us keep f32 shared arrays as float-typed and avoid the uint retyping + bitcast overhead on every load/store. However, it only partially helps:

  1. f32 shared atomics - yes, the SPIRV-Cross fix would be sufficient. No need to retype the array; the CAS loop would use uint result type on a float pointer, and SPIRV-Cross would emit the correct Metal code.
  2. f16 shared atomics - no. The SPIRV-Cross fix doesn't help here because the problem is deeper: OpAtomicCompareExchange needs at least 32-bit operands (Metal/Vulkan lack 16-bit atomics). We'd still need u32 backing for f16, with width
    conversion at load/store/CAS boundaries.
  3. Spec compliance - the uint retyping approach produces valid SPIR-V (uint pointer + uint atomic = matching types). The SPIRV-Cross approach relies on mismatched types (uint atomic on float pointer), which may not be spec-compliant and could
    break with other SPIR-V consumers.
  4. Vulkan - the SPIRV-Cross fix only patches spirv_msl.cpp (Metal backend). MoltenVK would benefit (it uses SPIRV-Cross), but native Vulkan drivers wouldn't.

So the SPIRV-Cross fix could simplify the f32 case, but the uint retyping approach is still needed for f16 and is more robust overall. Given that the pre-scan now limits the retyping to only arrays with atomics, the overhead is minimal.

// Propagated from shared_atomic_allocs_ to derived MatrixPtrStmt nodes
// during codegen, so that load/store/atomic visitors know to bitcast.
// Example: if `sharr` (AllocaStmt) is in shared_atomic_allocs_, then
// `sharr[0]` (MatrixPtrStmt) is added here during visit(MatrixPtrStmt).
Collaborator

Nice explanation thanks. It would reduce my cognitive load if you could give me a case study showing some code (perhaps LLVM IR, or python, no strong preference), and how that maps to these sets, and to the resulting spir-v instructions. Doesn't have to be as comments in the code; could be in the PR description, or in some slides or similar.

Contributor Author

@duburcqa duburcqa Mar 31, 2026

Python kernel:

  sharr = qd.simt.block.SharedArray((32,), qd.f32)
  sharr[tid] = qd.f32(tid)
  qd.simt.block.sync()
  qd.atomic_add(sharr[0], qd.f32(tid))
  qd.simt.block.sync()
  out[i] = sharr[0]

IR statements and set membership:

  %sharr = AllocaStmt(shared, array<f32, 32>)     # in shared_atomic_allocs_ (pre-scan found atomic_add targets it)
                                                    # in shared_float_retyped_ (visit(AllocaStmt) retypes to u32)
  %ptr0  = MatrixPtrStmt(%sharr, tid)              # in shared_float_retyped_ (propagated from %sharr)
  %ptr1  = MatrixPtrStmt(%sharr, 0)                # in shared_float_retyped_ (propagated from %sharr)
           LocalStoreStmt(%ptr0, f32(tid))          # sees %ptr0 in shared_float_retyped_ -> float_to_shared_uint
           AtomicOpStmt(add, %ptr1, f32(tid))       # dest_is_ptr=true -> CAS with u32 atomics
  %val   = LocalLoadStmt(%ptr1)                    # sees %ptr1 in shared_float_retyped_ -> shared_uint_to_float

  Generated SPIR-V (simplified):
  ; Allocation: u32 array instead of f32 (retyped)
  %sharr = OpVariable Workgroup array<u32, 32>

  ; Store: f32 -> bitcast to u32 -> store
  %u_tid = OpBitcast u32 %f_tid
           OpStore %sharr[tid] %u_tid

  ; Atomic add (CAS loop):
  %old   = OpAtomicLoad u32 %sharr[0]
  %old_f = OpBitcast f32 %old              ; u32 -> f32
  %new_f = OpFAdd f32 %old_f %f_tid        ; float add
  %new   = OpBitcast u32 %new_f            ; f32 -> u32
  %loaded = OpAtomicCompareExchange u32 %sharr[0] %new %old
           ; (loop until %loaded == %old)

  ; Load: load u32 -> bitcast to f32
  %raw   = OpLoad u32 %sharr[0]
  %val   = OpBitcast f32 %raw

  For f16, the only difference is the array is still u32-backed (not u16, since Metal/Vulkan lack 16-bit atomics), with OpUConvert inserted between the bitcast and the atomic:
  ; Store f16: bitcast f16->u16, widen u16->u32, store
  ; Load u32: narrow u32->u16, bitcast u16->f16
  ; CAS: OpAtomicLoad u32, narrow->bitcast->FAdd->bitcast->widen, OpAtomicCompareExchange u32
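That f16 width-conversion path can also be sketched in Python using the `struct` module ('e' is the IEEE binary16 format; these function names are hypothetical, not the PR's actual helpers):

```python
import struct

def f16_to_u32(x: float) -> int:
    # Bitcast f16 -> u16 ('e' = IEEE binary16), then zero-extend to u32,
    # since the shared array is u32-backed (no 16-bit atomics on Metal/Vulkan).
    u16 = struct.unpack("<H", struct.pack("<e", x))[0]
    return u16  # zero-extension: the value already fits in 32 bits

def u32_to_f16(u: int) -> float:
    # Narrow u32 -> u16, then bitcast u16 -> f16.
    return struct.unpack("<e", struct.pack("<H", u & 0xFFFF))[0]

print(u32_to_f16(f16_to_u32(1.5)))  # 1.5 survives the round trip exactly
print(hex(f16_to_u32(1.0)))         # 0x3c00, the binary16 encoding of 1.0
```

Only the low 16 bits of each u32 word carry the f16 payload; the CAS loop itself still operates on full 32-bit words, exactly as in the f32 case.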

ir_->register_value(const_stmt->raw_name(), val);
}

const AllocaStmt *TaskCodegen::trace_to_alloca(const Stmt *s) {
Collaborator

could we add a comment to this function saying what it does, and providing an example.

Contributor Author

done.

return nullptr;
}

void TaskCodegen::scan_shared_atomic_allocs(Block *block) {
Collaborator

So this scans a single block, but writes the results to a class level collection? I wonder if it would be more intuitive/re-usable/testable if we pased that collection in as a function parameter somehow?

Collaborator

Are there tests for this function? Could we add some?

How do we know that this function is complete? It seems fairly complex.

Contributor Author

So this scans a single block, but writes the results to a class level collection? I wonder if it would be more intuitive/re-usable/testable if we pased that collection in as a function parameter somehow?

Done.

Are there tests for this function? Could we add some?

A hole would cause a shader compilation error, not a silent bug. If the pre-scan misses an atomic target, the array stays float-typed, but the CAS emulation expects a uint pointer. The resulting type mismatch (OpAtomicLoad(u32, ptr_to_f32)) is invalid SPIR-V and gets rejected at compile time. So I don't think more testing is necessary at this point.

@duburcqa duburcqa force-pushed the duburcqa/fix_metal_atomics_shared_mem branch from 2890be3 to a806667 on March 31, 2026 06:18