[Bug]: Parameter count in FNO model is too low when following specification #67

@t-muser

Description

Describe the issue:

I think torchinfo miscounts the parameters because of the way it reports the ModuleLists inside FNOBlocks. Summing the parameters directly in PyTorch (a simple loop over model.parameters()) gives a much lower count: 9,637,763 instead of 16,815,299. The difference is exactly the 7,177,536 (about 7.18 million) parameters that torchinfo assigns to the FNOBlocks: 1-2 row. A simple fix would be to increase hidden_channels to 180, which gives 19,056,783 parameters in total.

===================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
===================================================================================================================
FNO                                      [4, 3, 64, 64]            [4, 3, 64, 64]            --
├─ChannelMLP: 1-1                        [4, 3, 64, 64]            [4, 128, 64, 64]          --
│    └─ModuleList: 2-1                   --                        --                        --
│    │    └─Conv1d: 3-1                  [4, 3, 4096]              [4, 256, 4096]            1,024
│    │    └─Conv1d: 3-2                  [4, 256, 4096]            [4, 128, 4096]            32,896
├─FNOBlocks: 1-2                         [4, 128, 64, 64]          [4, 128, 64, 64]          7,177,536
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-3         [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-1             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-4              [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-5            [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-6              [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-2         --                        --                        16,576
├─FNOBlocks: 1-3                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-7         [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-3             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-8              [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-9            [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-10             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-4         --                        --                        16,576
├─FNOBlocks: 1-4                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-11        [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-5             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-12             [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-13           [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-14             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-6         --                        --                        16,576
├─FNOBlocks: 1-5                         [4, 128, 64, 64]          [4, 128, 64, 64]          (recursive)
│    └─ModuleList: 2-14                  --                        --                        (recursive)
│    │    └─Flattened1dConv: 3-15        [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─Conv1d: 4-7             [4, 128, 4096]            [4, 128, 4096]            16,384
│    └─ModuleList: 2-15                  --                        --                        (recursive)
│    │    └─SoftGating: 3-16             [4, 128, 64, 64]          [4, 128, 64, 64]          128
│    └─ModuleList: 2-16                  --                        --                        (recursive)
│    │    └─SpectralConv: 3-17           [4, 128, 64, 64]          [4, 128, 64, 64]          2,359,424
│    └─ModuleList: 2-17                  --                        --                        (recursive)
│    │    └─ChannelMLP: 3-18             [4, 128, 64, 64]          [4, 128, 64, 64]          --
│    │    │    └─ModuleList: 4-8         --                        --                        16,576
├─ChannelMLP: 1-6                        [4, 128, 64, 64]          [4, 3, 64, 64]            --
│    └─ModuleList: 2-18                  --                        --                        --
│    │    └─Conv1d: 3-19                 [4, 128, 4096]            [4, 256, 4096]            33,024
│    │    └─Conv1d: 3-20                 [4, 256, 4096]            [4, 3, 4096]              771
===================================================================================================================
Total params: 16,815,299
Trainable params: 16,815,299
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 3.27
===================================================================================================================
Input size (MB): 0.20
Forward/backward pass size (MB): 319.16
Params size (MB): 0.80
Estimated Total Size (MB): 320.16
===================================================================================================================
Torch counting:
        9,637,763
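The torch count can be cross-checked by hand from the per-row numbers in the summary. The following is a minimal sketch with hypothetical helper names (not part of neuralop or torchinfo). It assumes the layer shapes that the table implies:

* every convolution is a 1x1 Conv1d,
* the lifting and projection widths are 2 * hidden_channels,
* the in-block ChannelMLP uses half the hidden width,
* SoftGating has one weight per channel,
* SpectralConv has a weight of shape (h, h, modes, modes // 2 + 1), counted as one parameter per complex entry, plus a bias. This is inferred from 2,359,424 = 128 * 128 * 16 * 9 + 128.

Under those assumptions, the sketch reproduces the 9,637,763 torch count. It also shows that the extra 7,177,536 in the torchinfo total is exactly three blocks' worth of parameters, and it gives 19,056,783 for hidden_channels=180.

```python
# Hypothetical helpers: rebuild the FNO parameter count from the layer shapes
# visible in the torchinfo table (assumed shapes, see the lead-in).

def conv1x1(c_in: int, c_out: int, bias: bool = True) -> int:
    """Parameters of a Conv1d with kernel_size=1."""
    return c_in * c_out + (c_out if bias else 0)


def fno_block_params(h: int, modes: int) -> int:
    """Parameters of one FNO layer: skip conv, gating, spectral conv, channel MLP."""
    skip = conv1x1(h, h, bias=False)                   # Flattened1dConv, no bias
    gating = h                                         # SoftGating, one weight per channel
    spectral = h * h * modes * (modes // 2 + 1) + h    # complex weight entries + bias
    mlp = conv1x1(h, h // 2) + conv1x1(h // 2, h)      # ChannelMLP at half width
    return skip + gating + spectral + mlp


def fno_params(c_in: int, c_out: int, h: int, modes: int, n_layers: int) -> int:
    """Lifting MLP + n_layers blocks + projection MLP."""
    lifting = conv1x1(c_in, 2 * h) + conv1x1(2 * h, h)
    projection = conv1x1(h, 2 * h) + conv1x1(2 * h, c_out)
    return lifting + n_layers * fno_block_params(h, modes) + projection


per_block = fno_block_params(128, 16)
print(per_block)                                   # 2392512
print(fno_params(3, 3, 128, 16, 4))                # 9637763
print(16_815_299 - fno_params(3, 3, 128, 16, 4))   # 7177536
print(3 * per_block)                               # 7177536
print(fno_params(3, 3, 180, 16, 4))                # 19056783
```

The arithmetic suggests that the four per-layer rows already account for every block parameter once. The extra amount on the FNOBlocks: 1-2 row therefore looks like a reporting artifact rather than real weights. This is an inference from the numbers, not a confirmed diagnosis of torchinfo internals.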

Code to reproduce the issue:

import torch
from torchinfo import summary

from neuralop.models import FNO

# Configuration following the specification
in_channels = 3
out_channels = 3
modes = 16
hidden_channels = 128
n_layers = 4
n_spatial_dims = 2  # implied by the 2-tuple passed as n_modes below

batch_size = 4
height, width = 64, 64

device = torch.device('cpu')

# Unused below: summary() builds its own random input from input_size
dummy_input = torch.randn(batch_size, in_channels, height, width, device=device)

model = FNO(
    n_modes=(modes, modes),
    in_channels=in_channels,
    out_channels=out_channels,
    hidden_channels=hidden_channels,
    n_layers=n_layers,
    positional_embedding=None,
).to(device)

# Count reported by torchinfo (Total params: 16,815,299 in the output above)
summary(
    model,
    input_size=(batch_size, in_channels, height, width),
    depth=4,
    col_names=["input_size", "output_size", "num_params"],
    device=device,
)

# Count from PyTorch directly; parameters() yields each tensor once (9,637,763)
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print('Torch counting:')
print(f'{num_params:,}')
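One way to localize a discrepancy like this without torchinfo is to sum numel() for each top-level child using only torch.nn.Module.named_children() and parameters(). The sketch below is a hypothetical toy: ToyModel, ToyBlocks, and params_by_child are illustrative names, not neuralop classes. It mimics the pattern visible in the table, where one block container holds a ModuleList and is called once per layer in forward.

parameters() yields each parameter once, no matter how many times a module is invoked. The per-child totals therefore add up to the plain torch count.

```python
import torch
from torch import nn


def params_by_child(model: nn.Module) -> dict[str, int]:
    """Sum parameter counts under each top-level child module."""
    return {
        name: sum(p.numel() for p in child.parameters())
        for name, child in model.named_children()
    }


class ToyBlocks(nn.Module):
    """Hypothetical stand-in for a block container with one ModuleList entry per layer."""

    def __init__(self, width: int, n_layers: int) -> None:
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv1d(width, width, kernel_size=1) for _ in range(n_layers)
        )

    def forward(self, x: torch.Tensor, index: int) -> torch.Tensor:
        # Only the layer selected by `index` runs on each call.
        return torch.relu(self.convs[index](x))


class ToyModel(nn.Module):
    """Hypothetical lifting -> blocks -> projection model; blocks is called once per layer."""

    def __init__(self, width: int = 8, n_layers: int = 4) -> None:
        super().__init__()
        self.n_layers = n_layers
        self.lifting = nn.Conv1d(3, width, kernel_size=1)
        self.blocks = ToyBlocks(width, n_layers)
        self.projection = nn.Conv1d(width, 3, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.lifting(x)
        for i in range(self.n_layers):
            x = self.blocks(x, i)  # same module invoked n_layers times
        return self.projection(x)


model = ToyModel()
counts = params_by_child(model)
total = sum(p.numel() for p in model.parameters())  # each parameter yielded once
print(counts)  # {'lifting': 32, 'blocks': 288, 'projection': 27}
print(total)   # 347
```

Applied to the FNO model in the reproduction script, the same helper should give per-child totals that add up to the 9,637,763 torch count. That makes it easy to compare each child against the corresponding torchinfo rows.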

Version

1.1

Environment

dependencies = [
    "torch >= 2.0.1",
    "torchvision >= 0.15.2",
    "numpy < 2.0.0",
    "transformers == 4.55.0",
    "matplotlib",
    "accelerate>=0.32.0",
    "wandb==0.22.2",
    "h5py",
    "pandas",
    "pyyaml",
    "netcdf4>=1.7.2",
    "einops>=0.8.1",
    "scipy>=1.16.1",
    "pytorch-lightning>=2.3.3",
    "ninja>=1.13.0",
    "ipykernel>=6.30.1",
    "seaborn>=0.13.2",
    "huggingface-hub[cli]>=0.34.4",
    "xarray>=2025.8.0",
    "torchinfo>=1.8.0",
    "the-well>=1.1.0",
    "hydra-core>=1.3.2",
    "denoising-diffusion-pytorch>=2.2.5",
    "timm>=1.0.20",
    "neuraloperator>=2.0.0",
    "triton>=3.4.0",
    "ruff>=0.14.5",
]

Context for the issue:

No response

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests