Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions src/Backends/DotCompute.Backends.CPU/CpuMemoryManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1236,13 +1236,20 @@ public ValueTask CopyToHostAsync<TDestination>(Memory<TDestination> destination,
public ValueTask DisposeAsync() => _view.DisposeAsync();

// Helper methods
private CpuMemoryBuffer GetParentBuffer()
{
// Access the parent buffer from the view
// This uses reflection to access the private field, which is not ideal
// but necessary for the current implementation
var field = typeof(CpuMemoryBufferView).GetField("_parent",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
return (CpuMemoryBuffer)(field?.GetValue(_view) ?? throw new InvalidOperationException("Could not access parent buffer"));
}
private CpuMemoryBuffer GetParentBuffer() => CpuMemoryBufferViewAccessor.GetParent(_view);
}

/// <summary>
/// AOT-friendly accessor for the private <c>_parent</c> field on
/// <see cref="CpuMemoryBufferView"/>. Replaces runtime <c>FieldInfo.GetValue</c>
/// reflection with a source-gen-resolved extern shim that the JIT inlines into a
/// direct field load — same speed as a normal field access, no trim/AOT warnings.
/// </summary>
internal static class CpuMemoryBufferViewAccessor
{
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_parent")]
private static extern ref CpuMemoryBuffer ParentRef(CpuMemoryBufferView view);

public static CpuMemoryBuffer GetParent(CpuMemoryBufferView view)
=> ParentRef(view) ?? throw new InvalidOperationException("Could not access parent buffer");
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,21 @@ public void SetComputePipelineState(IntPtr encoder, object kernel)

ArgumentNullException.ThrowIfNull(kernel);

// Extract pipeline state from kernel using reflection (MetalCompiledKernel has private _pipelineState field)
var kernelType = kernel.GetType();
#pragma warning disable IL2075 // Reflection on kernel type is safe - Metal backend controls kernel types
var pipelineStateField = kernelType.GetField("_pipelineState",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
#pragma warning restore IL2075

if (pipelineStateField != null && pipelineStateField.GetValue(kernel) is IntPtr pipelineState && pipelineState != IntPtr.Zero)
// Production path: statically-typed access via UnsafeAccessor (AOT-safe, no reflection).
// Test path: anything that isn't a MetalCompiledKernel (e.g. mock) silently no-ops, matching
// the previous reflection-based behavior.
if (kernel is MetalCompiledKernel compiledKernel)
{
MetalNative.SetComputePipelineState(encoder, pipelineState);
_logger.LogTrace("Set pipeline state {Pipeline} on encoder {Encoder}", pipelineState, encoder);
}
else
{
// This might be a test scenario with a mock kernel
_logger.LogTrace("Skipping pipeline state for kernel type: {KernelType}", kernelType.Name);
var pipelineState = MetalCompiledKernelAccessor.PipelineState(compiledKernel);
if (pipelineState != IntPtr.Zero)
{
MetalNative.SetComputePipelineState(encoder, pipelineState);
_logger.LogTrace("Set pipeline state {Pipeline} on encoder {Encoder}", pipelineState, encoder);
return;
}
}

_logger.LogTrace("Skipping pipeline state for kernel type: {KernelType}", kernel.GetType().Name);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,39 +257,18 @@ private static void ValidateGridDimensions(GridDimensions dimensions, string par
}

/// <summary>
/// Extracts the pipeline state handle from a compiled Metal kernel using reflection.
/// Extracts the pipeline state handle from a compiled Metal kernel.
/// </summary>
/// <param name="kernel">The compiled kernel.</param>
/// <returns>The pipeline state handle.</returns>
/// <exception cref="InvalidOperationException">Thrown when pipeline state cannot be extracted.</exception>
/// <remarks>
/// This method uses reflection to access the private _pipelineState field.
/// This is necessary because MetalCompiledKernel doesn't expose the pipeline state publicly.
/// Delegates to <see cref="MetalCompiledKernelAccessor"/>, which uses
/// <c>UnsafeAccessor</c> to access the private <c>_pipelineState</c> field on
/// <see cref="MetalCompiledKernel"/>. The accessor is resolved at source-gen time
/// (AOT-safe, no trim warnings) and inlined into a direct field load by the JIT.
/// </remarks>
private static IntPtr GetPipelineStateFromKernel(MetalCompiledKernel kernel)
{
var kernelType = typeof(MetalCompiledKernel);

#pragma warning disable IL2075 // Reflection on kernel type is safe - Metal backend controls kernel types
var pipelineStateField = kernelType.GetField("_pipelineState",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
#pragma warning restore IL2075

if (pipelineStateField == null)
{
throw new InvalidOperationException(
"Unable to access pipeline state from MetalCompiledKernel. Internal structure may have changed.");
}

var pipelineStateValue = pipelineStateField.GetValue(kernel);
if (pipelineStateValue is IntPtr pipelineState)
{
return pipelineState;
}

throw new InvalidOperationException(
"Pipeline state field exists but has unexpected type or null value.");
}
=> MetalCompiledKernelAccessor.PipelineState(kernel);

/// <summary>
/// Disposes the execution engine and releases native resources.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) 2025 Michael Ivertowski
// Licensed under the MIT License. See LICENSE file in the project root for license information.

using System.Runtime.CompilerServices;

namespace DotCompute.Backends.Metal.Kernels;

/// <summary>
/// AOT-friendly accessor for the private <c>_pipelineState</c> field on
/// <see cref="MetalCompiledKernel"/>.
///
/// <para>
/// The execution layer needs the raw <c>MTLComputePipelineState</c> handle to bind
/// to a command encoder, but <see cref="MetalCompiledKernel"/> deliberately keeps the
/// native handle private. Using <see cref="UnsafeAccessorAttribute"/> instead of
/// <c>FieldInfo.GetValue</c> resolves the field at source-gen time, removes the
/// runtime reflection call (and its associated trim/AOT warnings), and lets the JIT
/// inline the access into a single field load.
/// </para>
/// </summary>
internal static class MetalCompiledKernelAccessor
{
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "_pipelineState")]
private static extern ref IntPtr PipelineStateRef(MetalCompiledKernel kernel);

/// <summary>
/// Returns the native <c>MTLComputePipelineState</c> handle for the given compiled kernel.
/// </summary>
public static IntPtr PipelineState(MetalCompiledKernel kernel)
{
ArgumentNullException.ThrowIfNull(kernel);
return PipelineStateRef(kernel);
}
}
38 changes: 26 additions & 12 deletions src/Core/DotCompute.Memory/UnifiedBufferMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,11 @@ public void Resize(int newLength)
// Update length and recalculate size
var newSizeInBytes = newLength * Unsafe.SizeOf<T>();

// Use reflection to update readonly properties

var lengthField = typeof(UnifiedBuffer<T>).GetField("<Length>k__BackingField",

System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
lengthField?.SetValue(this, newLength);


var sizeField = typeof(UnifiedBuffer<T>).GetField("<SizeInBytes>k__BackingField",

System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
sizeField?.SetValue(this, newSizeInBytes);
// AOT-friendly: mutate the auto-property backing fields via UnsafeAccessor.
// Replaces FieldInfo.SetValue reflection with a source-gen-resolved extern shim
// that the JIT inlines into a direct field store.
UnifiedBufferBackingFields.LengthRef(this) = newLength;
UnifiedBufferBackingFields.SizeInBytesRef(this) = newSizeInBytes;

// Reallocate host memory
_hostArray = new T[newLength];
Expand Down Expand Up @@ -303,6 +296,27 @@ public async Task PrefetchToHostAsync()
}
}

/// <summary>
/// AOT-friendly accessors for the auto-property backing fields of <see cref="UnifiedBuffer{T}"/>.
/// Used by <see cref="UnifiedBuffer{T}.Resize(int)"/> to mutate the otherwise-readonly
/// <c>Length</c> and <c>SizeInBytes</c> properties without runtime reflection.
///
/// <para>
/// <see cref="UnsafeAccessorAttribute"/> resolves the field reference at source-gen time,
/// and the JIT inlines the accessor into a direct field store. This removes two
/// <c>FieldInfo.SetValue</c> calls from the Resize hot path and makes the code
/// trim/AOT-safe.
/// </para>
/// </summary>
internal static class UnifiedBufferBackingFields
{
[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "<Length>k__BackingField")]
public static extern ref int LengthRef<T>(UnifiedBuffer<T> buffer) where T : unmanaged;

[UnsafeAccessor(UnsafeAccessorKind.Field, Name = "<SizeInBytes>k__BackingField")]
public static extern ref long SizeInBytesRef<T>(UnifiedBuffer<T> buffer) where T : unmanaged;
}

/// <summary>
/// Information about buffer memory allocation and usage.
/// </summary>
Expand Down
Loading