
mlx_array_dtype returns wrong value for MLX_COMPLEX64 (ComplexF32) arrays #19

@stemann

Description

```julia
using MLX

a = [1.2f0 + 3.4f0im]
a_mlx = MLXArray(a)
UInt32(MLX.Wrapper.mlx_array_dtype(a_mlx.mlx_array)) # Returns `0x3374534e` instead of `0x0000000c` / `MLX.Wrapper.MLX_COMPLEX64`
```

This happens even though `mlx_array_new_data` / `mlx_array_set_data` are called with the correct arguments:

```julia
using MLX

array = [1.2f0 + 3.4f0im]
N = ndims(array) # 1
T = eltype(array) # ComplexF32
shape = collect(Cint.(reverse(size(array)))) # Int32[1]
dtype = convert(MLX.Wrapper.mlx_dtype, T) # MLX_COMPLEX64
array_mlx = MLX.Wrapper.mlx_array_new_data(pointer(array), pointer(shape), N, dtype)

UInt32(MLX.Wrapper.mlx_array_dtype(array_mlx)) # 0x3374534e
```

The problem does not seem to reproduce in C:

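```c
#include <stdio.h>
#include "mlx/c/mlx.h"
int main(void) {
  float v[] = {1.2f, 3.4f}; // one complex64 element: real 1.2, imaginary 3.4
  int v_shape[] = {1};
  mlx_array v_arr = mlx_array_new_data(v, v_shape, 1, MLX_COMPLEX64);
  printf("mlx_array_dtype: %u\n", (unsigned)mlx_array_dtype(v_arr)); // Prints `mlx_array_dtype: 12` / MLX_COMPLEX64
  mlx_array_free(v_arr);
  return 0;
}
```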

Metadata


Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
