From 1ebf46d97314393e5164dacf62aeab3c24f4dbda Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sat, 9 May 2026 00:18:01 +0530 Subject: [PATCH 01/12] Implement MPIFFT2D and MPIFFTND --- environment-dev.yml | 1 + pylops_mpi/signalprocessing/FFT2D.py | 139 +++++++++++++ pylops_mpi/signalprocessing/FFTND.py | 251 +++++++++++++++++++++++ pylops_mpi/signalprocessing/__init__.py | 6 + pylops_mpi/signalprocessing/_baseffts.py | 137 +++++++++++++ pylops_mpi/utils/decorators.py | 10 +- pylops_mpi/utils/deps.py | 20 +- requirements-dev.txt | 1 + tests/test_ffts.py | 139 +++++++++++++ 9 files changed, 701 insertions(+), 3 deletions(-) create mode 100644 pylops_mpi/signalprocessing/FFT2D.py create mode 100644 pylops_mpi/signalprocessing/FFTND.py create mode 100644 pylops_mpi/signalprocessing/_baseffts.py create mode 100644 tests/test_ffts.py diff --git a/environment-dev.yml b/environment-dev.yml index 2a3f961f..101151e8 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -21,5 +21,6 @@ dependencies: - nbsphinx - pydata-sphinx-theme - flake8 + - mpi4py-fft - pip: - sphinx-gallery diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py new file mode 100644 index 00000000..5613fdb8 --- /dev/null +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -0,0 +1,139 @@ +from typing import Sequence + +from mpi4py import MPI + +from pylops.utils import DTypeLike, InputDimsLike + +from pylops_mpi.DistributedArray import DistributedArray +from pylops_mpi.signalprocessing.FFTND import MPIFFTND + + +class MPIFFT2D(MPIFFTND): + r"""Two-dimensional Fast-Fourier Transform. + + Apply two-dimensional Fast-Fourier Transform (FFT) to any pair of ``axes`` of a + multidimensional array. + + When using ``real=True``, the result of the forward is also multiplied by + :math:`\sqrt{2}` for all frequency bins except zero and Nyquist, and the input of + the adjoint is multiplied by :math:`1 / \sqrt{2}` for the same frequencies. 
+ + For a real valued input signal, it is advised to use the flag ``real=True`` + as it stores the values of the Fourier transform of the last axis in ``axes`` at positive + frequencies only as values at negative frequencies are simply their complex conjugates. + + Parameters + ---------- + dims : :obj:`tuple` + Number of samples for each dimension + axes : :obj:`tuple`, optional + Pair of axes along which FFT2D is applied + sampling : :obj:`tuple` or :obj:`float`, optional + Sampling steps for each axis in ``axes``. When supplied a single value, it is used + for both axes. + norm : `{"none", "1/n"}`, optional + - "none": Does not scale the forward or the adjoint FFT transforms. Default is "none". + - "1/n": Scales both the forward and adjoint FFT transforms by + :math:`1/N_F`. + real : :obj:`bool`, optional + Model to which fft is applied has real numbers (``True``) or not + (``False``). Used to enforce that the output of adjoint of a real + model is real. Note that the real FFT is applied only to the first + dimension to which the FFT2D operator is applied (last element of + ``axes``) + dtype : :obj:`str`, optional + Type of elements in input array. Note that the ``dtype`` of the operator + is the corresponding complex type even when a real type is provided. + In addition, note that the NumPy backend does not support returning ``dtype`` + different from ``complex128``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method + + Attributes + ---------- + f1 : :obj:`numpy.ndarray` + Discrete Fourier Transform sample frequencies along ``axes[0]`` + f2 : :obj:`numpy.ndarray` + Discrete Fourier Transform sample frequencies along ``axes[1]`` + nffts : :obj:`tuple` or :obj:`int`, optional + Number of samples in Fourier Transform for each axis in ``axes``. + real : :obj:`bool` + When ``True``, uses real fast fourier transform. 
+ rdtype : :obj:`bool` + Expected input type to the forward + cdtype : :obj:`bool` + Output type of the forward. Complex equivalent to ``rdtype``. + shape : :obj:`tuple` + Operator shape. + clinear : :obj:`bool` + Operator is complex-linear. Is false when either ``real=True`` or when + ``dtype`` is not a complex type. + fft : :obj:`mpi4py_fft.PFFT` + Parallel FFT operator object handling the distributed transform across + MPI processes. Configured with the base communicator, dimension + decomposition, transform axes, and dtype. + + See Also + -------- + MPIFFTND: N-dimensional FFT + + Raises + ------ + ValueError + - If ``norm`` is not one of "none", or "1/n". + + Notes + ----- + The FFT2D operator (using ``norm="none"``) applies the two-dimensional forward + Fourier transform to a signal :math:`d(y, x)` in forward mode: + + .. math:: + D(k_y, k_x) = \mathscr{F} (d) = \iint\limits_{-\infty}^\infty d(y, x) e^{-j2\pi k_yy} + e^{-j2\pi k_xx} \,\mathrm{d}y \,\mathrm{d}x + + Similarly, the two-dimensional inverse Fourier transform is applied to + the Fourier spectrum :math:`D(k_y, k_x)` in adjoint mode: + + .. math:: + d(y,x) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} \iint\limits_{-\infty}^\infty D(k_y, k_x) e^{j2\pi k_yy} + e^{j2\pi k_xx} \,\mathrm{d}k_y \,\mathrm{d}k_x + + where :math:`N_F` is the number of samples in the Fourier domain given by the + product of the elements of ``nffts``. + + Both operators are effectively discretized and solved by a fast iterative + algorithm known as Fast Fourier Transform. Note that when using ``norm="none"``, + the adjoint is **not** the inverse of the forward mode; instead, the inverse + requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). 
+ + """ + def __init__( + self, + dims: InputDimsLike, + axes: InputDimsLike = (0, 1), + sampling: float | Sequence[float] = 1.0, + norm: str = "none", + real: bool = False, + dtype: DTypeLike = "complex128", + base_comm: MPI.Comm = MPI.COMM_WORLD, + **kwargs_fft, + ) -> None: + # checks + if len(dims) < 2: + msg = "FFT2D requires at least two input dimensions" + raise ValueError(msg) + if len(axes) != 2: + msg = "FFT2D must be applied along exactly two dimensions" + raise ValueError(msg) + super().__init__(dims=dims, axes=axes, sampling=sampling, norm=norm, real=real, + dtype=dtype, base_comm=base_comm, **kwargs_fft) + self.f1, self.f2 = self.fs + del self.fs + + def _matvec(self, x: DistributedArray) -> DistributedArray: + return super()._matvec(x) + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + return super()._rmatvec(x) diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py new file mode 100644 index 00000000..f3159ec5 --- /dev/null +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -0,0 +1,251 @@ +import warnings +from typing import Sequence + +from mpi4py import MPI +import numpy as np + +from pylops.signalprocessing._baseffts import _FFTNorms +from pylops.utils import DTypeLike, InputDimsLike, get_array_module + +from pylops_mpi.utils.decorators import reshaped +from pylops_mpi.DistributedArray import DistributedArray, Partition +from pylops_mpi.signalprocessing._baseffts import _MPIBaseFFTND +from pylops_mpi.utils import deps + +mpi4py_fft_message = deps.mpi4py_fft_import("mpi4py_fft") + +if mpi4py_fft_message is None: + from mpi4py_fft import PFFT, newDistArray + from mpi4py_fft.pencil import Subcomm + + +class MPIFFTND(_MPIBaseFFTND): + r"""N-dimensional Fast-Fourier Transform. + + Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes`` + of a multidimensional array. 
+ + When using ``real=True``, the result of the forward is also multiplied by + :math:`\sqrt{2}` for all frequency bins except zero and Nyquist along the last + ``axes``, and the input of the adjoint is multiplied by + :math:`1 / \sqrt{2}` for the same frequencies. + + For a real valued input signal, it is advised to use the flag ``real=True`` + as it stores the values of the Fourier transform of the last axis in ``axes`` at positive + frequencies only as values at negative frequencies are simply their complex conjugates. + + Parameters + ---------- + dims : :obj:`tuple` + Number of samples for each dimension + axes : :obj:`tuple`, optional + Axes (or axis) along which FFTND is applied + sampling : :obj:`tuple` or :obj:`float`, optional + Sampling steps for each direction. When supplied a single value, it is used + for all directions. + norm : `{"none", "1/n"}`, optional + - "none": Does not scale the forward or the adjoint FFT transforms. Default is "none". + - "1/n": Scales both the forward and adjoint FFT transforms by + :math:`1/N_F`. + real : :obj:`bool`, optional + Model to which fft is applied has real numbers (``True``) or not + (``False``). Used to enforce that the output of adjoint of a real + model is real. Note that the real FFT is applied only to the first + dimension to which the FFTND operator is applied (last element of + ``axes``) + dtype : :obj:`str`, optional + Type of elements in input array. Note that the ``dtype`` of the operator + is the corresponding complex type even when a real type is provided. + In addition, note that the NumPy backend does not support returning ``dtype`` + different from ``complex128``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. 
+ **kwargs_fft + Arbitrary keyword arguments to be passed to the selected fft method + + Attributes + ---------- + fs : :obj:`tuple` + Each element of the tuple corresponds to the Discrete Fourier Transform + sample frequencies along the respective direction given by ``axes``. + nffts : :obj:`tuple` or :obj:`int`, optional + Number of samples in Fourier Transform for each axis in ``axes``. + real : :obj:`bool` + When ``True``, uses real fast fourier transform + rdtype : :obj:`bool` + Expected input type to the forward + cdtype : :obj:`bool` + Output type of the forward. Complex equivalent to ``rdtype``. + shape : :obj:`tuple` + Operator shape. + clinear : :obj:`bool` + Operator is complex-linear. Is false when either ``real=True`` or when + ``dtype`` is not a complex type. + fft : :obj:`mpi4py_fft.PFFT` + Parallel FFT operator object handling the distributed transform across + MPI processes. Configured with the base communicator, dimension + decomposition, transform axes, and dtype. + + See Also + -------- + MPIFFT2D: Two-dimensional FFT + + Raises + ------ + ValueError + - If ``norm`` is not one of "none", or "1/n". + + Notes + ----- + The MPIFFTND operator (using ``norm="none"``) applies the N-dimensional forward + Fourier transform to a multidimensional array. Considering an N-dimensional + signal :math:`d(x_1, \ldots, x_N)`. The MPIFFTND in forward mode is: + + .. math:: + D(k_1, \ldots, k_N) = \mathscr{F} (d) = + \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty + d(x_1, \ldots, x_N) + e^{-j2\pi k_1 x_1} \cdots + e^{-j 2 \pi k_N x_N} \,\mathrm{d}x_1 \cdots \mathrm{d}x_N + + Similarly, the N-dimensional inverse Fourier transform is applied to + the Fourier spectrum :math:`D(k_1, \ldots, k_N)` in adjoint mode: + + .. 
math:: + d(x_1, \ldots, x_N) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} + \int\limits_{-\infty}^\infty \cdots \int\limits_{-\infty}^\infty + D(k_1, \ldots, k_N) + e^{j2\pi k_1 x_1} \cdots + e^{j 2 \pi k_N x_N} \,\mathrm{d}k_1 \cdots \mathrm{d}k_N + + where :math:`N_F` is the number of samples in the Fourier domain given by the + product of the elements of ``nffts``. + + Both operators are effectively discretized and solved by a fast iterative + algorithm known as Fast Fourier Transform. Note that when using ``norm="none"``, + the adjoint is **not** the inverse of the forward mode; instead, the inverse + requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). + + """ + def __init__( + self, + dims: InputDimsLike, + axes: InputDimsLike = (0, 1, 2), + sampling: float | Sequence[float] = 1.0, + norm: str = "none", + real: bool = False, + dtype: DTypeLike = "complex128", + base_comm: MPI.Comm = MPI.COMM_WORLD, + **kwargs_fft, + ) -> None: + super().__init__( + dims=dims, + axes=axes, + sampling=sampling, + norm=norm, + real=real, + dtype=dtype, + base_comm=base_comm + ) + if self.cdtype != np.complex128: + warnings.warn( + "numpy backend always returns complex128 dtype. 
" + "To respect the passed dtype, data will be cast to {self.cdtype}.", + stacklevel=2, + ) + + self._kwargs_fft = kwargs_fft + if self.norm is _FFTNorms.NONE: + self._scale = np.prod(self.nffts) + elif self.norm is _FFTNorms.ONE_OVER_N: + self._scale = 1.0 / np.prod(self.nffts) + fft_dtype = self.rdtype if self.real else self.cdtype + # Distribute only along axes[0]; all other axes are non-distributed (0=distributed, 1=not) + subcomm_dims = (np.arange(len(axes)) != axes[0]).astype(int) + subcomm = Subcomm(self.base_comm, dims=np.resize(subcomm_dims, len(dims))) + self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype, collapse=True, **self._kwargs_fft) + + @reshaped + def _matvec(self, x: DistributedArray) -> DistributedArray: + if x.engine == "cupy": + raise ValueError(f"x should be a numpy array with engine=numpy" + f"Got {x.engine} instead...") + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}" + f"Got {x.partition} instead...") + ncp = get_array_module(x.local_array) + if not self.clinear: + x[:] = ncp.real(x.local_array) + # Allocate distributed arrays for input and output + u_dist = newDistArray(self.fft, forward_output=False) + u_hat = newDistArray(self.fft, forward_output=True) + # Redistribute input to match the axis decomposed by PFFT + x = x.redistribute(axis=self.axes[0]) + u_dist[:] = x.local_array + # Perform the parallel forward FFT + self.fft.forward(u_dist, u_hat, normalize=False) + + # Axis along which PFFT decomposes the output array across MPI processes + dist_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1][0] + y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis, + base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) + y[:] = u_hat + if self.real: + # Redistribute so that self.axes[-1] is not the one sliced + safe_axis = next(i for i in range(len(self.dims)) if i != self.axes[-1]) + y = 
y.redistribute(axis=safe_axis) + y_local = y.local_array + # Apply scaling to obtain a correct adjoint for this operator + y_local = ncp.swapaxes(y_local, -1, self.axes[-1]) + y_local[..., 1: 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2) + y_local = ncp.swapaxes(y_local, self.axes[-1], -1) + y[:] = y_local + if self.norm is _FFTNorms.ONE_OVER_N: + y[:] *= self._scale + y[:] = y.local_array.astype(self.cdtype) + return y + + @reshaped + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + if x.engine == "cupy": + raise ValueError(f"x should be a numpy array with engine=numpy" + f"Got {x.engine} instead...") + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}, " + f"Got {x.partition} instead...") + ncp = get_array_module(x.local_array) + if self.real: + # Redistribute so that self.axes[-1] is not the one sliced + safe_axis = next(i for i in range(len(self.dims)) if i != self.axes[-1]) + x = x.redistribute(axis=safe_axis) + # Apply scaling to obtain a correct adjoint for this operator + x_local = x.local_array + x_local = ncp.swapaxes(x_local, -1, self.axes[-1]) + x_local[..., 1: 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2) + x_local = ncp.swapaxes(x_local, self.axes[-1], -1) + x[:] = x_local + # Allocate distributed arrays for input and output + u_dist = newDistArray(self.fft, forward_output=False) + u_hat = newDistArray(self.fft, forward_output=True) + # Redistribute input to match the axis decomposed by PFFT + x_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1][0] + x = x.redistribute(axis=x_axis) + u_hat[:] = x.local_array + # Perform the parallel backward FFT + self.fft.backward(u_hat, u_dist, normalize=True) + + # Axis along which PFFT decomposes the output array across MPI processes + dist_axis = [i for i, s in enumerate(u_dist.subcomm) if s.Get_size() > 1][0] + y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=dist_axis, + base_comm=x.base_comm, 
base_comm_nccl=x.base_comm_nccl, engine=x.engine) + y[:] = u_dist + if self.norm is _FFTNorms.NONE: + y[:] *= self._scale + if self.nffts[0] > self.dims[self.axes[0]]: + y[:] = ncp.take(y.local_array, ncp.arange(self.dims[self.axes[0]]), axis=self.axes[0]) + if self.nffts[1] > self.dims[self.axes[1]]: + y[:] = ncp.take(y.local_array, ncp.arange(self.dims[self.axes[1]]), axis=self.axes[1]) + if not self.clinear: + y[:] = ncp.real(y.local_array) + y[:] = y.local_array.astype(self.rdtype) + return y diff --git a/pylops_mpi/signalprocessing/__init__.py b/pylops_mpi/signalprocessing/__init__.py index 3a0b83ab..209b64e4 100644 --- a/pylops_mpi/signalprocessing/__init__.py +++ b/pylops_mpi/signalprocessing/__init__.py @@ -8,12 +8,18 @@ A list of operators present in pylops_mpi.signalprocessing : MPIFredholm1 Fredholm integral of first kind. + MPIFFT2D Two-dimensional Fast-Fourier Transform + MPIFFTND N-dimensional Fast-Fourier Transform """ from .Fredholm1 import * +from .FFT2D import * +from .FFTND import * __all__ = [ "MPIFredholm1", + "MPIFFT2D", + "MPIFFTND", ] diff --git a/pylops_mpi/signalprocessing/_baseffts.py b/pylops_mpi/signalprocessing/_baseffts.py new file mode 100644 index 00000000..78a67bd2 --- /dev/null +++ b/pylops_mpi/signalprocessing/_baseffts.py @@ -0,0 +1,137 @@ +import warnings +from typing import Sequence + +from mpi4py import MPI +import numpy as np + +from pylops.signalprocessing._baseffts import _FFTNorms +from pylops.utils import InputDimsLike, DTypeLike, get_normalize_axis_index, get_real_dtype, get_complex_dtype +from pylops.utils._internal import _value_or_sized_to_array, _raise_on_wrong_dtype, _value_or_sized_to_tuple + +from pylops_mpi.DistributedArray import DistributedArray +from pylops_mpi.LinearOperator import MPILinearOperator + + +class _MPIBaseFFTND(MPILinearOperator): + """Base class for N-dimensional fast Fourier Transform""" + + def __init__( + self, + dims: int | InputDimsLike, + axes: int | InputDimsLike | None = None, + nffts: 
int | InputDimsLike | None = None, + sampling: float | Sequence[float] = 1.0, + norm: str = "none", + real: bool = False, + dtype: DTypeLike = "complex128", + base_comm: MPI.Comm = MPI.COMM_WORLD + ): + dims = _value_or_sized_to_array(dims) + _raise_on_wrong_dtype(dims, np.integer, "dims") + self.dims = tuple(dims) + self.ndim = len(dims) + + axes = _value_or_sized_to_array(axes) + _raise_on_wrong_dtype(axes, np.integer, "axes") + self.axes = np.array([get_normalize_axis_index()(d, self.ndim) for d in axes]) + self.naxes = len(self.axes) + if self.naxes != len(np.unique(self.axes)): + warnings.warn( + "At least one direction is repeated. This may cause unexpected results.", + stacklevel=2, + ) + + nffts = _value_or_sized_to_array(nffts, repeat=self.naxes) + if len(nffts[np.equal(nffts, None)]) > 0: # Found None(s) in nffts + nffts[np.equal(nffts, None)] = np.array( + [dims[d] for d, n in zip(axes, nffts, strict=True) if n is None] + ) + nffts = nffts.astype(np.array(dims).dtype) + _raise_on_wrong_dtype(nffts, np.integer, "nffts") + self.nffts = _value_or_sized_to_tuple( + nffts + ) # tuple is strictly needed for cupy + + sampling = _value_or_sized_to_array(sampling, repeat=self.naxes) + if np.issubdtype(sampling.dtype, np.integer): # Promote to float64 if integer + sampling = sampling.astype(np.float64) + self.sampling = sampling + _raise_on_wrong_dtype(self.sampling, np.floating, "sampling") + + if ( + self.naxes != len(self.nffts) + or self.naxes != len(self.sampling) + ): + msg = ( + "`axes`, `nffts`, `sampling` must the have same number of elements. Received " + f"{self.naxes}, {len(self.nffts)}, {len(self.sampling)}, " + "respectively." + ) + raise ValueError(msg) + + # Check if the user provided nfft smaller than n. 
See _BaseFFT for + # details + nfftshort = [ + nfft < dims[direction] + for direction, nfft in zip(self.axes, self.nffts, strict=True) + ] + self.doifftpad = any(nfftshort) + if self.doifftpad: + self.ifftpad = [(0, 0)] * self.ndim + for idir, (direction, nfshort) in enumerate( + zip(self.axes, nfftshort, strict=True) + ): + if nfshort: + self.ifftpad[direction] = ( + 0, + dims[direction] - self.nffts[idir], + ) + warnings.warn( + f"nffts in directions {np.where(nfftshort)[0]} have been selected to be smaller than the size of the original signal. " + "This is rarely intended behavior as the original signal will be truncated prior to applying fft, " + f"if this is the required behaviour ignore this message.", + stacklevel=2, + ) + + if norm == "none": + self.norm = _FFTNorms.NONE + elif norm.lower() == "1/n": + self.norm = _FFTNorms.ONE_OVER_N + elif norm == "backward": + msg = 'To use no scaling on the forward transform, use "none". Note that in this case, the adjoint transform will *not* have a 1/n scaling.' + raise ValueError(msg) + elif norm == "forward": + msg = 'To use 1/n scaling on the forward transform, use "1/n". Note that in this case, the adjoint transform will *also* have a 1/n scaling.' + raise ValueError(msg) + else: + msg = f"`norm`={norm} is not one of 'none' or '1/n'" + raise ValueError(msg) + + self.real = real + + fs = [np.fft.fftfreq(n, d=s) for n, s in zip(self.nffts, self.sampling, strict=True)] + if self.real: + fs[-1] = np.fft.rfftfreq(self.nffts[-1], d=self.sampling[-1]) + self.fs = tuple(fs) + dimsd = np.array(dims) + dimsd[self.axes] = self.nffts + if self.real: + dimsd[self.axes[-1]] = self.nffts[-1] // 2 + 1 + self.dimsd = dimsd + # Find types to enforce to forward and adjoint outputs. This is + # required as np.fft.fft always returns complex128 even if input is + # float32 or less. Moreover, when choosing real=True, the type of the + # adjoint output is forced to be real even if the provided dtype + # is complex. 
+ self.rdtype = get_real_dtype(dtype) if self.real else np.dtype(dtype) + self.cdtype = get_complex_dtype(dtype) + self.clinear = False if self.real or np.issubdtype(dtype, np.floating) else True + super().__init__(dtype=self.cdtype, shape=(int(np.prod(dimsd)), int(np.prod(dims))), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + msg = "_BaseFFT does not provide _matvec. It must be implemented separately." + raise NotImplementedError(msg) + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + msg = "_BaseFFT does not provide _rmatvec. It must be implemented separately." + raise NotImplementedError(msg) diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index 21b16906..f216dc6d 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -51,8 +51,13 @@ def wrapper(self, x: DistributedArray): local_shapes = getattr(self, "local_shapes_n") global_shape = x.global_shape else: + fwd = ( + "rmat" not in f.__name__ + and f.__name__ != "div" + and f.__name__ != "__truediv__" + ) local_shapes = None - global_shape = getattr(self, "dims") + global_shape = getattr(self, "dims") if fwd else getattr(self, "dimsd") arr = DistributedArray(global_shape=global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, @@ -72,7 +77,8 @@ def wrapper(self, x: DistributedArray): arr[:] = ghosted_array[index: arr_local_shapes[self.rank] + index].reshape(arr.local_shape) y: DistributedArray = f(self, arr) if len(y.global_shape) > 1: - y = y.ravel() + # Make sure y is distributed along axis=0 before applying ravel + y = y.redistribute(axis=0).ravel() return y return wrapper if func is not None: diff --git a/pylops_mpi/utils/deps.py b/pylops_mpi/utils/deps.py index eb639054..c6b4938a 100644 --- a/pylops_mpi/utils/deps.py +++ b/pylops_mpi/utils/deps.py @@ -3,7 +3,7 @@ ] import os -from importlib import util +from importlib import import_module, util from typing import Optional @@ -39,6 
+39,22 @@ def nccl_import(message: Optional[str] = None) -> str: return nccl_message +def mpi4py_fft_import(message: str | None) -> str | None: + if mpi4py_fft: + try: + import_module("mpi4py_fft") # noqa: F401 + mpi4py_fft_message = None + except Exception as e: + mpi4py_fft_message = f"Failed to import mpi4py_fft (error:{e})." + else: + mpi4py_fft_message = ( + f"mpi4py_fft package not installed. In order to be able to use " + f"{message} run " + f'"pip install mpi4py_fft" or "conda install -c conda-forge mpi4py_fft".' + ) + return mpi4py_fft_message + + cuda_aware_mpi_enabled: bool = ( False if int(os.getenv("PYLOPS_MPI_CUDA_AWARE", 0)) == 0 else True ) @@ -46,3 +62,5 @@ def nccl_import(message: Optional[str] = None) -> str: nccl_enabled: bool = ( True if (nccl_import() is None and int(os.getenv("NCCL_PYLOPS_MPI", 1)) == 1) else False ) + +mpi4py_fft = util.find_spec("mpi4py_fft") is not None diff --git a/requirements-dev.txt b/requirements-dev.txt index 24d12f34..af56d1fb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,3 +12,4 @@ numba nbsphinx sphinx_gallery flake8 +mpi4py-fft \ No newline at end of file diff --git a/tests/test_ffts.py b/tests/test_ffts.py new file mode 100644 index 00000000..975a87af --- /dev/null +++ b/tests/test_ffts.py @@ -0,0 +1,139 @@ +"""Test FFT classes + Designed to run with n processes + $ mpiexec -n 10 pytest test_ffts.py --with-mpi +""" +import os +import pytest + +if int(os.environ.get("TEST_CUPY_PYLOPS", 0)): + import cupy as np + backend = "cupy" +else: + import numpy as np + + backend = "numpy" + +from mpi4py import MPI +from numpy.testing import assert_allclose + +from pylops.signalprocessing import FFT2D, FFTND + +from pylops_mpi.signalprocessing import MPIFFT2D, MPIFFTND +from pylops_mpi.DistributedArray import DistributedArray + +par1 = { + "dims": (41, 51), + "axes": (0, 1), + "real": False, + "dtype": np.complex128, + "imag": 1j, + "norm": "none" +} +par2 = { + "dims": (50, 50), + "axes": (0, 1), + 
"real": False, + "dtype": np.complex128, + "imag": 1j, + "norm": "1/n" +} +par3 = { + "dims": (41, 51), + "axes": (0, 1), + "real": True, + "dtype": np.float64, + "imag": 0, + "norm": "1/n" +} +par4 = { + "dims": (50, 50), + "axes": (0, 1), + "real": True, + "dtype": np.float64, + "imag": 0, + "norm": "none" +} +par5 = { + "dims": (41, 51, 50), + "axes": (0, 1, 2), + "real": True, + "dtype": np.float64, + "imag": 0, + "norm": "none" +} +par6 = { + "dims": (41, 51, 50), + "axes": (0, 2, 1), + "real": True, + "dtype": np.float64, + "imag": 0, + "norm": "1/n" +} +par7 = { + "dims": (41, 51, 50), + "axes": (2, 1, 0), + "real": False, + "dtype": np.complex128, + "imag": 1j, + "norm": "none" +} +par8 = { + "dims": (41, 51, 50), + "axes": (2, 0, 1), + "real": False, + "dtype": np.complex128, + "imag": 1j, + "norm": "1/n" +} + +rank = MPI.COMM_WORLD.Get_rank() + + +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)]) +def test_FFT2d(par): + """MPIFFT2D Operator""" + if backend == "cupy": + pytest.skip("Skipping cupy backend") + np.random.seed(10) + ff2d_mpi = MPIFFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + x = DistributedArray(global_shape=ff2d_mpi.shape[1], dtype=par['dtype']) + x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) + x_global = x.asarray() + # Forward + y_dist = ff2d_mpi.matvec(x) + y = y_dist.asarray() + # Adjoint + y_adj_dist = ff2d_mpi.rmatvec(y_dist) + y_adj = y_adj_dist.asarray() + if rank == 0: + fft2d = FFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + assert ff2d_mpi.shape == fft2d.shape + y_np = fft2d.matvec(x_global) + y_adj_np = fft2d.rmatvec(y_np) + assert_allclose(y, y_np, rtol=1e-5, atol=1e-8) + assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) + + +@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]) +def test_FFTND(par): + """MPIFFTND 
Operator""" + if backend == "cupy": + pytest.skip("Skipping cupy backend") + np.random.seed(10) + ffnd_mpi = MPIFFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + x = DistributedArray(global_shape=ffnd_mpi.shape[1], dtype=par['dtype']) + x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) + x_global = x.asarray() + # Forward + y_dist = ffnd_mpi.matvec(x) + y = y_dist.asarray() + # Adjoint + y_adj_dist = ffnd_mpi.rmatvec(y_dist) + y_adj = y_adj_dist.asarray() + if rank == 0: + fftnd = FFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + assert ffnd_mpi.shape == fftnd.shape + y_np = fftnd.matvec(x_global) + y_adj_np = fftnd.rmatvec(y_np) + assert_allclose(y, y_np, rtol=1e-5, atol=1e-8) + assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) From 074489ceffe84029936a1ecd4ac81e04a00a9a27 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sat, 9 May 2026 10:29:28 +0530 Subject: [PATCH 02/12] Update GA to include fftw libraries --- .github/workflows/build.yml | 13 +++++++++++++ requirements-dev.txt | 1 - requirements-fft.txt | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 requirements-fft.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1b46d876..78e0c8e0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,10 +36,23 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install fftw libraries + run: | + case $(uname) in + Linux) + sudo apt update + sudo apt install -y -q libfftw3-dev + ;; + Darwin) + export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=1 + brew install fftw + ;; + esac - name: Installing Dependencies run: | python -m pip install --upgrade pip setuptools if [ -f requirements.txt ]; then pip install -r requirements-dev.txt; fi + if [ -f requirements-fft.txt ]; then pip install -r 
requirements-fft.txt; fi - name: Install pylops-mpi run: pip install . - name: Testing using pytest-mpi diff --git a/requirements-dev.txt b/requirements-dev.txt index af56d1fb..24d12f34 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,3 @@ numba nbsphinx sphinx_gallery flake8 -mpi4py-fft \ No newline at end of file diff --git a/requirements-fft.txt b/requirements-fft.txt new file mode 100644 index 00000000..ca5fb8be --- /dev/null +++ b/requirements-fft.txt @@ -0,0 +1 @@ +mpi4py-fft From a7743ca1aec61546cee58bf0f1919bd34061dc3d Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sat, 9 May 2026 10:40:14 +0530 Subject: [PATCH 03/12] Fallback to dims if dimsd does not exist --- pylops_mpi/utils/decorators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index f216dc6d..5988e2d7 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -57,7 +57,7 @@ def wrapper(self, x: DistributedArray): and f.__name__ != "__truediv__" ) local_shapes = None - global_shape = getattr(self, "dims") if fwd else getattr(self, "dimsd") + global_shape = getattr(self, "dims") if fwd else getattr(self, "dimsd", getattr(self, "dims")) arr = DistributedArray(global_shape=global_shape, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, From 13d35a8a6e56353ae098cf725945e11428e505bf Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sat, 9 May 2026 13:49:58 +0530 Subject: [PATCH 04/12] Add __truediv__ --- pylops_mpi/signalprocessing/FFT2D.py | 3 +++ pylops_mpi/signalprocessing/FFTND.py | 5 +++++ tests/test_ffts.py | 21 +++++++++++++-------- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py index 5613fdb8..f08e5bae 100644 --- a/pylops_mpi/signalprocessing/FFT2D.py +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -137,3 +137,6 @@ def _matvec(self, x: DistributedArray) -> 
DistributedArray: def _rmatvec(self, x: DistributedArray) -> DistributedArray: return super()._rmatvec(x) + + def __truediv__(self, y: DistributedArray) -> DistributedArray: + return super().__truediv__(y) diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index f3159ec5..b81cbc45 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -249,3 +249,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y[:] = ncp.real(y.local_array) y[:] = y.local_array.astype(self.rdtype) return y + + def __truediv__(self, y: DistributedArray) -> DistributedArray: + y_div = self._rmatvec(y) + y_div[:] = y_div.local_array / self._scale + return y_div diff --git a/tests/test_ffts.py b/tests/test_ffts.py index 975a87af..f5e6facf 100644 --- a/tests/test_ffts.py +++ b/tests/test_ffts.py @@ -100,16 +100,16 @@ def test_FFT2d(par): x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() # Forward - y_dist = ff2d_mpi.matvec(x) + y_dist = ff2d_mpi @ x y = y_dist.asarray() # Adjoint - y_adj_dist = ff2d_mpi.rmatvec(y_dist) + y_adj_dist = ff2d_mpi.H @ y_dist y_adj = y_adj_dist.asarray() if rank == 0: fft2d = FFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) assert ff2d_mpi.shape == fft2d.shape - y_np = fft2d.matvec(x_global) - y_adj_np = fft2d.rmatvec(y_np) + y_np = fft2d @ x_global + y_adj_np = fft2d.H @ y_np assert_allclose(y, y_np, rtol=1e-5, atol=1e-8) assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) @@ -125,15 +125,20 @@ def test_FFTND(par): x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() # Forward - y_dist = ffnd_mpi.matvec(x) + y_dist = ffnd_mpi @ x y = y_dist.asarray() # Adjoint - y_adj_dist = ffnd_mpi.rmatvec(y_dist) + y_adj_dist = ffnd_mpi.H @ y_dist y_adj = y_adj_dist.asarray() + # Div + y_div_dist = ffnd_mpi / y_dist + 
y_div = y_div_dist.asarray() if rank == 0: fftnd = FFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) assert ffnd_mpi.shape == fftnd.shape - y_np = fftnd.matvec(x_global) - y_adj_np = fftnd.rmatvec(y_np) + y_np = fftnd @ x_global + y_adj_np = fftnd.H @ y_np + y_div_np = fftnd / y_np assert_allclose(y, y_np, rtol=1e-5, atol=1e-8) assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) + assert_allclose(y_div, y_div_np, rtol=1e-5, atol=1e-8) From de364da5301087f40c2c131b402c78c2a0f1fad8 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Mon, 11 May 2026 13:40:32 +0530 Subject: [PATCH 05/12] Add example and fftshifts --- docs/source/api/index.rst | 2 +- examples/plot_ffts.py | 111 +++++++++++++++++++++++ pylops_mpi/signalprocessing/FFT2D.py | 22 ++++- pylops_mpi/signalprocessing/FFTND.py | 41 +++++++-- pylops_mpi/signalprocessing/_baseffts.py | 38 +++++++- pylops_mpi/utils/__init__.py | 1 + pylops_mpi/utils/fft_helper.py | 91 +++++++++++++++++++ tests/test_ffts.py | 38 ++++++-- 8 files changed, 325 insertions(+), 19 deletions(-) create mode 100644 examples/plot_ffts.py create mode 100644 pylops_mpi/utils/fft_helper.py diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 73d41469..9913547c 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -107,7 +107,7 @@ Basic cgls Sparsity -~~~~~ +~~~~~~~~ .. currentmodule:: pylops_mpi.optimization.cls_sparsity diff --git a/examples/plot_ffts.py b/examples/plot_ffts.py new file mode 100644 index 00000000..06b697c9 --- /dev/null +++ b/examples/plot_ffts.py @@ -0,0 +1,111 @@ +""" +Fourier Transform +================= +This example shows how to use the :py:class:`pylops_mpi.signalprocessing.MPIFFT2D` +and :py:class:`pylops_mpi.signalprocessing.MPIFFTND` operators to apply the Fourier +Transform to the model and the inverse Fourier Transform to the data. 
+""" + +import matplotlib.pyplot as plt +import numpy as np + +import pylops_mpi + +plt.close("all") + +############################################################################### +# We start by applying the two dimensional MPI-distributed FFT to a +# two-dimensional signal using :py:class:`pylops_mpi.signalprocessing.MPIFFT2D`. +# The input signal is a :py:class:`pylops_mpi.DistributedArray` which is +# distributed across MPI ranks before applying the transform. + +dt, dx = 0.005, 5 +nt, nx = 2**7, 2**8 +t = np.arange(nt) * dt +x = np.arange(nx) * dx +f0 = 10 + +d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) +dist = pylops_mpi.DistributedArray.to_dist(x=d.ravel()) + +FFTop = pylops_mpi.signalprocessing.MPIFFT2D( + dims=(nt, nx), sampling=(dt, dx) +) + +D = FFTop * dist + +dinv = FFTop.H * D +dinv = FFTop / D +dinv = np.real(dinv.asarray()).reshape(nt, nx) + +D_2d = D.asarray().reshape(nt, nx) + +fig, axs = plt.subplots(2, 2, figsize=(10, 6)) + +axs[0][0].imshow(d, vmin=-100, vmax=100, cmap="bwr") +axs[0][0].set_title("Signal") +axs[0][0].axis("tight") + +axs[0][1].imshow( + np.abs(np.fft.fftshift(D_2d, axes=1)[:nt // 2, :]), cmap="bwr" +) +axs[0][1].set_title("Fourier Transform") +axs[0][1].axis("tight") + +axs[1][0].imshow(dinv, vmin=-100, vmax=100, cmap="bwr") +axs[1][0].set_title("Inverted") +axs[1][0].axis("tight") + +axs[1][1].imshow(d - dinv, vmin=-100, vmax=100, cmap="bwr") +axs[1][1].set_title("Error") +axs[1][1].axis("tight") + +fig.tight_layout() + +############################################################################### +# We can also apply the three dimensional MPI-distributed FFT to a +# three-dimensional signal using :py:class:`pylops_mpi.signalprocessing.MPIFFTND`. 
+ +dt, dx, dy = 0.005, 5, 3 +nt, nx, ny = 2**7, 2**6, 13 +t = np.arange(nt) * dt +x = np.arange(nx) * dx +y = np.arange(ny) * dy +f0 = 10 + +d = np.outer(np.sin(2 * np.pi * f0 * t), np.arange(nx) + 1) +d = np.tile(d[:, :, np.newaxis], [1, 1, ny]) +dist = pylops_mpi.DistributedArray.to_dist(x=d.ravel()) + +FFTop = pylops_mpi.signalprocessing.MPIFFTND( + dims=(nt, nx, ny), + sampling=(dt, dx, dy) +) + +D = FFTop * dist +dinv = FFTop.H * D +dinv = FFTop / D +dinv = np.real(dinv.asarray()).reshape(nt, nx, ny) +D_3d = D.asarray().reshape(nt, nx, ny) # shape matches dims now + +fig, axs = plt.subplots(2, 2, figsize=(10, 6)) + +axs[0][0].imshow(d[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +axs[0][0].set_title("Signal") +axs[0][0].axis("tight") +axs[0][1].imshow( + np.abs(np.fft.fftshift(D_3d, axes=1)[:nx // 2, :, ny // 2]), + cmap="bwr" +) +axs[0][1].set_title("Fourier Transform") +axs[0][1].axis("tight") + +axs[1][0].imshow(dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +axs[1][0].set_title("Inverted") +axs[1][0].axis("tight") + +axs[1][1].imshow(d[:, :, ny // 2] - dinv[:, :, ny // 2], vmin=-20, vmax=20, cmap="bwr") +axs[1][1].set_title("Error") +axs[1][1].axis("tight") + +fig.tight_layout() diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py index f08e5bae..0fdb36b2 100644 --- a/pylops_mpi/signalprocessing/FFT2D.py +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -41,6 +41,21 @@ class MPIFFT2D(MPIFFTND): model is real. Note that the real FFT is applied only to the first dimension to which the FFT2D operator is applied (last element of ``axes``) + ifftshift_before : :obj:`tuple` or :obj:`bool`, optional + Apply ifftshift (``True``) or not (``False``) to model vector (before FFT). + Consider using this option when the model vector's respective axis is symmetric + with respect to the zero value sample. This will shift the zero value sample to + coincide with the zero index sample. 
With such an arrangement, FFT will not + introduce a sample-dependent phase-shift when compared to the continuous Fourier + Transform. When passing a single value, the shift will be the same for every direction. + Pass a tuple to specify which dimensions are shifted. + fftshift_after : :obj:`tuple` or :obj:`bool`, optional + Apply fftshift (``True``) or not (``False``) to data vector (after FFT). + Consider using this option when you require frequencies to be arranged + naturally, from negative to positive. When not applying fftshift after FFT, + frequencies are arranged from zero to largest positive, and then from negative + Nyquist to the frequency bin before zero. When passing a single value, the shift + will be the same for every direction. Pass a tuple to specify which dimensions are shifted. dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided. @@ -116,6 +131,8 @@ def __init__( sampling: float | Sequence[float] = 1.0, norm: str = "none", real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, dtype: DTypeLike = "complex128", base_comm: MPI.Comm = MPI.COMM_WORLD, **kwargs_fft, @@ -127,8 +144,9 @@ def __init__( if len(axes) != 2: msg = "FFT2D must be applied along exactly two dimensions" raise ValueError(msg) - super().__init__(dims=dims, axes=axes, sampling=sampling, norm=norm, real=real, - dtype=dtype, base_comm=base_comm, **kwargs_fft) + super().__init__(dims=dims, axes=axes, sampling=sampling, norm=norm, real=real, dtype=dtype, + ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, base_comm=base_comm, + **kwargs_fft) self.f1, self.f2 = self.fs del self.fs diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index b81cbc45..4d36eaa3 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -10,7 +10,7 @@ from
pylops_mpi.utils.decorators import reshaped from pylops_mpi.DistributedArray import DistributedArray, Partition from pylops_mpi.signalprocessing._baseffts import _MPIBaseFFTND -from pylops_mpi.utils import deps +from pylops_mpi.utils import deps, fftshift, ifftshift mpi4py_fft_message = deps.mpi4py_fft_import("mpi4py_fft") @@ -53,6 +53,21 @@ class MPIFFTND(_MPIBaseFFTND): model is real. Note that the real FFT is applied only to the first dimension to which the FFTND operator is applied (last element of ``axes``) + ifftshift_before : :obj:`tuple` or :obj:`bool`, optional + Apply ifftshift (``True``) or not (``False``) to model vector (before FFT). + Consider using this option when the model vector's respective axis is symmetric + with respect to the zero value sample. This will shift the zero value sample to + coincide with the zero index sample. With such an arrangement, FFT will not + introduce a sample-dependent phase-shift when compared to the continuous Fourier + Transform. When passing a single value, the shift will be the same for every direction. + Pass a tuple to specify which dimensions are shifted. + fftshift_after : :obj:`tuple` or :obj:`bool`, optional + Apply fftshift (``True``) or not (``False``) to data vector (after FFT). + Consider using this option when you require frequencies to be arranged + naturally, from negative to positive. When not applying fftshift after FFT, + frequencies are arranged from zero to largest positive, and then from negative + Nyquist to the frequency bin before zero. When passing a single value, the shift + will be the same for every direction. Pass a tuple to specify which dimensions are shifted. dtype : :obj:`str`, optional Type of elements in input array. Note that the ``dtype`` of the operator is the corresponding complex type even when a real type is provided.
@@ -134,6 +149,8 @@ def __init__( sampling: float | Sequence[float] = 1.0, norm: str = "none", real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, dtype: DTypeLike = "complex128", base_comm: MPI.Comm = MPI.COMM_WORLD, **kwargs_fft, @@ -144,6 +161,8 @@ def __init__( sampling=sampling, norm=norm, real=real, + fftshift_after=fftshift_after, + ifftshift_before=ifftshift_before, dtype=dtype, base_comm=base_comm ) @@ -174,6 +193,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: raise ValueError(f"x should have partition={Partition.SCATTER}" f"Got {x.partition} instead...") ncp = get_array_module(x.local_array) + if self.ifftshift_before.any(): + x = ifftshift(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: x[:] = ncp.real(x.local_array) # Allocate distributed arrays for input and output @@ -186,8 +207,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: self.fft.forward(u_dist, u_hat, normalize=False) # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1][0] - y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis, + dist_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1] + y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) y[:] = u_hat if self.real: @@ -203,6 +224,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if self.norm is _FFTNorms.ONE_OVER_N: y[:] *= self._scale y[:] = y.local_array.astype(self.cdtype) + if self.fftshift_after.any(): + y = fftshift(y, axes=self.axes[self.fftshift_after]) return y @reshaped @@ -214,6 +237,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: raise ValueError(f"x should have partition={Partition.SCATTER}, " f"Got {x.partition} instead...") ncp = 
get_array_module(x.local_array) + if self.fftshift_after.any(): + x = ifftshift(x, axes=self.axes[self.fftshift_after]) if self.real: # Redistribute so that self.axes[-1] is not the one sliced safe_axis = next(i for i in range(len(self.dims)) if i != self.axes[-1]) @@ -228,15 +253,15 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: u_dist = newDistArray(self.fft, forward_output=False) u_hat = newDistArray(self.fft, forward_output=True) # Redistribute input to match the axis decomposed by PFFT - x_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1][0] - x = x.redistribute(axis=x_axis) + x_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1] + x = x.redistribute(axis=x_axis[0] if x_axis else 0) u_hat[:] = x.local_array # Perform the parallel backward FFT self.fft.backward(u_hat, u_dist, normalize=True) # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(u_dist.subcomm) if s.Get_size() > 1][0] - y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=dist_axis, + dist_axis = [i for i, s in enumerate(u_dist.subcomm) if s.Get_size() > 1] + y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) y[:] = u_dist if self.norm is _FFTNorms.NONE: @@ -248,6 +273,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if not self.clinear: y[:] = ncp.real(y.local_array) y[:] = y.local_array.astype(self.rdtype) + if self.ifftshift_before.any(): + y = fftshift(y, axes=self.axes[self.ifftshift_before]) return y def __truediv__(self, y: DistributedArray) -> DistributedArray: diff --git a/pylops_mpi/signalprocessing/_baseffts.py b/pylops_mpi/signalprocessing/_baseffts.py index 78a67bd2..b0548a7a 100644 --- a/pylops_mpi/signalprocessing/_baseffts.py +++ b/pylops_mpi/signalprocessing/_baseffts.py @@ -23,6 +23,8 @@ def __init__( sampling: float 
| Sequence[float] = 1.0, norm: str = "none", real: bool = False, + ifftshift_before: bool = False, + fftshift_after: bool = False, dtype: DTypeLike = "complex128", base_comm: MPI.Comm = MPI.COMM_WORLD ): @@ -57,14 +59,26 @@ def __init__( sampling = sampling.astype(np.float64) self.sampling = sampling _raise_on_wrong_dtype(self.sampling, np.floating, "sampling") - + self.ifftshift_before = _value_or_sized_to_array( + ifftshift_before, repeat=self.naxes + ) + _raise_on_wrong_dtype(self.ifftshift_before, bool, "ifftshift_before") + + self.fftshift_after = _value_or_sized_to_array( + fftshift_after, repeat=self.naxes + ) + _raise_on_wrong_dtype(self.fftshift_after, bool, "fftshift_after") if ( self.naxes != len(self.nffts) or self.naxes != len(self.sampling) + or self.naxes != len(self.ifftshift_before) + or self.naxes != len(self.fftshift_after) ): msg = ( - "`axes`, `nffts`, `sampling` must the have same number of elements. Received " + "`axes`, `nffts`, `sampling`, `ifftshift_before` and " + "`fftshift_after` must the have same number of elements. Received " f"{self.naxes}, {len(self.nffts)}, {len(self.sampling)}, " + f"{len(self.ifftshift_before)} and {len(self.fftshift_after)}, " "respectively." ) raise ValueError(msg) @@ -109,9 +123,27 @@ def __init__( self.real = real - fs = [np.fft.fftfreq(n, d=s) for n, s in zip(self.nffts, self.sampling, strict=True)] + fs = [ + np.fft.fftshift(np.fft.fftfreq(n, d=s)) + if fftshift + else np.fft.fftfreq(n, d=s) + for n, s, fftshift in zip( + self.nffts, self.sampling, self.fftshift_after, strict=True + ) + ] if self.real: fs[-1] = np.fft.rfftfreq(self.nffts[-1], d=self.sampling[-1]) + if self.fftshift_after[-1]: + warnings.warn( + "Using real=True and fftshift_after on the last direction. " + "fftshift should only be applied on directions with negative " + "and positive frequencies. When using FFTND with real=True, " + "are all directions except the last. 
If you wish to proceed " + "applying fftshift on a frequency axis with only positive " + "frequencies, ignore this message.", + stacklevel=2, + ) + fs[-1] = np.fft.fftshift(fs[-1]) self.fs = tuple(fs) dimsd = np.array(dims) dimsd[self.axes] = self.nffts diff --git a/pylops_mpi/utils/__init__.py b/pylops_mpi/utils/__init__.py index 7f91a785..c2ff5d86 100644 --- a/pylops_mpi/utils/__init__.py +++ b/pylops_mpi/utils/__init__.py @@ -3,3 +3,4 @@ from .benchmark import * from .dottest import * from .deps import * +from .fft_helper import * diff --git a/pylops_mpi/utils/fft_helper.py b/pylops_mpi/utils/fft_helper.py new file mode 100644 index 00000000..9e5a60e4 --- /dev/null +++ b/pylops_mpi/utils/fft_helper.py @@ -0,0 +1,91 @@ +all = [ + "fftshift", + "ifftshift" +] + +import numpy as np + +from pylops.utils import InputDimsLike + +from pylops_mpi import DistributedArray + + +def fftshift(x: DistributedArray, axes: InputDimsLike = None): + """ + Shift the zero-frequency component to the center of the spectrum for a DistributedArray. + + This is the distributed equivalent of :func:`numpy.fft.fftshift`. For axes + that are local to each process, the shift is applied directly. For the + distributed axis, the array is first redistributed to a different axis so + the shift can be performed locally, then left in the redistributed state. + + Parameters + ---------- + x : :obj: `pylops_mpi.DistributedArray` + Input array to shift. Modified in-place along each axis. + axes : tuple, optional + Axes over which to shift. Defaults to all axes if not specified. + + Returns + ------- + x : :obj: `pylops_mpi.DistributedArray` + The shifted array. May be distributed along a different axis than the + input if the original distributed axis was included in ``axes``. 
+ """ + if axes is None: + axes = tuple(range(x.ndim)) + elif np.isscalar(axes): + axes = (axes,) + local_axes = [ax for ax in axes if ax != x.axis] + remote_axes = [ax for ax in axes if ax == x.axis] + if local_axes: + shifts = [x.global_shape[ax] // 2 for ax in local_axes] + x[:] = np.roll(x.local_array, shift=shifts, axis=local_axes) + if remote_axes: + new_axis = 1 if x.axis == 0 else 0 + # Redistribute to a new axis for computation + x = x.redistribute(axis=new_axis) + shifts = [x.global_shape[ax] // 2 for ax in remote_axes] + x[:] = np.roll(x.local_array, shift=shifts, axis=remote_axes) + return x + + +def ifftshift(x: DistributedArray, axes: InputDimsLike = None): + """ + Shift the zero-frequency component back to the beginning of the spectrum for a DistributedArray. + + This is the distributed equivalent of :func:`numpy.fft.ifftshift``. + Shifts are applied in the negative direction (i.e. ``-(n // 2)`` per axis) to undo + a prior :func:`pylops_mpi.utils.fftshift`. For axes that are local to each process, the shift is applied directly. + For the distributed axis, the array is first redistributed to a different + axis so the shift can be performed locally. + + Parameters + ---------- + x : :obj: `pylops_mpi.DistributedArray` + Input array to shift. Modified in-place along each axis. + axes : int or sequence of int, optional + Axes over which to shift. Defaults to all axes if not specified. + + Returns + ------- + x : :obj: `pylops_mpi.DistributedArray` + The shifted array. May be distributed along a different axis than the + input if the original distributed axis was included in ``axes``. 
+ """ + if axes is None: + axes = tuple(range(x.ndim)) + elif np.isscalar(axes): + axes = (axes,) + local_axes = [ax for ax in axes if ax != x.axis] + dist_axes = [ax for ax in axes if ax == x.axis] + if local_axes: + shifts = [-(x.global_shape[ax] // 2) for ax in local_axes] + x[:] = np.roll(x.local_array, shift=shifts, axis=local_axes) + if dist_axes: + new_axis = 1 if x.axis == 0 else 0 + # Redistribute to a new axis for computation + x = x.redistribute(axis=new_axis) + shifts = [-(x.global_shape[ax] // 2) for ax in dist_axes] + x[:] = np.roll(x.local_array, shift=shifts, axis=dist_axes) + return x diff --git a/tests/test_ffts.py b/tests/test_ffts.py index f5e6facf..794186f7 100644 --- a/tests/test_ffts.py +++ b/tests/test_ffts.py @@ -90,12 +90,23 @@ @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)]) -def test_FFT2d(par): +@pytest.mark.parametrize( + "ifftshift_before, fftshift_after", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_FFT2d(par, ifftshift_before, fftshift_after): """MPIFFT2D Operator""" if backend == "cupy": pytest.skip("Skipping cupy backend") np.random.seed(10) - ff2d_mpi = MPIFFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + ff2d_mpi = MPIFFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], + ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, + dtype=par['dtype']) x = DistributedArray(global_shape=ff2d_mpi.shape[1], dtype=par['dtype']) x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() @@ -106,7 +117,9 @@ def test_FFT2d(par): y_adj_dist = ff2d_mpi.H @ y_dist y_adj = y_adj_dist.asarray() if rank == 0: - fft2d = FFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + fft2d = FFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], + ifftshift_before=ifftshift_before, 
fftshift_after=fftshift_after, + dtype=par['dtype']) assert ff2d_mpi.shape == fft2d.shape y_np = fft2d @ x_global y_adj_np = fft2d.H @ y_np @@ -115,12 +128,23 @@ def test_FFT2d(par): @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]) -def test_FFTND(par): +@pytest.mark.parametrize( + "ifftshift_before, fftshift_after", + [ + (False, False), + (True, False), + (False, True), + (True, True), + ], +) +def test_FFTND(par, ifftshift_before, fftshift_after): """MPIFFTND Operator""" if backend == "cupy": pytest.skip("Skipping cupy backend") np.random.seed(10) - ffnd_mpi = MPIFFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + ffnd_mpi = MPIFFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], + ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, + dtype=par['dtype']) x = DistributedArray(global_shape=ffnd_mpi.shape[1], dtype=par['dtype']) x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() @@ -134,7 +158,9 @@ def test_FFTND(par): y_div_dist = ffnd_mpi / y_dist y_div = y_div_dist.asarray() if rank == 0: - fftnd = FFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], dtype=par['dtype']) + fftnd = FFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], + ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, + dtype=par['dtype']) assert ffnd_mpi.shape == fftnd.shape y_np = fftnd @ x_global y_adj_np = fftnd.H @ y_np From bcdd74f65da5ab8a2c036303bf5d00b6db514c11 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Mon, 11 May 2026 14:26:15 +0530 Subject: [PATCH 06/12] Add FFTND and FFT2D to docs --- docs/source/api/index.rst | 3 ++- pylops_mpi/signalprocessing/FFTND.py | 10 +++++----- pylops_mpi/utils/fft_helper.py | 22 ++++++++++++++++++---- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/docs/source/api/index.rst 
b/docs/source/api/index.rst index 9913547c..8b74e40a 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -71,7 +71,8 @@ Signal Processing :toctree: generated/ MPIFredholm1 - + MPIFFT2D + MPIFFTND Wave-Equation processing ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index 4d36eaa3..fc4a891c 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -10,7 +10,7 @@ from pylops_mpi.utils.decorators import reshaped from pylops_mpi.DistributedArray import DistributedArray, Partition from pylops_mpi.signalprocessing._baseffts import _MPIBaseFFTND -from pylops_mpi.utils import deps, fftshift, ifftshift +from pylops_mpi.utils import deps, fftshift_nd, ifftshift_nd mpi4py_fft_message = deps.mpi4py_fft_import("mpi4py_fft") @@ -194,7 +194,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: f"Got {x.partition} instead...") ncp = get_array_module(x.local_array) if self.ifftshift_before.any(): - x = ifftshift(x, axes=self.axes[self.ifftshift_before]) + x = ifftshift_nd(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: x[:] = ncp.real(x.local_array) # Allocate distributed arrays for input and output @@ -225,7 +225,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: y[:] *= self._scale y[:] = y.local_array.astype(self.cdtype) if self.fftshift_after.any(): - y = fftshift(y, axes=self.axes[self.fftshift_after]) + y = fftshift_nd(y, axes=self.axes[self.fftshift_after]) return y @reshaped @@ -238,7 +238,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: f"Got {x.partition} instead...") ncp = get_array_module(x.local_array) if self.fftshift_after.any(): - x = ifftshift(x, axes=self.axes[self.fftshift_after]) + x = ifftshift_nd(x, axes=self.axes[self.fftshift_after]) if self.real: # Redistribute so that self.axes[-1] is not the one sliced safe_axis = next(i for i in range(len(self.dims)) if i != 
self.axes[-1]) @@ -274,7 +274,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y[:] = ncp.real(y.local_array) y[:] = y.local_array.astype(self.rdtype) if self.ifftshift_before.any(): - y = fftshift(y, axes=self.axes[self.ifftshift_before]) + y = fftshift_nd(y, axes=self.axes[self.ifftshift_before]) return y def __truediv__(self, y: DistributedArray) -> DistributedArray: diff --git a/pylops_mpi/utils/fft_helper.py b/pylops_mpi/utils/fft_helper.py index 9e5a60e4..b91bcbfa 100644 --- a/pylops_mpi/utils/fft_helper.py +++ b/pylops_mpi/utils/fft_helper.py @@ -1,6 +1,6 @@ all = [ - "fftshift", - "ifftshift" + "fftshift_nd", + "ifftshift_nd" ] import numpy as np @@ -10,7 +10,7 @@ from pylops_mpi import DistributedArray -def fftshift(x: DistributedArray, axes: InputDimsLike = None): +def fftshift_nd(x: DistributedArray, axes: InputDimsLike = None): """ Shift the zero-frequency component to the center of the spectrum for a DistributedArray. @@ -19,6 +19,9 @@ def fftshift(x: DistributedArray, axes: InputDimsLike = None): distributed axis, the array is first redistributed to a different axis so the shift can be performed locally, then left in the redistributed state. + .. note:: + This function only supports nd arrays (n >= 2). + Parameters ---------- x : :obj: `pylops_mpi.DistributedArray` @@ -32,6 +35,10 @@ def fftshift(x: DistributedArray, axes: InputDimsLike = None): The shifted array. May be distributed along a different axis than the input if the original distributed axis was included in ``axes``. """ + if x.ndim < 2: + raise ValueError( + f"fftshift_nd requires a 2D or higher array, but got ndim={x.ndim}. 
" + ) if axes is None: axes = tuple(range(x.ndim)) elif np.isscalar(axes): @@ -50,7 +57,7 @@ def fftshift(x: DistributedArray, axes: InputDimsLike = None): return x -def ifftshift(x: DistributedArray, axes: InputDimsLike = None): +def ifftshift_nd(x: DistributedArray, axes: InputDimsLike = None): """ Shift the zero-frequency component back to the beginning of the spectrum for a DistributedArray. @@ -60,6 +67,9 @@ def ifftshift(x: DistributedArray, axes: InputDimsLike = None): For the distributed axis, the array is first redistributed to a different axis so the shift can be performed locally. + .. note:: + This function only supports nd arrays (n >= 2). + Parameters ---------- x : :obj: `pylops_mpi.DistributedArray` @@ -73,6 +83,10 @@ def ifftshift(x: DistributedArray, axes: InputDimsLike = None): The shifted array. May be distributed along a different axis than the input if the original distributed axis was included in ``axes``. """ + if x.ndim < 2: + raise ValueError( + f"ifftshift_nd requires a 2D or higher array, but got ndim={x.ndim}. 
" + ) if axes is None: axes = tuple(range(x.ndim)) elif np.isscalar(axes): From c3253de56f1bbdeea1466ae5c06985055cc83f85 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Tue, 12 May 2026 19:36:41 +0530 Subject: [PATCH 07/12] Use assert_array_allmost_equal and minor changes --- pylops_mpi/utils/fft_helper.py | 18 +++++++++--------- tests/test_ffts.py | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/pylops_mpi/utils/fft_helper.py b/pylops_mpi/utils/fft_helper.py index b91bcbfa..ea4ebc7a 100644 --- a/pylops_mpi/utils/fft_helper.py +++ b/pylops_mpi/utils/fft_helper.py @@ -3,9 +3,7 @@ "ifftshift_nd" ] -import numpy as np - -from pylops.utils import InputDimsLike +from pylops.utils import InputDimsLike, get_module from pylops_mpi import DistributedArray @@ -39,21 +37,22 @@ def fftshift_nd(x: DistributedArray, axes: InputDimsLike = None): raise ValueError( f"fftshift_nd requires a 2D or higher array, but got ndim={x.ndim}. " ) + ncp = get_module(x.engine) if axes is None: axes = tuple(range(x.ndim)) - elif np.isscalar(axes): + elif ncp.isscalar(axes): axes = (axes,) local_axes = [ax for ax in axes if ax != x.axis] remote_axes = [ax for ax in axes if ax == x.axis] if local_axes: shifts = [x.global_shape[ax] // 2 for ax in local_axes] - x[:] = np.roll(x.local_array, shift=shifts, axis=local_axes) + x[:] = ncp.roll(x.local_array, shift=shifts, axis=local_axes) if remote_axes: new_axis = 1 if x.axis == 0 else 0 # Redistribute to a new axis for computation x = x.redistribute(axis=new_axis) shifts = [x.global_shape[ax] // 2 for ax in remote_axes] - x[:] = np.roll(x.local_array, shift=shifts, axis=remote_axes) + x[:] = ncp.roll(x.local_array, shift=shifts, axis=remote_axes) return x @@ -87,19 +86,20 @@ def ifftshift_nd(x: DistributedArray, axes: InputDimsLike = None): raise ValueError( f"ifftshift_nd requires a 2D or higher array, but got ndim={x.ndim}. 
" ) + ncp = get_module(x.engine) if axes is None: axes = tuple(range(x.ndim)) - elif np.isscalar(axes): + elif ncp.isscalar(axes): axes = (axes,) local_axes = [ax for ax in axes if ax != x.axis] dist_axes = [ax for ax in axes if ax == x.axis] if local_axes: shifts = [-(x.global_shape[ax] // 2) for ax in local_axes] - x[:] = np.roll(x.local_array, shift=shifts, axis=local_axes) + x[:] = ncp.roll(x.local_array, shift=shifts, axis=local_axes) if dist_axes: new_axis = 1 if x.axis == 0 else 0 # Redistribute to a new axis for computation x = x.redistribute(axis=new_axis) shifts = [-(x.global_shape[ax] // 2) for ax in dist_axes] - x[:] = np.roll(x.local_array, shift=shifts, axis=dist_axes) + x[:] = ncp.roll(x.local_array, shift=shifts, axis=dist_axes) return x diff --git a/tests/test_ffts.py b/tests/test_ffts.py index 794186f7..2b39313b 100644 --- a/tests/test_ffts.py +++ b/tests/test_ffts.py @@ -7,14 +7,14 @@ if int(os.environ.get("TEST_CUPY_PYLOPS", 0)): import cupy as np + from cupy.testing import assert_array_almost_equal backend = "cupy" else: import numpy as np - + from numpy.testing import assert_array_almost_equal backend = "numpy" from mpi4py import MPI -from numpy.testing import assert_allclose from pylops.signalprocessing import FFT2D, FFTND @@ -107,7 +107,7 @@ def test_FFT2d(par, ifftshift_before, fftshift_after): ff2d_mpi = MPIFFT2D(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, dtype=par['dtype']) - x = DistributedArray(global_shape=ff2d_mpi.shape[1], dtype=par['dtype']) + x = DistributedArray(global_shape=ff2d_mpi.shape[1], dtype=par['dtype'], engine=backend) x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() # Forward @@ -123,8 +123,8 @@ def test_FFT2d(par, ifftshift_before, fftshift_after): assert ff2d_mpi.shape == fft2d.shape y_np = fft2d @ x_global y_adj_np = fft2d.H @ y_np - assert_allclose(y, y_np, 
rtol=1e-5, atol=1e-8) - assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) + assert_array_almost_equal(y, y_np, decimal=7) + assert_array_almost_equal(y_adj, y_adj_np, decimal=7) @pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6), (par7), (par8)]) @@ -145,7 +145,7 @@ def test_FFTND(par, ifftshift_before, fftshift_after): ffnd_mpi = MPIFFTND(dims=par['dims'], axes=par['axes'], norm=par['norm'], real=par['real'], ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, dtype=par['dtype']) - x = DistributedArray(global_shape=ffnd_mpi.shape[1], dtype=par['dtype']) + x = DistributedArray(global_shape=ffnd_mpi.shape[1], dtype=par['dtype'], engine=backend) x[:] = np.random.randn(*(x.local_shape)) + par['imag'] * np.random.randn(*(x.local_shape)) x_global = x.asarray() # Forward @@ -165,6 +165,6 @@ def test_FFTND(par, ifftshift_before, fftshift_after): y_np = fftnd @ x_global y_adj_np = fftnd.H @ y_np y_div_np = fftnd / y_np - assert_allclose(y, y_np, rtol=1e-5, atol=1e-8) - assert_allclose(y_adj, y_adj_np, rtol=1e-5, atol=1e-8) - assert_allclose(y_div, y_div_np, rtol=1e-5, atol=1e-8) + assert_array_almost_equal(y, y_np, decimal=7) + assert_array_almost_equal(y_adj, y_adj_np, decimal=7) + assert_array_almost_equal(y_div, y_div_np, decimal=7) From 7b7e8926d902b01303207f93470ea3f4fd6a655b Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Thu, 14 May 2026 23:52:51 +0530 Subject: [PATCH 08/12] Add optional dependencies and update code for self.real --- .github/workflows/build.yml | 3 +- .github/workflows/deploy-docs.yml | 7 +- pylops_mpi/signalprocessing/FFT2D.py | 6 +- pylops_mpi/signalprocessing/FFTND.py | 100 ++++++++++++++------------- pyproject.toml | 4 ++ requirements-fft.txt | 1 - 6 files changed, 65 insertions(+), 56 deletions(-) delete mode 100644 requirements-fft.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 78e0c8e0..6f4c3447 100644 --- a/.github/workflows/build.yml +++ 
b/.github/workflows/build.yml @@ -52,9 +52,8 @@ jobs: run: | python -m pip install --upgrade pip setuptools if [ -f requirements.txt ]; then pip install -r requirements-dev.txt; fi - if [ -f requirements-fft.txt ]; then pip install -r requirements-fft.txt; fi - name: Install pylops-mpi - run: pip install . + run: pip install .[all] - name: Testing using pytest-mpi run: | if [ "${{ matrix.mpi }}" = "openmpi" ]; then diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index 59e83239..c2a3777f 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -22,12 +22,15 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.11 - + - name: Install fftw libraries + run: | + sudo apt update + sudo apt install -y -q libfftw3-dev - name: Install dependencies run: | python -m pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements-dev.txt; fi - pip install . + pip install .[all] - name: Build docs run: | sphinx-build -b html ./docs/source ./docs/build diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py index 0fdb36b2..2c3dfc77 100644 --- a/pylops_mpi/signalprocessing/FFT2D.py +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -134,8 +134,7 @@ def __init__( ifftshift_before: bool = False, fftshift_after: bool = False, dtype: DTypeLike = "complex128", - base_comm: MPI.Comm = MPI.COMM_WORLD, - **kwargs_fft, + base_comm: MPI.Comm = MPI.COMM_WORLD ) -> None: # checks if len(dims) < 2: @@ -145,8 +144,7 @@ def __init__( msg = "FFT2D must be applied along exactly two dimensions" raise ValueError(msg) super().__init__(dims=dims, axes=axes, sampling=sampling, norm=norm, real=real, dtype=dtype, - ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, base_comm=base_comm, - **kwargs_fft) + ifftshift_before=ifftshift_before, fftshift_after=fftshift_after, base_comm=base_comm) self.f1, self.f2 = self.fs del self.fs diff --git 
a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index fc4a891c..011db4f6 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -75,8 +75,6 @@ class MPIFFTND(_MPIBaseFFTND): different from ``complex128``. base_comm : :obj:`mpi4py.MPI.Comm`, optional MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. - **kwargs_fft - Arbitrary keyword arguments to be passed to the selected fft method Attributes ---------- @@ -152,8 +150,7 @@ def __init__( ifftshift_before: bool = False, fftshift_after: bool = False, dtype: DTypeLike = "complex128", - base_comm: MPI.Comm = MPI.COMM_WORLD, - **kwargs_fft, + base_comm: MPI.Comm = MPI.COMM_WORLD ) -> None: super().__init__( dims=dims, @@ -172,17 +169,16 @@ def __init__( "To respect the passed dtype, data will be cast to {self.cdtype}.", stacklevel=2, ) - - self._kwargs_fft = kwargs_fft if self.norm is _FFTNorms.NONE: self._scale = np.prod(self.nffts) elif self.norm is _FFTNorms.ONE_OVER_N: self._scale = 1.0 / np.prod(self.nffts) fft_dtype = self.rdtype if self.real else self.cdtype # Distribute only along axes[0]; all other axes are non-distributed (0=distributed, 1=not) - subcomm_dims = (np.arange(len(axes)) != axes[0]).astype(int) - subcomm = Subcomm(self.base_comm, dims=np.resize(subcomm_dims, len(dims))) - self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype, collapse=True, **self._kwargs_fft) + subcomm_dims = np.ones(len(dims), dtype=int) + subcomm_dims[axes[0]] = 0 + subcomm = Subcomm(base_comm, subcomm_dims) + self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype, collapse=True) @reshaped def _matvec(self, x: DistributedArray) -> DistributedArray: @@ -192,35 +188,39 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}" f"Got {x.partition} instead...") - ncp = get_array_module(x.local_array) if 
self.ifftshift_before.any(): x = ifftshift_nd(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: - x[:] = ncp.real(x.local_array) - # Allocate distributed arrays for input and output - u_dist = newDistArray(self.fft, forward_output=False) - u_hat = newDistArray(self.fft, forward_output=True) + x[:] = np.real(x.local_array) + x_dist_pfft = newDistArray(self.fft, forward_output=False) + y_dist_pfft = newDistArray(self.fft, forward_output=True) # Redistribute input to match the axis decomposed by PFFT x = x.redistribute(axis=self.axes[0]) - u_dist[:] = x.local_array + x_dist_pfft[:] = x.local_array # Perform the parallel forward FFT - self.fft.forward(u_dist, u_hat, normalize=False) + self.fft.forward(x_dist_pfft, y_dist_pfft, normalize=False) # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1] + dist_axis = [i for i, s in enumerate(y_dist_pfft.subcomm) if s.Get_size() > 1] y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) - y[:] = u_hat + y[:] = y_dist_pfft if self.real: - # Redistribute so that self.axes[-1] is not the one sliced - safe_axis = next(i for i in range(len(self.dims)) if i != self.axes[-1]) - y = y.redistribute(axis=safe_axis) - y_local = y.local_array - # Apply scaling to obtain a correct adjoint for this operator - y_local = ncp.swapaxes(y_local, -1, self.axes[-1]) - y_local[..., 1: 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2) - y_local = ncp.swapaxes(y_local, self.axes[-1], -1) - y[:] = y_local + if y.axis == self.axes[-1]: + sizes = [loc_shape[self.axes[-1]] for loc_shape in y.local_shapes] + start = sum(sizes[:self.base_comm.rank]) + stop = start + sizes[self.base_comm.rank] + # Local overlap with the global frequency slice [1:k] + g0, g1 = max(1, start), min(1 + (self.nffts[-1] - 1) // 2, stop) + if g1 > g0: + sl = 
[slice(None)] * y.ndim + sl[self.axes[-1]] = slice(g0 - start, g1 - start) + y[tuple(sl)] *= np.sqrt(2) + else: + # Axis is local on this rank, so direct slicing + sl = [slice(None)] * y.ndim + sl[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2) + y[tuple(sl)] *= np.sqrt(2) if self.norm is _FFTNorms.ONE_OVER_N: y[:] *= self._scale y[:] = y.local_array.astype(self.cdtype) @@ -236,42 +236,48 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if x.partition != Partition.SCATTER: raise ValueError(f"x should have partition={Partition.SCATTER}, " f"Got {x.partition} instead...") - ncp = get_array_module(x.local_array) + np = get_array_module(x.local_array) if self.fftshift_after.any(): x = ifftshift_nd(x, axes=self.axes[self.fftshift_after]) if self.real: - # Redistribute so that self.axes[-1] is not the one sliced - safe_axis = next(i for i in range(len(self.dims)) if i != self.axes[-1]) - x = x.redistribute(axis=safe_axis) - # Apply scaling to obtain a correct adjoint for this operator - x_local = x.local_array - x_local = ncp.swapaxes(x_local, -1, self.axes[-1]) - x_local[..., 1: 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2) - x_local = ncp.swapaxes(x_local, self.axes[-1], -1) - x[:] = x_local + if x.axis == self.axes[-1]: + sizes = [loc_shape[self.axes[-1]] for loc_shape in x.local_shapes] + start = sum(sizes[:self.base_comm.rank]) + stop = start + sizes[self.base_comm.rank] + # Local overlap with the global frequency slice [1:k] + g0, g1 = max(1, start), min(1 + (self.nffts[-1] - 1) // 2, stop) + if g1 > g0: + sl = [slice(None)] * x.ndim + sl[self.axes[-1]] = slice(g0 - start, g1 - start) + x[tuple(sl)] /= np.sqrt(2) + else: + # Axis is local on this rank, so direct slicing + sl = [slice(None)] * x.ndim + sl[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2) + x[tuple(sl)] /= np.sqrt(2) # Allocate distributed arrays for input and output - u_dist = newDistArray(self.fft, forward_output=False) - u_hat = newDistArray(self.fft, 
forward_output=True) + y_dist_pfft = newDistArray(self.fft, forward_output=False) + x_dist_pfft = newDistArray(self.fft, forward_output=True) # Redistribute input to match the axis decomposed by PFFT - x_axis = [i for i, s in enumerate(u_hat.subcomm) if s.Get_size() > 1] + x_axis = [i for i, s in enumerate(x_dist_pfft.subcomm) if s.Get_size() > 1] x = x.redistribute(axis=x_axis[0] if x_axis else 0) - u_hat[:] = x.local_array + x_dist_pfft[:] = x.local_array # Perform the parallel backward FFT - self.fft.backward(u_hat, u_dist, normalize=True) + self.fft.backward(x_dist_pfft, y_dist_pfft, normalize=True) # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(u_dist.subcomm) if s.Get_size() > 1] + dist_axis = [i for i, s in enumerate(y_dist_pfft.subcomm) if s.Get_size() > 1] y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) - y[:] = u_dist + y[:] = y_dist_pfft if self.norm is _FFTNorms.NONE: y[:] *= self._scale if self.nffts[0] > self.dims[self.axes[0]]: - y[:] = ncp.take(y.local_array, ncp.arange(self.dims[self.axes[0]]), axis=self.axes[0]) + y[:] = np.take(y.local_array, np.arange(self.dims[self.axes[0]]), axis=self.axes[0]) if self.nffts[1] > self.dims[self.axes[1]]: - y[:] = ncp.take(y.local_array, ncp.arange(self.dims[self.axes[1]]), axis=self.axes[1]) + y[:] = np.take(y.local_array, np.arange(self.dims[self.axes[1]]), axis=self.axes[1]) if not self.clinear: - y[:] = ncp.real(y.local_array) + y[:] = np.real(y.local_array) y[:] = y.local_array.astype(self.rdtype) if self.ifftshift_before.any(): y = fftshift_nd(y, axes=self.axes[self.ifftshift_before]) diff --git a/pyproject.toml b/pyproject.toml index d70d32c4..e2df251d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,3 +42,7 @@ exclude = ["pytests"] [tool.setuptools_scm] version_file = "pylops_mpi/version.py" + 
+[project.optional-dependencies] +fft = ["mpi4py-fft"] +all = ["mpi4py-fft"] diff --git a/requirements-fft.txt b/requirements-fft.txt deleted file mode 100644 index ca5fb8be..00000000 --- a/requirements-fft.txt +++ /dev/null @@ -1 +0,0 @@ -mpi4py-fft From f87b5b7bc3cdcaf2c8c42db09b17125d4e922744 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Fri, 15 May 2026 21:25:49 +0530 Subject: [PATCH 09/12] Add scale_real_fft and set subcomm_axis=0 --- pylops_mpi/signalprocessing/FFTND.py | 97 +++++++++++++++------------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index 011db4f6..493abe50 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -174,11 +174,20 @@ def __init__( elif self.norm is _FFTNorms.ONE_OVER_N: self._scale = 1.0 / np.prod(self.nffts) fft_dtype = self.rdtype if self.real else self.cdtype - # Distribute only along axes[0]; all other axes are non-distributed (0=distributed, 1=not) subcomm_dims = np.ones(len(dims), dtype=int) - subcomm_dims[axes[0]] = 0 + # axis=0 for the initial distribution by default, if the final axis is 0, distribute along axis 1 instead. 
+ if axes[-1] == 0: + subcomm_dims[1] = 0 + else: + subcomm_dims[0] = 0 subcomm = Subcomm(base_comm, subcomm_dims) - self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype, collapse=True) + self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype) + self._pfft_in_axis = next( + i for i, s in enumerate(self.fft.pencil[False].subcomm) if s.Get_size() > 1 + ) + self._pfft_out_axis = next( + i for i, s in enumerate(self.fft.pencil[True].subcomm) if s.Get_size() > 1 + ) @reshaped def _matvec(self, x: DistributedArray) -> DistributedArray: @@ -194,33 +203,17 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: x[:] = np.real(x.local_array) x_dist_pfft = newDistArray(self.fft, forward_output=False) y_dist_pfft = newDistArray(self.fft, forward_output=True) - # Redistribute input to match the axis decomposed by PFFT - x = x.redistribute(axis=self.axes[0]) + # Redistribute input to match the input PFFT axis + x = x.redistribute(axis=self._pfft_in_axis) x_dist_pfft[:] = x.local_array # Perform the parallel forward FFT self.fft.forward(x_dist_pfft, y_dist_pfft, normalize=False) - # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(y_dist_pfft.subcomm) if s.Get_size() > 1] - y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=dist_axis[0] if dist_axis else 0, + y = DistributedArray(global_shape=self.dimsd, dtype=self.dtype, axis=self._pfft_out_axis, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) y[:] = y_dist_pfft if self.real: - if y.axis == self.axes[-1]: - sizes = [loc_shape[self.axes[-1]] for loc_shape in y.local_shapes] - start = sum(sizes[:self.base_comm.rank]) - stop = start + sizes[self.base_comm.rank] - # Local overlap with the global frequency slice [1:k] - g0, g1 = max(1, start), min(1 + (self.nffts[-1] - 1) // 2, stop) - if g1 > g0: - sl = [slice(None)] * y.ndim - sl[self.axes[-1]] = slice(g0 - start, g1 - start) - y[tuple(sl)] *= 
np.sqrt(2) - else: - # Axis is local on this rank, so direct slicing - sl = [slice(None)] * y.ndim - sl[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2) - y[tuple(sl)] *= np.sqrt(2) + self._scale_real_fft(y, inverse=False) if self.norm is _FFTNorms.ONE_OVER_N: y[:] *= self._scale y[:] = y.local_array.astype(self.cdtype) @@ -240,34 +233,17 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: if self.fftshift_after.any(): x = ifftshift_nd(x, axes=self.axes[self.fftshift_after]) if self.real: - if x.axis == self.axes[-1]: - sizes = [loc_shape[self.axes[-1]] for loc_shape in x.local_shapes] - start = sum(sizes[:self.base_comm.rank]) - stop = start + sizes[self.base_comm.rank] - # Local overlap with the global frequency slice [1:k] - g0, g1 = max(1, start), min(1 + (self.nffts[-1] - 1) // 2, stop) - if g1 > g0: - sl = [slice(None)] * x.ndim - sl[self.axes[-1]] = slice(g0 - start, g1 - start) - x[tuple(sl)] /= np.sqrt(2) - else: - # Axis is local on this rank, so direct slicing - sl = [slice(None)] * x.ndim - sl[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2) - x[tuple(sl)] /= np.sqrt(2) + self._scale_real_fft(x, inverse=True) # Allocate distributed arrays for input and output y_dist_pfft = newDistArray(self.fft, forward_output=False) x_dist_pfft = newDistArray(self.fft, forward_output=True) - # Redistribute input to match the axis decomposed by PFFT - x_axis = [i for i, s in enumerate(x_dist_pfft.subcomm) if s.Get_size() > 1] - x = x.redistribute(axis=x_axis[0] if x_axis else 0) + # Redistribute input to match the PFFT axis + x = x.redistribute(axis=self._pfft_out_axis) x_dist_pfft[:] = x.local_array # Perform the parallel backward FFT self.fft.backward(x_dist_pfft, y_dist_pfft, normalize=True) - # Axis along which PFFT decomposes the output array across MPI processes - dist_axis = [i for i, s in enumerate(y_dist_pfft.subcomm) if s.Get_size() > 1] - y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=dist_axis[0] if dist_axis 
else 0, + y = DistributedArray(global_shape=self.dims, dtype=self.dtype, axis=self._pfft_in_axis, base_comm=x.base_comm, base_comm_nccl=x.base_comm_nccl, engine=x.engine) y[:] = y_dist_pfft if self.norm is _FFTNorms.NONE: @@ -283,6 +259,39 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y = fftshift_nd(y, axes=self.axes[self.ifftshift_before]) return y + def _scale_real_fft(self, x: DistributedArray, inverse: bool = False) -> None: + """Apply scaling for real-valued FFTs. + + Scales the non-DC positive frequency components along the final FFT axis + by ``sqrt(2)`` in forward mode and ``1/sqrt(2)`` in inverse mode. + + When the final FFT axis is distributed across MPI ranks, only the local + portion overlapping with the global positive-frequency range is scaled. + + Parameters + ---------- + x : DistributedArray + Distributed FFT array to scale in-place. + inverse : bool, optional + Apply inverse scaling when ``True``. Default is ``False``. + """ + scale = 1 / np.sqrt(2) if inverse else np.sqrt(2) + if x.axis == self.axes[-1]: + sizes = [loc_shape[self.axes[-1]] for loc_shape in x.local_shapes] + local_start = sum(sizes[:self.base_comm.rank]) + local_stop = local_start + sizes[self.base_comm.rank] + freq_start, freq_stop = max(1, local_start), min(1 + (self.nffts[-1] - 1) // 2, local_stop) + # Local overlap with the global frequency slice [1:k] + if freq_stop > freq_start: + local_slice = [slice(None)] * x.ndim + local_slice[self.axes[-1]] = slice(freq_start - local_start, freq_stop - local_start) + x[tuple(local_slice)] *= scale + else: + # Axis is local on this rank, so direct slicing + freq_slice = [slice(None)] * x.ndim + freq_slice[self.axes[-1]] = slice(1, 1 + (self.nffts[-1] - 1) // 2) + x[tuple(freq_slice)] *= scale + def __truediv__(self, y: DistributedArray) -> DistributedArray: y_div = self._rmatvec(y) y_div[:] = y_div.local_array / self._scale From 4d76502f7951860ee0efcf35fb32d03917b8355a Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: 
Sat, 16 May 2026 13:33:02 +0530 Subject: [PATCH 10/12] Update documentation --- Makefile | 4 ++++ docs/source/conf.py | 3 ++- docs/source/installation.rst | 19 +++++++++++++++++++ examples/plot_ffts.py | 2 -- pylops_mpi/signalprocessing/FFT2D.py | 21 ++++++++++++++++++++- pylops_mpi/signalprocessing/FFTND.py | 24 +++++++++++++++++++++--- 6 files changed, 66 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 1d824ee0..10d121bd 100644 --- a/Makefile +++ b/Makefile @@ -43,6 +43,10 @@ dev-install_conda: dev-install_conda_nccl: conda env create -f environment-dev.yml && conda activate pylops_mpi && conda install -c conda-forge cupy nccl && pip install -e . +dev-install_fft: + make pipcheck + $(PIP) install -r requirements-dev.txt && $(PIP) install -e ".[fft]" + lint: flake8 pylops_mpi/ tests/ examples/ tutorials/ diff --git a/docs/source/conf.py b/docs/source/conf.py index 4e7c2c86..169d232e 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,7 +28,8 @@ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "matplotlib": ("https://matplotlib.org/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/stable/", None), - "pylops": ("https://pylops.readthedocs.io/en/stable/", None) + "pylops": ("https://pylops.readthedocs.io/en/stable/", None), + "mpi4py_fft": ("https://mpi4py-fft.readthedocs.io/en/stable/", None) } # Generate autodoc stubs with summaries from code diff --git a/docs/source/installation.rst b/docs/source/installation.rst index e1d7faf3..5f7e7211 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -79,6 +79,13 @@ using `pip`: Replace `12x` with your CUDA version (e.g., `11x` for CUDA 11.x). +Install fft dependencies (optional) +=================================== +Similarly, to use the FFT classes with distributed arrays, install PyLops-MPI with the ``fft`` extra: + +.. code-block:: bash + + >> pip install pylops-mpi[fft] .. 
_UserInstall: @@ -94,6 +101,12 @@ command in your terminal to install the PyPI distribution: Note that when installing via `pip`, only *required* dependencies are installed. +Alternatively, all optional dependencies can be installed at once via: + +.. code-block:: bash + + >> pip install pylops-mpi[all] + .. _DevInstall: @@ -169,6 +182,12 @@ Otherwise, you can change the command in `Makefile` to an appropriate CUDA versi i.e., If you use CUDA 11.x, change ``cupy-cuda12x`` and ``nvidia-nccl-cu12`` to ``cupy-cuda11x`` and ``nvidia-nccl-cu11`` and run the command. +If you want to be able to use FFT classes with distributed arrays, run: + +.. code-block:: bash + + >> make dev-install_fft + Run tests ========= To ensure that everything has been setup correctly, run tests: diff --git a/examples/plot_ffts.py b/examples/plot_ffts.py index 06b697c9..56b86b48 100644 --- a/examples/plot_ffts.py +++ b/examples/plot_ffts.py @@ -34,7 +34,6 @@ D = FFTop * dist -dinv = FFTop.H * D dinv = FFTop / D dinv = np.real(dinv.asarray()).reshape(nt, nx) @@ -83,7 +82,6 @@ ) D = FFTop * dist -dinv = FFTop.H * D dinv = FFTop / D dinv = np.real(dinv.asarray()).reshape(nt, nx, ny) D_3d = D.asarray().reshape(nt, nx, ny) # shape matches dims now diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py index 2c3dfc77..d360e86c 100644 --- a/pylops_mpi/signalprocessing/FFT2D.py +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -85,7 +85,7 @@ class MPIFFT2D(MPIFFTND): clinear : :obj:`bool` Operator is complex-linear. Is false when either ``real=True`` or when ``dtype`` is not a complex type. - fft : :obj:`mpi4py_fft.PFFT` + fft : :obj:`mpi4py_fft.mpifft.PFFT` Parallel FFT operator object handling the distributed transform across MPI processes. Configured with the base communicator, dimension decomposition, transform axes, and dtype. 
@@ -123,6 +123,25 @@ class MPIFFT2D(MPIFFTND): the adjoint is **not** the inverse of the forward mode; instead, the inverse requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). + **MPI Parallelization** + + The distributed 2-D FFT relies on ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` (Parallel FFT) class. The global 2-D array is + decomposed across MPI ranks using a *slab decomposition* managed by + :class:`mpi4py_fft.pencil.Subcomm`, which distributes along a single + axis. By default, the input domain is distributed along + ``axis=0``; if ``axes[-1] == 0``, distribution shifts to ``axis=1`` + to avoid a conflict between the transform and decomposition axes. + + In the forward pass, the input is redistributed to match the axis + along which :attr:`fft` (a :class:`mpi4py_fft.mpifft.PFFT` instance) + expects its input, and :meth:`PFFT.forward` is called with + ``normalize=False``. In the adjoint pass, :meth:`PFFT.backward` is + called with ``normalize=True``, meaning ``PFFT`` divides by + :math:`N_x N_y` internally. All inter-rank data movement is + handled internally by ``mpi4py_fft``. + + """ def __init__( self, diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index 493abe50..869f36dc 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -94,7 +94,7 @@ class MPIFFTND(_MPIBaseFFTND): clinear : :obj:`bool` Operator is complex-linear. Is false when either ``real=True`` or when ``dtype`` is not a complex type. - fft : :obj:`mpi4py_fft.PFFT` + fft : :obj:`mpi4py_fft.mpifft.PFFT` Parallel FFT operator object handling the distributed transform across MPI processes. Configured with the base communicator, dimension decomposition, transform axes, and dtype. 
@@ -139,6 +139,24 @@ class MPIFFTND(_MPIBaseFFTND): the adjoint is **not** the inverse of the forward mode; instead, the inverse requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). + **MPI Parallelization** + + The distributed N-dimensional FFT relies on ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` (Parallel FFT) class. The global array is + decomposed across MPI ranks using a *pencil decomposition* managed by + :class:`mpi4py_fft.pencil.Subcomm`, which distributes along a single + axis at a time. By default, the input domain is distributed along + ``axis=0``; if ``axes[-1] == 0``, distribution shifts to ``axis=1`` + to avoid a conflict between the transform and decomposition axes. + + In the forward pass, the input is redistributed to match the axis + along which :attr:`fft` (a :class:`mpi4py_fft.mpifft.PFFT` instance) + expects its input, and :meth:`PFFT.forward` is called with + ``normalize=False``. In the adjoint pass, :meth:`PFFT.backward` is + called with ``normalize=True``, meaning ``PFFT`` divides by + :math:`N_F` internally. All inter-rank data movement (pencil transfer) is + handled internally by ``mpi4py_fft``. 
+ """ def __init__( self, @@ -183,10 +201,10 @@ def __init__( subcomm = Subcomm(base_comm, subcomm_dims) self.fft = PFFT(subcomm, self.dims, axes=self.axes, dtype=fft_dtype) self._pfft_in_axis = next( - i for i, s in enumerate(self.fft.pencil[False].subcomm) if s.Get_size() > 1 + (i for i, s in enumerate(self.fft.pencil[False].subcomm) if s.Get_size() > 1), 0 ) self._pfft_out_axis = next( - i for i, s in enumerate(self.fft.pencil[True].subcomm) if s.Get_size() > 1 + (i for i, s in enumerate(self.fft.pencil[True].subcomm) if s.Get_size() > 1), 0 ) @reshaped From d78311b28194db7b742c1aa9471dbf96d9396364 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sat, 16 May 2026 20:36:13 +0530 Subject: [PATCH 11/12] Minor change in plot_ffts.py --- examples/plot_ffts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/plot_ffts.py b/examples/plot_ffts.py index 56b86b48..eb8df6d5 100644 --- a/examples/plot_ffts.py +++ b/examples/plot_ffts.py @@ -34,7 +34,7 @@ D = FFTop * dist -dinv = FFTop / D +dinv = FFTop.H * D dinv = np.real(dinv.asarray()).reshape(nt, nx) D_2d = D.asarray().reshape(nt, nx) @@ -82,7 +82,7 @@ ) D = FFTop * dist -dinv = FFTop / D +dinv = FFTop.H * D dinv = np.real(dinv.asarray()).reshape(nt, nx, ny) D_3d = D.asarray().reshape(nt, nx, ny) # shape matches dims now From 887dad6ba18056d95a5d3d202c6ace022053af09 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sun, 24 May 2026 23:24:23 +0530 Subject: [PATCH 12/12] Update notes section --- pylops_mpi/signalprocessing/FFT2D.py | 48 ++++++++++------------------ pylops_mpi/signalprocessing/FFTND.py | 48 ++++++++++------------------ 2 files changed, 34 insertions(+), 62 deletions(-) diff --git a/pylops_mpi/signalprocessing/FFT2D.py b/pylops_mpi/signalprocessing/FFT2D.py index d360e86c..f6f7bb56 100644 --- a/pylops_mpi/signalprocessing/FFT2D.py +++ b/pylops_mpi/signalprocessing/FFT2D.py @@ -101,47 +101,33 @@ class MPIFFT2D(MPIFFTND): Notes ----- - The FFT2D operator (using 
``norm="none"``) applies the two-dimensional forward - Fourier transform to a signal :math:`d(y, x)` in forward mode: + The MPIFFT2D operator performs forward and adjoint passes on a + :class:`pylops_mpi.DistributedArray`, which is internally reshaped to the 2-dimensional layout + defined by ``dims``. The 2-dimensional FFT is then applied across MPI ranks using ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` class, with the global array decomposed via a pencil decomposition. + :class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default, + shifting to ``axis=1`` if ``axes[-1] == 0`` to avoid a conflict between the transform and + decomposition axes. + + In the forward pass, :meth:`PFFT.forward` is called with ``normalize=False``, computing: .. math:: D(k_y, k_x) = \mathscr{F} (d) = \iint\limits_{-\infty}^\infty d(y, x) e^{-j2\pi k_yy} e^{-j2\pi k_xx} \,\mathrm{d}y \,\mathrm{d}x - Similarly, the two-dimensional inverse Fourier transform is applied to - the Fourier spectrum :math:`D(k_y, k_x)` in adjoint mode: + When ``norm="1/n"``, the result is additionally scaled by :math:`1/N_F`. + + In the adjoint pass, :meth:`PFFT.backward` is called with ``normalize=True``, so ``PFFT`` + internally divides by :math:`N_F = N_1 \cdot N_2`, computing: .. math:: d(y,x) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} \iint\limits_{-\infty}^\infty D(k_y, k_x) e^{j2\pi k_yy} e^{j2\pi k_xx} \,\mathrm{d}k_y \,\mathrm{d}k_x - where :math:`N_F` is the number of samples in the Fourier domain given by the - product of the elements of ``nffts``. - - Both operators are effectively discretized and solved by a fast iterative - algorithm known as Fast Fourier Transform. Note that when using ``norm="none"``, - the adjoint is **not** the inverse of the forward mode; instead, the inverse - requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). 
- - **MPI Parallelization** - - The distributed 2-D FFT relies on ``mpi4py_fft``'s - :class:`mpi4py_fft.mpifft.PFFT` (Parallel FFT) class. The global 2-D array is - decomposed across MPI ranks using a *slab decomposition* managed by - :class:`mpi4py_fft.pencil.Subcomm`, which distributes along a single - axis. By default, the input domain is distributed along - ``axis=0``; if ``axes[-1] == 0``, distribution shifts to ``axis=1`` - to avoid a conflict between the transform and decomposition axes. - - In the forward pass, the input is redistributed to match the axis - along which :attr:`fft` (a :class:`mpi4py_fft.mpifft.PFFT` instance) - expects its input, and :meth:`PFFT.forward` is called with - ``normalize=False``. In the adjoint pass, :meth:`PFFT.backward` is - called with ``normalize=True``, meaning ``PFFT`` divides by - :math:`N_x N_y` internally. All inter-rank data movement is - handled internally by ``mpi4py_fft``. - - + When ``norm="none"``, the adjoint multiplies by :math:`N_F` to cancel this internal scaling, + returning a true unscaled adjoint. The result is then flattened back to a 1D + :class:`pylops_mpi.DistributedArray`. All inter-rank data movement is handled internally by + ``mpi4py_fft``. """ def __init__( self, diff --git a/pylops_mpi/signalprocessing/FFTND.py b/pylops_mpi/signalprocessing/FFTND.py index 869f36dc..de4d8297 100644 --- a/pylops_mpi/signalprocessing/FFTND.py +++ b/pylops_mpi/signalprocessing/FFTND.py @@ -110,9 +110,15 @@ class MPIFFTND(_MPIBaseFFTND): Notes ----- - The MPIFFTND operator (using ``norm="none"``) applies the N-dimensional forward - Fourier transform to a multidimensional array. Considering an N-dimensional - signal :math:`d(x_1, \ldots, x_N)`. The MPIFFTND in forward mode is: + The MPIFFTND operator performs forward and adjoint passes on a + :class:`pylops_mpi.DistributedArray`, which is internally reshaped to the N-dimensional layout + defined by ``dims``. 
The N-dimensional FFT is then applied across MPI ranks using ``mpi4py_fft``'s + :class:`mpi4py_fft.mpifft.PFFT` class, with the global array decomposed via a pencil decomposition. + :class:`mpi4py_fft.pencil.Subcomm` selects the axis of distribution: ``axis=0`` by default, + shifting to ``axis=1`` if ``axes[-1] == 0`` to avoid a conflict between the transform and + decomposition axes. + + In the forward pass, :meth:`PFFT.forward` is called with ``normalize=False``, computing: .. math:: D(k_1, \ldots, k_N) = \mathscr{F} (d) = @@ -121,8 +127,10 @@ class MPIFFTND(_MPIBaseFFTND): e^{-j2\pi k_1 x_1} \cdots e^{-j 2 \pi k_N x_N} \,\mathrm{d}x_1 \cdots \mathrm{d}x_N - Similarly, the N-dimensional inverse Fourier transform is applied to - the Fourier spectrum :math:`D(k_1, \ldots, k_N)` in adjoint mode: + When ``norm="1/n"``, the result is additionally scaled by :math:`1/N_F`. + + In the adjoint pass, :meth:`PFFT.backward` is called with ``normalize=True``, so ``PFFT`` + internally divides by :math:`N_F = \prod_i N_i`, computing: .. math:: d(x_1, \ldots, x_N) = \mathscr{F}^{-1} (D) = \frac{1}{N_F} @@ -131,32 +139,10 @@ class MPIFFTND(_MPIBaseFFTND): e^{j2\pi k_1 x_1} \cdots e^{j 2 \pi k_N x_N} \,\mathrm{d}k_1 \cdots \mathrm{d}k_N - where :math:`N_F` is the number of samples in the Fourier domain given by the - product of the elements of ``nffts``. - - Both operators are effectively discretized and solved by a fast iterative - algorithm known as Fast Fourier Transform. Note that when using ``norm="none"``, - the adjoint is **not** the inverse of the forward mode; instead, the inverse - requires an explicit :math:`1/N_F` scaling factor (applied in the adjoint/inverse). - - **MPI Parallelization** - - The distributed N-dimensional FFT relies on ``mpi4py_fft``'s - :class:`mpi4py_fft.mpifft.PFFT` (Parallel FFT) class. 
The global array is - decomposed across MPI ranks using a *pencil decomposition* managed by - :class:`mpi4py_fft.pencil.Subcomm`, which distributes along a single - axis at a time. By default, the input domain is distributed along - ``axis=0``; if ``axes[-1] == 0``, distribution shifts to ``axis=1`` - to avoid a conflict between the transform and decomposition axes. - - In the forward pass, the input is redistributed to match the axis - along which :attr:`fft` (a :class:`mpi4py_fft.mpifft.PFFT` instance) - expects its input, and :meth:`PFFT.forward` is called with - ``normalize=False``. In the adjoint pass, :meth:`PFFT.backward` is - called with ``normalize=True``, meaning ``PFFT`` divides by - :math:`N_F` internally. All inter-rank data movement (pencil transfer) is - handled internally by ``mpi4py_fft``. - + When ``norm="none"``, the adjoint multiplies by :math:`N_F` to cancel this internal scaling, + returning a true unscaled adjoint. The result is then flattened back to a 1D + :class:`pylops_mpi.DistributedArray`. All inter-rank data movement is handled internally by + ``mpi4py_fft``. """ def __init__( self,