Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ using AbstractFFTs: fft, rfft, ifft, irfft, fftshift
using ConcreteStructs: @concrete
using Random: Random, AbstractRNG

using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction, Scale
using Lux: Lux, Chain, Dense, Conv, Parallel, NoOpLayer, WrappedFunction, Scale,
Upsample, MeanPool, SamePad

using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using LuxLib: fast_activation!!
using NNlib: NNlib, batched_mul, pad_constant, gelu
Expand All @@ -18,6 +20,8 @@ include("layers.jl")
include("models/fno.jl")
include("models/deeponet.jl")
include("models/nomad.jl")
include("models/cno.jl")


export FourierTransform
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
Expand All @@ -26,5 +30,7 @@ export GridEmbedding, ComplexDecomposedLayer, SoftGating
export FourierNeuralOperator
export DeepONet
export NOMAD
export ConvolutionalNeuralOperator


end
6 changes: 5 additions & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,18 @@ function (layer::GridEmbedding)(x::AbstractArray{T, N}, ps, st) where {T, N}
grid = meshgrid(
map(enumerate(layer.grid_boundaries)) do (i, (min, max))
return range(T(min), T(max); length = size(x, i))
end...
end...,
)

grid = repeat(
Lux.Utils.contiguous(reshape(grid, size(grid)..., 1)),
ntuple(Returns(1), N - 1)...,
size(x, N),
)

# Move the CPU-built grid to the same device as x (fixes CUDA scalar indexing, #125)
grid = Lux.get_device(x)(grid)

return cat(grid, x; dims = N - 1), st
end

Expand Down
151 changes: 151 additions & 0 deletions src/models/cno.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
    CNOBlock(
        in_channels::Integer,
        out_channels::Integer,
        modes::Dims{N},
        activation = gelu;
        upsample_factor::Integer = 2,
    ) where {N}

A single Convolutional Neural Operator (CNO) block.

The block wraps a `Chain` of three layers, applied in this order:

 1. **Upsample** the input by `upsample_factor` with
    `Upsample(:bilinear; scale = upsample_factor)`. The constructor always requests the
    `:bilinear` mode, regardless of `N`.
 2. **Convolve** in the higher-resolution space with a `3×(…×3)` kernel (one `3` per
    spatial dimension) and `SamePad` padding. The **activation function** is passed to
    the `Conv` layer, so it is applied pointwise to the convolution output.
 3. **Downsample** via average pooling (`MeanPool`) with a window of `upsample_factor`
    along every spatial dimension.

According to the reference below, this upsample → convolve → downsample design lets the
discrete operator converge to a continuous limit as resolution increases, unlike
standard CNNs which only approximate finite-dimensional maps.

## Arguments

  - `in_channels`: Number of input channels.
  - `out_channels`: Number of output channels.
  - `modes`: Spatial dimensions tuple (length = data dimensionality). Only the length is
    used; it sets the dimensionality of the convolution kernel and the pooling window.
  - `activation`: Pointwise activation applied after convolution. Default is `gelu`.

## Keyword Arguments

  - `upsample_factor`: Integer upsampling factor, also used as the pooling window size.
    Default is `2`.

## References

[1] Raonic et al., "Convolutional Neural Operators for robust and accurate learning of
PDEs," NeurIPS 2023. https://arxiv.org/abs/2302.01178
"""
@concrete struct CNOBlock <: AbstractLuxWrapperLayer{:model}
    model
end

# Construct a `CNOBlock`: a `Chain` of upsample -> convolution (+ activation) -> mean-pool.
#
# Arguments
#   - `in_channels`, `out_channels`: channel counts of the convolution.
#   - `modes`: only its length `N` is used; it fixes the number of spatial dimensions of
#     the convolution kernel and of the pooling window.
#   - `activation`: passed to `Conv`, so it is applied to the convolution output.
#
# Keyword arguments
#   - `upsample_factor`: scale handed to `Upsample` and window size handed to `MeanPool`.
#
# Throws
#   - `ArgumentError` if `upsample_factor < 1`; such a factor would yield a degenerate
#     upsampling scale and pooling window.
function CNOBlock(
        in_channels::Integer,
        out_channels::Integer,
        modes::Dims{N},
        activation = gelu;
        upsample_factor::Integer = 2,
    ) where {N}
    # Validate at the API boundary instead of failing later inside the layers.
    upsample_factor >= 1 || throw(
        ArgumentError(
            "upsample_factor must be a positive integer, got $(upsample_factor)"
        ),
    )

    kernel = ntuple(Returns(3), N)
    pool_window = ntuple(Returns(upsample_factor), N)

    return CNOBlock(
        Chain(
            # 1. Upsample to the higher resolution.
            # NOTE(review): `:bilinear` is requested for every `N`; confirm that
            # Lux/NNlib accept this mode for 1-D and 3-D inputs.
            Upsample(:bilinear; scale = upsample_factor),
            # 2. Convolution at high resolution; `SamePad` keeps the spatial size.
            Conv(kernel, in_channels => out_channels, activation; pad = SamePad()),
            # 3. Downsample via average pooling (paper Section 3.1).
            MeanPool(pool_window),
        ),
    )
end

"""
    ConvolutionalNeuralOperator(
        modes::Dims{N},
        in_channels::Integer,
        out_channels::Integer,
        hidden_channels::Integer;
        num_layers::Integer = 4,
        activation = gelu,
        upsample_factor::Integer = 2,
    ) where {N}

Convolutional Neural Operator (CNO) for learning PDE solution operators.

CNO applies a sequence of convolutional blocks. Each block upsamples the input, applies
a convolution in the higher-resolution space, and downsamples back. According to the
reference below, this design converges to a well-defined continuous operator as
resolution increases, which is what makes CNO resolution-invariant by construction.

**Architecture** (a `Chain` with the named entries `lifting`, `cno_blocks` and
`projection`):

 1. **Lifting** `Conv(1×…×1)`: maps `in_channels → hidden_channels` (no activation).
 2. **CNO blocks** × `num_layers`: each is a [`CNOBlock`](@ref) mapping
    `hidden_channels → hidden_channels`, i.e. upsample → conv → activation → avgpool.
 3. **Projection**: `Conv(1×…×1, activation)` on `hidden_channels`, followed by
    `Conv(1×…×1)` which maps to `out_channels`.

## Arguments

  - `modes`: Spatial size tuple (length `d` for d-dimensional data). Only its length
    matters — kept consistent with the FNO API.
  - `in_channels`: Number of input channels.
  - `out_channels`: Number of output channels.
  - `hidden_channels`: Number of channels inside the CNO blocks.

## Keyword Arguments

  - `num_layers`: Number of `CNOBlock` layers. Default is `4`.
  - `activation`: Activation function used inside each block and in the first projection
    convolution. Default is `gelu`.
  - `upsample_factor`: Spatial upsampling factor inside each block. Default is `2`.

## References

[1] Raonic et al., "Convolutional Neural Operators for robust and accurate learning of
PDEs," NeurIPS 2023. https://arxiv.org/abs/2302.01178

## Example

```jldoctest
julia> cno = ConvolutionalNeuralOperator((16,), 1, 1, 32; num_layers=3);

julia> ps, st = Lux.setup(Xoshiro(), cno);

julia> u = rand(Float32, 64, 1, 5);

julia> size(first(cno(u, ps, st)))
(64, 1, 5)
```
"""
@concrete struct ConvolutionalNeuralOperator <: AbstractLuxWrapperLayer{:model}
    model <: AbstractLuxLayer
end

# Assemble the CNO as a named `Chain`: `lifting` -> `cno_blocks` -> `projection`.
#
#   - `lifting`:    pointwise (all-ones kernel) `Conv`, `in_channels => hidden_channels`.
#   - `cno_blocks`: `num_layers` `CNOBlock`s, each `hidden_channels => hidden_channels`.
#   - `projection`: pointwise `Conv` with `activation`, then a pointwise `Conv` mapping
#                   `hidden_channels => out_channels`.
#
# Only the length `N` of `modes` is used here; `modes` itself is forwarded to `CNOBlock`.
function ConvolutionalNeuralOperator(
        modes::Dims{N},
        in_channels::Integer,
        out_channels::Integer,
        hidden_channels::Integer;
        num_layers::Integer = 4,
        activation = gelu,
        upsample_factor::Integer = 2,
    ) where {N}
    pointwise = ntuple(Returns(1), N)

    # One independent block per layer; an empty range yields no blocks.
    blocks = map(1:num_layers) do _
        return CNOBlock(
            hidden_channels, hidden_channels, modes, activation;
            upsample_factor = upsample_factor,
        )
    end

    return ConvolutionalNeuralOperator(
        Chain(;
            lifting = Conv(pointwise, in_channels => hidden_channels),
            cno_blocks = Chain(blocks...),
            projection = Chain(
                Conv(pointwise, hidden_channels => hidden_channels, activation),
                Conv(pointwise, hidden_channels => out_channels),
            ),
        ),
    )
end
59 changes: 59 additions & 0 deletions test/models/cno_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using NeuralOperators, Test

include("../shared_testsetup.jl")

# End-to-end checks for `ConvolutionalNeuralOperator` in 1-D and 2-D:
#   * the CPU forward pass produces the expected output size,
#   * the Reactant-compiled forward pass agrees with the CPU result,
#   * Reactant gradients agree with the finite-difference gradients.
# NOTE(review): `StableRNG`, `Lux`, `reactant_device`, `@jit`, `∇sumabs2_reactant_fd`,
# `∇sumabs2_reactant` and `check_approx` are presumably provided by
# `shared_testsetup.jl` — confirm.
@testset "Convolutional Neural Operator" begin
    rng = StableRNG(12345)

    # `x_size` / `y_size` are (spatial..., channels, batch).
    setups = [
        (
            modes = (4,),
            in_channels = 1,
            out_channels = 1,
            hidden_channels = 8,
            num_layers = 2,
            x_size = (16, 1, 4),
            y_size = (16, 1, 4),
        ),
        (
            modes = (4, 4),
            in_channels = 2,
            out_channels = 1,
            hidden_channels = 8,
            num_layers = 2,
            x_size = (16, 16, 2, 4),
            y_size = (16, 16, 1, 4),
        ),
    ]

    xdev = reactant_device(; force = true)

    @testset "$(length(setup.modes))D" for setup in setups
        cno = ConvolutionalNeuralOperator(
            setup.modes, setup.in_channels, setup.out_channels, setup.hidden_channels;
            num_layers = setup.num_layers,
        )
        # Also exercises the pretty-printing of the model.
        display(cno)
        ps, st = Lux.setup(rng, cno)

        x = rand(rng, Float32, setup.x_size...)

        # Single CPU forward pass, reused for the size check and as the reference
        # for the Reactant comparison below.
        res = first(cno(x, ps, st))
        @test size(res) == setup.y_size

        ps_ra, st_ra = (ps, st) |> xdev
        x_ra = x |> xdev

        res_ra, _ = @jit cno(x_ra, ps_ra, st_ra)
        @test res_ra ≈ res atol = 1.0f-2 rtol = 1.0f-2

        @testset "check gradients" begin
            ∂x_fd, ∂ps_fd = ∇sumabs2_reactant_fd(cno, x_ra, ps_ra, st_ra)
            ∂x_ra, ∂ps_ra = ∇sumabs2_reactant(cno, x_ra, ps_ra, st_ra)

            @test ∂x_fd ≈ ∂x_ra atol = 1.0f-1 rtol = 1.0f-1
            @test check_approx(∂ps_fd, ∂ps_ra; atol = 1.0f-1, rtol = 1.0f-1)
        end
    end
end