This package is a JAX-based implementation of Conditional Flow Matching
(CFM) - an approach for generative modelling based on continuous normalizing
flows. The API design closely follows that of the TorchCFM library, so that
users familiar with TorchCFM can migrate to JAX with little friction.
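
For readers new to CFM, the core training objective can be sketched in plain JAX. This is a minimal illustration, not this package's API: it assumes the simplest variant (independent pairing of source and target samples, straight-line interpolation, no added noise), and the names sample_conditional_path, linear_vector_field, and cfm_loss are hypothetical helpers introduced only for this example.

```python
import jax
import jax.numpy as jnp


def sample_conditional_path(key, x0, x1):
    """Draw t ~ U(0, 1) per sample and build x_t with its target velocity."""
    t = jax.random.uniform(key, (x0.shape[0], 1))
    xt = (1.0 - t) * x0 + t * x1  # point on the straight line from x0 to x1
    ut = x1 - x0                  # velocity of that line, constant in t
    return t, xt, ut


def linear_vector_field(params, t, x):
    """Hypothetical toy model: an affine map of the concatenated [x, t]."""
    inputs = jnp.concatenate([x, t], axis=-1)
    return inputs @ params["w"] + params["b"]


def cfm_loss(params, key, x0, x1):
    """Regress the model's velocity onto the conditional target velocity."""
    t, xt, ut = sample_conditional_path(key, x0, x1)
    vt = linear_vector_field(params, t, xt)
    return jnp.mean(jnp.sum((vt - ut) ** 2, axis=-1))


# Toy data: a standard normal source and a shifted normal target (assumed).
key = jax.random.PRNGKey(0)
k_src, k_tgt, k_t = jax.random.split(key, 3)
dim = 2
x0 = jax.random.normal(k_src, (128, dim))
x1 = jax.random.normal(k_tgt, (128, dim)) + 3.0

params = {"w": jnp.zeros((dim + 1, dim)), "b": jnp.zeros((dim,))}
loss, grads = jax.value_and_grad(cfm_loss)(params, k_t, x0, x1)
```

In practice the affine model would be replaced by a neural network, and the gradients would be passed to an optimizer inside a training loop.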
This repository is currently under construction and may be incomplete or contain bugs.
To install JAX-CFM, clone this repository and run pip install . in an
environment with Python >= 3.10.
If you intend to contribute or run the examples, consider installing the
optional dependencies as well (e.g. pip install .[dev] or
pip install .[examples]).
Eventually, the goal is to make the package available on PyPI.
JAX-CFM relies on ott-jax for optimal
transport-related tasks and uses
equinox and
jaxtyping for API
design, type annotations, and type checking.