Merged
31 changes: 27 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
# Electrolyte Foundation Model
Benchmarking RoBERTa model pre-training on molecular datasets.
# MIST: Molecular Insight SMILES Transformer

<div align="center">

![GitHub License](https://img.shields.io/github/license/BattModels/mist)
<a href="https://arxiv.org/abs/2510.18900">![arXiv:2510.18900](https://img.shields.io/badge/cs.LG-2510.18900-b31b1b?style=flat&amp;logo=arxiv&amp;logoColor=red)</a>
[![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-sm.svg)](https://huggingface.co/mist-models)

</div>


MIST is a family of molecular foundation models for molecular property prediction.
The models were pre-trained on [smirk tokenized](https://github.com/BattModels/smirk) SMILES strings from the [Enamine REAL Space](https://enamine.net/compound-collections/real-compounds/real-space-navigator) dataset using the Masked Language Modeling (MLM) objective, then fine-tuned for downstream prediction tasks.

# Installation

The following provides installation instructions for the top-level package (`electrolyte_fm`); optional add-ons for our
various additional analysis and downstream applications (See `opt/`) may require additional configuration.
various additional analysis and downstream applications (see [`./opt`](./opt/)) may require additional configuration.

1. Install [uv](https://docs.astral.sh/uv/getting-started/installation/) and [julia](https://julialang.org/downloads/) (julia is only needed for the [`./opt`](./opt) tasks)
2. Instantiate the environment: `uv sync`
3. Use [`submit/submit.py`](./submit/submit.py) to submit a training job or check out one of our applications in [`./opt`](./opt)

> You may need to install [rust](https://www.rust-lang.org/tools/install) if pre-built wheels for [smirk](https://github.com/BattModels/smirk) are not available on [PyPI](https://pypi.org/project/smirk/).
> Feel free to [open an issue](https://github.com/BattModels/smirk/issues) to request additional pre-built wheels.
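Put together, first-time setup boils down to two commands. The following is a dry-run sketch (it only prints the commands so you can review them before running; `uv` must already be on your PATH):

```shell
#!/bin/sh
# Dry run of the setup sequence above: builds the command list and prints it
# for review; run each printed line by hand once it looks right.
set -eu
SETUP_ENV="uv sync"                      # step 2: instantiate the environment
SUBMIT_HELP="./submit/submit.py --help"  # step 3: inspect training-job options
echo "$SETUP_ENV"
echo "$SUBMIT_HELP"
```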

## Polaris

@@ -34,17 +52,22 @@ Same as above except:
1. Build the image: `bash container/build.sh`; once built, relocate it: `mv /tmp/mist.sif ./mist.sif`
2. Run training within the image `apptainer run --nv mist.sif python train.py ...`

> See `submit/dgx.j2` or `submit/delta.j2` for a more complete example of using the container
> See [`submit/dgx.j2`](./submit/dgx.j2) or [`submit/delta.j2`](./submit/delta.j2) for a more complete example of using the container
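The two container steps combine into one short script. This is a dry-run sketch (commands are printed, not executed; `apptainer` and an NVIDIA GPU are assumed for the final step):

```shell
#!/bin/sh
# Build the image, move it out of /tmp, then launch training inside it.
set -eu
BUILD="bash container/build.sh"
RELOCATE="mv /tmp/mist.sif ./mist.sif"
TRAIN="apptainer run --nv mist.sif python train.py"
for step in "$BUILD" "$RELOCATE" "$TRAIN"; do
    echo "$step"
done
```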

# Submitting Jobs

We use a python script ([`submit/submit.py`](./submit/submit.py)) to template training jobs for submission on HPC systems across multiple sites.
Templates may need to be modified for your particular HPC cluster, but should provide a starting point.

```shell
source ./activate # Activate Environment
./submit/submit.py ./submit/polaris.j2 --data ./submit/pretrain.yaml | qsub
```

See `submit/submit.py --help` for more info.

> Note: [./activate](./activate) is used to activate the python virtual environment *and* set various environment variables.
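The pipeline can also be split into two stages, which is handy when you want to inspect the rendered job script before submitting. A sketch, assuming the `.j2` templates are Jinja2 and that your scheduler is PBS (`job.pbs` is a hypothetical filename; substitute your scheduler's submit command for `qsub`):

```shell
#!/bin/sh
# Two-stage version of `submit.py ... | qsub`: render first, submit second.
set -eu
RENDER="./submit/submit.py ./submit/polaris.j2 --data ./submit/pretrain.yaml"
JOB_SCRIPT="job.pbs"
echo "$RENDER > $JOB_SCRIPT"   # stage 1: render the template to a file
echo "qsub $JOB_SCRIPT"        # stage 2: submit the rendered script
```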

# Development

## Pre-commit
1 change: 1 addition & 0 deletions opt/BayesianScaling/Project.toml
@@ -21,6 +21,7 @@ OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
20 changes: 20 additions & 0 deletions opt/BayesianScaling/README.md
@@ -0,0 +1,20 @@
# BayesianScaling

A Julia package for fitting regression models with MCMC; it was used to fit the penalized neural scaling laws.
To install:

- Install [Julia](https://julialang.org/downloads/)
- Instantiate the package: `julia --project -e 'using Pkg; Pkg.instantiate()'`
- Download the wandb records or chains ([doi:10.5281/zenodo.17527149](https://doi.org/10.5281/zenodo.17527149))
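As one script, the install steps look like the following (a dry-run sketch: the commands are printed for review; the Zenodo download is left as a manual step, and the `./data` destination is an assumption, not a documented path):

```shell
#!/bin/sh
# Dry run of the BayesianScaling setup: print each step for review.
set -eu
INSTANTIATE="julia --project -e 'using Pkg; Pkg.instantiate()'"
DATA_DOI="10.5281/zenodo.17527149"
echo "$INSTANTIATE"
echo "fetch chains from https://doi.org/$DATA_DOI (destination up to you)"
```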

## Code Organization

- `./scripts/`: scripts for fitting and analyzing the neural scaling laws
- `./plots/`: plotting code for the paper and various conferences
- `./src`: the MCMC regression and analysis package powering this work
  - [ppl.jl](./src/ppl.jl): defines a regression-first interface for fitting MCMC models, plus single-pass algorithms for working with posterior samples
  - [scaling.jl](./src/scaling.jl): functional forms for neural scaling laws and derived quantities
  - [analysis.jl](./src/analysis.jl): code for predicting the performance of models using fitted neural scaling laws
- `./test/`: unit tests for the BayesianScaling.jl package
- `./benchmark/`: benchmark suite for evaluating different AD backends using [PkgJogger.jl](https://github.com/awadell1/PkgJogger.jl)
4 changes: 2 additions & 2 deletions opt/BayesianScaling/src/ppl.jl
@@ -216,7 +216,7 @@ function transform_samples(t::TransformVariables.AbstractTransform, x::Matrix{T}
end

function transform!(y::AbstractVector, tt::TransformVariables.TransformTuple, x::AbstractVector)
(; transformations) = tt
transformations = getfield(tt, :inner)
@assert TransformVariables.dimension(tt) == length(y) == length(x)
index = firstindex(y)
for t in transformations
@@ -242,7 +242,7 @@ transform!(y::AbstractVector, t::TransformVariables.AbstractTransform, x::Abstra
function transfrom_axis(tt::TransformVariables.TransformTuple{<:NamedTuple})
ax_tt = []
index = 1
for (k, t) in pairs(tt.transformations)
for (k, t) in pairs(getfield(tt, :inner))
ax = transfrom_axis(t)
n = TransformVariables.dimension(t)
if ax isa Union{ComponentArrays.ShapedAxis,ComponentArrays.Axis}
13 changes: 13 additions & 0 deletions opt/FeatureMiner/README.md
@@ -0,0 +1,13 @@
# Feature Miner

Code for evaluating how well fitted linear probes predict various chemically meaningful features.

# Replication

1. Install [Julia](https://julialang.org/downloads/)
2. Instantiate the project: `julia --project -e 'using Pkg; Pkg.instantiate()'`
3. Train linear probes on MIST fine-tuned models using [linear_probe.jsonnet](../../submit/linear_probe.jsonnet) and [submit/submit.py](../../submit/submit.py)
4. Run `julia --project explore_probes.jl` to extract fitted probe weights from the checkpoints
5. Instantiate the plotting code: `julia --project=plots -e 'using Pkg; Pkg.instantiate()'`
6. Evaluate fitted probes: `julia --project=plots ./plots/lipinski_probes.jl`
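The replication pipeline above can be printed in order as a dry-run checklist (a sketch: step 3 is summarized rather than spelled out, since the exact `submit.py` invocation depends on your cluster; run each line by hand as its inputs become available):

```shell
#!/bin/sh
# Print the FeatureMiner replication pipeline, one step per line.
set -eu
for step in \
    "julia --project -e 'using Pkg; Pkg.instantiate()'" \
    "train probes via ../../submit/submit.py with linear_probe.jsonnet (see --help)" \
    "julia --project explore_probes.jl" \
    "julia --project=plots -e 'using Pkg; Pkg.instantiate()'" \
    "julia --project=plots ./plots/lipinski_probes.jl"; do
    echo "$step"
done
```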
3 changes: 3 additions & 0 deletions opt/MISTStyle/README.md
@@ -0,0 +1,3 @@
# MISTStyle.jl

A collection of plotting utilities and themes for [Makie.jl](https://docs.makie.org/stable/) used throughout the codebase to generate high-quality plots for publication with a consistent visual theme.
11 changes: 10 additions & 1 deletion opt/TokenizerStats/README.md
@@ -1,4 +1,13 @@
# Analysis Code for "Smirk: An Atomically Complete Tokenizer for Molecular Foundation Models"
# Analysis Code for "Tokenization for Molecular Foundation Models"

<div align="center">

![GitHub License](https://img.shields.io/github/license/BattModels/smirk)
<a href="https://doi.org/10.1021/acs.jcim.5c01856">![paper](https://img.shields.io/badge/paper-10.1021%2Facs.jcim.5c01856-blue)</a>
<a href="https://doi.org/10.5281/zenodo.13761262">![data](https://img.shields.io/badge/data-10.5281%2Fzenodo.13761262-blue)</a>
<a href="https://arxiv.org/abs/2409.15370">![arXiv:2409.15370](https://img.shields.io/badge/cs.LG-2409.15370-b31b1b?style=flat&amp;logo=arxiv&amp;logoColor=red)</a>

</div>

## Installation

12 changes: 12 additions & 0 deletions opt/design/README.md
@@ -0,0 +1,12 @@
# Evaluating Chemical Trends with MIST

Source code for querying the MIST models on hydrocarbons and other templatable organic molecules.

## Installation

> All commands run from this directory

1. Install [Julia](https://julialang.org/downloads/) and [uv](https://docs.astral.sh/uv/getting-started/installation/)
2. Instantiate the project: `uv run julia --project -e 'using Pkg; Pkg.instantiate()'`
3. Download the MIST models to `models/`
4. Recreate the plots: `uv run julia --project plots.jl`
4 changes: 2 additions & 2 deletions opt/interp_embeddings/README.md
@@ -5,6 +5,6 @@ Scripts for exploring MIST's embeddings and generating relevant figures from the
## Reproducing Analysis

1. Install [julia](https://julialang.org/downloads/) and the base project (See [Project README](../../README.md))
2. Instantiate the environment `julia --project -e 'using Pkg; Pkg.instantiate()'`
2. Instantiate the environment `uv run julia --project -e 'using Pkg; Pkg.instantiate()'`
3. Obtain model files and place at the appropriate path (see `plots.jl`)
4. Run the script: `julia --project plots.jl`
4. Run the script: `uv run julia --project plots.jl`
17 changes: 17 additions & 0 deletions opt/mixtures/README.md
@@ -0,0 +1,17 @@
# Mixtures

Code for evaluating the MIST mixture models, exploring mixture space, and optimizing mixture composition.

## Installation

> All commands run from this directory

1. Install [Julia](https://julialang.org/downloads/) and [uv](https://docs.astral.sh/uv/getting-started/installation/)
2. Instantiate the project: `uv run julia --project -e 'using Pkg; Pkg.instantiate()'`
3. Obtain the mixtures dataset from [doi:10.5281/zenodo.17527149](https://doi.org/10.5281/zenodo.17527149)

## Reproducing Plots

Once installed, most of the scripts in the current directory can be run with:
- python: `uv run <script>.py`
- julia: `uv run julia --project <script>.jl`
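The two run patterns can be folded into a single dispatch helper that picks the interpreter from the file extension (a sketch; `explore.py` and `optimize.jl` are hypothetical script names, and `uv` is assumed to be on your PATH):

```shell
#!/bin/sh
# mixture_cmd: print the `uv run` invocation for a script, choosing the
# interpreter (python vs julia) from the file extension.
mixture_cmd() {
    case "$1" in
        *.py) echo "uv run $1" ;;
        *.jl) echo "uv run julia --project $1" ;;
        *)    echo "unsupported: $1" >&2; return 1 ;;
    esac
}
mixture_cmd explore.py    # -> uv run explore.py
mixture_cmd optimize.jl   # -> uv run julia --project optimize.jl
```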
10 changes: 10 additions & 0 deletions opt/olfactory/README.md
@@ -0,0 +1,10 @@
# Olfactory

Scripts for exploring MIST's olfaction model and generating relevant figures from the paper.

## Reproducing Analysis

1. Install [julia](https://julialang.org/downloads/) and the base project (See [Project README](../../README.md))
2. Instantiate the environment `uv run julia --project -e 'using Pkg; Pkg.instantiate()'`
3. Obtain model files and place at the appropriate path (see `discordance.jl` and `olfactory.jl`)
4. Run the script: `uv run julia --project discordance.jl`
1 change: 1 addition & 0 deletions opt/package/.python-version
@@ -0,0 +1 @@
3.12
File renamed without changes.
86 changes: 64 additions & 22 deletions opt/package/__main__.py
@@ -60,6 +60,7 @@
save_tokenizer,
)
from .write_model_class import write_modeling_module
from .channel_schema import resolve_dataset_channels

cli = typer.Typer()
logging.basicConfig(level=logging.INFO)
@@ -152,7 +153,7 @@ def pretrained(ckpt: Path, name: Optional[str] = None, safe: bool = True):
create_tar_gz(save_dir)


def export_finetuned(ckpt: Path) -> MISTFinetuned:
def export_finetuned(ckpt: Path, dataset: Optional[str] = None) -> MISTFinetuned:
bundle, best_ckpt = load_model(ckpt, model_class="MISTFinetuned")
train_cfg = read_training_config(ckpt)
tokenizer = train_cfg.get("data")
@@ -161,12 +162,11 @@ def export_finetuned(ckpt: Path) -> MISTFinetuned:
else:
tokenizer = load_tokenizer("smirk")

# Try to get channels from training config or packaged model config
try:
channels = train_cfg["data"]["init_args"].get("target_columns")
except KeyError:
# If loading from already-packaged model, channels are at top level
channels = train_cfg.get("channels")
channels = resolve_dataset_channels(
train_cfg,
dataset=dataset,
source_names=[ckpt.name, best_ckpt.name],
)

model = MISTFinetuned.from_components(
encoder=bundle.encoder,
@@ -179,11 +179,19 @@


@cli.command()
def finetuned(ckpt: Path, name: Optional[str] = None, safe: bool = True):
def finetuned(
ckpt: Path,
name: Optional[str] = None,
safe: bool = True,
dataset: Optional[str] = typer.Option(
None,
help="Override dataset name instead of autodetecting it from the checkpoint config",
),
):
"""
Export a finetuned model with embedded remote code.
"""
model, best_ckpt = export_finetuned(ckpt)
model, best_ckpt = export_finetuned(ckpt, dataset=dataset)

tag = name_model(
model,
@@ -344,7 +352,7 @@ def excess_physics(ckpt: Path, name: Optional[str] = None, safe: bool = True):
create_tar_gz(save_dir)


def export_mixtures(ckpt: Path) -> MISTMixtures:
def export_mixtures(ckpt: Path, dataset: Optional[str] = None) -> MISTMixtures:
bundle, best_ckpt = load_model(ckpt, model_class="MISTMixtures")

train_cfg = read_training_config(ckpt)
@@ -358,11 +366,11 @@ def export_mixtures(ckpt: Path) -> MISTMixtures:
if hasattr(temperature_condition, "value"):
temperature_condition = temperature_condition.value

# Try to get channels from training config
try:
channels = train_cfg["data"]["init_args"].get("target_col")
except KeyError:
channels = model_cfg.get("target_columns")
channels = resolve_dataset_channels(
train_cfg,
dataset=dataset,
source_names=[ckpt.name, best_ckpt.name],
)

model = MISTMixtures.from_components(
encoder=bundle.encoder,
@@ -378,11 +386,19 @@


@cli.command()
def mixtures(ckpt: Path, name: Optional[str] = None, safe: bool = True):
def mixtures(
ckpt: Path,
name: Optional[str] = None,
safe: bool = True,
dataset: Optional[str] = typer.Option(
None,
help="Override dataset name instead of autodetecting it from the checkpoint config",
),
):
"""
Export a mixture property prediction model.
"""
model, best_ckpt = export_mixtures(ckpt)
model, best_ckpt = export_mixtures(ckpt, dataset=dataset)

tag = name_model(
model,
@@ -423,21 +439,35 @@ def mixtures(ckpt: Path, name: Optional[str] = None, safe: bool = True):
create_tar_gz(save_dir)


def export_multitask(encoder_ckpt: Path, task_ckpt: List[Path]) -> MISTMultiTask:
def export_multitask(
encoder_ckpt: Path,
task_ckpt: List[Path],
task_datasets: Optional[List[Optional[str]]] = None,
) -> MISTMultiTask:
encoder_ckpt = maybe_best_ckpt(encoder_ckpt)
if task_datasets is None:
task_datasets = [None] * len(task_ckpt)
if len(task_datasets) != len(task_ckpt):
raise ValueError("task_datasets must match task_ckpt length")
try:
# Try loading from training checkpoints
encoder_ckpt = maybe_best_ckpt(encoder_ckpt)
encoder = load_encoder(encoder_ckpt)
tokenizer = get_ckpt_tokenizer(encoder_ckpt)

task_networks, transforms, channels = [], [], []
for ckpt in task_ckpt:
for ckpt, dataset in zip(task_ckpt, task_datasets):
ckpt = maybe_best_ckpt(ckpt)
cfg = read_training_config(ckpt)
assert cfg["model"]["init_args"][
"freeze_encoder"
], f"Encoder not frozen for {ckpt}"
channels.extend(cfg["data"]["init_args"]["target_columns"])
channels.extend(
resolve_dataset_channels(
cfg,
dataset=dataset,
source_names=[ckpt.name],
)
)
bundle = SaveConfigWithCkpts.load(ckpt, strict=False)
task_networks.append(bundle.task_network)
transforms.append(bundle.transform)
@@ -498,6 +528,10 @@ def multitask(
task_ckpt: List[Path] = typer.Option(
[], help="Repeat to add multiple task checkpoints"
),
task_dataset: List[str] = typer.Option(
[],
help="Repeat to override the dataset name for each task checkpoint in order",
),
tasks_in_folder: bool = False,
name: Optional[str] = None,
safe: bool = True,
@@ -510,7 +544,15 @@
if d.is_dir() and d.name != "pretrained":
task_ckpt.append(get_best_ckpt(d))

model = export_multitask(encoder_ckpt, task_ckpt)
task_datasets: List[Optional[str]]
if task_dataset:
if len(task_dataset) != len(task_ckpt):
raise ValueError("task_dataset must be provided once per task_ckpt")
task_datasets = list(task_dataset)
else:
task_datasets = [None] * len(task_ckpt)

model = export_multitask(encoder_ckpt, task_ckpt, task_datasets=task_datasets)

tag = name_model(
model,