Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ ScopedValues = "1"
julia = "1"

[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Random", "Test"]
60 changes: 60 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,27 @@ MLXMatrix(array::AbstractMatrix{T}) where {T} = MLXMatrix{T}(array)

const MLXVecOrMat{T} = Union{MLXVector{T}, MLXMatrix{T}}

# UndefInitializer

function MLXArray{T, N}(::UndefInitializer, dims::Dims{N}) where {T, N}
stream = get_stream()
result_ref = Ref(Wrapper.mlx_array_new())
shape = collect(Cint.(dims))
dtype = convert(Wrapper.mlx_dtype, T)
Wrapper.mlx_zeros(result_ref, pointer(shape), Cint(N), dtype, stream.mlx_stream)
return MLXArray{T, N}(result_ref[])
end

function MLXArray{T}(::UndefInitializer, dims::Dims{N}) where {T, N}
return MLXArray{T, N}(undef, dims)
end

# BitArray

MLXArray{Bool, N}(array::BitArray{N}) where {N} = MLXArray(Array{Bool}(array))

MLXArray(array::BitArray{N}) where {N} = MLXArray{Bool, N}(array)

# AbstractArray interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-array

function Base.size(array::MLXArray)
Expand All @@ -89,6 +110,13 @@ function Base.setindex!(array::MLXArray{T, N}, v::T, i::Int) where {T, N}
return array
end

function Base.similar(array::MLXArray{T, N}, ::Type{T}, ::Dims{N}) where {T, N}
stream = get_stream()
result_ref = Ref(Wrapper.mlx_array_new())
Wrapper.mlx_zeros_like(result_ref, array.mlx_array, stream.mlx_stream)
return MLXArray{T, N}(result_ref[])
end

# Strided array interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interface-strided-arrays

function Base.strides(array::MLXArray)
Expand Down Expand Up @@ -164,3 +192,35 @@ function Base.unsafe_wrap(array::MLXArray{T, N}) where {T, N}
return PermutedDimsArray(wrapped_array, reverse(1:ndims(array)))
end
end

# Broadcasting interface, cf. https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting

Base.BroadcastStyle(::Type{<:MLXArray}) = Broadcast.ArrayStyle{MLXArray}()

function Base.similar(
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}, ::Type{TElement}
) where {TElement}
first_mlx_array(bc::Broadcast.Broadcasted) = first_mlx_array(bc.args)
function first_mlx_array(args::Tuple)
return first_mlx_array(first_mlx_array(args[1]), Base.tail(args))
end
first_mlx_array(x) = x
first_mlx_array(::Tuple{}) = nothing
first_mlx_array(a::MLXArray, _) = a
first_mlx_array(::Any, rest) = first_mlx_array(rest)
mlx_array = first_mlx_array(bc)
if isnothing(mlx_array)
return similar(MLXArray{TElement}, ())
end
return similar(mlx_array)
end

function Base.Broadcast.materialize(
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}
)
result = copy(Broadcast.instantiate(bc))
if iszero(ndims(result)) # Drop 0-dim arrays to scalars, cf. https://github.com/JuliaLang/julia/issues/28866
return result[]
end
return result
end
86 changes: 86 additions & 0 deletions test/array_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
@static if VERSION < v"1.11"
using ScopedValues
else
using Base.ScopedValues
end

using MLX
using Random
using Test

@testset "MLXArray" begin
Random.seed!(42)

device_types = [MLX.DeviceTypeCPU]
if MLX.metal_is_available()
push!(device_types, MLX.DeviceTypeGPU)
end

@test IndexStyle(MLXArray) == IndexLinear()

array_sizes = [(), (1,), (2,), (1, 1), (2, 1), (3, 2), (4, 3, 2)]
Expand Down Expand Up @@ -31,6 +45,22 @@ using Test
array[1] = T(1)
@test setindex!(mlx_array, T(1), 1) == array
end

@testset "similar(::$MLXArray{$T, $N}), array_size=$array_size" begin
for device_type in device_types
if T ∉ MLX.supported_number_types(device_type)
continue
end
@testset "similar(::$MLXArray{$T, $N}), with array_size=$array_size, $device_type" begin
with(MLX.device => MLX.Device(; device_type)) do
similar_mlx_array = similar(mlx_array)
@test typeof(similar_mlx_array) == typeof(mlx_array)
@test size(similar_mlx_array) == size(mlx_array)
@test similar_mlx_array !== mlx_array
end
end
end
end
end
end
end
Expand Down Expand Up @@ -62,7 +92,63 @@ using Test
@test Base.elsize(MLXArray{T, 0}) == Base.elsize(Array{T, 0})
end
end

@testset "Unsupported Number types" begin
@test_throws ArgumentError convert(MLX.Wrapper.mlx_dtype, Rational{Int})
end

@testset "BitArray" begin
for array_size in array_sizes
N = length(array_size)
@testset "$MLXArray{Bool, $N}(::BitArray), array_size=$array_size" begin
array = BitArray(rand(Bool, array_size))
if N > 2 || N == 0
mlx_array = MLXArray(array)
elseif N > 1
mlx_array = MLXMatrix(array)
else
mlx_array = MLXVector(array)
end
@test array == mlx_array
end
end
end

@testset "Broadcasting interface" begin
@testset "broadcast over tuple with no MLXArray" begin
result = similar(
Broadcast.Broadcasted{Broadcast.ArrayStyle{MLXArray}}(identity, ()), Bool
)
@test result isa MLXArray{Bool, 0}
end

test_cases(T, array_size) = [
(fn = identity, args = ()),
(fn = T == Bool ? xor : +, args = (T == Bool ? true : T(2),)),
(fn = T == Bool ? xor : +, args = (MLXArray(rand(T, array_size)),)),
]
for device_type in device_types,
T in MLX.supported_number_types(device_type),
array_size in array_sizes,
test_case in test_cases(T, array_size)

compare_op = T <: Integer ? (==) : ≈
N = length(array_size)
@testset "broadcast($(repr(test_case.fn)), $(join(map(arg -> "::$(typeof(arg))", [MLXArray{T, N}, test_case.args...]), ", "))), array_size=$array_size, $device_type" begin
array = rand(T, array_size)
mlx_array = MLXArray(array)

with(MLX.device => MLX.Device(; device_type)) do
actual = broadcast(test_case.fn, mlx_array, test_case.args...)
expected = broadcast(test_case.fn, array, test_case.args...)
if N == 0
@test actual isa T
else
@test actual isa MLXArray{T, N}
end
@test compare_op(actual, expected)
end
end
end
end
end
1 change: 1 addition & 0 deletions test/device_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down
1 change: 1 addition & 0 deletions test/stream_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
else
using Base.ScopedValues
end

using MLX
using Test

Expand Down
Loading