fix: transfer GridEmbedding grid to input device (fixes CUDA scalar indexing #125) #132

Open
jitendravjh wants to merge 1 commit into SciML:main from jitendravjh:fix/grid-embedding-gpu-scalar-indexing

@jitendravjh

What this fixes

Calling FourierNeuralOperator on a GPU (CuArray) raised:

ERROR: Scalar indexing is disallowed. Invocation of getindex resulted in scalar indexing of a GPU array.

Reported in #125.

Root cause

GridEmbedding builds its positional coordinate grid from CPU range vectors via meshgrid, so the resulting grid is a plain Array{T} on the CPU. When cat(grid, x; dims = N - 1) is then called with x a CuArray, Julia's fallback cat path reads from the GPU array element by element on the CPU, which triggers the scalar indexing error.

The second stacktrace in the issue points directly to the offending line:

[20] (::GridEmbedding{...})(x::CuArray{...}) @ NeuralOperators src/layers.jl:240
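
To make the root cause concrete, here is a minimal CPU-only sketch of the pattern. It is not the code from src/layers.jl: build_grid, the [0, 1] coordinate range, and the array sizes are hypothetical stand-ins. The point it illustrates is that the grid's array type is fixed by how it is constructed and does not depend on the type of x.

```julia
# Hypothetical stand-in for the grid construction in GridEmbedding.
# The grid is built from CPU ranges, so it is always a plain Array.
function build_grid(::Type{T}, nx::Int, ny::Int, batch::Int) where {T}
    xs = range(T(0), T(1); length = nx)
    ys = range(T(0), T(1); length = ny)
    # One coordinate channel per spatial dimension, repeated over the batch.
    gx = [xs[i] for i in 1:nx, j in 1:ny, c in 1:1, b in 1:batch]
    gy = [ys[j] for i in 1:nx, j in 1:ny, c in 1:1, b in 1:batch]
    return cat(gx, gy; dims = 3)
end

x = rand(Float32, 4, 4, 3, 2)          # (nx, ny, channels, batch)
N = ndims(x)
grid = build_grid(Float32, 4, 4, 2)    # Array{Float32, 4}, whatever x is

# On CPU both arguments are Arrays, so this concatenation is fine.
# With x a CuArray, grid would still be a CPU Array at this point.
out = cat(grid, x; dims = N - 1)
```

On CPU this runs without complaint because both arguments to cat are ordinary Arrays, which is why the problem only shows up once x lives on the GPU.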

Fix

After building the grid, move it to the same device as x before the cat:

# Move the CPU-built grid to the same device as x (fixes CUDA scalar indexing, #125)
grid = Lux.get_device(x)(grid)

Lux.get_device is already available via the existing Lux import and:

  • Is a no-op on CPU: it returns a CPUDevice functor, which leaves an array that is already on the CPU where it is
  • Transparently transfers to GPU: it returns a CUDADevice functor that calls cu(grid)
  • Works equally for Metal, ROCm, and any other MLDataDevices-supported backend
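
The device-matching step can be sketched in isolation as follows. This is an illustrative sketch, not the patched method: embed_on_input_device is a hypothetical helper, and it calls get_device from MLDataDevices directly, on the assumption that this is the same function the PR reaches as Lux.get_device. Only the CPU path is exercised here.

```julia
using MLDataDevices  # provides get_device and the callable device objects

# Hypothetical helper: move `grid` to the device that holds `x`,
# then concatenate along the channel dimension (N - 1).
function embed_on_input_device(grid::AbstractArray, x::AbstractArray{T, N}) where {T, N}
    dev = get_device(x)      # a CPUDevice for a plain Array
    grid_dev = dev(grid)     # device objects are callable and move data
    return cat(grid_dev, x; dims = N - 1)
end

x = rand(Float32, 4, 4, 3, 2)
grid = zeros(Float32, 4, 4, 2, 2)
y = embed_on_input_device(grid, x)
```

With a GPU backend loaded and x on that device, the same helper would be expected to transfer grid before the cat, so both arguments share an array type.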

Checklist

  • Appropriate tests were added (manually verified on CPU; GPU path requires CUDA hardware)
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated (N/A — behaviour fix only)
  • The new code follows the contributor guidelines and SciML Style Guide
  • Any new documentation only uses public API

…ndexing SciML#125)

GridEmbedding built the positional grid using CPU range/meshgrid, then
called cat(grid, x) where x may be a CuArray. This caused:

  ERROR: Scalar indexing is disallowed. Invocation of getindex resulted
  in scalar indexing of a GPU array.

Fix: call Lux.get_device(x)(grid) immediately after building the grid,
so the array is moved to the same device as the input before the cat.
This is a no-op on CPU and transparently transfers to GPU/Metal/etc.

Fixes SciML#125
@jitendravjh
Author

The CI failures are pre-existing and unrelated to this change:

  • CUDA GPU Tests: Malt.TerminatedWorkerException() in spectral_conv_tests, spectral_kernel_tests, and fno_tests - these are Reactant worker process crashes, not GridEmbedding failures. GridEmbedding and scalar indexing do not appear anywhere in the error log.
  • Documentation: SIGABRT (exit 134) - runner OOM during doc build.

The same failures appear on main in run #25599450582. This PR only modifies the GridEmbedding forward method and passes spell check + format check cleanly.

@ChrisRackauckas
Member

We'd need to get master working again to get merges here then.

@jitendravjh
Author

I just opened a quick PR (#135) to fix the broken Windows tests on main (the Float32 finite difference tolerance was too strict for XLA on Windows). Once that's in, we should be unblocked!
