CrystalFormer-CSP is an autoregressive transformer designed for crystal structure prediction.
Generating Cs2ZnFe(CN)6 Crystal (mp-570545)
- Contents
- Model Card
- Status
- Get Started
- Installation
- Available Weights
- Crystal Structure Prediction
- Advanced Usage
- How to cite
The model is an autoregressive transformer for the formula-conditioned crystal probability distribution $P(C \mid f) = P(g \mid f)\, P(W_1 \mid \dots)\, P(A_1 \mid \dots)\, P(X_1 \mid \dots)\, P(W_2 \mid \dots) \cdots P(L \mid \dots)$, where
- $f$: chemical formula, e.g. Cu12Sb4S13
- $g$: space group number 1-230
- $W$: Wyckoff letter ('a', 'b', ..., 'A')
- $A$: atom type ('H', 'He', ..., 'Og') in the chemical formula
- $X$: fractional coordinates
- $L$: lattice vector [a, b, c, alpha, beta, gamma]
- $P(W_i \mid \dots)$ and $P(A_i \mid \dots)$ are categorical distributions.
- $P(X_i \mid \dots)$ is a mixture of von Mises distributions.
- $P(L \mid \dots)$ is a mixture of Gaussian distributions.
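To make the von Mises mixture for fractional coordinates concrete, here is a minimal pure-Python sketch. It is not the repository's implementation: the function names, the number of components, and the parameter values are assumptions chosen for illustration. The point it shows is that mapping a fractional coordinate x to the angle 2πx makes the density periodic, so x = 0 and x = 1 are treated as the same position.

```python
import math


def bessel_i0(kappa, terms=50):
    """Modified Bessel function I0 via its power series (fine for moderate kappa)."""
    total, term = 0.0, 1.0
    for k in range(terms):
        total += term
        term *= (kappa / 2.0) ** 2 / ((k + 1) ** 2)
    return total


def von_mises_mixture_pdf(x, weights, locs, kappas):
    """Density of a fractional coordinate x under a mixture of von Mises components.

    locs are component centres given as fractional coordinates, kappas are
    concentrations. The factor 2*pi from the change of variable theta = 2*pi*x
    cancels the 2*pi in the von Mises normalisation.
    """
    theta = 2.0 * math.pi * x
    density = 0.0
    for w, mu, kappa in zip(weights, locs, kappas):
        density += w * math.exp(kappa * math.cos(theta - 2.0 * math.pi * mu)) / bessel_i0(kappa)
    return density


# Hypothetical two-component mixture (values are illustrative only).
weights, locs, kappas = [0.3, 0.7], [0.1, 0.6], [4.0, 10.0]
print(von_mises_mixture_pdf(0.6, weights, locs, kappas))
```

A Gaussian mixture on x would assign different densities to 0.99 and 0.01 even though they are neighbours under periodic boundary conditions, which is why a circular distribution is the natural choice here.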
We only consider symmetry-inequivalent atoms in the crystal representation. The remaining atoms are restored from the space group and Wyckoff letters. There is a natural alphabetical ordering for the Wyckoff letters, starting with 'a' for a position with the site-symmetry group of maximal order and ending with the highest letter for the general position. The sampling procedure starts from higher-symmetry sites (with smaller multiplicities) and then moves on to lower-symmetry ones (with larger multiplicities). Fractional coordinates need to be considered in the loss or during sampling only when the Wyckoff letter does not fully determine the position.
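The ordering described above can be sketched in a few lines of Python. The helper names and the example site list are hypothetical and not taken from the repository; the only thing borrowed from the text is the letter sequence 'a', 'b', ..., 'A'.

```python
import string

# Wyckoff letters in their natural order: 'a'..'z', then 'A' as the 27th letter.
WYCKOFF_LETTERS = string.ascii_lowercase + "A"


def wyckoff_index(letter):
    """1-based rank of a Wyckoff letter in the ordering a < b < ... < z < A."""
    return WYCKOFF_LETTERS.index(letter) + 1


def sort_sites(sites):
    """Order (wyckoff_letter, element) pairs from special positions to the general one."""
    return sorted(sites, key=lambda site: wyckoff_index(site[0]))


# Hypothetical list of symmetry-inequivalent sites.
sites = [("e", "S"), ("a", "Cu"), ("A", "O"), ("c", "Sb")]
print(sort_sites(sites))
# [('a', 'Cu'), ('c', 'Sb'), ('e', 'S'), ('A', 'O')]
```

An explicit lookup table is used because a plain string sort would place 'A' before 'a' in ASCII order, which is the opposite of the intended ranking.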
Major milestones are summarized below.
- v0.1 : Initial release developed from CrystalFormer v0.4.2
Notebooks: The quickest way to get started with CrystalFormer-CSP is through our notebooks on the Google Colab platform.
Create a new environment and install the required packages. We recommend using Python 3.10.* and conda to create the environment:
conda create -n crystalgpt python=3.10
conda activate crystalgpt
Before installing the required packages, you need to install jax and jaxlib first.
pip install -U "jax[cpu]"
If you intend to use CUDA (GPU) to speed up the training, it is important to install the appropriate version of jax and jaxlib. It is recommended to check the jax docs for the installation guide. The basic installation command is given below:
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install -U "jax[cuda12]"
Then install the required packages:
pip install -r requirements.txt
To use the command line tools, you need to install the crystalformer package. You can use the following command to install the package:
pip install .
We release the weights of the model trained on the Alex20s dataset. More details can be seen in the model card.
python ./main.py --optimizer none --restore_path RESTORE_PATH --K 40 --num_samples 1000 --formula Cu12Sb4S13 --save_path SAVE_PATH
- optimizer: the optimizer to use, none means no training, only sampling
- restore_path: the path to the model weights
- K: the number of top-K space groups, which will be sampled uniformly
- num_samples: the number of samples to generate
- formula: the chemical formula
- save_path: [Optional] the path to save the generated structures; if not provided, the structures will be saved in the RESTORE_PATH folder
Instead of providing K for top-K sampling, you may directly provide your favorite space group number
spacegroup: the space group number [1-230]
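One way to read "top-K space groups sampled uniformly" is sketched below. This is an interpretation, not the code in main.py: the function name and the toy score table are hypothetical, and the actual script may rank or sample space groups differently.

```python
import random


def sample_space_group(scores, k, rng=random):
    """Pick one of the k highest-scoring space groups uniformly at random.

    scores maps a space group number (1-230) to a model score such as log P(g|f).
    """
    top_k = sorted(scores, key=scores.get, reverse=True)[:k]
    return rng.choice(top_k)


# Hypothetical scores for a handful of space groups (illustrative numbers only).
scores = {1: -5.0, 2: -1.0, 14: -0.5, 62: -0.7, 225: -3.0}
rng = random.Random(0)
print([sample_space_group(scores, 3, rng) for _ in range(5)])
```

Under this reading, the model's ranking decides which space groups are eligible, but each eligible group gets the same share of the sample budget. Passing a fixed space group number instead corresponds to restricting the candidates to a single entry.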
The sampled structures will be saved in the SAVE_PATH/output_Cu12Sb4S13.csv file. To transform the generated structures from g, W, A, X, L to the CIF format, you can use the following command:
python ./scripts/awl2struct.py --output_path SAVE_PATH --formula FORMULA
- output_path: the path to read the generated L, W, A, X and save the cif files
- formula: the chemical formula constrained in the structure
This will save the generated structures in the CIF format to an output_Cu12Sb4S13_struct.csv file.
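As a sanity check on what "formula constrained" means, the sketch below compares the composition implied by a list of symmetry-inequivalent sites with the requested formula. It is a hedged illustration: the helper names are hypothetical, the parser only handles flat formulas without parentheses, the site list is an illustrative input, and whether the repository matches the formula exactly or up to a number of formula units is an assumption made here.

```python
import re
from collections import Counter
from functools import reduce
from math import gcd


def parse_formula(formula):
    """Parse a flat formula such as 'Cu12Sb4S13' into element counts."""
    counts = Counter()
    for element, number in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        counts[element] += int(number) if number else 1
    return dict(counts)


def reduce_counts(counts):
    """Divide all counts by their greatest common divisor."""
    divisor = reduce(gcd, counts.values())
    return {element: n // divisor for element, n in counts.items()}


def matches_formula(sites, formula):
    """sites is a list of (element, wyckoff_multiplicity) for inequivalent atoms."""
    cell = Counter()
    for element, multiplicity in sites:
        cell[element] += multiplicity
    return reduce_counts(cell) == reduce_counts(parse_formula(formula))


# Illustrative sites: a cell holding two formula units of Cu12Sb4S13.
sites = [("Cu", 12), ("Cu", 12), ("Sb", 8), ("S", 24), ("S", 2)]
print(matches_formula(sites, "Cu12Sb4S13"))  # True
```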
python scripts/mlff_relax.py \
--restore_path SAVE_PATH \
--filename output_Cu12Sb4S13_struct.csv \
--model orb-v3-conservative-inf-mpa \
--model_path path/to/orb-v3.ckpt \
--relaxation
This will produce relaxed structures in relaxed_structures with predicted energies.
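The predicted energies feed the subsequent Ehull ranking. As a rough illustration of what the energy above the convex hull measures, here is a toy sketch for a binary system in pure Python. The compositions and formation energies are made-up numbers and the function names are hypothetical. The repository's script instead reads a precomputed hull from convex_hull_pbe.json.bz2 and is not limited to binaries.

```python
def lower_hull(points):
    """Lower convex hull of (composition, energy) points via the monotone chain method."""
    hull = []
    for p in sorted(points):
        while len(hull) >= 2:
            (x1, y1), (x2, y2) = hull[-2], hull[-1]
            cross = (x2 - x1) * (p[1] - y1) - (y2 - y1) * (p[0] - x1)
            if cross <= 0:
                hull.pop()  # last hull point lies on or above the new segment
            else:
                break
        hull.append(p)
    return hull


def e_above_hull(x, energy, hull):
    """Energy of a phase at composition x relative to the hull, interpolated linearly."""
    for (x1, y1), (x2, y2) in zip(hull, hull[1:]):
        if x1 == x2:
            continue  # skip degenerate vertical segments
        if x1 <= x <= x2:
            return energy - (y1 + (y2 - y1) * (x - x1) / (x2 - x1))
    raise ValueError("composition outside the hull range")


# Made-up formation energies (eV/atom) versus fraction of the second element.
phases = [(0.0, 0.0), (0.25, -0.2), (0.5, -0.6), (0.75, -0.1), (1.0, 0.0)]
hull = lower_hull(phases)
ranked = sorted(phases, key=lambda p: e_above_hull(p[0], p[1], hull))
print(hull)
print(ranked[0])
```

Phases on the hull have Ehull equal to zero, and candidates are usually ranked by how far above the hull they sit.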
Compute Ehull for all relaxed structures:
python scripts/e_above_hull_alex.py \
--convex_path convex_hull_pbe.json.bz2 \
--restore_path SAVE_PATH \
--filename relaxed_structures.csv
Run sampling → CIF conversion → relaxation → Ehull ranking:
./postprocess.sh \
-r RESTORE_PATH \
-k 40 \
--relaxation true \
-n 1000 \
-f Cu12Sb4S13 \
-s SAVE_PATH
In case you are curious about the parameters, run:
./postprocess.sh -h
Important: Before running the reinforcement fine-tuning, please make sure you have installed the corresponding machine learning force field model or property prediction model. The mlff_model and mlff_path arguments in the command line should be set according to the model you are using. Currently, only the orb model is supported. The BatchRelaxer is also needed for batch structure relaxation during the fine-tuning.
train_ppo --folder ./data/ \
--restore_path YOUR_PATH \
--reward ehull \
--convex_path YOUR_PATH/convex_hull_pbe.json.bz2 \
--mlff_model orb-v3-conservative-inf-mpa \
--mlff_path YOUR_PATH/orb-v3-conservative-inf-mpa-20250404.ckpt \
--lr 1e-05 \
--dropout_rate 0.0 \
--K 40 \
--batchsize 500 \
--formula LiPH2O4
where
- folder: the folder to save the model and logs
- restore_path: the path to the pre-trained model weights
- reward: the reward function to use, ehull means the energy above the convex hull
- convex_path: the path to the convex hull data, which is used to calculate the $E_{hull}$. Only used when the reward is ehull
- mlff_model: the machine learning force field model to predict the total energy. We support the orb model for the $E_{hull}$ reward
- mlff_path: the path to load the checkpoint of the machine learning force field model
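The name train_ppo suggests proximal policy optimization. The sketch below shows the standard PPO clipped surrogate on a batch of samples, under two loudly stated assumptions that the text above does not confirm: the reward is taken as the negative of $E_{hull}$, and the advantage is the reward minus the batch mean. Function names, the clipping value, and the numbers are all illustrative and do not describe the repository's actual objective.

```python
import math


def advantages_from_ehull(ehulls):
    """Assumed reward = -Ehull, centred by the batch mean (an illustrative baseline)."""
    rewards = [-e for e in ehulls]
    mean = sum(rewards) / len(rewards)
    return [r - mean for r in rewards]


def ppo_clipped_objective(logp_new, logp_old, advantages, clip_eps=0.2):
    """Standard PPO clipped surrogate, averaged over the batch (to be maximised)."""
    total = 0.0
    for ln, lo, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)  # probability ratio of new to old policy
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)


# Made-up Ehull values (eV/atom) and log-probabilities for three samples.
adv = advantages_from_ehull([0.05, 0.30, 0.10])
print(ppo_clipped_objective([-10.0, -12.0, -11.0], [-10.2, -11.8, -11.0], adv))
```

The clipping keeps a single update from moving the policy too far from the pre-trained model, which matters when the reward comes from a comparatively expensive relaxation step.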
python ./main.py --folder ./data/ --train_path YOUR_PATH/alex20s/train.csv --valid_path YOUR_PATH/alex20s/val.csv
where
- folder: the folder to save the model and logs
- train_path: the path to the training dataset
- valid_path: the path to the validation dataset
Test the prediction accuracy of space groups on the test dataset
python scripts/predict_g.py --restore_path /home/user_wanglei/private/datafile/crystalgpt/csp/alex20s/csp-99fcd/adam_bs_8000_lr_0.0001_decay_0_clip_1_A_119_W_28_N_21_a_1_w_1_l_1_Nf_5_Kx_16_Kl_4_h0_256_l_16_H_8_k_32_m_256_e_256_drop_0.1_0.1/epoch_046000.pkl --valid_path /opt/data/bcmdata/ZONES/data/PROJECTS/datafile/PRIVATE/zdcao/crystal_gpt/dataset/alex/PBE_20241204/test.lmdb --Nf 5 --Kx 16 --Kl 4 --h0_size 256 --transformer_layers 16 --num_heads 8 --key_size 32 --model_size 256 --embed_size 256 --batchsize 1000
@article{cao2024space,
title={Space Group Informed Transformer for Crystalline Materials Generation},
author={Zhendong Cao and Xiaoshan Luo and Jian Lv and Lei Wang},
year={2024},
eprint={2403.15734},
archivePrefix={arXiv},
primaryClass={cond-mat.mtrl-sci}
}
@article{cao2025crystalformerrl,
title={CrystalFormer-RL: Reinforcement Fine-Tuning for Materials Design},
author={Zhendong Cao and Lei Wang},
year={2025},
eprint={2504.02367},
archivePrefix={arXiv},
primaryClass={cond-mat.mtrl-sci},
url={https://arxiv.org/abs/2504.02367},
}
@misc{cao2025crystalformercsp,
title={CrystalFormer-CSP: Thinking Fast and Slow for Crystal Structure Prediction},
author={Zhendong Cao and Shigang Ou and Lei Wang},
year={2025},
eprint={2512.18251},
archivePrefix={arXiv},
primaryClass={cond-mat.mtrl-sci},
url={https://arxiv.org/abs/2512.18251},
}
Note: This project is unrelated to https://github.com/omron-sinicx/crystalformer with the same name.
