Motivation
JAX is widely used in ML research, particularly for:
- DeepMind and Google Brain research projects
- TPU training workloads
- Differentiable programming and scientific computing
- Graph neural networks and algorithmic learning
Adding JAX support would expand tropical-gemm's reach to these ecosystems.
Proposed API
```python
from tropical_gemm.jax import (
    tropical_maxplus_matmul,
    tropical_minplus_matmul,
    tropical_maxmul_matmul,
)
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
c = tropical_maxplus_matmul(a, b)  # c[i,j] = max_k(a[i,k] + b[k,j])
```
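For testing, a pure `jax.numpy` reference of the three semirings could serve as an oracle for the Rust-backed functions. This is a minimal sketch under stated assumptions. The `*_ref` names are hypothetical. The max-mul definition (`max_k(a[i,k] * b[k,j])`) is inferred from the function name. Because it materializes an `(M, K, N)` intermediate, it only suits small inputs.

```python
import jax.numpy as jnp

# Hypothetical reference implementations (not part of tropical-gemm).
# Each broadcasts to an (M, K, N) intermediate and reduces over K.

def maxplus_ref(a, b):
    # c[i, j] = max_k(a[i, k] + b[k, j])
    return jnp.max(a[:, :, None] + b[None, :, :], axis=1)

def minplus_ref(a, b):
    # c[i, j] = min_k(a[i, k] + b[k, j])
    return jnp.min(a[:, :, None] + b[None, :, :], axis=1)

def maxmul_ref(a, b):
    # Assumed semantics: c[i, j] = max_k(a[i, k] * b[k, j])
    return jnp.max(a[:, :, None] * b[None, :, :], axis=1)

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
print(maxplus_ref(a, b).tolist())  # [[9.0, 10.0], [11.0, 12.0]]
print(minplus_ref(a, b).tolist())  # [[6.0, 7.0], [8.0, 9.0]]
print(maxmul_ref(a, b).tolist())   # [[14.0, 16.0], [28.0, 32.0]]
```

Once the Rust-backed functions exist, a test could assert `jnp.allclose(tropical_maxplus_matmul(a, b), maxplus_ref(a, b))` on random inputs.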
Implementation Approach
1. Leverage Existing Infrastructure
- Rust core: already framework-agnostic
- DLPack: already implemented; JAX supports DLPack via `jax.dlpack`
2. JAX Custom Derivatives
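```python
from jax import custom_vjp
import jax.numpy as jnp
import tropical_gemm


@custom_vjp
def tropical_maxplus_matmul(a, b):
    """JAX-compatible tropical max-plus matmul with autodiff support."""
    # Direct call on concrete arrays; under jax.jit this needs jax.pure_callback
    # (see Considerations)
    c, _ = tropical_gemm.maxplus_matmul_with_argmax_dlpack(a, b)
    return c


def _fwd(a, b):
    c, argmax = tropical_gemm.maxplus_matmul_with_argmax_dlpack(a, b)
    # Save argmax (M, N) for backward. a and b are kept so the backward pass can
    # read K and the dtypes from static array shapes; a Python int saved as a
    # residual is treated as an array value and is not guaranteed to stay static.
    return c, (argmax, a, b)


def _bwd(res, g):
    argmax, a, b = res
    m, n = g.shape
    rows = jnp.arange(m)[:, None]  # (M, 1), broadcasts against argmax (M, N)
    cols = jnp.arange(n)[None, :]  # (1, N)
    # Sparse gradient: only the winning index k* = argmax[i, j] receives g[i, j]
    grad_a = jnp.zeros_like(a).at[rows, argmax].add(g)  # grad_a[i, k*] += g[i, j]
    grad_b = jnp.zeros_like(b).at[argmax, cols].add(g)  # grad_b[k*, j] += g[i, j]
    return grad_a, grad_b


tropical_maxplus_matmul.defvjp(_fwd, _bwd)
```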
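The backward rule can be checked without the Rust extension by swapping in a pure-JAX forward that also returns the argmax. The sketch below is a stand-in under assumptions, and `_maxplus_with_argmax` and `maxplus_matmul_ref` are hypothetical names. It reuses the same argmax-routed backward pass and compares it against JAX's own autodiff of `jnp.max`. The two should only agree when each maximum is unique: on ties, JAX's max reduction splits the gradient across the tied entries, while the argmax rule sends it to a single index.

```python
import jax
import jax.numpy as jnp

def _maxplus_with_argmax(a, b):
    # Hypothetical pure-JAX stand-in for the Rust kernel: returns (c, argmax)
    s = a[:, :, None] + b[None, :, :]  # (M, K, N)
    return jnp.max(s, axis=1), jnp.argmax(s, axis=1)

@jax.custom_vjp
def maxplus_matmul_ref(a, b):
    c, _ = _maxplus_with_argmax(a, b)
    return c

def _fwd(a, b):
    c, argmax = _maxplus_with_argmax(a, b)
    return c, (argmax, a, b)  # a, b kept for their shapes/dtypes

def _bwd(res, g):
    argmax, a, b = res
    m, n = g.shape
    rows = jnp.arange(m)[:, None]  # (M, 1)
    cols = jnp.arange(n)[None, :]  # (1, N)
    grad_a = jnp.zeros_like(a).at[rows, argmax].add(g)
    grad_b = jnp.zeros_like(b).at[argmax, cols].add(g)
    return grad_a, grad_b

maxplus_matmul_ref.defvjp(_fwd, _bwd)

def loss_custom(a, b):
    return maxplus_matmul_ref(a, b).sum()

def loss_autodiff(a, b):
    # Plain jnp.max version: JAX differentiates the reduction itself
    return jnp.max(a[:, :, None] + b[None, :, :], axis=1).sum()

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])  # k = 1 wins for every (i, j): no ties
ga, gb = jax.jit(jax.grad(loss_custom, argnums=(0, 1)))(a, b)
ga_ref, gb_ref = jax.grad(loss_autodiff, argnums=(0, 1))(a, b)
print(ga.tolist())  # [[0.0, 2.0], [0.0, 2.0]]
print(gb.tolist())  # [[0.0, 0.0], [2.0, 2.0]]
```

Wrapping the gradient in `jax.jit` also confirms that the backward rule only relies on static shapes.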
3. File Structure
```text
tropical-gemm/
├── crates/tropical-gemm-python/python/tropical_gemm/
│   ├── __init__.py
│   ├── pytorch.py   # Existing
│   └── jax.py       # New
```
Benefits
| Benefit | Description |
|---|---|
| TPU support | JAX has first-class TPU support |
| Research adoption | Popular in academic ML research |
| Minimal overhead | Share Rust core, thin JAX wrapper |
| DLPack reuse | Zero-copy tensor exchange already implemented |
Considerations
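- JAX arrays are immutable, which fits well with pure tropical operations.
- JAX's tracing/JIT compilation must be handled: under `jax.jit` the inputs are abstract tracers with no concrete buffers, so the Rust kernel cannot be called on them directly.
- Consider `jax.pure_callback` for non-traceable Rust calls.
- Batched operations (see #26, "Add batched tropical matmul support") would also benefit JAX users.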
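A minimal sketch of the `jax.pure_callback` route follows. It assumes a host function that takes and returns NumPy arrays. `_host_maxplus_with_argmax` is a hypothetical NumPy stand-in for the real Rust binding, whose actual signature may differ. The output shapes and dtypes are declared up front with `jax.ShapeDtypeStruct`, because JAX cannot trace into the callback.

```python
import numpy as np
import jax
import jax.numpy as jnp

def _host_maxplus_with_argmax(a, b):
    # Hypothetical host-side stand-in for the Rust kernel; runs outside tracing
    # on concrete NumPy arrays.
    a, b = np.asarray(a), np.asarray(b)
    s = a[:, :, None] + b[None, :, :]           # (M, K, N)
    c = s.max(axis=1).astype(a.dtype)           # must match the declared dtype
    argmax = s.argmax(axis=1).astype(np.int32)  # must match the declared dtype
    return c, argmax

def maxplus_with_argmax_callback(a, b):
    # Declare output shapes/dtypes so JAX can trace without running the callback
    m, n = a.shape[0], b.shape[1]
    out_types = (
        jax.ShapeDtypeStruct((m, n), a.dtype),
        jax.ShapeDtypeStruct((m, n), jnp.int32),
    )
    return jax.pure_callback(_host_maxplus_with_argmax, out_types, a, b)

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[5.0, 6.0], [7.0, 8.0]])
c, argmax = jax.jit(maxplus_with_argmax_callback)(a, b)
print(c.tolist())       # [[9.0, 10.0], [11.0, 12.0]]
print(argmax.tolist())  # [[1, 1], [1, 1]]
```

In the `custom_vjp` forward pass, a function like this would take the place of the direct `tropical_gemm.maxplus_matmul_with_argmax_dlpack` call. `pure_callback` defines no derivative of its own, so gradients would still come from the custom VJP. The callback runs on the host, so device-resident arrays are copied to and from host memory on each call.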
Related