Skip to content

Perf: 2-4x speedup for ShermanMorrison _solve_2D2#445

Open
jberg5 wants to merge 1 commit into
nanograv:devfrom
jberg5:solve-2D2-perf
Open

Perf: 2-4x speedup for ShermanMorrison _solve_2D2#445
jberg5 wants to merge 1 commit into
nanograv:devfrom
jberg5:solve-2D2-perf

Conversation

@jberg5
Copy link
Copy Markdown

@jberg5 jberg5 commented Apr 25, 2026

2-4x speedup for

def _solve_2D2(self, X, Z):

A single numpy matrix multiplication is much faster than outer product + subtraction in a loop. We can make the _solve_2D2 correction anywhere from 10-100x faster (end to end speedup will be lower because other things are unchanged) just by accumulating all zn and xn terms into matrices, since the sum of scaled outer products is itself a matrix product, as long as you decompose beta:

$$ \sum_{i=1}^{E} \beta_i ; \mathbf{z}_i , \mathbf{x}_i^T = \sum_{i=1}^{E} \left(\sqrt{\beta_i};\mathbf{z}_i\right)\left(\sqrt{\beta_i};\mathbf{x}_i\right)^T = V^T W $$

where we define:

$$V \in \mathbb{R}^{E \times N_\text{basis}}, \quad V_{i,:} = \sqrt{\beta_i};\mathbf{z}_i$$

$$W \in \mathbb{R}^{E \times N_\text{basis}}, \quad W_{i,:} = \sqrt{\beta_i};\mathbf{x}_i$$

Note that this isn't really an algorithmic improvement; the total number of flops is going to be roughly the same. The speedup all comes from being able to use the BLAS dgemm kernel (and the corresponding speedup you see will depend on your hardware). There is a negligible memory overhead from larger intermediate matrices.

Here's a synthetic benchmarking script that you can run standalone to see the boost (thanks Claude for writing it):

Click to expand script
"""
Microbenchmark: full _solve_2D2 original vs batched.
Sweeps realistic problem sizes. No enterprise dependencies.
"""
import time
import numpy as np

np.random.seed(42)


def _solve_2D2_original(nvec, jvec, idxs, X, Z):
  ZNX = np.dot(Z.T / nvec, X)
  for idx, jv in zip(idxs, jvec):
      if len(idx) > 1:
          niblock = 1 / nvec[idx]
          beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
          zn = np.dot(niblock, Z[idx, :])
          xn = np.dot(niblock, X[idx, :])
          ZNX -= beta * np.outer(zn.T, xn)
  return ZNX


def _solve_2D2_batched(nvec, jvec, idxs, X, Z):
  ZNX = np.dot(Z.T / nvec, X)
  n_epochs = len(idxs)
  zn_all = np.zeros((n_epochs, Z.shape[1]))
  xn_all = np.zeros((n_epochs, X.shape[1]))
  beta_all = np.zeros(n_epochs)
  for i, (idx, jv) in enumerate(zip(idxs, jvec)):
      if len(idx) > 1:
          niblock = 1.0 / nvec[idx]
          beta_all[i] = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
          zn_all[i] = np.dot(niblock, Z[idx, :])
          xn_all[i] = np.dot(niblock, X[idx, :])
  sqrt_beta = np.sqrt(beta_all)[:, None]
  ZNX -= np.dot((sqrt_beta * zn_all).T, sqrt_beta * xn_all)
  return ZNX


def make_synthetic(n_toas, n_basis, n_epochs, epoch_size):
  nvec = np.random.uniform(0.1, 10.0, n_toas)
  jvec = np.random.uniform(0.01, 1.0, n_epochs)
  idxs = []
  pos = 0
  for i in range(n_epochs):
      end = min(pos + epoch_size, n_toas)
      if end > pos:
          idxs.append(np.arange(pos, end))
      pos = end
      if pos >= n_toas:
          break
  idxs = idxs[:n_epochs]
  jvec = jvec[:len(idxs)]
  X = np.random.randn(n_toas, n_basis)
  Z = np.random.randn(n_toas, n_basis)
  return nvec, jvec, idxs, X, Z


def bench(nvec, jvec, idxs, X, Z, n_runs=300):
  for _ in range(30):
      _solve_2D2_original(nvec, jvec, idxs, X, Z)
      _solve_2D2_batched(nvec, jvec, idxs, X, Z)

  times_orig = []
  times_batch = []
  for _ in range(n_runs):
      t0 = time.perf_counter()
      _solve_2D2_original(nvec, jvec, idxs, X, Z)
      times_orig.append(time.perf_counter() - t0)

      t0 = time.perf_counter()
      _solve_2D2_batched(nvec, jvec, idxs, X, Z)
      times_batch.append(time.perf_counter() - t0)

  return np.median(times_orig), np.median(times_batch)


configs = [
  # (n_toas, n_basis, n_epochs, epoch_size, label)
  (1000,   80,   60,  17, "small"),
  (2000,  120,  120,  17, ""),
  (4005,  151,  235,  17, "B1855+09"),
  (6000,  160,  350,  17, ""),
  (9730,  178,  369,  26, "B1937+21"),
  (10259, 166,  307,  33, "J1909-3744"),
  (15000, 200,  880,  17, ""),
  (20000, 220, 1175,  17, ""),
  (30000, 260, 1765,  17, ""),
  (50000, 320, 2940,  17, "NANOGrav 15yr scale"),
]

print(f"{'N_toa':>7} {'N_basis':>7} {'E':>6} {'original':>12} {'batched':>12} {'speedup':>8}  {'':>5}")
print("-" * 65)

for n_toas, n_basis, n_epochs, epoch_size, label in configs:
  nvec, jvec, idxs, X, Z = make_synthetic(n_toas, n_basis, n_epochs, epoch_size)
  actual_e = len(idxs)
  n_runs = max(50, min(500, 3000000 // (n_toas * n_basis)))

  r1 = _solve_2D2_original(nvec, jvec, idxs, X, Z)
  r2 = _solve_2D2_batched(nvec, jvec, idxs, X, Z)
  assert np.allclose(r1, r2, atol=1e-10), f"Mismatch at {label or n_toas}!"

  t_orig, t_batch = bench(nvec, jvec, idxs, X, Z, n_runs=n_runs)
  speedup = t_orig / t_batch

  print(f"{n_toas:>7} {n_basis:>7} {actual_e:>6} {t_orig*1e3:>10.3f}ms {t_batch*1e3:>10.3f}ms {speedup:>7.2f}x  {label}")

On my macbook I see:

  N_toa N_basis      E     original      batched  speedup       
-----------------------------------------------------------------
   1000      80     59      0.990ms      0.481ms    2.06x  small
   2000     120    118      3.440ms      1.201ms    2.86x  
   4005     151    235      8.988ms      2.863ms    3.14x  B1855+09
   6000     160    350     14.527ms      4.444ms    3.27x  
   9730     178    369     29.083ms      6.465ms    4.50x  B1937+21
  10259     166    307     27.239ms      6.487ms    4.20x  J1909-3744
  15000     200    880     54.472ms     13.827ms    3.94x  
  20000     220   1175     81.170ms     19.975ms    4.06x  
  30000     260   1765    226.318ms     37.010ms    6.12x  
  50000     320   2940    808.887ms     73.055ms   11.07x  NANOGrav 15yr scale

On a gcloud c2-standard-4:

        N_toa N_basis      E     original      batched  speedup
     -----------------------------------------------------------------
        1000      80     59      2.984ms      1.742ms    1.71x  small
        2000     120    118      8.566ms      4.233ms    2.02x
        4005     151    235     22.757ms     10.025ms    2.27x  B1855+09
        6000     160    350     35.839ms     15.105ms    2.37x
        9730     178    369     51.758ms     22.793ms    2.27x  B1937+21
       10259     166    307     42.799ms     21.437ms    2.00x  J1909-3744
       15000     200    880    127.768ms     46.973ms    2.72x
       20000     220   1175    200.198ms     75.111ms    2.67x
       30000     260   1765    428.119ms    142.687ms    3.00x
       50000     320   2940   1040.927ms    274.145ms    3.80x  NANOGrav 15yr scale

@jberg5
Copy link
Copy Markdown
Author

jberg5 commented Apr 25, 2026

Just for fun I went ahead and benchmarked against fastshermanmorrison and this approach is a bit faster on apple silicon:

  ARM Mac (Apple Accelerate):                                                                           
                                                                                                        
  ┌────────────┬──────────┬──────────────┬─────────────────────────┐                                    
  │   Pulsar   │ Original │   PR #445    │ fastshermanmorrison (C) │                                    
  ├────────────┼──────────┼──────────────┼─────────────────────────┤                                    
  │ B1855+09   │ 15.3ms   │ 3.2ms (4.8x) │ 3.3ms (4.6x)            │                                    
  ├────────────┼──────────┼──────────────┼─────────────────────────┤
  │ B1937+21   │ 31.8ms   │ 8.0ms (4.0x) │ 10.1ms (3.2x)           │                                    
  ├────────────┼──────────┼──────────────┼─────────────────────────┤                                    
  │ J1909-3744 │ 24.5ms   │ 7.2ms (3.4x) │ 9.2ms (2.7x)            │                                    
  └────────────┴──────────┴──────────────┴─────────────────────────┘  

But slower on x86:

  x86 gcloud c2-standard-4 (OpenBLAS):                                                                  
                                                                          
  ┌────────────┬──────────┬───────────────┬─────────────────────────┐                                   
  │   Pulsar   │ Original │    PR #445    │ fastshermanmorrison (C) │
  ├────────────┼──────────┼───────────────┼─────────────────────────┤                                   
  │ B1855+09   │ 32.9ms   │ 13.3ms (2.5x) │ 5.7ms (5.8x)            │     
  ├────────────┼──────────┼───────────────┼─────────────────────────┤
  │ B1937+21   │ 68.3ms   │ 27.6ms (2.5x) │ 11.8ms (5.8x)           │                                   
  ├────────────┼──────────┼───────────────┼─────────────────────────┤
  │ J1909-3744 │ 52.4ms   │ 22.8ms (2.3x) │ 9.9ms (5.3x)            │                                   
  └────────────┴──────────┴───────────────┴─────────────────────────┘     

@jberg5
Copy link
Copy Markdown
Author

jberg5 commented Apr 26, 2026

Same approach for fastshermanmorrison here: nanograv/fastshermanmorrison#10, some decent speedups.

Also explains some of the mystery of why Python beat the C kernel on my Apple silicon.

@vhaasteren
Copy link
Copy Markdown
Member

Hi @jberg5 , it is great you are looking at this! Thanks. Can you modify the PR to merge into dev rather than master please?

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 71.69%. Comparing base (90dfa56) to head (4ea33d1).
⚠️ Report is 1 commits behind head on master.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #445      +/-   ##
==========================================
+ Coverage   71.58%   71.69%   +0.10%     
==========================================
  Files          13       13              
  Lines        3245     3243       -2     
==========================================
+ Hits         2323     2325       +2     
+ Misses        922      918       -4     
Files with missing lines Coverage Δ
enterprise/signals/signal_base.py 90.28% <100.00%> (+0.03%) ⬆️

... and 3 files with indirect coverage changes


Continue to review full report in Codecov by Sentry.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 6335ff7...4ea33d1. Read the comment docs.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jberg5 jberg5 changed the base branch from master to dev April 30, 2026 12:07
@jberg5
Copy link
Copy Markdown
Author

jberg5 commented Apr 30, 2026

Thanks @vhaasteren ! Have switched to dev.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants