Implement MPIFFT2D and MPIFFTND #195

Open
rohanbabbar04 wants to merge 12 commits into PyLops:main from rohanbabbar04:fft

Conversation

Collaborator

@rohanbabbar04 rohanbabbar04 commented May 8, 2026

  • Implement FFT using mpi4py-fft
  • Introduce MPIFFTND, MPIFFT2D and _MPIBaseFFTND
  • Basic tests comparing the results against the PyLops operators.
  • Working example in examples/
  • Update GA to include fftw libraries.
  • Add mpi implementations of fftshift and ifftshift.

Note: mpi4py-fft works only with NumPy arrays, and PFFT works only with multi-dimensional arrays (it does not support 1-D arrays).

@rohanbabbar04
Collaborator Author

I worked on implementing the FFT using mpi4py-fft. Overall, the implementation is straightforward: get the data into PFFT and retrieve it back into our DistributedArray. There are some key considerations to keep in mind.

  • Since we have axes as a parameter in the class, we need to ensure that x (the data) is redistributed across the axes[0] axis before pushing it into the DistArray (mpi4py-fft). This is a key step which acts as the starting point.
  • I tried implementing fftshift and ifftshift. I think the best approach is: if the distribution axis differs from the required axis, we can compute it directly; otherwise, we redistribute to a new axis first and then compute.
  • PFFT doesn't support 1-D arrays, which makes sense since there is no free axis available for computation. We can explore a workaround for this.
  • The nffts parameter from the PyLops version is currently missing. PFFT does have a padding parameter, but we will need to investigate how it can be mapped to nffts.
  • mpi4py-fft works with NumPy arrays, since DistArray is built on top of np.ndarray. For CuPy/NCCL support, we should look into using nvidia.distributed.fft.
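
The fftshift point above can be checked without MPI. The sketch below is only an illustration: it simulates the ranks with `np.array_split`, and the helper name `local_fftshift_matches_global` is hypothetical, not part of the PR. It compares a chunk-by-chunk fftshift with the global one.

```python
import numpy as np


def local_fftshift_matches_global(x, dist_axis, shift_axis, nranks=2):
    """Apply fftshift chunk by chunk and compare with the global result.

    Ranks are simulated with np.array_split; no MPI is involved.
    """
    chunks = np.array_split(x, nranks, axis=dist_axis)
    shifted = [np.fft.fftshift(c, axes=shift_axis) for c in chunks]
    local_result = np.concatenate(shifted, axis=dist_axis)
    return np.array_equal(local_result, np.fft.fftshift(x, axes=shift_axis))


x = np.arange(24).reshape(4, 6)

# Shift axis differs from the distribution axis: purely local work is enough
shift_other_axis = local_fftshift_matches_global(x, dist_axis=0, shift_axis=1)

# Shift axis equals the distribution axis: local shifts give a different result
shift_dist_axis = local_fftshift_matches_global(x, dist_axis=0, shift_axis=0)

print(shift_other_axis, shift_dist_axis)
```

In this simulation the local computation matches the global one only when the shift axis is not the distributed axis, which is why a redistribution is proposed for the other case.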

Once we agree on the implementation, I will go ahead and update the documentation to cover mpi4py-fft.

@rohanbabbar04 rohanbabbar04 marked this pull request as ready for review May 13, 2026 05:40
@rohanbabbar04 rohanbabbar04 requested a review from mrava87 May 13, 2026 05:40
Contributor

@mrava87 mrava87 left a comment

@rohanbabbar04 this is a great addition to pylops-mpi!

I have gone through the PR, mostly trying to understand the rationale of your code and the decisions made in the implementation. I left some comments and questions.

Once we agree on the right approach, I'll do a more focused review on the actual code 😄

Comment thread examples/plot_ffts.py Outdated
Comment thread examples/plot_ffts.py
Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated

Notes
-----
The MPIFFTND operator (using ``norm="none"``) applies the N-dimensional forward
Contributor

@mrava87 mrava87 May 13, 2026

I did a quick check of the documentation. In general, for operators we re-implement in pylops-mpi, our approach has been to focus the notes on the aspects specific to the distributed version of the algorithm rather than on the general operator. This is really a copy-paste from the PyLops operator, which is not very useful: anyone who wants to understand what FFT2D is can just look at PyLops's documentation. What really matters is to explain how the MPIFFT2D version differs from it.

This leads me to the main question of this PR: could you explain the overall rationale behind the current implementation? So far this

# Distribute only along axes[0]; all other axes are non-distributed (0=distributed, 1=not)
subcomm_dims = (np.arange(len(axes)) != axes[0]).astype(int)

makes me think we can only distribute along axis=0, and so (I guess) we can only pass either a 1D DistributedArray or an ND one distributed with axis=0?

Contributor

Better, but I still left some comments

Collaborator Author

@rohanbabbar04 rohanbabbar04 Jun 2, 2026

Yes, you are right...
In the future, if the user provides N-D arrays distributed along axis=0, we won't need to redistribute and can perform the FFT straight away. But if the user has, say, axis=-1 as the distribution axis, a minor redistribution is needed before we perform the FFT (this will make the function more flexible).

I would suggest we keep the self.fft fixed rather than changing it with every matvec or rmatvec call (so we initialise it once inside the init itself.)

Contributor

Yes, it makes sense to fix the self.fft object 😄 which basically means (as for many PyLops operators) that once an operator is defined, one can't change, for example, how the input distributed array is distributed.

Comment thread environment-dev.yml
Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated
Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated
fft_dtype = self.rdtype if self.real else self.cdtype
# Distribute only along axes[0]; all other axes are non-distributed (0=distributed, 1=not)
subcomm_dims = (np.arange(len(axes)) != axes[0]).astype(int)
subcomm = Subcomm(self.base_comm, dims=np.resize(subcomm_dims, len(dims)))
Contributor

A bit lost about these two lines... what are you trying to achieve?

If I understand your comment, I would think this is a much easier way to achieve it:

subcomm_dims = np.ones(len(dims), dtype=int)
subcomm_dims[axes[0]] = 0  # 0=distributed, 1=not

Apart from simplifying the code, can you explain why you want to distribute only along axes[0] instead of, for example, along the axis over which the input is actually distributed? I am just trying to think whether we can minimize any re-distribution. Of course, this way you would need to ask upfront along which axis the operator's expected input will be distributed, but if that is the price to pay I would still prefer it to doing more re-distributions than needed.
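
To make the comparison concrete, here is a NumPy-only sketch that evaluates the expression from the diff next to a direct form. It assumes the intent stated in the code comment (0 marks the distributed axis, 1 a non-distributed one); the function names are hypothetical and nothing here touches `Subcomm` or MPI.

```python
import numpy as np


def subcomm_dims_pr(dims, axes):
    # Expression quoted from the PR diff (0 = distributed, 1 = not distributed)
    flags = (np.arange(len(axes)) != axes[0]).astype(int)
    # np.resize repeats the flags cyclically when len(dims) > len(axes)
    return np.resize(flags, len(dims))


def subcomm_dims_direct(dims, axes):
    # Hypothetical direct form: mark only axes[0] as distributed
    flags = np.ones(len(dims), dtype=int)
    flags[axes[0]] = 0
    return flags


cases = [((4, 5), (0, 1)), ((4, 5), (1, 0)), ((4, 5, 6), (0, 2)), ((4, 5, 6), (2,))]
for dims, axes in cases:
    print(dims, axes, subcomm_dims_pr(dims, axes), subcomm_dims_direct(dims, axes))
```

In this check the two forms agree when `axes` covers every dimension, but they can differ when `len(axes) < len(dims)`, because `np.resize` fills the extra entries with repeated copies of the shorter flag array.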

Collaborator Author

@rohanbabbar04 rohanbabbar04 May 16, 2026

@mrava87, this is what I thought just to make the function pretty flexible.

This subcomm parameter controls which axis is distributed. By default, PFFT determines this automatically based on the number of ranks (and the distribution type: slab or non-slab). Also, our DistributedArray supports slab decomposition (distribution along one axis).

I changed this to subcomm_dim[0] = 0, meaning only the first axis is distributed, which is what our reshaped method does. However, if the last axis provided by the user (via the axes param) is 0, we need to override this and use subcomm_dim[1] = 0 instead, as PFFT raises an error in that case.

The distribution axis for the forward and backward functions depends on the axes argument provided by the user. To handle this, I introduced two internal parameters, _pfft_in_axis and _pfft_out_axis, which track the distribution axis of the input and output respectively. These differ because PFFT may internally transpose the axes during the FFT computation, meaning the output of forward (and similarly backward) can be distributed along a different axis than its input.

Contributor

Ok, I think I get it... we apply reshaped, so internally we always create an N-D distributed array scattered over axis=0, but PFFT does not like the first axis over which the FFT is computed to be distributed, so if this is the case we redistribute over another axis?
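
As a reading aid, the rule described in this thread can be written as a tiny helper. This is only a sketch of my understanding: `pick_distribution_axis` is a hypothetical name, not a function in the PR, and it assumes `axes` has already been normalized to non-negative indices.

```python
def pick_distribution_axis(axes):
    """Return the axis to distribute over, following the rule in this thread.

    Default: axis 0, matching the layout produced by the reshaped step.
    Exception: if the last transform axis is 0, distribute over axis 1 instead.
    """
    return 1 if axes[-1] == 0 else 0


print(pick_distribution_axis((0, 1)))  # default case
print(pick_distribution_axis((1, 0)))  # last transform axis is 0
```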

Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated
Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated
Comment thread requirements-fft.txt Outdated
Comment thread .github/workflows/build.yml
Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated

# Axis along which PFFT decomposes the output array across MPI processes
dist_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1]
y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0,
Contributor

Does this mean that the output could have a different distribution from the input?

Collaborator Author

@rohanbabbar04 rohanbabbar04 Jun 3, 2026

Yes, it could be, as it depends on the axes parameter.

Contributor

@rohanbabbar04 ok, this is not ideal as it is not really predictable.

A user with a global 2D array chunked over axis=0, flattened and put into a 1D distributed array, may get back a new 1D distributed array whose local arrays, once reshaped to 2D, are chunked over axis=1 and contain the entire axis 0?

If so, how would a user know this? Is there at least a parameter in the FFT operator they can consult?
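
To illustrate the concern, the following NumPy-only simulation (two ranks mimicked with `np.array_split`, no MPI or pylops-mpi involved) shows how the flattened local arrays change when the same global 2D array is chunked over axis=1 instead of axis=0.

```python
import numpy as np

nranks = 2
x = np.arange(16).reshape(4, 4)  # global 2D array

# Flattened local array held by each simulated rank for the two layouts
local_axis0 = [c.ravel() for c in np.array_split(x, nranks, axis=0)]
local_axis1 = [c.ravel() for c in np.array_split(x, nranks, axis=1)]

# Concatenating in rank order mimics the global 1D vector
flat_axis0 = np.concatenate(local_axis0)
flat_axis1 = np.concatenate(local_axis1)

# Same values overall, but in a different order
same_values = np.array_equal(np.sort(flat_axis0), np.sort(flat_axis1))
same_order = np.array_equal(flat_axis0, flat_axis1)

print(local_axis0[0], local_axis1[0], same_values, same_order)
```

The two layouts hold the same values but in a different order, so a user needs to know the local shape (here (2, 4) versus (4, 2)) to reshape the local array correctly.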

Comment thread pylops_mpi/signalprocessing/FFTND.py Outdated

Contributor

mrava87 commented Jun 1, 2026

@rohanbabbar04 thanks for the updates!

I left a few additional comments (and unresolved some conversations to make sure I fully understand some of your design choices... should be clear now but please confirm).

Once you have handled these last few points, I think this is good to go!

Next, we can look into enabling N-D Distributed arrays into MPI Linear Operators as in some cases like FFT we could benefit from not having to always distribute the first axis... I had a little play and it seems doable with minor code changes, I may create a draft PR to document what I did so far

@rohanbabbar04
Collaborator Author

> @rohanbabbar04 thanks for the updates!
>
> I left a few additional comments (and unresolved some conversations to make sure I fully understand some of your design choices... should be clear now but please confirm).
>
> Once you have handled these last few points, I think this is good to go!
>
> Next, we can look into enabling N-D Distributed arrays into MPI Linear Operators as in some cases like FFT we could benefit from not having to always distribute the first axis... I had a little play and it seems doable with minor code changes, I may create a draft PR to document what I did so far

Thanks @mrava87 for the review, I have added comments on the respective questions. Do let me know if you need any more information 🙂

Contributor

mrava87 commented Jun 3, 2026

> > @rohanbabbar04 thanks for the updates!
> > I left a few additional comments (and unresolved some conversations to make sure I fully understand some of your design choices... should be clear now but please confirm).
> > Once you have handled these last few points, I think this is good to go!
> > Next, we can look into enabling N-D Distributed arrays into MPI Linear Operators as in some cases like FFT we could benefit from not having to always distribute the first axis... I had a little play and it seems doable with minor code changes, I may create a draft PR to document what I did so far
>
> Thanks @mrava87 for the review, I have added comments on the respective questions. Do let me know if you need any more information 🙂

Quite a few comments are still unanswered (you may have completely missed my comments from the second review I did... I pinged you in each of them, so you can easily find them 😄)

def __init__(
    self,
    dims: InputDimsLike,
    axes: InputDimsLike = (0, 1, 2),
Contributor

I just realized you don't pass nffts, so it always ends up as None and is therefore set equal to dims. Any reason why?
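
For context, as far as I understand the serial PyLops operators, nffts sets the number of samples of the transform along each axis, which zero-pads the input when it is larger than dims. The sketch below uses plain NumPy (the `s` argument of `np.fft.fftn`) to show that behaviour; it says nothing about how PFFT's padding option works, which is the open question in this PR.

```python
import numpy as np

x = np.arange(16, dtype=float).reshape(4, 4)
nffts = (8, 8)  # hypothetical transform sizes, larger than x.shape

# NumPy zero-pads each axis up to the requested size before transforming
padded_fft = np.fft.fftn(x, s=nffts, axes=(0, 1))

# Same result when padding explicitly and transforming the padded array
explicit = np.fft.fftn(np.pad(x, ((0, 4), (0, 4))), axes=(0, 1))
matches = np.allclose(padded_fft, explicit)

print(padded_fft.shape, matches)
```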

Raises
------
ValueError
- If ``norm`` is not one of "none", or "1/n".
Contributor

Remove the "-" if ValueError is raised for only one case.

Contributor

But looking at the init, there are at least 2 raises... fix this.

Notes
-----
The MPIFFT2D operator performs forward and adjoint passes on a
Contributor

"The MPIFFT2D operator performs forward and adjoint passes on a" -> "The MPIFFT2D operator applies the forward and inverse 2-dimensional Fast Fourier transform to a"

)
raise ValueError(msg)

# Check if the user provided nfft smaller than n. See _BaseFFT for
Contributor

This comment does not make sense right now, as a user is not even allowed to pass nffts (in the public API classes)...

Contributor

Also what is _BaseFFT here?

:class:`pylops_mpi.DistributedArray`, which is internally reshaped to the 2-dimensional layout
defined by ``dims``. The 2-dimensional FFT is then applied across MPI ranks using ``mpi4py_fft``'s
:class:`mpi4py_fft.mpifft.PFFT` class, with the global array decomposed via a pencil decomposition.
:class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default,
Contributor

:class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default,
shifting to ``axis=1`` if ``axes[-1] == 0`` to avoid a conflict between the transform and decomposition axes.

This is still unclear. Since MPI operators can only be applied to a 1D DistributedArray, I guess we always assume that the first axis of the underlying 2D array is distributed, so when the local array is reshaped internally, the only axis over which a straightforward FFT isn't allowed is axis=0.

Now if axes=(0,1) this is all fine, right? But if axes=(1,0), the Subcomm is defined such that axis=1 is assumed to be distributed, but then we need to use .redistribute to make this happen? And then at the end, after the FFT is applied, we let @reshaped bring it back to 1D?

I'm not asking you to write in such detail, but phrases like "to avoid a conflict between the transform and decomposition axes" are very unclear and a user would not gain much from them.

2 participants