jax deseq2 #441

Open
adamgayoso wants to merge 1 commit into scverse:main from adamgayoso:jax

Conversation

@adamgayoso (Member) commented Feb 25, 2026

Implement a JaxInference backend for DESeq2 inference.

This backend allows hardware acceleration via GPU/TPU. Please see the class docstring for a detailed description of the methodology.

A test file has been added covering the same use cases as the default inference class, and all tests pass.

Because of the import structure, the class can only be imported when jax is installed, so no additional logic was needed to handle missing imports.
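To illustrate the "only importable when jax is installed" behavior described above, here is a hedged sketch of a common optional-dependency guard; the module path and class name are hypothetical, not the actual package layout:

```python
# Hypothetical sketch of an optional-dependency guard; the import path and
# class name below are illustrative, not taken from this PR.
import importlib.util

HAS_JAX = importlib.util.find_spec("jax") is not None

def get_inference(backend="default"):
    if backend == "jax":
        if not HAS_JAX:
            raise ImportError("the 'jax' backend requires jax to be installed")
        from pydeseq2_jax import JaxInference  # hypothetical import path
        return JaxInference()
    raise ValueError(f"unknown backend: {backend!r}")
```

With this pattern, users without jax never touch the jax-only module, and requesting the backend fails with a clear error instead of an opaque `ModuleNotFoundError` deep inside the package.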


@ilan-gold left a comment


Spent some time with this, seems like I would have some broad questions:

  1. Is it worth it to abstract the "math" from the "jax" via the array API? My comments are along these lines. My software-engineering instincts say "yes", but I don't know Jax or the math well enough to say either way, and it would also be a massive lift, since it would require unifying with the current implementation. I'll try to spend some time reading up.
  2. In general, I see a few uses of things that have multiple implementations in jax, specifically where one implementation lives in jax.scipy and the other in jax.numpy. There usually seem to be good reasons for this, like slightly different features. But that raises the further question, beyond the above, of how much of this could be refactored for reuse (modulo which scipy module is called for a numpy or jax array).

Not sure I'm the one most qualified to review, but figured I'd get the ball rolling! Maybe we should discuss among core what to do.
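As a concrete illustration of the abstraction the reviewer is asking about, the math can be written once against a namespace parameter so NumPy and jax.numpy share one implementation. A minimal sketch (the function name is hypothetical; only the `w` and gram-matrix expressions come from the reviewed snippet):

```python
import numpy as np

def weighted_gram_logdet(design_matrix, mu, alpha, xp=np):
    # The IRLS weight and gram matrix from the reviewed snippet, written once
    # against a namespace parameter: passing xp=jax.numpy (with jax arrays)
    # would run the same code on an accelerator; here it runs on NumPy.
    w = mu / (1.0 + mu * alpha)
    mat = (design_matrix.T * w) @ design_matrix
    _, logdet = xp.linalg.slogdet(mat)
    return w, mat, logdet
```

Arithmetic operators dispatch by duck typing already; the `xp` parameter is only needed for module-level functions such as `slogdet`, which is where the jax.numpy-vs-jax.scipy split the reviewer mentions would surface.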

return beta_, mu_, converged_


def _compute_hat_matrix_diagonal(

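The `_compute_hat_matrix_diagonal` helper anchored above presumably computes the leverage values that feed Cook's distances for outlier detection in DESeq2. A NumPy sketch of a weighted hat-matrix diagonal (signature assumed, not taken from the PR):

```python
import numpy as np

def hat_matrix_diagonal(design_matrix, w):
    # diag(H) with H = W^{1/2} X (X^T W X)^{-1} X^T W^{1/2}.
    # H is a projection matrix in the weighted space, so each leverage
    # lies in [0, 1] and the diagonal sums to the rank of X.
    xw = design_matrix * np.sqrt(w)[:, None]
    gram_inv = np.linalg.inv(design_matrix.T @ (w[:, None] * design_matrix))
    return np.einsum("ij,jk,ik->i", xw, gram_inv, xw)
```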

)


def _compute_mu(


w = mu / (1 + mu * alpha)
mat = (design_matrix.T * w) @ design_matrix

# There are a few ways to compute the log determinant. Here we use an approach

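The comment in the snippet above notes there are several ways to compute the log determinant of the symmetric positive-definite matrix `(design_matrix.T * w) @ design_matrix`. One standard option is to go through the Cholesky factor; a small NumPy sketch (not the PR's actual code):

```python
import numpy as np

def logdet_spd(mat):
    # For a symmetric positive-definite matrix A = L @ L.T (Cholesky),
    # log|A| = 2 * sum(log(diag(L))). Working in log space avoids the
    # overflow/underflow of computing det(A) directly.
    chol = np.linalg.cholesky(mat)
    return 2.0 * np.log(np.diag(chol)).sum()
```

`np.linalg.slogdet` (and its jax.numpy counterpart) computes the same quantity via an LU factorization and also works for non-SPD matrices, at the cost of not exploiting symmetry.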

return n * dispersion_neg1 * jnp.log(dispersion) - logbinom.sum()


def _nb_loss_variable_terms(

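The `logbinom` and `dispersion_neg1 * log(dispersion)` terms in the snippet above are pieces of the negative binomial log-likelihood that the loss functions split into constant (beta-independent) and variable parts. For reference, a hedged stdlib sketch of the full NB log-pmf in the mean/dispersion parameterization (not the PR's code):

```python
import math

def nb_logpmf(k, mean, alpha):
    # Negative binomial log-pmf parameterized by mean and dispersion alpha
    # (variance = mean + alpha * mean**2), i.e. r = 1/alpha "successes".
    # The lgamma terms form the log binomial coefficient; the remaining
    # terms depend on the mean and are the beta-dependent part of the loss.
    r = 1.0 / alpha
    logbinom = math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
    return logbinom + r * math.log(r / (r + mean)) + k * math.log(mean / (r + mean))
```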



2 participants