diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1b46d876..6f4c3447 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,12 +36,24 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install fftw libraries + run: | + case $(uname) in + Linux) + sudo apt update + sudo apt install -y -q libfftw3-dev + ;; + Darwin) + export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 + brew install fftw + ;; + esac - name: Installing Dependencies run: | python -m pip install --upgrade pip setuptools if [ -f requirements.txt ]; then pip install -r requirements-dev.txt; fi - name: Install pylops-mpi - run: pip install . + run: pip install .[all] - name: Testing using pytest-mpi run: | if [ "${{ matrix.mpi }}" = "openmpi" ]; then diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index 59e83239..c2a3777f 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -22,12 +22,15 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.11 - + - name: Install fftw libraries + run: | + sudo apt update + sudo apt install -y -q libfftw3-dev - name: Install dependencies run: | python -m pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements-dev.txt; fi - pip install . + pip install .[all] - name: Build docs run: | sphinx-build -b html ./docs/source ./docs/build diff --git a/Makefile b/Makefile index 1d824ee0..10d121bd 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,10 @@ dev-install_conda: dev-install_conda_nccl: conda env create -f environment-dev.yml && conda activate pylops_mpi && conda install -c conda-forge cupy nccl && pip install -e . 
+dev-install_fft: + make pipcheck + $(PIP) install -r requirements-dev.txt && $(PIP) install -e ".[fft]" + lint: flake8 pylops_mpi/ tests/ examples/ tutorials/ diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 73d41469..8b74e40a 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -71,7 +71,8 @@ Signal Processing :toctree: generated/ MPIFredholm1 - + MPIFFT2D + MPIFFTND Wave-Equation processing ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -107,7 +108,7 @@ Basic cgls Sparsity -~~~~~ +~~~~~~~~ .. currentmodule:: pylops_mpi.optimization.cls_sparsity diff --git a/docs/source/conf.py b/docs/source/conf.py index 4e7c2c86..169d232e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,7 +28,8 @@ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "matplotlib": ("https://matplotlib.org/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/stable/", None), - "pylops": ("https://pylops.readthedocs.io/en/stable/", None) + "pylops": ("https://pylops.readthedocs.io/en/stable/", None), + "mpi4py_fft": ("https://mpi4py-fft.readthedocs.io/en/stable/", None) } # Generate autodoc stubs with summaries from code diff --git a/docs/source/installation.rst b/docs/source/installation.rst index e1d7faf3..5f7e7211 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -79,6 +79,13 @@ using `pip`: Replace `12x` with your CUDA version (e.g., `11x` for CUDA 11.x). +Install fft dependencies (optional) +=================================== +Similarly, to use the FFT classes with distributed arrays, install PyLops-MPI with the ``fft`` extra: + +.. code-block:: bash + + >> pip install pylops-mpi[fft] .. _UserInstall: @@ -94,6 +101,12 @@ command in your terminal to install the PyPI distribution: Note that when installing via `pip`, only *required* dependencies are installed. +Alternatively, all optional dependencies can be installed at once via: + +.. code-block:: bash + + >> pip install pylops-mpi[all] + .. 
_DevInstall: @@ -169,6 +182,12 @@ Otherwise, you can change the command in `Makefile` to an appropriate CUDA versi i.e., If you use CUDA 11.x, change ``cupy-cuda12x`` and ``nvidia-nccl-cu12`` to ``cupy-cuda11x`` and ``nvidia-nccl-cu11`` and run the command. +If you want to be able to use FFT classes with distributed arrays, run: + +.. code-block:: bash + + >> make dev-install_fft + Run tests ========= To ensure that everything has been setup correctly, run tests: diff --git a/environment-dev.yml b/environment-dev.yml index 2a3f961f..101151e8 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -21,5 +21,6 @@ dependencies: - nbsphinx - pydata-sphinx-theme - flake8 + - mpi4py-fft - pip: - sphinx-gallery diff --git a/examples/plot_ffts.py b/examples/plot_ffts.py new file mode 100644 index 00000000..eb8df6d5 --- /dev/null +++ b/examples/plot_ffts.py @@ -0,0 +1,109 @@ +""" +Fourier Transform +================= +This example shows how to use the :py:class:`pylops_mpi.signalprocessing.MPIFFT2D` +and :py:class:`pylops_mpi.signalprocessing.MPIFFTND` operators to apply the Fourier +Transform to the model and the inverse Fourier Transform to the data. +""" + +import matplotlib.pyplot as plt +import numpy as np + +import pylops_mpi + +plt.close("all") + +############################################################################### +# We start by applying the two dimensional MPI-distributed FFT to a +# two-dimensional signal using :py:class:`pylops_mpi.signalprocessing.MPIFFT2D`. +# The input signal is a :py:class:`pylops_mpi.DistributedArray` which is +# distributed across MPI ranks before applying the transform. 
+
+dt, dx = 0.005, 5
+nt, nx = 2**7, 2**8
+t = np.arange(nt) * dt
+x = np.arange(nx) * dx
+f0 = 10
+
+d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1)
+dist = pylops_mpi.DistributedArray.to_dist(x=d.ravel())
+
+FFTop = pylops_mpi.signalprocessing.MPIFFT2D(
+    dims=(nt, nx), sampling=(dt, dx)
+)
+
+D = FFTop * dist
+
+# With the default ``norm="none"`` the adjoint is an unscaled transform, i.e.
+# ``FFTop.H * (FFTop * dist)`` returns the signal multiplied by ``nt * nx``.
+# The inverse transform is therefore obtained with the division operator.
+dinv = FFTop / D
+dinv = np.real(dinv.asarray()).reshape(nt, nx)
+
+D_2d = D.asarray().reshape(nt, nx)
+
+fig, axs = plt.subplots(2, 2, figsize=(10, 6))
+
+axs[0][0].imshow(d, vmin=-100, vmax=100, cmap="bwr")
+axs[0][0].set_title("Signal")
+axs[0][0].axis("tight")
+
+# Only the first half of the (unshifted) time-frequency axis is displayed
+axs[0][1].imshow(
+    np.abs(np.fft.fftshift(D_2d, axes=1)[:nt // 2, :]), cmap="bwr"
+)
+axs[0][1].set_title("Fourier Transform")
+axs[0][1].axis("tight")
+
+axs[1][0].imshow(dinv, vmin=-100, vmax=100, cmap="bwr")
+axs[1][0].set_title("Inverted")
+axs[1][0].axis("tight")
+
+axs[1][1].imshow(d - dinv, vmin=-100, vmax=100, cmap="bwr")
+axs[1][1].set_title("Error")
+axs[1][1].axis("tight")
+
+fig.tight_layout()
+
+###############################################################################
+# We can also apply the three dimensional MPI-distributed FFT to a
+# three-dimensional signal using :py:class:`pylops_mpi.signalprocessing.MPIFFTND`.
+
+dt, dx, dy = 0.005, 5, 3
+nt, nx, ny = 2**7, 2**6, 13
+t = np.arange(nt) * dt
+x = np.arange(nx) * dx
+y = np.arange(ny) * dy
+f0 = 10
+
+d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1)
+d = np.tile(d[:, :, np.newaxis], [1, 1, ny])
+dist = pylops_mpi.DistributedArray.to_dist(x=d.ravel())
+
+FFTop = pylops_mpi.signalprocessing.MPIFFTND(
+    dims=(nt, nx, ny),
+    sampling=(dt, dx, dy)
+)
+
+D = FFTop * dist
+# With the default ``norm="none"`` the adjoint is an unscaled transform, i.e.
+# ``FFTop.H * (FFTop * dist)`` returns the signal multiplied by ``nt * nx * ny``.
+# The inverse transform is therefore obtained with the division operator.
+dinv = FFTop / D
+dinv = np.real(dinv.asarray()).reshape(nt, nx, ny)
+D_3d = D.asarray().reshape(nt, nx, ny)  # shape matches dims now
+
+fig, axs = plt.subplots(2, 2, figsize=(10, 6))
+
+axs[0][0].imshow(d[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr")
+axs[0][0].set_title("Signal")
+axs[0][0].axis("tight")
+# The first axis of ``D_3d`` has ``nt`` samples, so its first half is ``nt // 2``
+# (same display convention as the two-dimensional example)
+axs[0][1].imshow(
+    np.abs(np.fft.fftshift(D_3d, axes=1)[:nt // 2, :, ny // 2]),
+    cmap="bwr"
+)
+axs[0][1].set_title("Fourier Transform")
+axs[0][1].axis("tight")
+
+axs[1][0].imshow(dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr")
+axs[1][0].set_title("Inverted")
+axs[1][0].axis("tight")
+
+axs[1][1].imshow(d[:, :, ny // 2] - dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr")
+axs[1][1].set_title("Error")
+axs[1][1].axis("tight")
+
+fig.tight_layout()
diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py new file mode 100644 index 00000000..f6f7bb56 --- /dev/null +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -0,0 +1,163 @@ +from typing import Sequence + +from mpi4py import MPI + +from pylops.utils import DTypeLike, InputDimsLike + +from pylops_mpi.DistributedArray import DistributedArray +from pylops_mpi.signalprocessing.FFTND import MPIFFTND + + +class MPIFFT2D(MPIFFTND): + r"""Two-dimensional Fast-Fourier Transform. + + Apply two-dimensional Fast-Fourier Transform (FFT) to any pair of ``axes`` of a + multidimensional array. 
+ + When using ``real=True``, the result of the forward is also multiplied by + :math:`\sqrt{2}` for all frequency bins except zero and Nyquist, and the input of + the adjoint is multiplied by :math:`1 / \sqrt{2}` for the same frequencies. + + For a real valued input signal, it is advised to use the flag ``real=True`` + as it stores the values of the Fourier transform of the last axis in ``axes`` at positive + frequencies only as values at negative frequencies are simply their complex conjugates. + + Parameters + ---------- + dims : :obj:`tuple` + Number of samples for each dimension + axes : :obj:`tuple`, optional + Pair of axes along which FFT2D is applied + sampling : :obj:`tuple` or :obj:`float`, optional + Sampling steps for each axis in ``axes``. When supplied a single value, it is used + for both axes. + norm : `{"none", "1/n"}`, optional + - "none": Does not scale the forward or the adjoint FFT transforms. Default is "none". + - "1/n": Scales both the forward and adjoint FFT transforms by + :math:`1/N_F`. + real : :obj:`bool`, optional + Model to which fft is applied has real numbers (``True``) or not + (``False``). Used to enforce that the output of adjoint of a real + model is real. Note that the real FFT is applied only to the first + dimension to which the FFT2D operator is applied (last element of + ``axes``) + ifftshift_before : :obj:`tuple` or :obj:`bool`, optional + Apply ifftshift (``True``) or not (``False``) to model vector (before FFT). + Consider using this option when the model vector's respective axis is symmetric + with respect to the zero value sample. This will shift the zero value sample to + coincide with the zero index sample. With such an arrangement, FFT will not + introduce a sample-dependent phase-shift when compared to the continuous Fourier + Transform. When passing a single value, the shift will be the same for every direction. + Pass a tuple to specify which dimensions are shifted. 
+ fftshift_after : :obj:`tuple` or :obj:`bool`, optional + Apply fftshift (``True``) or not (``False``) to data vector (after FFT). + Consider using this option when you require frequencies to be arranged + naturally, from negative to positive. When not applying fftshift after FFT, + frequencies are arranged from zero to largest positive, and then from negative + Nyquist to the frequency bin before zero. When passing a single value, the shift + will be the same for every direction. Pass a tuple to specify which dimensions are shifted. + dtype : :obj:`str`, optional + Type of elements in input array. Note that the ``dtype`` of the operator + is the corresponding complex type even when a real type is provided. + In addition, note that the NumPy backend does not support returning ``dtype`` + different from ``complex128``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method + + Attributes + ---------- + f1 : :obj:`numpy.ndarray` + Discrete Fourier Transform sample frequencies along ``axes[0]`` + f2 : :obj:`numpy.ndarray` + Discrete Fourier Transform sample frequencies along ``axes[1]`` + nffts : :obj:`tuple` or :obj:`int`, optional + Number of samples in Fourier Transform for each axis in ``axes``. + real : :obj:`bool` + When ``True``, uses real fast fourier transform. + rdtype : :obj:`numpy.dtype` + Expected input type to the forward + cdtype : :obj:`numpy.dtype` + Output type of the forward. Complex equivalent to ``rdtype``. + shape : :obj:`tuple` + Operator shape. + clinear : :obj:`bool` + Operator is complex-linear. Is false when either ``real=True`` or when + ``dtype`` is not a complex type. + fft : :obj:`mpi4py_fft.mpifft.PFFT` + Parallel FFT operator object handling the distributed transform across + MPI processes. Configured with the base communicator, dimension + decomposition, transform axes, and dtype. 
+ + See Also + -------- + MPIFFTND: N-dimensional FFT + + Raises + ------ + ValueError + - If ``norm`` is not one of "none", or "1/n". + + Notes + ----- + The MPIFFT2D operator performs forward and adjoint passes on a + :class:`pylops_mpi.DistributedArray`, which is internally reshaped to the 2-dimensional layout + defined by ``dims``. The 2-dimensional FFT is then applied across MPI ranks using ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` class, with the global array decomposed via a pencil decomposition. + :class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default, + shifting to ``axis=1`` if ``axes[-1] == 0`` to avoid a conflict between the transform and + decomposition axes. + + In the forward pass, :meth:`PFFT.forward` is called with ``normalize=False``, computing: + + .. math:: + D(k_y, k_x) = \mathscr{F} (d) = \iint\limits_{-\infty}^\infty d(y, x) e^{-j2\pi k_yy} + e^{-j2\pi k_xx} \,\mathrm{d}y \,\mathrm{d}x + + When ``norm="1/n"``, the result is additionally scaled by :math:`1/N_F`. + + In the adjoint pass, :meth:`PFFT.backward` is called with ``normalize=True``, so ``PFFT`` + internally divides by :math:`N_F = N_1 \cdot N_2`, computing: + + .. math:: + d(y,x) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} \iint\limits_{-\infty}^\infty D(k_y, k_x) e^{j2\pi k_yy} + e^{j2\pi k_xx} \,\mathrm{d}k_y \,\mathrm{d}k_x + + When ``norm="none"``, the adjoint multiplies by :math:`N_F` to cancel this internal scaling, + returning a true unscaled adjoint. The result is then flattened back to a 1D + :class:`pylops_mpi.DistributedArray`. All inter-rank data movement is handled internally by + ``mpi4py_fft``. 
+    """
+    def __init__(
+        self,
+        dims: InputDimsLike,
+        axes: InputDimsLike = (0, 1),
+        sampling: float | Sequence[float] = 1.0,
+        norm: str = "none",
+        real: bool = False,
+        ifftshift_before: bool = False,
+        fftshift_after: bool = False,
+        dtype: DTypeLike = "complex128",
+        base_comm: MPI.Comm = MPI.COMM_WORLD
+    ) -> None:
+        # checks
+        # Only the 2D-specific constraints are validated here; every other
+        # argument is forwarded unchanged to MPIFFTND.__init__.
+        if len(dims) < 2:
+            msg = "FFT2D requires at least two input dimensions"
+            raise ValueError(msg)
+        if len(axes) != 2:
+            msg = "FFT2D must be applied along exactly two dimensions"
+            raise ValueError(msg)
+        super().__init__(dims=dims, axes=axes, sampling=sampling, norm=norm, real=real, dtype=dtype,
+                         ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, base_comm=base_comm)
+        # Expose the two frequency axes as ``f1``/``f2`` and remove the generic
+        # ``fs`` tuple that the parent constructor stored on the instance.
+        self.f1, self.f2 = self.fs
+        del self.fs
+
+    # The three methods below add no 2D-specific logic: each one forwards its
+    # argument to the corresponding MPIFFTND implementation.
+    def _matvec(self, x: DistributedArray) -> DistributedArray:
+        return super()._matvec(x)
+
+    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+        return super()._rmatvec(x)
+
+    def __truediv__(self, y: DistributedArray) -> DistributedArray:
+        return super().__truediv__(y)
diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py new file mode 100644 index 00000000..de4d8297 --- /dev/null +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -0,0 +1,302 @@ +import warnings +from typing import Sequence + +from mpi4py import MPI +import numpy as np + +from pylops.signalprocessing._baseffts import _FFTNorms +from pylops.utils import DTypeLike, InputDimsLike, get_array_module + +from pylops_mpi.utils.decorators import reshaped +from pylops_mpi.DistributedArray import DistributedArray, Partition +from pylops_mpi.signalprocessing._baseffts import _MPIBaseFFTND +from pylops_mpi.utils import deps, fftshift_nd, ifftshift_nd + +mpi4py_fft_message = deps.mpi4py_fft_import("mpi4py_fft") + +if mpi4py_fft_message is None: + from mpi4py_fft import PFFT, newDistArray + from mpi4py_fft.pencil import Subcomm + + +class MPIFFTND(_MPIBaseFFTND): + r"""N-dimensional Fast-Fourier Transform. 
+ + Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes`` + of a multidimensional array. + + When using ``real=True``, the result of the forward is also multiplied by + :math:`\sqrt{2}` for all frequency bins except zero and Nyquist along the last + ``axes``, and the input of the adjoint is multiplied by + :math:`1 / \sqrt{2}` for the same frequencies. + + For a real valued input signal, it is advised to use the flag ``real=True`` + as it stores the values of the Fourier transform of the last axis in ``axes`` at positive + frequencies only as values at negative frequencies are simply their complex conjugates. + + Parameters + ---------- + dims : :obj:`tuple` + Number of samples for each dimension + axes : :obj:`tuple`, optional + Axes (or axis) along which FFTND is applied + sampling : :obj:`tuple` or :obj:`float`, optional + Sampling steps for each direction. When supplied a single value, it is used + for all directions. + norm : `{"none", "1/n"}`, optional + - "none": Does not scale the forward or the adjoint FFT transforms. Default is "none". + - "1/n": Scales both the forward and adjoint FFT transforms by + :math:`1/N_F`. + real : :obj:`bool`, optional + Model to which fft is applied has real numbers (``True``) or not + (``False``). Used to enforce that the output of adjoint of a real + model is real. Note that the real FFT is applied only to the first + dimension to which the FFTND operator is applied (last element of + ``axes``) + ifftshift_before : :obj:`tuple` or :obj:`bool`, optional + Apply ifftshift (``True``) or not (``False``) to model vector (before FFT). + Consider using this option when the model vector's respective axis is symmetric + with respect to the zero value sample. This will shift the zero value sample to + coincide with the zero index sample. With such an arrangement, FFT will not + introduce a sample-dependent phase-shift when compared to the continuous Fourier + Transform. 
When passing a single value, the shift will be the same for every direction. + Pass a tuple to specify which dimensions are shifted. + fftshift_after : :obj:`tuple` or :obj:`bool`, optional + Apply fftshift (``True``) or not (``False``) to data vector (after FFT). + Consider using this option when you require frequencies to be arranged + naturally, from negative to positive. When not applying fftshift after FFT, + frequencies are arranged from zero to largest positive, and then from negative + Nyquist to the frequency bin before zero. When passing a single value, the shift + will be the same for every direction. Pass a tuple to specify which dimensions are shifted. + dtype : :obj:`str`, optional + Type of elements in input array. Note that the ``dtype`` of the operator + is the corresponding complex type even when a real type is provided. + In addition, note that the NumPy backend does not support returning ``dtype`` + different from ``complex128``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + + Attributes + ---------- + fs : :obj:`tuple` + Each element of the tuple corresponds to the Discrete Fourier Transform + sample frequencies along the respective direction given by ``axes``. + nffts : :obj:`tuple` or :obj:`int`, optional + Number of samples in Fourier Transform for each axis in ``axes``. + real : :obj:`bool` + When ``True``, uses real fast fourier transform + rdtype : :obj:`numpy.dtype` + Expected input type to the forward + cdtype : :obj:`numpy.dtype` + Output type of the forward. Complex equivalent to ``rdtype``. + shape : :obj:`tuple` + Operator shape. + clinear : :obj:`bool` + Operator is complex-linear. Is false when either ``real=True`` or when + ``dtype`` is not a complex type. + fft : :obj:`mpi4py_fft.mpifft.PFFT` + Parallel FFT operator object handling the distributed transform across + MPI processes. Configured with the base communicator, dimension + decomposition, transform axes, and dtype. 
+ + See Also + -------- + MPIFFT2D: Two-dimensional FFT + + Raises + ------ + ValueError + - If ``norm`` is not one of "none", or "1/n". + + Notes + ----- + The MPIFFTND operator performs forward and adjoint passes on a + :class:`pylops_mpi.DistributedArray`, which is internally reshaped to the N-dimensional layout + defined by ``dims``. The N-dimensional FFT is then applied across MPI ranks using ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` class, with the global array decomposed via a pencil decomposition. + :class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default, + shifting to ``axis=1`` if ``axes[-1] == 0`` to avoid a conflict between the transform and + decomposition axes. + + In the forward pass, :meth:`PFFT.forward` is called with ``normalize=False``, computing: + + .. math:: + D(k_1, \ldots, k_N) = \mathscr{F} (d) = + \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty + d(x_1, \ldots, x_N) + e^{-j2\pi k_1 x_1} \cdots + e^{-j 2 \pi k_N x_N} \,\mathrm{d}x_1 \cdots \mathrm{d}x_N + + When ``norm="1/n"``, the result is additionally scaled by :math:`1/N_F`. + + In the adjoint pass, :meth:`PFFT.backward` is called with ``normalize=True``, so ``PFFT`` + internally divides by :math:`N_F = \prod_i N_i`, computing: + + .. math:: + d(x_1, \ldots, x_N) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} + \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty + D(k_1, \ldots, k_N) + e^{j2\pi k_1 x_1} \cdots + e^{j 2 \pi k_N x_N} \,\mathrm{d}k_1 \cdots \mathrm{d}k_N + + When ``norm="none"``, the adjoint multiplies by :math:`N_F` to cancel this internal scaling, + returning a true unscaled adjoint. The result is then flattened back to a 1D + :class:`pylops_mpi.DistributedArray`. All inter-rank data movement is handled internally by + ``mpi4py_fft``. 
+    """
+    def __init__(
+        self,
+        dims: InputDimsLike,
+        axes: InputDimsLike = (0, 1, 2),
+        sampling: float | Sequence[float] = 1.0,
+        norm: str = "none",
+        real: bool = False,
+        ifftshift_before: bool = False,
+        fftshift_after: bool = False,
+        dtype: DTypeLike = "complex128",
+        base_comm: MPI.Comm = MPI.COMM_WORLD
+    ) -> None:
+        # ``PFFT``, ``newDistArray`` and ``Subcomm`` are only imported when
+        # ``mpi4py_fft_message`` is None; fail here with that message instead of
+        # hitting a NameError further down.
+        if mpi4py_fft_message is not None:
+            raise ModuleNotFoundError(mpi4py_fft_message)
+        super().__init__(
+            dims=dims,
+            axes=axes,
+            sampling=sampling,
+            norm=norm,
+            real=real,
+            fftshift_after=fftshift_after,
+            ifftshift_before=ifftshift_before,
+            dtype=dtype,
+            base_comm=base_comm
+        )
+        if self.cdtype != np.complex128:
+            warnings.warn(
+                "numpy backend always returns complex128 dtype. "
+                f"To respect the passed dtype, data will be cast to {self.cdtype}.",
+                stacklevel=2,
+            )
+        # ``_scale`` is applied by the adjoint when norm="none" and by the
+        # forward when norm="1/n"; ``__truediv__`` divides by it.
+        if self.norm is _FFTNorms.NONE:
+            self._scale = np.prod(self.nffts)
+        elif self.norm is _FFTNorms.ONE_OVER_N:
+            self._scale = 1.0 / np.prod(self.nffts)
+        fft_dtype = self.rdtype if self.real else self.cdtype
+        subcomm_dims = np.ones(self.ndim, dtype=int)
+        # axis=0 for the initial distribution by default, if the final axis is 0, distribute along axis 1 instead.
+        # The normalized ``self.axes`` is used (not the raw ``axes`` argument) so that
+        # a negative axis resolving to 0 is detected as well.
+        if self.axes[-1] == 0:
+            subcomm_dims[1] = 0
+        else:
+            subcomm_dims[0] = 0
+        subcomm = Subcomm(base_comm, subcomm_dims)
+        self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype)
+        # Axis along which the PFFT input (False) and output (True) pencils are
+        # split over more than one rank; defaults to 0 when nothing is split.
+        self._pfft_in_axis = next(
+            (i for i, s in enumerate(self.fft.pencil[False].subcomm) if s.Get_size() > 1), 0
+        )
+        self._pfft_out_axis = next(
+            (i for i, s in enumerate(self.fft.pencil[True].subcomm) if s.Get_size() > 1), 0
+        )
+
+    @reshaped
+    def _matvec(self, x: DistributedArray) -> DistributedArray:
+        if x.engine == "cupy":
+            raise ValueError("x should be a numpy array with engine=numpy, "
+                             f"Got {x.engine} instead...")
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}, "
+                             f"Got {x.partition} instead...")
+        if self.ifftshift_before.any():
+            x = ifftshift_nd(x, axes=self.axes[self.ifftshift_before])
+        if not self.clinear:
+            x[:] = np.real(x.local_array)
+        x_dist_pfft = newDistArray(self.fft, forward_output=False)
+        y_dist_pfft = newDistArray(self.fft, forward_output=True)
+        # Redistribute input to match the input PFFT axis
+        x = x.redistribute(axis=self._pfft_in_axis)
+        x_dist_pfft[:] = x.local_array
+        # Perform the parallel forward FFT
+        self.fft.forward(x_dist_pfft, y_dist_pfft, normalize=False)
+
+        y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=self._pfft_out_axis,
+                             base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine)
+        y[:] = y_dist_pfft
+        if self.real:
+            self._scale_real_fft(y, inverse=False)
+        if self.norm is _FFTNorms.ONE_OVER_N:
+            y[:] *= self._scale
+        y[:] = y.local_array.astype(self.cdtype)
+        if self.fftshift_after.any():
+            y = fftshift_nd(y, axes=self.axes[self.fftshift_after])
+        return y
+
+    @reshaped
+    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+        if x.engine == "cupy":
+            raise ValueError("x should be a numpy array with engine=numpy, "
+                             f"Got {x.engine} instead...")
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}, "
+                             f"Got {x.partition} instead...")
+        # Array module of the local data, kept under its own name so that the
+        # module-level ``np`` import is not shadowed inside this method
+        ncp = get_array_module(x.local_array)
+        if self.fftshift_after.any():
+            x = ifftshift_nd(x, axes=self.axes[self.fftshift_after])
+        if self.real:
+            self._scale_real_fft(x, inverse=True)
+        # Allocate distributed arrays for input and output
+        y_dist_pfft = newDistArray(self.fft, forward_output=False)
+        x_dist_pfft = newDistArray(self.fft, forward_output=True)
+        # Redistribute input to match the PFFT axis
+        x = x.redistribute(axis=self._pfft_out_axis)
+        x_dist_pfft[:] = x.local_array
+        # Perform the parallel backward FFT
+        self.fft.backward(x_dist_pfft, y_dist_pfft, normalize=True)
+
+        y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=self._pfft_in_axis,
+                             base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine)
+        y[:] = y_dist_pfft
+        if self.norm is _FFTNorms.NONE:
+            # Undo the division applied by ``normalize=True`` above
+            y[:] *= self._scale
+        # Trim every transformed axis whose nfft exceeds the model size. Looping over
+        # all axes (rather than indexing entries 0 and 1) also supports operators
+        # defined along a single axis or along more than two axes.
+        for nfft, ax in zip(self.nffts, self.axes):
+            if nfft > self.dims[ax]:
+                y[:] = ncp.take(y.local_array, ncp.arange(self.dims[ax]), axis=ax)
+        if not self.clinear:
+            y[:] = ncp.real(y.local_array)
+        y[:] = y.local_array.astype(self.rdtype)
+        if self.ifftshift_before.any():
+            y = fftshift_nd(y, axes=self.axes[self.ifftshift_before])
+        return y
+
+    def _scale_real_fft(self, x: DistributedArray, inverse: bool = False) -> None:
+        """Apply scaling for real-valued FFTs.
+
+        Scales the non-DC positive frequency components along the final FFT axis
+        by ``sqrt(2)`` in forward mode and ``1/sqrt(2)`` in inverse mode.
+
+        When the final FFT axis is distributed across MPI ranks, only the local
+        portion overlapping with the global positive-frequency range is scaled.
+
+        Parameters
+        ----------
+        x : DistributedArray
+            Distributed FFT array to scale in-place.
+        inverse : bool, optional
+            Apply inverse scaling when ``True``. Default is ``False``.
+        """
+        scale = 1 / np.sqrt(2) if inverse else np.sqrt(2)
+        if x.axis == self.axes[-1]:
+            # Global index range owned by this rank along the final FFT axis
+            sizes = [loc_shape[self.axes[-1]] for loc_shape in x.local_shapes]
+            local_start = sum(sizes[:self.base_comm.rank])
+            local_stop = local_start + sizes[self.base_comm.rank]
+            freq_start, freq_stop = max(1, local_start), min(1 + (self.nffts[-1] - 1) // 2, local_stop)
+            # Local overlap with the global frequency slice [1:k]
+            if freq_stop > freq_start:
+                local_slice = [slice(None)] * x.ndim
+                local_slice[self.axes[-1]] = slice(freq_start - local_start, freq_stop - local_start)
+                x[tuple(local_slice)] *= scale
+        else:
+            # Axis is local on this rank, so direct slicing
+            freq_slice = [slice(None)] * x.ndim
+            freq_slice[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2)
+            x[tuple(freq_slice)] *= scale
+
+    def __truediv__(self, y: DistributedArray) -> DistributedArray:
+        # Inverse transform: the adjoint result divided by ``self._scale``
+        y_div = self._rmatvec(y)
+        y_div[:] = y_div.local_array / self._scale
+        return y_div
diff --git a/pylops_mpi/signalprocessing/__init__.py b/pylops_mpi/signalprocessing/__init__.py index 3a0b83ab..209b64e4 100644 --- a/pylops_mpi/signalprocessing/__init__.py +++ b/pylops_mpi/signalprocessing/__init__.py @@ -8,12 +8,18 @@ A list of operators present in pylops_mpi.signalprocessing : MPIFredholm1 Fredholm integral of first kind. 
+ MPIFFT2D Two-dimensional Fast-Fourier Transform + MPIFFTND N-dimensional Fast-Fourier Transform """ from .Fredholm1 import * +from .FFT2D import * +from .FFTND import * __all__ = [ "MPIFredholm1", + "MPIFFT2D", + "MPIFFTND", ] diff --git a/pylops_mpi/signalprocessing/_baseffts.py b/pylops_mpi/signalprocessing/_baseffts.py new file mode 100644 index 00000000..b0548a7a --- /dev/null +++ b/pylops_mpi/signalprocessing/_baseffts.py @@ -0,0 +1,169 @@ +import warnings +from typing import Sequence + +from mpi4py import MPI +import numpy as np + +from pylops.signalprocessing._baseffts import _FFTNorms +from pylops.utils import InputDimsLike, DTypeLike, get_normalize_axis_index, get_real_dtype, get_complex_dtype +from pylops.utils._internal import _value_or_sized_to_array, _raise_on_wrong_dtype, _value_or_sized_to_tuple + +from pylops_mpi.DistributedArray import DistributedArray +from pylops_mpi.LinearOperator import MPILinearOperator + + +class _MPIBaseFFTND(MPILinearOperator): + """Base class for N-dimensional fast Fourier Transform""" + + def __init__( + self, + dims: int | InputDimsLike, + axes: int | InputDimsLike | None = None, + nffts: int | InputDimsLike | None = None, + sampling: float | Sequence[float] = 1.0, + norm: str = "none", + real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, + dtype: DTypeLike = "complex128", + base_comm: MPI.Comm = MPI.COMM_WORLD + ): + dims = _value_or_sized_to_array(dims) + _raise_on_wrong_dtype(dims, np.integer, "dims") + self.dims = tuple(dims) + self.ndim = len(dims) + + axes = _value_or_sized_to_array(axes) + _raise_on_wrong_dtype(axes, np.integer, "axes") + self.axes = np.array([get_normalize_axis_index()(d, self.ndim) for d in axes]) + self.naxes = len(self.axes) + if self.naxes != len(np.unique(self.axes)): + warnings.warn( + "At least one direction is repeated. 
This may cause unexpected results.", + stacklevel=2, + ) + + nffts = _value_or_sized_to_array(nffts, repeat=self.naxes) + if len(nffts[np.equal(nffts, None)]) > 0: # Found None(s) in nffts + nffts[np.equal(nffts, None)] = np.array( + [dims[d] for d, n in zip(axes, nffts, strict=True) if n is None] + ) + nffts = nffts.astype(np.array(dims).dtype) + _raise_on_wrong_dtype(nffts, np.integer, "nffts") + self.nffts = _value_or_sized_to_tuple( + nffts + ) # tuple is strictly needed for cupy + + sampling = _value_or_sized_to_array(sampling, repeat=self.naxes) + if np.issubdtype(sampling.dtype, np.integer): # Promote to float64 if integer + sampling = sampling.astype(np.float64) + self.sampling = sampling + _raise_on_wrong_dtype(self.sampling, np.floating, "sampling") + self.ifftshift_before = _value_or_sized_to_array( + ifftshift_before, repeat=self.naxes + ) + _raise_on_wrong_dtype(self.ifftshift_before, bool, "ifftshift_before") + + self.fftshift_after = _value_or_sized_to_array( + fftshift_after, repeat=self.naxes + ) + _raise_on_wrong_dtype(self.fftshift_after, bool, "fftshift_after") + if ( + self.naxes != len(self.nffts) + or self.naxes != len(self.sampling) + or self.naxes != len(self.ifftshift_before) + or self.naxes != len(self.fftshift_after) + ): + msg = ( + "`axes`, `nffts`, `sampling`, `ifftshift_before` and " + "`fftshift_after` must the have same number of elements. Received " + f"{self.naxes}, {len(self.nffts)}, {len(self.sampling)}, " + f"{len(self.ifftshift_before)} and {len(self.fftshift_after)}, " + "respectively." + ) + raise ValueError(msg) + + # Check if the user provided nfft smaller than n. 
See _BaseFFT for + # details + nfftshort = [ + nfft < dims[direction] + for direction, nfft in zip(self.axes, self.nffts, strict=True) + ] + self.doifftpad = any(nfftshort) + if self.doifftpad: + self.ifftpad = [(0, 0)] * self.ndim + for idir, (direction, nfshort) in enumerate( + zip(self.axes, nfftshort, strict=True) + ): + if nfshort: + self.ifftpad[direction] = ( + 0, + dims[direction] - self.nffts[idir], + ) + warnings.warn( + f"nffts in directions {np.where(nfftshort)[0]} have been selected to be smaller than the size of the original signal. " + "This is rarely intended behavior as the original signal will be truncated prior to applying fft, " + f"if this is the required behaviour ignore this message.", + stacklevel=2, + ) + + if norm == "none": + self.norm = _FFTNorms.NONE + elif norm.lower() == "1/n": + self.norm = _FFTNorms.ONE_OVER_N + elif norm == "backward": + msg = 'To use no scaling on the forward transform, use "none". Note that in this case, the adjoint transform will *not* have a 1/n scaling.' + raise ValueError(msg) + elif norm == "forward": + msg = 'To use 1/n scaling on the forward transform, use "1/n". Note that in this case, the adjoint transform will *also* have a 1/n scaling.' + raise ValueError(msg) + else: + msg = f"`norm`={norm} is not one of 'none' or '1/n'" + raise ValueError(msg) + + self.real = real + + fs = [ + np.fft.fftshift(np.fft.fftfreq(n, d=s)) + if fftshift + else np.fft.fftfreq(n, d=s) + for n, s, fftshift in zip( + self.nffts, self.sampling, self.fftshift_after, strict=True + ) + ] + if self.real: + fs[-1] = np.fft.rfftfreq(self.nffts[-1], d=self.sampling[-1]) + if self.fftshift_after[-1]: + warnings.warn( + "Using real=True and fftshift_after on the last direction. " + "fftshift should only be applied on directions with negative " + "and positive frequencies. When using FFTND with real=True, " + "are all directions except the last. 
def _matvec(self, x: DistributedArray) -> DistributedArray:
    """Forward-mode placeholder.

    The base class carries no forward implementation, so calling this
    method always raises.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "_BaseFFT does not provide _matvec. It must be implemented separately."
    )


def _rmatvec(self, x: DistributedArray) -> DistributedArray:
    """Adjoint-mode placeholder.

    The base class carries no adjoint implementation, so calling this
    method always raises.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "_BaseFFT does not provide _rmatvec. It must be implemented separately."
    )
# ``__all__`` (not ``all``) is the name the import system consults for
# ``from module import *``; it also keeps the builtin ``all`` unshadowed.
__all__ = [
    "fftshift_nd",
    "ifftshift_nd",
]

from pylops.utils import InputDimsLike, get_module

from pylops_mpi import DistributedArray


def _normalize_axes(axes, ndim):
    """Map possibly-negative axis indices onto ``0 .. ndim - 1``.

    Raises
    ------
    IndexError
        If an axis lies outside ``[-ndim, ndim)``.
    """
    normalized = []
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise IndexError(
                f"axis {ax} is out of bounds for array of dimension {ndim}"
            )
        normalized.append(int(ax) % ndim)
    return normalized


def _half_roll_nd(x, axes, sign, caller):
    """Roll ``x`` by ``sign * (n // 2)`` along each requested axis.

    Shared implementation of :func:`fftshift_nd` (``sign=1``) and
    :func:`ifftshift_nd` (``sign=-1``); ``caller`` is only used to build
    the error message.
    """
    if x.ndim < 2:
        raise ValueError(
            f"{caller} requires a 2D or higher array, but got ndim={x.ndim}. "
        )
    ncp = get_module(x.engine)
    if axes is None:
        axes = tuple(range(x.ndim))
    elif ncp.isscalar(axes):
        axes = (axes,)
    # Negative axes must be resolved before they are compared with x.axis,
    # otherwise e.g. axes=-1 would be mistaken for a local axis
    axes = _normalize_axes(axes, x.ndim)
    local_axes = [ax for ax in axes if ax != x.axis]
    dist_axes = [ax for ax in axes if ax == x.axis]
    if local_axes:
        shifts = [sign * (x.global_shape[ax] // 2) for ax in local_axes]
        x[:] = ncp.roll(x.local_array, shift=shifts, axis=local_axes)
    if dist_axes:
        # Move the distribution to another axis so that the roll along the
        # originally distributed axis can be done on the local array
        new_axis = 1 if x.axis == 0 else 0
        x = x.redistribute(axis=new_axis)
        shifts = [sign * (x.global_shape[ax] // 2) for ax in dist_axes]
        x[:] = ncp.roll(x.local_array, shift=shifts, axis=dist_axes)
    return x


def fftshift_nd(x: DistributedArray, axes: InputDimsLike = None) -> DistributedArray:
    """
    Shift the zero-frequency component to the center of the spectrum for a DistributedArray.

    This is the distributed counterpart of :func:`numpy.fft.fftshift`: each
    selected axis of size ``n`` is rolled by ``n // 2``. Axes other than the
    distributed one are rolled directly on the local arrays. If the
    distributed axis is selected, the array is first redistributed (to axis 1
    when distributed over axis 0, to axis 0 otherwise) so that the roll can be
    performed locally, and it is left in the redistributed state.

    .. note::
        This function only supports nd arrays (n >= 2).

    Parameters
    ----------
    x : :obj:`pylops_mpi.DistributedArray`
        Input array to shift. Its local array is overwritten when shifting
        along non-distributed axes.
    axes : :obj:`int` or :obj:`tuple`, optional
        Axes over which to shift (negative values count from the last axis).
        Defaults to all axes if not specified.

    Returns
    -------
    x : :obj:`pylops_mpi.DistributedArray`
        The shifted array. It is the result of ``x.redistribute`` (and hence
        distributed along a different axis than the input) if the original
        distributed axis was included in ``axes``.

    Raises
    ------
    ValueError
        If ``x`` has fewer than 2 dimensions.
    IndexError
        If an entry of ``axes`` is out of bounds for ``x``.
    """
    return _half_roll_nd(x, axes, 1, "fftshift_nd")


def ifftshift_nd(x: DistributedArray, axes: InputDimsLike = None) -> DistributedArray:
    """
    Shift the zero-frequency component back to the beginning of the spectrum for a DistributedArray.

    This is the distributed counterpart of :func:`numpy.fft.ifftshift`: each
    selected axis of size ``n`` is rolled by ``-(n // 2)``, which undoes a
    prior :func:`pylops_mpi.utils.fft_helper.fftshift_nd`. Axes other than the
    distributed one are rolled directly on the local arrays. If the
    distributed axis is selected, the array is first redistributed (to axis 1
    when distributed over axis 0, to axis 0 otherwise) so that the roll can be
    performed locally, and it is left in the redistributed state.

    .. note::
        This function only supports nd arrays (n >= 2).

    Parameters
    ----------
    x : :obj:`pylops_mpi.DistributedArray`
        Input array to shift. Its local array is overwritten when shifting
        along non-distributed axes.
    axes : :obj:`int` or :obj:`tuple`, optional
        Axes over which to shift (negative values count from the last axis).
        Defaults to all axes if not specified.

    Returns
    -------
    x : :obj:`pylops_mpi.DistributedArray`
        The shifted array. It is the result of ``x.redistribute`` (and hence
        distributed along a different axis than the input) if the original
        distributed axis was included in ``axes``.

    Raises
    ------
    ValueError
        If ``x`` has fewer than 2 dimensions.
    IndexError
        If an entry of ``axes`` is out of bounds for ``x``.
    """
    return _half_roll_nd(x, axes, -1, "ifftshift_nd")
# Real-valued input, square 2D grid, unscaled transform
par4 = {
    "dims": (50, 50),
    "axes": (0, 1),
    "real": True,
    "dtype": np.float64,
    "imag": 0,
    "norm": "none",
}
# Real-valued input, 3D grid, axes in natural order
par5 = {
    "dims": (41, 51, 50),
    "axes": (0, 1, 2),
    "real": True,
    "dtype": np.float64,
    "imag": 0,
    "norm": "none",
}
# Real-valued input, 3D grid, permuted axes, 1/n scaling
par6 = {
    "dims": (41, 51, 50),
    "axes": (0, 2, 1),
    "real": True,
    "dtype": np.float64,
    "imag": 0,
    "norm": "1/n",
}
# Complex-valued input, 3D grid, reversed axes
par7 = {
    "dims": (41, 51, 50),
    "axes": (2, 1, 0),
    "real": False,
    "dtype": np.complex128,
    "imag": 1j,
    "norm": "none",
}
# Complex-valued input, 3D grid, permuted axes, 1/n scaling
par8 = {
    "dims": (41, 51, 50),
    "axes": (2, 0, 1),
    "real": False,
    "dtype": np.complex128,
    "imag": 1j,
    "norm": "1/n",
}

rank = MPI.COMM_WORLD.Get_rank()

# Every combination of (ifftshift_before, fftshift_after) exercised by the tests
_SHIFT_FLAGS = [
    (False, False),
    (True, False),
    (False, True),
    (True, True),
]


def _operator_kwargs(par, ifftshift_before, fftshift_after):
    """Collect the keyword arguments passed to both the MPI and the serial operator."""
    return {
        "dims": par["dims"],
        "axes": par["axes"],
        "norm": par["norm"],
        "real": par["real"],
        "ifftshift_before": ifftshift_before,
        "fftshift_after": fftshift_after,
        "dtype": par["dtype"],
    }


def _random_model(op, par):
    """Create a DistributedArray matching the model size of ``op`` and fill it with random values."""
    model = DistributedArray(global_shape=op.shape[1], dtype=par["dtype"], engine=backend)
    local_shape = model.local_shape
    # par["imag"] is 0 for the real-valued cases and 1j for the complex ones
    model[:] = np.random.randn(*local_shape) + par["imag"] * np.random.randn(*local_shape)
    return model


@pytest.mark.parametrize("par", [par1, par2, par3, par4])
@pytest.mark.parametrize("ifftshift_before, fftshift_after", _SHIFT_FLAGS)
def test_FFT2d(par, ifftshift_before, fftshift_after):
    """Compare forward and adjoint of MPIFFT2D with pylops' FFT2D on rank 0"""
    if backend == "cupy":
        pytest.skip("Skipping cupy backend")
    np.random.seed(10)
    kwargs = _operator_kwargs(par, ifftshift_before, fftshift_after)
    Fop_mpi = MPIFFT2D(**kwargs)
    x = _random_model(Fop_mpi, par)
    x_global = x.asarray()
    # Forward
    fwd_dist = Fop_mpi @ x
    fwd = fwd_dist.asarray()
    # Adjoint
    adj_dist = Fop_mpi.H @ fwd_dist
    adj = adj_dist.asarray()
    if rank != 0:
        return
    # Serial reference, evaluated on the gathered input
    Fop = FFT2D(**kwargs)
    assert Fop_mpi.shape == Fop.shape
    fwd_ref = Fop @ x_global
    adj_ref = Fop.H @ fwd_ref
    assert_array_almost_equal(fwd, fwd_ref, decimal=7)
    assert_array_almost_equal(adj, adj_ref, decimal=7)


@pytest.mark.parametrize("par", [par1, par2, par3, par4, par5, par6, par7, par8])
@pytest.mark.parametrize("ifftshift_before, fftshift_after", _SHIFT_FLAGS)
def test_FFTND(par, ifftshift_before, fftshift_after):
    """Compare forward, adjoint and division of MPIFFTND with pylops' FFTND on rank 0"""
    if backend == "cupy":
        pytest.skip("Skipping cupy backend")
    np.random.seed(10)
    kwargs = _operator_kwargs(par, ifftshift_before, fftshift_after)
    Fop_mpi = MPIFFTND(**kwargs)
    x = _random_model(Fop_mpi, par)
    x_global = x.asarray()
    # Forward
    fwd_dist = Fop_mpi @ x
    fwd = fwd_dist.asarray()
    # Adjoint
    adj_dist = Fop_mpi.H @ fwd_dist
    adj = adj_dist.asarray()
    # Div
    div_dist = Fop_mpi / fwd_dist
    div = div_dist.asarray()
    if rank != 0:
        return
    # Serial reference, evaluated on the gathered input
    Fop = FFTND(**kwargs)
    assert Fop_mpi.shape == Fop.shape
    fwd_ref = Fop @ x_global
    adj_ref = Fop.H @ fwd_ref
    div_ref = Fop / fwd_ref
    assert_array_almost_equal(fwd, fwd_ref, decimal=7)
    assert_array_almost_equal(adj, adj_ref, decimal=7)
    assert_array_almost_equal(div, div_ref, decimal=7)