diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml new file mode 100644 index 00000000..0ce4cfa6 --- /dev/null +++ b/.github/workflows/python-app.yml @@ -0,0 +1,47 @@ +name: Test NeuralForceField package + +on: [push] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + # python-version: ["pypy3.10", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install basics + run: python -m pip install --upgrade pip setuptools wheel + - name: Install package + run: python -m pip install . + # - name: Install linters + # run: python -m pip install flake8 mypy pylint + # - name: Install documentation requirements + # run: python -m pip install -r docs/requirements.txt + # - name: Test with flake8 + # run: flake8 polymethod + # - name: Test with mypy + # run: mypy polymethod + # - name: Test with pylint + # run: pylint polymethod + - name: Test with pytest + run: | + pip install pytest pytest-cov + pytest nff/tests --doctest-modules --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=nff --cov-report=xml --cov-report=html + - name: Upload pytest test results + uses: actions/upload-artifact@v4 + with: + name: pytest-results-${{ matrix.python-version }} + path: junit/test-results-${{ matrix.python-version }}.xml + if: ${{ always() }} + # - name: Test documentation + # run: sphinx-build docs/source docs/build diff --git a/.gitignore b/.gitignore index 6181965e..4b6cfd09 100644 --- a/.gitignore +++ b/.gitignore @@ -66,5 +66,17 @@ dist/ sandbox_excited/ build/ +# Editor files +# vim +*.swp +*.swo + +# pycharm +.idea/ + +# coverage and tests +junit +.coverage + # required exceptions !tutorials/models/ammonia/Ammonia.xyz diff --git a/nff/analysis/attribution.py b/nff/analysis/attribution.py index f8c6631f..43b0a271 100644 --- a/nff/analysis/attribution.py +++ b/nff/analysis/attribution.py @@ -1,18 +1,20 @@ +from typing import Dict, List, Optional, Union + +import numpy as np import torch -from ase.io import Trajectory, write from ase import Atoms -import numpy as np +from ase.io import Trajectory, write +from tqdm import tqdm -from nff.io.ase_calcs import EnsembleNFF from nff.io.ase import AtomsBatch -from nff.utils.scatter import compute_grad +from nff.io.ase_calcs import EnsembleNFF from nff.utils.cuda import batch_to -from typing import Union - -from tqdm import tqdm +from nff.utils.scatter import compute_grad -def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond", **kwargs) -> list[np.array]: +def get_molecules( + atom: AtomsBatch, bond_length: Optional[Dict[str, float]] = None, mode: str = "bond", **kwargs +) -> List[np.array]: """ find molecules in periodic or non-periodic system. bond mode finds molecules within bond length. Must pass bond_length dict: e.g bond_length=dict() @@ -29,7 +31,8 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" give extra cutoff = 6 e.g input output: - list of array of atom indices in molecules. e.g: if there is a H2O molecule, you will get a list with the atom indices + list of array of atom indices in molecules. 
e.g: if there is a H2O molecule, + you will get a list with the atom indices """ types = list(set(atom.numbers)) @@ -50,15 +53,18 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" oxy_neighbors = [] if mode == "bond": for t in types: - if bond_length.get("%s-%s" % (ty, t)) != None: + if bond_length.get(f"{ty}-{t}") is not None: oxy_neighbors.extend( list( np.where(atom.numbers == t)[0][ - np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length["%s-%s" % (ty, t)])[0] + np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length[f"{ty}-{t}"])[0] ] ) ) elif mode == "cutoff": + if "cutoff" not in kwargs: + raise ValueError("Specifying mode 'cutoff' requires passing a cutoff value as a keyword argument") + cutoff = kwargs["cutoff"] oxy_neighbors.extend(list(np.where(dis_sq[i] <= cutoff)[0])) # cutoff input extra argument oxy_neighbors = np.array(oxy_neighbors) if len(oxy_neighbors) == 0: @@ -69,10 +75,10 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" elif (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0: clusters[oxy_neighbors] = mm + 1 clusters[i] = mm + 1 - elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] == 0: + elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0: clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) - elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] != 0: + elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] != 0: tmp = clusters[oxy_neighbors][clusters[oxy_neighbors] != 0][ clusters[oxy_neighbors][clusters[oxy_neighbors] != 0] != min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) @@ -91,17 +97,17 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" return molecules -def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: int = None): +def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: List[np.array], centre: Optional[int] = None): """ Function to shift atoms when we create non-periodic system from periodic. inputs: atomsobject: Atomsbatch object from NFF mol_idx: list of array of atom indices in molecules or atoms you want to keep together when changing to non-periodic system - centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close to the centre which - is by default the first atom index in the array. For reconstructing molecules this is fine. However, for attribution, - we may have to shift a whole molecule to come closer to the atoms with high attribution. In that case, we manually assign - the atom index. + centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close + to the centre which is by default the first atom index in the array. For reconstructing molecules this is fine. + However, for attribution, we may have to shift a whole molecule to come closer to the atoms with high attribution. + In that case, we manually assign the atom index. """ sys_xyz = torch.Tensor(atomsobject.get_positions(wrap=True)) @@ -111,11 +117,11 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: mol_xyz = sys_xyz[idx] if any(atomsobject.pbc): center = mol_xyz.shape[0] // 2 - if centre != None: + if centre is not None: center = centre # changes the central atom to atom in focus intra_dmat = (mol_xyz[None, :, ...] 
- mol_xyz[:, None, ...])[center] if np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) != 0: - M, N = intra_dmat.shape[0], intra_dmat.shape[1] + M, _ = intra_dmat.shape[0], intra_dmat.shape[1] f = torch.linalg.solve(torch.Tensor(atomsobject.cell.T), (intra_dmat.view(-1, 3).T)).T g = f - torch.floor(f + 0.5) intra_dmat = torch.matmul(g, torch.Tensor(atomsobject.cell)) @@ -123,14 +129,13 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: offsets = -torch.floor(f + 0.5).view(M, 3) traj_unwrap = mol_xyz + torch.matmul(offsets, torch.Tensor(atomsobject.cell)) else: - sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len - add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len + (intra_dmat > 0.5 * box_len).to(torch.float) * box_len + (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len shift = torch.round(torch.divide(intra_dmat, box_len)) offsets = -shift traj_unwrap = mol_xyz + offsets * box_len else: traj_unwrap = mol_xyz - # traj_unwrap=mol_xyz+add-sub sys_xyz[idx] = traj_unwrap new_pos = sys_xyz.numpy() @@ -138,11 +143,8 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: return new_pos -# - - - class Attribution: - def __init__(self, ensemble: EnsembleNFF, save_file: str = None): + def __init__(self, ensemble: EnsembleNFF, save_file: Optional[str] = None): self.ensemble = ensemble self.save_file = save_file @@ -197,7 +199,7 @@ def calc_attribution_file( step: int = 1, progress_bar: bool = True, to_chemiscope: bool = False, - bond_length: dict = None, + bond_length: Optional[dict] = None, ) -> list: attributions = [] atoms_list = [] @@ -205,9 +207,7 @@ def calc_attribution_file( energy_stds = [] grads = [] grad_stds = [] - with tqdm( - range(skip, len(traj), step), disable=True if progress_bar == False else False - ) as pbar: # , postfix={"fbest":"?",}) as pbar: + with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar: # for i in range(skip,len(traj),step): for i in pbar: # create atoms batch object @@ -269,8 +269,7 @@ def calc_attribution_file( }, } return atoms_list, properties - else: - return attributions + return attributions def activelearning( self, @@ -281,12 +280,10 @@ def activelearning( skip: int = 0, step: int = 1, progress_bar: bool = True, - bond_length: dict = None, + bond_length: Optional[dict] = None, ): atom_list = [] - with tqdm( - range(skip, len(traj), step), disable=True if progress_bar == False else False - ) as pbar: # , postfix={"fbest":"?",}) as pbar: + with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar: # for i in range(skip,len(traj),step): for i in pbar: # create atoms batch object @@ -337,15 +334,15 @@ def activelearning( neighs = np.append(neighs, a) for n in neighs: atomstocare = np.append(atomstocare, molecules[np.where(balanced_mols == n)[0][0]]) - atomstocare = np.array((list(set(atomstocare)))) + atomstocare = np.array(list(set(atomstocare))) atomstocare = np.int64(atomstocare) atoms1 = atoms[atomstocare] index = np.where(atoms1.positions == atoms.positions[a])[0][0] xyz = reconstruct_atoms(atoms1, [np.arange(0, len(atoms1))], centre=index) atoms1.positions = xyz is_repeated = False - for Atoms in atom_list: - if atoms1.__eq__(Atoms): + for at in atom_list: + if atoms1 == at: is_repeated = True break if not is_repeated: diff --git a/nff/analysis/attribution_deprecate.py b/nff/analysis/attribution_deprecate.py index 
f8c6631f..8cd9a9f6 100644 --- a/nff/analysis/attribution_deprecate.py +++ b/nff/analysis/attribution_deprecate.py @@ -1,15 +1,16 @@ +# ruff: noqa +from typing import Union + +import numpy as np import torch -from ase.io import Trajectory, write from ase import Atoms -import numpy as np +from ase.io import Trajectory, write +from tqdm import tqdm -from nff.io.ase_calcs import EnsembleNFF from nff.io.ase import AtomsBatch -from nff.utils.scatter import compute_grad +from nff.io.ase_calcs import EnsembleNFF from nff.utils.cuda import batch_to -from typing import Union - -from tqdm import tqdm +from nff.utils.scatter import compute_grad def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond", **kwargs) -> list[np.array]: @@ -29,7 +30,8 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" give extra cutoff = 6 e.g input output: - list of array of atom indices in molecules. e.g: if there is a H2O molecule, you will get a list with the atom indices + list of array of atom indices in molecules. + e.g: if there is a H2O molecule, you will get a list with the atom indices """ types = list(set(atom.numbers)) @@ -50,7 +52,7 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" oxy_neighbors = [] if mode == "bond": for t in types: - if bond_length.get("%s-%s" % (ty, t)) != None: + if bond_length.get("%s-%s" % (ty, t)) is not None: oxy_neighbors.extend( list( np.where(atom.numbers == t)[0][ @@ -69,10 +71,10 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond" elif (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0: clusters[oxy_neighbors] = mm + 1 clusters[i] = mm + 1 - elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] == 0: + elif (clusters[oxy_neighbors] == 0).all() is False and clusters[i] == 0: clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) - elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] != 0: + elif (clusters[oxy_neighbors] == 0).all() is False and clusters[i] != 0: tmp = clusters[oxy_neighbors][clusters[oxy_neighbors] != 0][ clusters[oxy_neighbors][clusters[oxy_neighbors] != 0] != min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) @@ -98,10 +100,10 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: atomsobject: Atomsbatch object from NFF mol_idx: list of array of atom indices in molecules or atoms you want to keep together when changing to non-periodic system - centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close to the centre which - is by default the first atom index in the array. For reconstructing molecules this is fine. However, for attribution, - we may have to shift a whole molecule to come closer to the atoms with high attribution. In that case, we manually assign - the atom index. + centre: by default the atoms in a molecule or set of close atoms are shifted so as to + get them close to the centre which is by default the first atom index in the array. + For reconstructing molecules this is fine. However, for attribution, we may have to shift a whole molecule + to come closer to the atoms with high attribution. In that case, we manually assign the atom index. 
""" sys_xyz = torch.Tensor(atomsobject.get_positions(wrap=True)) @@ -111,7 +113,7 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: mol_xyz = sys_xyz[idx] if any(atomsobject.pbc): center = mol_xyz.shape[0] // 2 - if centre != None: + if centre is not None: center = centre # changes the central atom to atom in focus intra_dmat = (mol_xyz[None, :, ...] - mol_xyz[:, None, ...])[center] if np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) != 0: @@ -123,7 +125,7 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: offsets = -torch.floor(f + 0.5).view(M, 3) traj_unwrap = mol_xyz + torch.matmul(offsets, torch.Tensor(atomsobject.cell)) else: - sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len + (intra_dmat > 0.5 * box_len).to(torch.float) * box_len add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len shift = torch.round(torch.divide(intra_dmat, box_len)) offsets = -shift @@ -206,7 +208,7 @@ def calc_attribution_file( grads = [] grad_stds = [] with tqdm( - range(skip, len(traj), step), disable=True if progress_bar == False else False + range(skip, len(traj), step), disable=True if not progress_bar else False ) as pbar: # , postfix={"fbest":"?",}) as pbar: # for i in range(skip,len(traj),step): for i in pbar: @@ -269,8 +271,7 @@ def calc_attribution_file( }, } return atoms_list, properties - else: - return attributions + return attributions def activelearning( self, @@ -285,7 +286,7 @@ def activelearning( ): atom_list = [] with tqdm( - range(skip, len(traj), step), disable=True if progress_bar == False else False + range(skip, len(traj), step), disable=True if not progress_bar else False ) as pbar: # , postfix={"fbest":"?",}) as pbar: # for i in range(skip,len(traj),step): for i in pbar: @@ -337,15 +338,15 @@ def activelearning( neighs = np.append(neighs, a) for n in neighs: atomstocare = np.append(atomstocare, molecules[np.where(balanced_mols == n)[0][0]]) - atomstocare = np.array((list(set(atomstocare)))) + atomstocare = np.array(list(set(atomstocare))) atomstocare = np.int64(atomstocare) atoms1 = atoms[atomstocare] index = np.where(atoms1.positions == atoms.positions[a])[0][0] xyz = reconstruct_atoms(atoms1, [np.arange(0, len(atoms1))], centre=index) atoms1.positions = xyz is_repeated = False - for Atoms in atom_list: - if atoms1.__eq__(Atoms): + for at in atom_list: + if atoms1 == at: is_repeated = True break if not is_repeated: diff --git a/nff/analysis/cp3d.py b/nff/analysis/cp3d.py index e81b206d..19984526 100644 --- a/nff/analysis/cp3d.py +++ b/nff/analysis/cp3d.py @@ -2,22 +2,21 @@ Tools for analyzing conformer-based model predictions. 
""" +import json +import logging import os import pickle import random -import logging -import json import numpy as np import torch -from tqdm import tqdm -from sklearn.metrics import roc_auc_score, auc, precision_recall_curve -from sklearn.metrics.pairwise import cosine_similarity as cos_sim from rdkit import Chem +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score +from sklearn.metrics.pairwise import cosine_similarity as cos_sim +from tqdm import tqdm - -from nff.utils import fprint from nff.data.features import get_e3fp +from nff.utils import fprint LOGGER = logging.getLogger() LOGGER.disabled = True @@ -41,10 +40,8 @@ def get_pred_files(model_path): # should have the form _pred_.pickle # or pred_.pickle splits = ["train", "val", "test"] - starts_split = any([file.startswith(f"{split}_pred") - for split in splits]) - starts_pred = any([file.startswith(f"pred") - for split in splits]) + starts_split = any(file.startswith(f"{split}_pred") for split in splits) + starts_pred = file.startswith("pred") if (not starts_split) and (not starts_pred): continue if not file.endswith("pickle"): @@ -89,8 +86,8 @@ def get_att_type(dic): num_confs_list = [] for sub_dic in dic.values(): - num_learned_weights = sub_dic['learned_weights'].shape[0] - num_confs = sub_dic['boltz_weights'].shape[0] + num_learned_weights = sub_dic["learned_weights"].shape[0] + num_confs = sub_dic["boltz_weights"].shape[0] if num_learned_weights in num_weights_list: continue @@ -104,12 +101,11 @@ def get_att_type(dic): if len(num_confs_list) == 2: break - is_linear = ((num_weights_list[1] / num_weights_list[0]) - == (num_confs_list[1] / num_confs_list[0])) + is_linear = (num_weights_list[1] / num_weights_list[0]) == (num_confs_list[1] / num_confs_list[0]) if is_linear: num_heads = int(num_weights_list[0] / num_confs_list[0]) else: - num_heads = int((num_weights_list[0] / num_confs_list[0] ** 2)) + num_heads = int(num_weights_list[0] / num_confs_list[0] ** 2) return num_heads, is_linear @@ -128,34 +124,27 @@ def annotate_confs(dic): """ num_heads, is_linear = get_att_type(dic) for sub_dic in dic.values(): - num_confs = sub_dic['boltz_weights'].shape[0] - if is_linear: - split_sizes = [num_confs] * num_heads - else: - split_sizes = [num_confs ** 2] * num_heads + num_confs = sub_dic["boltz_weights"].shape[0] + split_sizes = [num_confs] * num_heads if is_linear else [num_confs**2] * num_heads - learned = torch.Tensor(sub_dic['learned_weights']) + learned = torch.Tensor(sub_dic["learned_weights"]) head_weights = torch.split(learned, split_sizes) # if it's not linear, sum over conformer pairs to # get the average importance of each conformer if not is_linear: - head_weights = [i.reshape(num_confs, num_confs).sum(0) - for i in head_weights] + head_weights = [i.reshape(num_confs, num_confs).sum(0) for i in head_weights] # the conformers with the highest weight, according to each # head - max_weight_confs = [head_weight.argmax().item() - for head_weight in head_weights] + max_weight_confs = [head_weight.argmax().item() for head_weight in head_weights] # the highest conformer weight assigned by each head - max_weights = [head_weight.max() - for head_weight in head_weights] + max_weights = [head_weight.max() for head_weight in head_weights] # the head that gave out the highest weight max_weight_head = np.argmax(max_weights) # the conformer with the highest of all weights max_weight_conf = max_weight_confs[max_weight_head] - sub_dic["head_weights"] = {i: weights.tolist() for i, weights in - enumerate(head_weights)} + 
sub_dic["head_weights"] = {i: weights.tolist() for i, weights in enumerate(head_weights)} sub_dic["max_weight_conf"] = max_weight_conf sub_dic["max_weight_head"] = max_weight_head @@ -181,9 +170,7 @@ def choices_from_pickle(paths): return fps_choices -def funcs_for_external(external_fp_fn, - summary_path, - rd_path): +def funcs_for_external(external_fp_fn, summary_path, rd_path): """ If requesting an external method to get and compare fingerprints, then use this function to get a dictionary @@ -231,38 +218,24 @@ def sample_species(dic, classifier, max_samples): # if it's not a classifier, you'll just randomly sample # different species pairs and compare their fingerprints keys = list(dic.keys()) - samples = [np.random.choice(keys, max_samples), - np.random.choice(keys, max_samples)] + samples = [np.random.choice(keys, max_samples), np.random.choice(keys, max_samples)] sample_dics = {"random_mols": samples} else: # if it is a classifier, you'll want to compare species # that are both hits, both misses, or one hit and one miss - pos_keys = [smiles for smiles, sub_dic in dic.items() - if sub_dic['true'] == 1] - neg_keys = [smiles for smiles, sub_dic in dic.items() - if sub_dic['true'] == 0] - - intra_pos = [np.random.choice(pos_keys, max_samples), - np.random.choice(pos_keys, max_samples)] - intra_neg = [np.random.choice(neg_keys, max_samples), - np.random.choice(neg_keys, max_samples)] - inter = [np.random.choice(pos_keys, max_samples), - np.random.choice(neg_keys, max_samples)] - - sample_dics = {"intra_pos": intra_pos, - "intra_neg": intra_neg, - "inter": inter} + pos_keys = [smiles for smiles, sub_dic in dic.items() if sub_dic["true"] == 1] + neg_keys = [smiles for smiles, sub_dic in dic.items() if sub_dic["true"] == 0] + + intra_pos = [np.random.choice(pos_keys, max_samples), np.random.choice(pos_keys, max_samples)] + intra_neg = [np.random.choice(neg_keys, max_samples), np.random.choice(neg_keys, max_samples)] + inter = [np.random.choice(pos_keys, max_samples), np.random.choice(neg_keys, max_samples)] + + sample_dics = {"intra_pos": intra_pos, "intra_neg": intra_neg, "inter": inter} return sample_dics -def calc_sim(dic, - smiles_0, - smiles_1, - func, - pickle_dic, - conf_type, - fp_kwargs): +def calc_sim(dic, smiles_0, smiles_1, func, pickle_dic, conf_type, fp_kwargs): """ Calculate the similatiy between conformers of two different species. Args: @@ -294,7 +267,6 @@ def calc_sim(dic, fp_1_choices = sub_dic_1["conf_fps"] if conf_type == "att": - conf_0_idx = sub_dic_0["max_weight_conf"] conf_1_idx = sub_dic_1["max_weight_conf"] @@ -312,20 +284,14 @@ def calc_sim(dic, if isinstance(fp, Chem.rdchem.Mol): fps[j] = func(fp, **fp_kwargs) - sim = cos_sim(fps[0].reshape(1, -1), - fps[1].reshape(1, -1)).item() + sim = cos_sim(fps[0].reshape(1, -1), fps[1].reshape(1, -1)).item() return sim -def attention_sim(dic, - max_samples, - classifier, - seed, - external_fp_fn=None, - summary_path=None, - rd_path=None, - fp_kwargs=None): +def attention_sim( + dic, max_samples, classifier, seed, external_fp_fn=None, summary_path=None, rd_path=None, fp_kwargs=None +): """ Calculate similarities of the conformer fingerprints of different pairs of species. 
@@ -359,9 +325,7 @@ def attention_sim(dic, # get an external fingeprinting function if asked if external_fp_fn is not None: - pickle_dic, func = funcs_for_external(external_fp_fn, - summary_path, - rd_path) + pickle_dic, func = funcs_for_external(external_fp_fn, summary_path, rd_path) else: pickle_dic = None func = None @@ -373,21 +337,22 @@ def attention_sim(dic, # conformer similarities for key, samples in sample_dics.items(): - fp_dics[key] = {} - conf_types = ['att', 'random'] + conf_types = ["att", "random"] for conf_type in conf_types: fp_sims = [] for i in tqdm(range(len(samples[0]))): smiles_0 = samples[0][i] smiles_1 = samples[1][i] - sim = calc_sim(dic=dic, - smiles_0=smiles_0, - smiles_1=smiles_1, - func=func, - pickle_dic=pickle_dic, - conf_type=conf_type, - fp_kwargs=fp_kwargs) + sim = calc_sim( + dic=dic, + smiles_0=smiles_0, + smiles_1=smiles_1, + func=func, + pickle_dic=pickle_dic, + conf_type=conf_type, + fp_kwargs=fp_kwargs, + ) fp_sims.append(sim) fp_dics[key][conf_type] = np.array(fp_sims) @@ -409,10 +374,11 @@ def analyze_data(bare_data, analysis): """ for key, val in bare_data.items(): if isinstance(val, np.ndarray): - analysis[key] = {"mean": np.mean(val), - "std": np.std(val), - "std_of_mean": (np.std(val) - / val.shape[0] ** 0.5)} + analysis[key] = { + "mean": np.mean(val), + "std": np.std(val), + "std_of_mean": (np.std(val) / val.shape[0] ** 0.5), + } else: if key not in analysis: analysis[key] = {} @@ -433,8 +399,8 @@ def report_delta(bare_dic): fprint("+/- indicates standard deviation of the mean") # attention and random differences in similarity - delta_att = dic['intra_pos']['att'] - dic['inter']['att'] - delta_rand = dic['intra_pos']['random'] - dic['inter']['random'] + delta_att = dic["intra_pos"]["att"] - dic["inter"]["att"] + delta_rand = dic["intra_pos"]["random"] - dic["inter"]["random"] # compute mean for attention delta_att_mean = np.mean(delta_att) @@ -449,24 +415,17 @@ def report_delta(bare_dic): # a measure of how much attention is learning delta_delta_mean = delta_att_mean - delta_rand_mean - delta_delta_std = ((np.var(delta_att) + np.var(delta_rand)) ** 0.5 - / (len(delta_att)) ** 0.5) + delta_delta_std = (np.var(delta_att) + np.var(delta_rand)) ** 0.5 / (len(delta_att)) ** 0.5 - fprint("Delta att: %.4f +/- %.4f" % (delta_att_mean, delta_att_std)) - fprint("Delta rand: %.4f +/- %.4f" % (delta_rand_mean, delta_rand_std)) - fprint("Delta delta: %.4f +/- %.4f" % - (delta_delta_mean, delta_delta_std)) + fprint(f"Delta att: {delta_att_mean:.4f} +/- {delta_att_std:.4f}") + fprint(f"Delta rand: {delta_rand_mean:.4f} +/- {delta_rand_std:.4f}") + fprint(f"Delta delta: {delta_delta_mean:.4f} +/- {delta_delta_std:.4f}") fprint("\n") -def conf_sims_from_files(model_path, - max_samples, - classifier, - seed, - external_fp_fn=None, - summary_path=None, - rd_path=None, - fp_kwargs=None): +def conf_sims_from_files( + model_path, max_samples, classifier, seed, external_fp_fn=None, summary_path=None, rd_path=None, fp_kwargs=None +): """ Get similarity among species according to predictions of different models, given a folder with all of the prediction pickles. 
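Note on `report_delta` above: the `delta_delta_std` expression is the standard error of a difference of two independent means, sqrt(var_a + var_b) / sqrt(N). A self-contained check on toy arrays (the sizes and distributions here are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
delta_att = rng.normal(0.10, 0.02, size=500)   # toy attention similarity deltas
delta_rand = rng.normal(0.01, 0.02, size=500)  # toy random-baseline deltas

# mean difference and its standard error, matching report_delta:
# std of (mean_a - mean_b) = sqrt(var_a + var_b) / sqrt(N)
delta_delta_mean = delta_att.mean() - delta_rand.mean()
delta_delta_std = (np.var(delta_att) + np.var(delta_rand)) ** 0.5 / len(delta_att) ** 0.5

print(f"Delta delta: {delta_delta_mean:.4f} +/- {delta_delta_std:.4f}")
```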
@@ -504,14 +463,16 @@ def conf_sims_from_files(model_path, for key in tqdm(pred): dic = pred[key] annotate_confs(dic) - fp_dics = attention_sim(dic=dic, - max_samples=max_samples, - classifier=classifier, - seed=seed, - external_fp_fn=external_fp_fn, - summary_path=summary_path, - rd_path=rd_path, - fp_kwargs=fp_kwargs) + fp_dics = attention_sim( + dic=dic, + max_samples=max_samples, + classifier=classifier, + seed=seed, + external_fp_fn=external_fp_fn, + summary_path=summary_path, + rd_path=rd_path, + fp_kwargs=fp_kwargs, + ) bare_data[key] = fp_dics # analyze the bare data @@ -524,7 +485,7 @@ def conf_sims_from_files(model_path, return analysis, bare_data -def get_scores(path, avg_metrics=['auc', 'prc-auc']): +def get_scores(path, avg_metrics=["auc", "prc-auc"]): """ Load pickle files that contain predictions and actual values, using models evaluated by different validation metrics, and use the predictions @@ -538,10 +499,9 @@ def get_scores(path, avg_metrics=['auc', 'prc-auc']): used, the validation metric used to get the model, and the PRC and AUC scores. """ - files = [i for i in os.listdir(path) if i.endswith(".pickle") - and i.startswith("pred")] + files = [i for i in os.listdir(path) if i.endswith(".pickle") and i.startswith("pred")] if not files: - return + return None scores = [] for file in files: with open(os.path.join(path, file), "rb") as f: @@ -549,38 +509,27 @@ def get_scores(path, avg_metrics=['auc', 'prc-auc']): split = file.split(".pickle")[0].split("_")[-1] from_metric = file.split("pred_")[-1].split(f"_{split}")[0] - pred = [sub_dic['pred'] for sub_dic in dic.values()] - true = [sub_dic['true'] for sub_dic in dic.values()] + pred = [sub_dic["pred"] for sub_dic in dic.values()] + true = [sub_dic["true"] for sub_dic in dic.values()] # then it's not a binary classification problem - if any([i not in [0, 1] for i in true]): - return + if any(i not in [0, 1] for i in true): + return None auc_score = roc_auc_score(y_true=true, y_score=pred) - precision, recall, thresholds = precision_recall_curve( - y_true=true, probas_pred=pred) + precision, recall, thresholds = precision_recall_curve(y_true=true, probas_pred=pred) prc_score = auc(recall, precision) - scores.append({"split": split, - "from_metric": from_metric, - "auc": auc_score, - "prc": prc_score}) + scores.append({"split": split, "from_metric": from_metric, "auc": auc_score, "prc": prc_score}) if avg_metrics is None: avg_metrics = [score["from_metric"] for score in scores] - all_auc = [score["auc"] for score in scores if score['from_metric'] - in avg_metrics] - all_prc = [score["prc"] for score in scores if score['from_metric'] - in avg_metrics] - avg_auc = {"mean": np.mean(all_auc), - "std": np.std(all_auc)} - avg_prc = {"mean": np.mean(all_prc), - "std": np.std(all_prc)} - scores.append({"from_metric": "average", - "auc": avg_auc, - "prc": avg_prc, - "avg_metrics": avg_metrics}) + all_auc = [score["auc"] for score in scores if score["from_metric"] in avg_metrics] + all_prc = [score["prc"] for score in scores if score["from_metric"] in avg_metrics] + avg_auc = {"mean": np.mean(all_auc), "std": np.std(all_auc)} + avg_prc = {"mean": np.mean(all_prc), "std": np.std(all_prc)} + scores.append({"from_metric": "average", "auc": avg_auc, "prc": avg_prc, "avg_metrics": avg_metrics}) save_path = os.path.join(path, "scores_from_metrics.json") with open(save_path, "w") as f: @@ -589,7 +538,7 @@ def get_scores(path, avg_metrics=['auc', 'prc-auc']): return scores -def recursive_scoring(base_path, avg_metrics=['auc', 'prc-auc']): 
+def recursive_scoring(base_path, avg_metrics=["auc", "prc-auc"]): """ Recursively search in a base directory to find sub-folders that have pickle files that can be used for scoring. Apply `get_scores` @@ -602,8 +551,7 @@ def recursive_scoring(base_path, avg_metrics=['auc', 'prc-auc']): None """ - files = [i for i in os.listdir(base_path) if i.endswith(".pickle") - and i.startswith("pred")] + files = [i for i in os.listdir(base_path) if i.endswith(".pickle") and i.startswith("pred")] if files: print(f"Analyzing {base_path}") get_scores(base_path, avg_metrics) @@ -612,15 +560,13 @@ def recursive_scoring(base_path, avg_metrics=['auc', 'prc-auc']): direc_path = os.path.join(base_path, direc) if not os.path.isdir(direc_path): continue - files = [i for i in os.listdir(direc_path) if i.endswith(".pickle") - and i.startswith("pred")] + files = [i for i in os.listdir(direc_path) if i.endswith(".pickle") and i.startswith("pred")] if files: print(f"Analyzing {direc_path}") get_scores(direc_path, avg_metrics) continue - folders = [os.path.join(direc_path, i) for i in - os.listdir(direc_path)] + folders = [os.path.join(direc_path, i) for i in os.listdir(direc_path)] folders = [i for i in folders if os.path.isdir(i)] if not folders: diff --git a/nff/analysis/loss_plot.py b/nff/analysis/loss_plot.py index d7458572..6286e2cd 100644 --- a/nff/analysis/loss_plot.py +++ b/nff/analysis/loss_plot.py @@ -4,15 +4,21 @@ from . import mpl_settings -def plot_loss(energy_history, forces_history, figname, train_key="train", val_key="val"): +def plot_loss( + energy_history: dict, + forces_history: dict, + figname: str, + train_key: str = "train", + val_key: str = "val", +) -> None: """Plot the loss history of the model. - Args: - energy_history (dict): energy loss history of the model for training and validation - forces_history (dict): forces loss history of the model for training and validation - figname (str): name of the figure - Returns: - None + Args: + energy_history: energy loss history of the model for training and validation + forces_history: forces loss history of the model for training and validation + figname: name of the figure + train_key: key for training data in the history dictionary + val_key: key for validation data in the history dictionary """ epochs = np.arange(1, len(energy_history[train_key]) + 1) fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI) diff --git a/nff/analysis/mpl_settings.py b/nff/analysis/mpl_settings.py index c61308ec..d3f23b80 100644 --- a/nff/analysis/mpl_settings.py +++ b/nff/analysis/mpl_settings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from pathlib import Path from typing import List @@ -61,11 +63,11 @@ plt.rcParams.update(custom_settings) -def update_custom_settings(custom_settings: dict = custom_settings) -> None: +def update_custom_settings(custom_settings: dict | None = custom_settings) -> None: """Update the custom settings for Matplotlib. Args: - custom_settings (dict, optional): Custom settings for Matplotlib. Defaults to + custom_settings: Custom settings for Matplotlib. Defaults to custom_settings. """ current_settings = plt.rcParams.copy() @@ -77,10 +79,7 @@ def hex_to_rgb(value: str) -> list[float]: """Converts hex to rgb colors. Args: - value (str): string of 6 characters representing a hex colour. - - Returns: - list: length 3 of RGB values + value: string of 6 characters representing a hex color. 
""" value = value.strip("#") # removes hash symbol if present lv = len(value) @@ -91,7 +90,7 @@ def rgb_to_dec(value: list[float]) -> list[float]: """Converts rgb to decimal colors (i.e. divides each value by 256). Args: - value (list[float]): string of 6 characters representing a hex colour. + value: string of 6 characters representing a hex color. Returns: list: length 3 of RGB values @@ -107,12 +106,9 @@ def get_continuous_cmap( each color in hex_list is mapped to the respective location in float_list. Args: - hex_list (list[str]): list of hex code strings - float_list (list[float]): list of floats between 0 and 1, same length as hex_list. Must - start with 0 and end with 1. - - Returns: - matplotlib.colors.LinearSegmentedColormap: continuous + hex_list: list of hex code strings + float_list: list of floats between 0 and 1, same length as hex_list. + Must start with 0 and end with 1. """ rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list] if float_list: diff --git a/nff/analysis/parity_plot.py b/nff/analysis/parity_plot.py index 59bb3287..5a44d9d6 100644 --- a/nff/analysis/parity_plot.py +++ b/nff/analysis/parity_plot.py @@ -1,13 +1,16 @@ -from typing import Dict, Literal, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Literal import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch from matplotlib.lines import Line2D from scipy import stats from scipy.stats import gaussian_kde +if TYPE_CHECKING: + from torch import Tensor from nff.data import to_tensor from nff.utils import cuda @@ -18,8 +21,8 @@ def plot_parity( - results: Dict[str, Union[list, torch.Tensor]], - targets: Dict[str, Union[list, torch.Tensor]], + results: Dict[str, list | Tensor], + targets: Dict[str, list | Tensor], figname: str, plot_type: Literal["hexbin", "scatter"] = "hexbin", energy_key: str = "energy", @@ -29,13 +32,13 @@ def plot_parity( """Perform a parity plot between the results and the targets. Args: - results (dict): dictionary containing the results - targets (dict): dictionary containing the targets - figname (str): name of the figure - plot_type (str): type of plot to use, either "hexbin" or "scatter" - energy_key (str): key for the energy - force_key (str): key for the forces - units (dict): dictionary containing the units of the keys + results: dictionary containing the results + targets: dictionary containing the targets + figname: name of the figure + plot_type: type of plot to use, either "hexbin" or "scatter" + energy_key: key for the energy + force_key: key for the forces + units: dictionary containing the units of the keys Returns: float: MAE of the energy @@ -94,13 +97,13 @@ def plot_parity( label = key ax.set_title(label.upper()) - ax.set_xlabel("Predicted %s [%s]" % (label, units[key])) - ax.set_ylabel("Target %s [%s]" % (label, units[key])) + ax.set_xlabel(f"Predicted {label} [{units[key]}]") + ax.set_ylabel(f"Target {label} [{units[key]}]") ax.text( 0.1, 0.9, - "MAE: %.3f %s" % (mae, units[key]), + f"MAE: {mae:.2f} {units[key]}", transform=ax.transAxes, ) @@ -113,8 +116,8 @@ def plot_parity( def plot_err_var( - err: Union[torch.Tensor, np.ndarray], - var: Union[torch.Tensor, np.ndarray], + err: Tensor | np.ndarray, + var: Tensor | np.ndarray, figname: str, units: str = "eV/Å", x_min: float = 0.0, @@ -128,27 +131,24 @@ def plot_err_var( """Plot the error vs variance of the forces. 
Args: - err (torch.Tensor): error of the forces - var (torch.Tensor): variance of the forces - figname (str): name of the figure - units (str): units of the error and variance - x_min (float): minimum value of the x-axis - x_max (float): maximum value of the x-axis - y_min (float): minimum value of the y-axis - y_max (float): maximum value of the y-axis - sample_frac (float): fraction of the data to sample for the plot - num_bins (int): number of bins to use for binning - cb_format (str): format of the colorbar - - Returns: - None + err: error of the forces + var: variance of the forces + figname: name of the figure + units: units of the error and variance + x_min: minimum value of the x-axis + x_max: maximum value of the x-axis + y_min: minimum value of the y-axis + y_max: maximum value of the y-axis + sample_frac: fraction of the data to sample for the plot + num_bins: number of bins to use for binning + cb_format: format of the colorbar """ fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=mpl_settings.DPI) idx = np.arange(len(var)) np.random.seed(2) sample_idx = np.random.choice(idx, size=int(len(idx) * sample_frac), replace=False) - n_samples = len(sample_idx) + len(sample_idx) var = var.flatten()[sample_idx] err = err.flatten()[sample_idx] @@ -194,10 +194,10 @@ def plot_err_var( label="Avg. best fit", zorder=1, ) - min_text = ax.text( + ax.text( 0.6, 0.9, - r"$R^2$: {:.3f}".format(res.rvalue**2), + rf"$R^2$: {res.rvalue**2:.3f}", transform=ax.transAxes, ) diff --git a/nff/analysis/roce.py b/nff/analysis/roce.py index 54d2fc05..361c5c48 100644 --- a/nff/analysis/roce.py +++ b/nff/analysis/roce.py @@ -3,20 +3,18 @@ at different enrichment factors. """ +import argparse import copy -import pickle import json import math +import pickle -import argparse +import numpy as np from matplotlib import pyplot as plt from matplotlib import rcParams -import numpy as np - from nff.utils import read_csv - # height of each ROCE bar slice in the plots, normalized # to max value of all bars in the plot BAR_HEIGHT = 0.02 @@ -25,17 +23,17 @@ DELTA = 0.2 # keys for specifying text attributes -TEXT_KEYS = ['fontsize'] +TEXT_KEYS = ["fontsize"] # use the same defaults as in iPython notebooks # to avoid an unhappy surprise after testing your plots # in a notebook -rcParams['figure.figsize'] = (6.0, 4.0) -rcParams['font.size'] = 10 -rcParams['savefig.dpi'] = 72 -rcParams['figure.subplot.bottom'] = 0.125 +rcParams["figure.figsize"] = (6.0, 4.0) +rcParams["font.size"] = 10 +rcParams["savefig.dpi"] = 72 +rcParams["figure.subplot.bottom"] = 0.125 def compute_roce(fpr, preds, real): @@ -213,9 +211,7 @@ def remove_overlap(scores, height): # the model that are ordered before this one for other_val in vals: # calculate how much you have to change this value - change = get_change(this_val=new_vals[model, i], - other_val=other_val, - height=height) + change = get_change(this_val=new_vals[model, i], other_val=other_val, height=height) # if you've changed it at all after comparing # to another value, break if abs(change) > eps: @@ -231,9 +227,7 @@ def remove_overlap(scores, height): return new_vals -def parse_csv(pred_path, - true_path, - target): +def parse_csv(pred_path, true_path, target): """ Get the list of predicted and real values from a csv file. 
Running `predict.sh` on the results of a ChemProp calculation @@ -263,9 +257,7 @@ def parse_csv(pred_path, return [pred], [real] -def parse_json(pred_path, - target, - split): +def parse_json(pred_path, target, split): """ Get the list of predicted and real values from a JSON file. Running `predict.sh` on the results of a ChemProp calculation @@ -297,12 +289,11 @@ def parse_json(pred_path, values. """ - with open(pred_path, 'r') as f_open: + with open(pred_path, "r") as f_open: pred_dic = json.load(f_open) # int keys for different seeds - int_keys = list([i for i in pred_dic.keys() - if i.isdigit()]) + int_keys = [i for i in pred_dic if i.isdigit()] # get the predictions of each seed preds = [] @@ -336,16 +327,13 @@ def parse_pickle(path): with open(path, "rb") as f_open: dic = pickle.load(f_open) - real = np.array([sub_dic['true'] for sub_dic in dic.values()]) - pred = np.array([sub_dic['pred'] for sub_dic in dic.values()]) + real = np.array([sub_dic["true"] for sub_dic in dic.values()]) + pred = np.array([sub_dic["pred"] for sub_dic in dic.values()]) return [pred], [real] -def get_all_preds(true_path, - pred_paths, - target, - split): +def get_all_preds(true_path, pred_paths, target, split): """ Get all predictions from various different versions of a model (e.g. different seeds of a ChemProp model). @@ -368,14 +356,10 @@ def get_all_preds(true_path, for path in pred_paths: if path.endswith("csv"): - these_preds, these_real = parse_csv(pred_path=path, - true_path=true_path, - target=target) + these_preds, these_real = parse_csv(pred_path=path, true_path=true_path, target=target) elif path.endswith("json"): - these_preds, these_real = parse_json(pred_path=path, - target=target, - split=split) + these_preds, these_real = parse_json(pred_path=path, target=target, split=split) elif path.endswith("pickle"): these_preds, these_real = parse_pickle(path) else: @@ -390,11 +374,7 @@ def get_all_preds(true_path, return preds, reals -def get_mean_roce(true_path, - pred_paths, - target, - split, - fpr_vals): +def get_mean_roce(true_path, pred_paths, target, split, fpr_vals): """ Get mean ROCE scores from various different versions of a model (e.g. different seeds of a ChemProp model). @@ -410,14 +390,10 @@ def get_mean_roce(true_path, averaged over the different versions of the model. 
""" - all_preds, all_reals = get_all_preds(true_path=true_path, - pred_paths=pred_paths, - target=target, - split=split) + all_preds, all_reals = get_all_preds(true_path=true_path, pred_paths=pred_paths, target=target, split=split) roces = [] for pred, real in zip(all_preds, all_reals): - roce = np.array([compute_roce(fpr, pred, real) - for fpr in fpr_vals]) + roce = np.array([compute_roce(fpr, pred, real) for fpr in fpr_vals]) roces.append(roce) mean_roce = np.stack(roces).mean(axis=0) @@ -446,11 +422,8 @@ def add_model_roces(plot_dic): for i, model_dic in enumerate(model_dics): mean_roce = get_mean_roce( - true_path=true_path, - pred_paths=model_dic["pred_paths"], - target=target, - split=split, - fpr_vals=fpr_vals) + true_path=true_path, pred_paths=model_dic["pred_paths"], target=target, split=split, fpr_vals=fpr_vals + ) model_dics[i]["roce"] = mean_roce return plot_dic @@ -490,21 +463,17 @@ def vals_for_plot(plot_dic): roce_no_overlap = remove_overlap(roce_scores, bar_height) # get the default cycle colors - fpr_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + fpr_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] # set the labels equal to the plot names, but replace every space # with a new line to avoid overlapping labels plot_names = [dic["plot_name"] for dic in model_dics] - labels = [l.replace(" ", "\n") for l in plot_names] + labels = [label.replace(" ", "\n") for label in plot_names] return roce_scores, roce_no_overlap, fpr_colors, labels, bar_height -def base_plot(roce_scores, - roce_no_overlap, - labels, - fpr_colors, - bar_height): +def base_plot(roce_scores, roce_no_overlap, labels, fpr_colors, bar_height): """ Make the basic ROCE plot without any extra features, label sizes, axis limits, etc. @@ -541,7 +510,6 @@ def base_plot(roce_scores, # go through each value of fpr for j in range(roce_no_overlap.shape[1]): - # start is where the slice starts start = -0.4 @@ -561,23 +529,11 @@ def base_plot(roce_scores, x_range = np.arange(start, end, interval / 100) y_vals = np.array([new_perform] * len(x_range)) - plt.plot(x_range, y_vals, - '-', - color=fpr_colors[j], - linewidth=3, - label='_nolegend_') + plt.plot(x_range, y_vals, "-", color=fpr_colors[j], linewidth=3, label="_nolegend_") # add black lines at +/- bar_height / 2 - plt.plot(x_range, (y_vals - bar_height / 2), - '-', - color='black', - linewidth=1, - label='_nolegend_') - plt.plot(x_range, (y_vals + bar_height / 2), - '-', - color='black', - linewidth=1, - label='_nolegend_') + plt.plot(x_range, (y_vals - bar_height / 2), "-", color="black", linewidth=1, label="_nolegend_") + plt.plot(x_range, (y_vals + bar_height / 2), "-", color="black", linewidth=1, label="_nolegend_") # add `DELTA` to `start` as you continue left to # right in the plot @@ -586,9 +542,7 @@ def base_plot(roce_scores, return axis -def set_plot_ylim(max_scale, - roce_no_overlap, - bar_height): +def set_plot_ylim(max_scale, roce_no_overlap, bar_height): """ Set the y limits for the plot. Args: @@ -605,15 +559,13 @@ def set_plot_ylim(max_scale, max_score = roce_no_overlap.max() min_scale = max_scale if (min_score < 0) else (1 / max_scale) - ylim = [min([min_score * min_scale, 0]) - - bar_height, max_score * max_scale] + ylim = [min([min_score * min_scale, 0]) - bar_height, max_score * max_scale] plt.ylim(ylim) return ylim -def set_tick_sizes(x_axis_dic, - y_axis_dic): +def set_tick_sizes(x_axis_dic, y_axis_dic): """ Set plot tick sizes. 
Args: @@ -625,24 +577,20 @@ def set_tick_sizes(x_axis_dic, None """ - # x-axis tick font size + # x-axis tick font size if "ticks" in x_axis_dic: tick_dic = x_axis_dic["ticks"] if "fontsize" in tick_dic: - plt.rc('xtick', labelsize=tick_dic["fontsize"]) + plt.rc("xtick", labelsize=tick_dic["fontsize"]) # y-axis tick font size if "ticks" in y_axis_dic: tick_dic = y_axis_dic["ticks"] if "fontsize" in tick_dic: - plt.rc('ytick', labelsize=tick_dic["fontsize"]) + plt.rc("ytick", labelsize=tick_dic["fontsize"]) -def label_plot(fpr_vals, - legend_dic, - x_axis_dic, - y_axis_dic, - axis): +def label_plot(fpr_vals, legend_dic, x_axis_dic, y_axis_dic, axis): """ Add various labels to the plot. Args: @@ -661,37 +609,30 @@ def label_plot(fpr_vals, # legend fpr_pct = [(use_val * 100) for use_val in fpr_vals] - fpr_str = [("%.1f" % val) if (val < 1) else ("%d" % val) - for val in fpr_pct] + fpr_str = [("%.1f" % val) if (val < 1) else ("%d" % val) for val in fpr_pct] - kwargs = {key: legend_dic[key] for key in - [*TEXT_KEYS, 'loc', 'ncol'] if key in legend_dic} + kwargs = {key: legend_dic[key] for key in [*TEXT_KEYS, "loc", "ncol"] if key in legend_dic} if legend_dic.get("use_legend", True): - plt.legend([f'{string}%' for string in fpr_str], - **kwargs) + plt.legend([f"{string}%" for string in fpr_str], **kwargs) # y-axis font size and label ylabel_kwargs = {} - if 'labels' in y_axis_dic: - label_dic = y_axis_dic['labels'] - if 'fontsize' in label_dic: + if "labels" in y_axis_dic: + label_dic = y_axis_dic["labels"] + if "fontsize" in label_dic: ylabel_kwargs["fontsize"] = label_dic["fontsize"] plt.ylabel("ROCE", **ylabel_kwargs) # x-axis label font sizes - if 'labels' in x_axis_dic: - label_dic = x_axis_dic['labels'] - if 'fontsize' in label_dic: + if "labels" in x_axis_dic: + label_dic = x_axis_dic["labels"] + if "fontsize" in label_dic: for label in axis.get_xticklabels(): - label.set_fontsize(label_dic['fontsize']) + label.set_fontsize(label_dic["fontsize"]) -def decorate_plot(labels, - ylim, - axis, - dividers=None, - texts=None): +def decorate_plot(labels, ylim, axis, dividers=None, texts=None): """ Add various "decorations" to the plot - such as dividers between different model categories, text on the plot, etc. 
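Note on the legend strings built in `label_plot` above: each false-positive rate is converted to a percentage, with sub-1% rates keeping one decimal place and larger ones printed as integers. An equivalent f-string sketch of that %-format logic (the rates are hypothetical):

```python
fpr_vals = [0.001, 0.01, 0.05]             # hypothetical false-positive rates
fpr_pct = [val * 100 for val in fpr_vals]  # -> [0.1, 1.0, 5.0]

# rates below 1% keep one decimal place, larger ones print as integers
fpr_str = [f"{val:.1f}" if val < 1 else f"{val:.0f}" for val in fpr_pct]
print([f"{s}%" for s in fpr_str])          # ['0.1%', '1%', '5%']
```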
@@ -709,13 +650,8 @@ def decorate_plot(labels, """ max_x = len(labels) - x_range = np.arange(-0.5, max_x, max_x/100) - plt.plot(x_range, [0] * len(x_range), - '-', - color='black', - linewidth=1, - label='_nolegend_', - zorder=-10) + x_range = np.arange(-0.5, max_x, max_x / 100) + plt.plot(x_range, [0] * len(x_range), "-", color="black", linewidth=1, label="_nolegend_", zorder=-10) # add any dividers if dividers is not None: @@ -723,25 +659,18 @@ def decorate_plot(labels, for divider in dividers: loc = labels.index(divider.replace(" ", "\n")) + 2.5 * DELTA locs.append(loc) - plt.vlines(locs, - ylim[0], - ylim[1], - linestyles='--', - color='black') + plt.vlines(locs, ylim[0], ylim[1], linestyles="--", color="black") # add any text if texts is not None: for item in texts: - text = item['text'] - pos = item['position'] - kwargs = {key: item[key] for key in TEXT_KEYS - if key in item} + text = item["text"] + pos = item["position"] + kwargs = {key: item[key] for key in TEXT_KEYS if key in item} - plt.text(*pos, text, - horizontalalignment='center', - verticalalignment='center', - transform=axis.transAxes, - **kwargs) + plt.text( + *pos, text, horizontalalignment="center", verticalalignment="center", transform=axis.transAxes, **kwargs + ) def save_plot(save_path): @@ -783,41 +712,39 @@ def plot(plot_dic): plot_info = plot_dic["plot_info"] # get ROCE scores and other values needed for the plot - (roce_scores, roce_no_overlap, - fpr_colors, labels, bar_height) = vals_for_plot(plot_dic=plot_dic) + (roce_scores, roce_no_overlap, fpr_colors, labels, bar_height) = vals_for_plot(plot_dic=plot_dic) # set tick sizes - this has to come before making the plot x_axis_dic = plot_info.get("x_axis", {}) y_axis_dic = plot_info.get("y_axis", {}) - set_tick_sizes(x_axis_dic=x_axis_dic, - y_axis_dic=y_axis_dic) + set_tick_sizes(x_axis_dic=x_axis_dic, y_axis_dic=y_axis_dic) # make the base plot - axis = base_plot(roce_scores=roce_scores, - roce_no_overlap=roce_no_overlap, - labels=labels, - fpr_colors=fpr_colors, - bar_height=bar_height) + axis = base_plot( + roce_scores=roce_scores, + roce_no_overlap=roce_no_overlap, + labels=labels, + fpr_colors=fpr_colors, + bar_height=bar_height, + ) # add labels - label_plot(fpr_vals=base_info["fpr_vals"], - legend_dic=plot_info.get("legend", {}), - x_axis_dic=x_axis_dic, - y_axis_dic=y_axis_dic, - axis=axis) + label_plot( + fpr_vals=base_info["fpr_vals"], + legend_dic=plot_info.get("legend", {}), + x_axis_dic=x_axis_dic, + y_axis_dic=y_axis_dic, + axis=axis, + ) # set the y limits - ylim = set_plot_ylim(max_scale=plot_info.get("max_height_scale", 1.2), - roce_no_overlap=roce_no_overlap, - bar_height=bar_height) + ylim = set_plot_ylim( + max_scale=plot_info.get("max_height_scale", 1.2), roce_no_overlap=roce_no_overlap, bar_height=bar_height + ) # add decorations - decorate_plot(labels=labels, - ylim=ylim, - axis=axis, - dividers=plot_info.get("dividers"), - texts=plot_info.get("texts")) + decorate_plot(labels=labels, ylim=ylim, axis=axis, dividers=plot_info.get("dividers"), texts=plot_info.get("texts")) # save and show save_plot(save_path=plot_info.get("save_path")) @@ -828,9 +755,7 @@ def plot(plot_dic): return roce_scores, labels, fpr_vals -def get_perform_info(fprs, - roce_scores, - labels): +def get_perform_info(fprs, roce_scores, labels): """ Summarize the information about model performances so it can be saved in a JSON. 
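Note on the `plt.text(..., transform=axis.transAxes)` call in `decorate_plot` above: the transform makes the position axes-relative rather than data-relative, so text placement is independent of the data limits. A minimal standalone sketch (the label text and output filename are placeholders):

```python
import matplotlib.pyplot as plt

fig, axis = plt.subplots()
# transform=axis.transAxes interprets (x, y) as axes fractions, so (0.5, 0.9)
# is horizontally centred, 90% of the way up -- regardless of the data limits
axis.text(0.5, 0.9, "hypothetical label", horizontalalignment="center",
          verticalalignment="center", transform=axis.transAxes)
fig.savefig("decorated_plot.png")
```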
@@ -857,13 +782,12 @@ def get_perform_info(fprs, sort_scores = scores[sort_idx].tolist() sort_labels = np.array(labels)[sort_idx].tolist() - score_list = [{"rank": i + 1, - "model": sort_labels[i].replace("\n", " "), - "roce": score} - for i, score in enumerate(sort_scores)] + score_list = [ + {"rank": i + 1, "model": sort_labels[i].replace("\n", " "), "roce": score} + for i, score in enumerate(sort_scores) + ] - this_info = {"fpr": fpr, - "scores": score_list} + this_info = {"fpr": fpr, "scores": score_list} info.append(this_info) return info @@ -883,9 +807,7 @@ def plot_all(plot_dics): for plot_dic in plot_dics: roce_scores, labels, fprs = plot(plot_dic) - info = get_perform_info(fprs=fprs, - roce_scores=roce_scores, - labels=labels) + info = get_perform_info(fprs=fprs, roce_scores=roce_scores, labels=labels) roces.append(info) return roces @@ -897,13 +819,14 @@ def main(): file with plot information. """ parser = argparse.ArgumentParser() - parser.add_argument('--config_file', type=str, - help=("Path to JSON file with plot information. " - "Please see config/plot_info.json for an " - "example.")) - parser.add_argument('--save_path', type=str, - help=("Path to JSON file with saved ROCE scores."), - default='roce.json') + parser.add_argument( + "--config_file", + type=str, + help=("Path to JSON file with plot information. " "Please see config/plot_info.json for an " "example."), + ) + parser.add_argument( + "--save_path", type=str, help=("Path to JSON file with saved ROCE scores."), default="roce.json" + ) args = parser.parse_args() config_file = args.config_file @@ -912,7 +835,7 @@ def main(): roces = plot_all(plot_dics=plot_dics) save_path = args.save_path - with open(save_path, 'w') as f_open: + with open(save_path, "w") as f_open: json.dump(roces, f_open, indent=4, sort_keys=True) print(f"Saved ROCE score information to {save_path}") diff --git a/nff/data/crystals.py b/nff/data/crystals.py index f98ca8fe..9504d100 100644 --- a/nff/data/crystals.py +++ b/nff/data/crystals.py @@ -1,6 +1,4 @@ import torch -import numpy as np -# from pymatgen.core.structure import Structure def get_crystal_graph(crystal, cutoff): @@ -21,7 +19,7 @@ def get_crystal_graph(crystal, cutoff): pbc = list(range(len(sites))) for site in crystal.sites: - for site, _, idx, _ in crystal.get_neighbors(site, cutoff, include_index=True, include_image=True): + for site, _, idx, _ in crystal.get_neighbors(site, cutoff, include_index=True, include_image=True): # noqa if site not in sites: sites.append(site) pbc.append(idx) @@ -30,4 +28,3 @@ def get_crystal_graph(crystal, cutoff): pbc = torch.LongTensor(pbc) return nxyz, pbc - diff --git a/nff/data/dataset.py b/nff/data/dataset.py index e7e74a25..e16eadc4 100644 --- a/nff/data/dataset.py +++ b/nff/data/dataset.py @@ -9,7 +9,7 @@ import numbers from collections import Counter from copy import deepcopy -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Dict, List, Literal import numpy as np import torch @@ -18,6 +18,7 @@ from sklearn.model_selection import train_test_split from sklearn.utils import shuffle as skshuffle from torch.utils.data import Dataset as TorchDataset +from tqdm import trange import nff.utils.constants as const from nff.data.features import ATOM_FEAT_TYPES, BOND_FEAT_TYPES @@ -82,10 +83,11 @@ class Dataset(TorchDataset): def __init__( self, - props: dict, + props: Dict[str, List[Any]], units: str = "kcal/mol", check_props: bool = True, do_copy: bool = True, + device: str = "cuda", ) -> None: """Constructor for Dataset class. 
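Note on the `device` parameter added to `Dataset.__init__` above: it is stored on the dataset and forwarded to every `AtomsBatch` built from it, and it defaults to "cuda". A minimal usage sketch, assuming the package's usual `nff.data.Dataset` export and a toy single-frame `props` dict (values are illustrative only):

```python
import torch
from nff.data import Dataset

# one toy frame; each nxyz row is [atomic_number, x, y, z]
nxyz = torch.tensor([
    [8.0, 0.00, 0.00, 0.00],
    [1.0, 0.96, 0.00, 0.00],
    [1.0, -0.24, 0.93, 0.00],
])
props = {"nxyz": [nxyz], "energy": [torch.tensor(0.0)]}

# device now defaults to "cuda"; pass "cpu" explicitly on GPU-less machines
dataset = Dataset(props, units="kcal/mol", device="cpu")
print(len(dataset))  # 1
```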
@@ -98,6 +100,7 @@ def __init__( to see if they are in the right format. do_copy (bool): whether to copy the properties or use the same dictionary. + device (str): The device to execute computations on ('cpu', 'cuda' etc.) """ if check_props: if do_copy: @@ -108,6 +111,7 @@ def __init__( self.props = props self.units = units self.to_units(units) + self.device = device def __len__(self) -> int: """Length of the dataset. @@ -203,7 +207,7 @@ def generate_neighbor_list( undirected: bool = True, key: str = "nbr_list", offset_key: str = "offsets", - ) -> list: + ) -> list | tuple[list, list]: """Generates a neighbor list for each one of the atoms in the dataset. By default, does not consider periodic boundary conditions. @@ -225,11 +229,6 @@ def generate_neighbor_list( return self.props[key] - # def make_nbr_to_mol(self): - # nbr_to_mol = [] - # for nbrs in self.props['nbr_list']: - # nbrs_to_mol.append(torch.zeros(len(nbrs))) - def make_all_directed(self): """Make everything in the dataset directed.""" make_dset_directed(self) @@ -289,6 +288,7 @@ def _get_periodic_neighbor_list( pbc=True, cutoff=cutoff, directed=(not undirected), + device=self.device, ) nbrs, offs = atoms.update_nbr_list() nbrlist.append(nbrs) @@ -300,7 +300,7 @@ def _get_periodic_neighbor_list( def generate_bond_idx(self, num_procs: int = 1) -> None: """For each index in the bond list, get the - index in the neighbour list that corresponds to the + index in the neighbor list that corresponds to the same directed pair of atoms. Args: @@ -444,6 +444,7 @@ def unwrap_xyz(self, mol_dic: dict) -> None: numbers=self.props["nxyz"][i][:, 0], cell=self.props["cell"][i], pbc=True, + device=self.device, ) # recontruct coordinates based on subgraphs index @@ -577,6 +578,7 @@ def gen_bond_prior(self, cutoff: float, bond_len_dict: dict | None = None) -> No "cutoff": cutoff, "cell": cell, "nbr_torch": False, + "device": self.device, } # the coordinates have been unwrapped and try to results offsets @@ -615,7 +617,7 @@ def as_atoms_batches( atoms_batches = [] num_batches = len(self.props["nxyz"]) - for i in range(num_batches): + for i in trange(num_batches): nxyz = self.props["nxyz"][i] atoms = AtomsBatch( nxyz[:, 0].long(), diff --git a/nff/data/features/graph.py b/nff/data/features/graph.py index c409af6d..a31f1dd1 100644 --- a/nff/data/features/graph.py +++ b/nff/data/features/graph.py @@ -2,14 +2,15 @@ Tools for generating graph-based features """ -import torch -import numpy as np import copy + +import numpy as np +import torch from rdkit import Chem from rdkit.Chem import AllChem -from nff.utils.xyz2mol import xyz2mol from nff.utils import tqdm_enum +from nff.utils.xyz2mol import xyz2mol # default options for xyz2mol @@ -20,44 +21,26 @@ # default feature types and options -BOND_FEAT_TYPES = ["bond_type", - "conjugated", - "in_ring", - "stereo", - "in_ring_size"] - -ATOM_FEAT_TYPES = ["atom_type", - "num_bonds", - "formal_charge", - "chirality", - "num_bonded_h", - "hybrid", - "aromaticity", - "mass"] - -CHIRAL_OPTIONS = ["chi_unspecified", - "chi_tetrahedral_cw", - "chi_tetrahedral_ccw", - "chi_other"] - -HYBRID_OPTIONS = ["s", - "sp", - "sp2", - "sp3", - "sp3d", - "sp3d2"] - -BOND_OPTIONS = ["single", - "double", - "triple", - "aromatic"] - -STEREO_OPTIONS = ["stereonone", - "stereoany", - "stereoz", - "stereoe", - "stereocis", - "stereotrans"] +BOND_FEAT_TYPES = ["bond_type", "conjugated", "in_ring", "stereo", "in_ring_size"] + +ATOM_FEAT_TYPES = [ + "atom_type", + "num_bonds", + "formal_charge", + "chirality", + "num_bonded_h", + 
"hybrid", + "aromaticity", + "mass", +] + +CHIRAL_OPTIONS = ["chi_unspecified", "chi_tetrahedral_cw", "chi_tetrahedral_ccw", "chi_other"] + +HYBRID_OPTIONS = ["s", "sp", "sp2", "sp3", "sp3d", "sp3d2"] + +BOND_OPTIONS = ["single", "double", "triple", "aromatic"] + +STEREO_OPTIONS = ["stereonone", "stereoany", "stereoz", "stereoe", "stereocis", "stereotrans"] AT_NUM = list(range(1, 100)) FORMAL_CHARGES = [-2, -1, 0, 1, 2] @@ -68,46 +51,31 @@ # dictionary with feature names, their options, type, # and size when stored as a vector -FEAT_DIC = {"bond_type": {"options": BOND_OPTIONS, - "num": len(BOND_OPTIONS) + 1}, - "conjugated": {"options": [bool], - "num": 1}, - "in_ring": {"options": [bool], - "num": 1}, - "stereo": {"options": STEREO_OPTIONS, - "num": len(STEREO_OPTIONS) + 1}, - "in_ring_size": {"options": RING_SIZE, - "num": len(RING_SIZE) + 1}, - "atom_type": {"options": AT_NUM, - "num": len(AT_NUM) + 1}, - "num_bonds": {"options": BONDS, - "num": len(BONDS) + 1}, - "formal_charge": {"options": FORMAL_CHARGES, - "num": len(FORMAL_CHARGES) + 1}, - "chirality": {"options": CHIRAL_OPTIONS, - "num": len(CHIRAL_OPTIONS) + 1}, - "num_bonded_h": {"options": NUM_H, - "num": len(NUM_H) + 1}, - "hybrid": {"options": HYBRID_OPTIONS, - "num": len(HYBRID_OPTIONS) + 1}, - "aromaticity": {"options": [bool], - "num": 1}, - "mass": {"options": [float], - "num": 1}} - -META_DATA = {"bond_features": BOND_FEAT_TYPES, - "atom_features": ATOM_FEAT_TYPES, - "details": FEAT_DIC} +FEAT_DIC = { + "bond_type": {"options": BOND_OPTIONS, "num": len(BOND_OPTIONS) + 1}, + "conjugated": {"options": [bool], "num": 1}, + "in_ring": {"options": [bool], "num": 1}, + "stereo": {"options": STEREO_OPTIONS, "num": len(STEREO_OPTIONS) + 1}, + "in_ring_size": {"options": RING_SIZE, "num": len(RING_SIZE) + 1}, + "atom_type": {"options": AT_NUM, "num": len(AT_NUM) + 1}, + "num_bonds": {"options": BONDS, "num": len(BONDS) + 1}, + "formal_charge": {"options": FORMAL_CHARGES, "num": len(FORMAL_CHARGES) + 1}, + "chirality": {"options": CHIRAL_OPTIONS, "num": len(CHIRAL_OPTIONS) + 1}, + "num_bonded_h": {"options": NUM_H, "num": len(NUM_H) + 1}, + "hybrid": {"options": HYBRID_OPTIONS, "num": len(HYBRID_OPTIONS) + 1}, + "aromaticity": {"options": [bool], "num": 1}, + "mass": {"options": [float], "num": 1}, +} + +META_DATA = {"bond_features": BOND_FEAT_TYPES, "atom_features": ATOM_FEAT_TYPES, "details": FEAT_DIC} # default number of atom features -NUM_ATOM_FEATS = sum([val["num"] for key, val in FEAT_DIC.items() - if key in ATOM_FEAT_TYPES]) +NUM_ATOM_FEATS = sum([val["num"] for key, val in FEAT_DIC.items() if key in ATOM_FEAT_TYPES]) # default number of bond features -NUM_BOND_FEATS = sum([val["num"] for key, val in FEAT_DIC.items() - if key in BOND_FEAT_TYPES]) +NUM_BOND_FEATS = sum([val["num"] for key, val in FEAT_DIC.items() if key in BOND_FEAT_TYPES]) def remove_bad_idx(dataset, smiles_list, bad_idx, verbose=True): @@ -122,11 +90,10 @@ def remove_bad_idx(dataset, smiles_list, bad_idx, verbose=True): None """ - bad_idx = sorted(list(set(bad_idx))) + bad_idx = list(set(bad_idx)) new_props = {} for key, values in dataset.props.items(): - new_props[key] = [val for i, val in enumerate( - values) if i not in bad_idx] + new_props[key] = [val for i, val in enumerate(values) if i not in bad_idx] if not new_props[key]: continue if type(values) is torch.Tensor: @@ -139,9 +106,7 @@ def remove_bad_idx(dataset, smiles_list, bad_idx, verbose=True): conv_pct = good_len / total_len * 100 if verbose: - print(("Converted %d of %d " - "species (%.2f%%)" 
% (
-               good_len, total_len, conv_pct)))
+    print("Converted %d of %d " "species (%.2f%%)" % (good_len, total_len, conv_pct))
 
 
 def smiles_from_smiles(smiles):
@@ -167,11 +132,11 @@ def smiles_from_mol(mol):
     """
     Get the canonical smiles from an rdkit mol.
     Args:
-        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
+        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
     Returns:
         new_smiles (str): canonical smiles
         new_mol (rdkit.Chem.rdchem.Mol): rdkit Mol created
-        from the canonical smiles.
+        from the canonical smiles.
     """
 
     new_smiles = Chem.MolToSmiles(mol)
@@ -184,11 +149,11 @@ def smiles_from_mol(mol):
 def get_undirected_bonds(mol):
     """
     Get an undirected bond list from an RDKit mol. This
-    means that bonds between atoms 1 and 0 are stored as
+    means that bonds between atoms 1 and 0 are stored as
     [0, 1], whereas in a directed list they would be stored
     as both [0, 1] and [1, 0].
     Args:
-        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
+        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
     Returns:
         bond_list (list): undirected bond list
     """
@@ -197,7 +162,6 @@ def get_undirected_bonds(mol):
     bonds = mol.GetBonds()
 
     for bond in bonds:
-
         start = bond.GetBeginAtomIdx()
         end = bond.GetEndAtomIdx()
         lower = min((start, end))
@@ -213,7 +177,7 @@ def undirected_bond_atoms(mol):
     Get a list of the atomic numbers comprising a bond
     in each bond of an undirected bond list.
     Args:
-        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
+        mol (rdkit.Chem.rdchem.Mol): rdkit Mol
     Returns:
         atom_num_list (list): list of the form [[num_00, num_01],
        [num_10, num_11], [num_20, num_21], ...], where the `num_ij`
@@ -224,7 +188,6 @@ def undirected_bond_atoms(mol):
     bonds = mol.GetBonds()
 
     for bond in bonds:
-
         start = bond.GetBeginAtom().GetAtomicNum()
         end = bond.GetEndAtom().GetAtomicNum()
         lower = min((start, end))
@@ -239,15 +202,15 @@ def check_connectivity(mol_0, mol_1):
     """
     Check if the atom connectivity in two mol objects
     is the same.
     Args:
-        mol_0 (rdkit.Chem.rdchem.Mol): first rdkit Mol
-        mol_1 (rdkit.Chem.rdchem.Mol): second rdkit Mol
+        mol_0 (rdkit.Chem.rdchem.Mol): first rdkit Mol
+        mol_1 (rdkit.Chem.rdchem.Mol): second rdkit Mol
     Returns:
         same (bool): whether or not the connectivity
             is the same
     """
 
     bonds_0 = undirected_bond_atoms(mol_0)
     bonds_1 = undirected_bond_atoms(mol_1)
-    same = (bonds_0 == bonds_1)
+    same = bonds_0 == bonds_1
 
     return same
 
@@ -257,7 +220,7 @@ def verify_smiles(rd_mol, smiles):
     Verify that an RDKit mol has the same smiles as
     the original smiles that made it.
     Args:
-        rd_mol (rdkit.Chem.rdchem.Mol): rdkit Mol
+        rd_mol (rdkit.Chem.rdchem.Mol): rdkit Mol
         smiles (str): claimed smiles
     Returns:
         None
@@ -286,28 +249,27 @@ def verify_smiles(rd_mol, smiles):
 
     # try checking bond connectivity
 
-    good_con = check_connectivity(mol_0=new_rd_mol,
-                                  mol_1=db_mol)
+    good_con = check_connectivity(mol_0=new_rd_mol, mol_1=db_mol)
 
     if good_con:
-        msg = (("WARNING: xyz2mol SMILES is {} "
-                "and database SMILES is {}. "
-                "However, the connectivity is the same. "
-                "Check to make sure the SMILES are resonances "
-                "structures.".format(rd_smiles, db_smiles)))
+        msg = (
+            f"WARNING: xyz2mol SMILES is {rd_smiles} "
+            f"and database SMILES is {db_smiles}. "
+            "However, the connectivity is the same. "
+            "Check to make sure the SMILES are resonance "
+            "structures."
+        )
        return
 
     # otherwise raise an exception
 
-    msg = (("SMILES created by xyz2mol is {}, "
-           "which doesn't match the database "
-           "SMILES {}.".format(rd_smiles, db_smiles)))
+    msg = f"SMILES created by xyz2mol is {rd_smiles}, which doesn't match the database SMILES {db_smiles}."
raise Exception(msg) def log_failure(bad_idx, i): """ - Log how many smiles have conformers that you've successfully converted + Log how many smiles have conformers that you've successfully converted to RDKit mols. Args: bad_idx (list[int]): indices to get rid of in the dataset @@ -322,9 +284,7 @@ def log_failure(bad_idx, i): good_len = i - len(bad_idx) conv_pct = good_len / i * 100 - print(("Converted %d of %d " - "species (%.2f%%)" % ( - good_len, i, conv_pct))) + print("Converted %d of %d " "species (%.2f%%)" % (good_len, i, conv_pct)) def log_missing(missing_e): @@ -341,8 +301,7 @@ def log_missing(missing_e): print("No elements are missing from xyz2mol") else: missing_e = list(set(missing_e)) - print("Elements {} are missing from xyz2mol".format( - ", ".join(missing_e))) + print("Elements {} are missing from xyz2mol".format(", ".join(missing_e))) def get_enum_func(track): @@ -355,17 +314,10 @@ def get_enum_func(track): tqdm if track == True. """ - if track: - func = tqdm_enum - else: - func = enumerate - return func + return tqdm_enum if track else enumerate -def make_rd_mols(dataset, - verbose=True, - check_smiles=False, - track=True): +def make_rd_mols(dataset, verbose=True, check_smiles=False, track=True): """ Use xyz2mol to add RDKit mols to a dataset that contains molecule coordinates. @@ -382,7 +334,7 @@ def make_rd_mols(dataset, """ - num_atoms = dataset.props['num_atoms'] + num_atoms = dataset.props["num_atoms"] # number of atoms in each conformer mol_size = dataset.props.get("mol_size", num_atoms).tolist() smiles_list = dataset.props["smiles"] @@ -396,7 +348,6 @@ def make_rd_mols(dataset, enum = get_enum_func(track) for i, smiles in enum(smiles_list): - # split the nxyz of each species into the component # nxyz of each conformer @@ -409,9 +360,7 @@ def make_rd_mols(dataset, missing_e = [] # go through each conformer nxyz - - for j, nxyz in enumerate(nxyz_list): - + for nxyz in nxyz_list: # if a conformer in the species has already failed # to produce an RDKit mol, then don't bother converting # any of the other conformers for that species @@ -420,28 +369,26 @@ def make_rd_mols(dataset, continue # coordinates and atomic numbers - xyz = nxyz[:, 1:].tolist() - atoms = nxyz[:, 0].numpy().astype('int').tolist() + atoms = nxyz[:, 0].numpy().astype("int").tolist() try: - - mol = xyz2mol(atoms=atoms, - coordinates=xyz, - charge=charge, - use_graph=QUICK, - allow_charged_fragments=CHARGED_FRAGMENTS, - embed_chiral=EMBED_CHIRAL, - use_huckel=USE_HUCKEL) + mol = xyz2mol( + atoms=atoms, + coordinates=xyz, + charge=charge, + use_graph=QUICK, + allow_charged_fragments=CHARGED_FRAGMENTS, + embed_chiral=EMBED_CHIRAL, + use_huckel=USE_HUCKEL, + ) if check_smiles: # check the smiles if requested verify_smiles(rd_mol=mol, smiles=smiles) except Exception as e: - - print(("xyz2mol failed " - "with error '{}' ".format(e))) - print("Removing smiles {}".format(smiles)) + print("xyz2mol failed " f"with error '{e}' ") + print(f"Removing smiles {smiles}") bad_idx.append(i) if verbose: @@ -463,10 +410,7 @@ def make_rd_mols(dataset, # remove any species with missing RDKit mols - remove_bad_idx(dataset=dataset, - smiles_list=smiles_list, - bad_idx=bad_idx, - verbose=verbose) + remove_bad_idx(dataset=dataset, smiles_list=smiles_list, bad_idx=bad_idx, verbose=verbose) if verbose: log_missing(missing_e) @@ -502,7 +446,7 @@ def bond_feat_to_vec(feat_type, feat): feat_type (int): what type of feature it is feat (Union[floa, int]): feaure value Returns: - one_hot (torch.Tensor): one-hot encoding of + one_hot 
(torch.Tensor): one-hot encoding of
+        one_hot (torch.Tensor): one-hot encoding of
         the feature.
     """
 
@@ -512,16 +456,14 @@ def bond_feat_to_vec(feat_type, feat):
         result = torch.Tensor([conj])
         return result
 
-    elif feat_type == "bond_type":
+    if feat_type == "bond_type":
         # select from `BOND_OPTIONS`
         options = BOND_OPTIONS
         bond_type = feat
-        one_hot = make_one_hot(options=options,
-                               result=bond_type)
+        one_hot = make_one_hot(options=options, result=bond_type)
         return one_hot
 
-    elif feat_type == "in_ring_size":
-
+    if feat_type == "in_ring_size":
         # This is already a one-hot encoded vector,
         # because RDKit tests if the bond is in a
         # ring of a specific size, so the feature we
@@ -537,24 +479,23 @@ def bond_feat_to_vec(feat_type, feat):
                 ring_size = option
                 break
 
-        one_hot = make_one_hot(options=options,
-                               result=ring_size)
+        one_hot = make_one_hot(options=options, result=ring_size)
         return one_hot
 
-    elif feat_type == "in_ring":
+    if feat_type == "in_ring":
         # just 0 or 1
         in_ring = feat
         result = torch.Tensor([in_ring])
         return result
 
-    elif feat_type == "stereo":
+    if feat_type == "stereo":
         # select from `STEREO_OPTIONS`
         stereo = feat
         options = STEREO_OPTIONS
-        one_hot = make_one_hot(options=options,
-                               result=stereo)
+        one_hot = make_one_hot(options=options, result=stereo)
         return one_hot
+    raise ValueError(f"Unrecognized feature type {feat_type}")
 
 
 def get_bond_features(bond, feat_type):
@@ -599,66 +540,58 @@ def atom_feat_to_vec(feat_type, feat):
         feat_type (int): what type of feature it is
         feat (Union[float, int]): feature value
     Returns:
-        one_hot (torch.Tensor): one-hot encoding of
+        one_hot (torch.Tensor): one-hot encoding of
         the feature.
     """
 
     if feat_type == "atom_type":
         options = AT_NUM
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "num_bonds":
+    if feat_type == "num_bonds":
         options = BONDS
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "formal_charge":
-
+    if feat_type == "formal_charge":
         options = FORMAL_CHARGES
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "chirality":
+    if feat_type == "chirality":
         options = CHIRAL_OPTIONS
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "num_bonded_h":
-
+    if feat_type == "num_bonded_h":
         options = NUM_H
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "hybrid":
-
+    if feat_type == "hybrid":
         options = HYBRID_OPTIONS
-        one_hot = make_one_hot(options=options,
-                               result=feat)
+        one_hot = make_one_hot(options=options, result=feat)
 
         return one_hot
 
-    elif feat_type == "aromaticity":
+    if feat_type == "aromaticity":
         one_hot = torch.Tensor([feat])
 
         return one_hot
 
-    elif feat_type == "mass":
+    if feat_type == "mass":
         # the mass is converted to a feature vector
         # by dividing by 100
         result = torch.Tensor([feat / 100])
         return result
+    raise ValueError(f"Unrecognized feature type {feat_type}")
 
 
 def get_atom_features(atom, feat_type):
@@ -680,21 +613,16 @@ def get_atom_features(atom, feat_type):
         feat = atom.GetTotalDegree()
 
     elif feat_type == "formal_charge":
-
         feat = atom.GetFormalCharge()
 
     elif feat_type == "chirality":
         feat = atom.GetChiralTag().name.lower()
 
     elif feat_type == "num_bonded_h":
-
-        neighbors = [at.GetAtomicNum() for at
in atom.GetNeighbors()] - feat = len([i for i in neighbors if - i == 1]) + neighbors = [at.GetAtomicNum() for at in atom.GetNeighbors()] + feat = len([i for i in neighbors if i == 1]) elif feat_type == "hybrid": - feat = atom.GetHybridization().name.lower() elif feat_type == "aromaticity": @@ -705,8 +633,7 @@ def get_atom_features(atom, feat_type): # convert to a feature vector - vec = atom_feat_to_vec(feat_type=feat_type, - feat=feat) + vec = atom_feat_to_vec(feat_type=feat_type, feat=feat) return vec @@ -718,7 +645,7 @@ def get_all_bond_feats(bond, feat_types): bond (rdkit.Chem.rdchem.Bond): bond object feat_types (list[str]): list of feature types Returns: - feat_dic (dict): dictionary of the form + feat_dic (dict): dictionary of the form {feat_type: bond_feat_vector} for all feature types. """ @@ -726,8 +653,7 @@ def get_all_bond_feats(bond, feat_types): feat_dic = {} for feat_type in feat_types: - feature = get_bond_features(bond=bond, - feat_type=feat_type) + feature = get_bond_features(bond=bond, feat_type=feat_type) feat_dic[feat_type] = feature return feat_dic @@ -740,7 +666,7 @@ def get_all_atom_feats(atom, feat_types): atom (rdkit.Chem.rdchem.Atom): atom object feat_types (list[str]): list of feature types Returns: - feat_dic (dict): dictionary of the form + feat_dic (dict): dictionary of the form {feat_type: atom_feat_vector} for all feature types. """ @@ -748,16 +674,13 @@ def get_all_atom_feats(atom, feat_types): feat_dic = {} for feat_type in feat_types: - feature = get_atom_features(atom=atom, - feat_type=feat_type) + feature = get_atom_features(atom=atom, feat_type=feat_type) feat_dic[feat_type] = feature return feat_dic -def featurize_bonds(dataset, - feat_types=BOND_FEAT_TYPES, - track=True): +def featurize_bonds(dataset, feat_types=BOND_FEAT_TYPES, track=True): """ Add the bond feature vectors of each species and conformer to the dataset. 
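For orientation, the contract between `get_all_bond_feats` and `featurize_bonds` below is simple concatenation: one fixed-length encoding per feature type, joined in `feat_types` order, for a total length of `NUM_BOND_FEATS` as bookkept in `FEAT_DIC`. A minimal sketch of that invariant; the SMILES is arbitrary, and the import path assumes the module layout in this diff:

    import torch
    from rdkit import Chem

    from nff.data.features.graph import (
        BOND_FEAT_TYPES,
        NUM_BOND_FEATS,
        get_all_bond_feats,
    )

    # a small molecule with explicit hydrogens so bonds exist
    rd_mol = Chem.AddHs(Chem.MolFromSmiles("CO"))
    bond = rd_mol.GetBondWithIdx(0)

    # one vector per feature type, keyed by feature name
    feat_dic = get_all_bond_feats(bond=bond, feat_types=BOND_FEAT_TYPES)

    # featurize_bonds concatenates these in order; the total length
    # should match NUM_BOND_FEATS computed from FEAT_DIC
    bond_vec = torch.cat([feat_dic[ft] for ft in BOND_FEAT_TYPES])
    assert bond_vec.shape[0] == NUM_BOND_FEATS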
@@ -779,14 +702,13 @@ def featurize_bonds(dataset, # number of bonds in a species props["num_bonds"] = [] - num_atoms = dataset.props['num_atoms'] + num_atoms = dataset.props["num_atoms"] mol_size = dataset.props.get("mol_size", num_atoms).tolist() enum = get_enum_func(track) # go through each set of RDKit mols for i, rd_mols in enum(dataset.props["rd_mols"]): - num_confs = (num_atoms[i] // mol_size[i]).item() split_sizes = [mol_size[i]] * num_confs @@ -798,12 +720,10 @@ def featurize_bonds(dataset, # go through each RDKit mol for j, rd_mol in enumerate(rd_mols): - bonds = rd_mol.GetBonds() bond_list = [] for bond in bonds: - all_props.append(torch.tensor([])) start = bond.GetBeginAtomIdx() @@ -815,12 +735,11 @@ def featurize_bonds(dataset, bond_list.append([lower, upper]) # get the bond features - feat_dic = get_all_bond_feats(bond=bond, - feat_types=feat_types) + feat_dic = get_all_bond_feats(bond=bond, feat_types=feat_types) # add to the features `all_props`, which contains # the bond features of all the conformers of this species - for key, feat in feat_dic.items(): + for feat in feat_dic.values(): all_props[-1] = torch.cat((all_props[-1], feat)) # shift the bond list for each conformer to take into account @@ -829,8 +748,7 @@ def featurize_bonds(dataset, other_atoms = sum(split_sizes[:j]) shifted_bond_list = np.array(bond_list) + other_atoms - props["bond_list"][-1].append(torch.LongTensor( - shifted_bond_list)) + props["bond_list"][-1].append(torch.LongTensor(shifted_bond_list)) props["num_bonds"][-1].append(len(bonds)) # convert everything into a tensor after looping through each conformer @@ -841,9 +759,7 @@ def featurize_bonds(dataset, return dataset -def featurize_atoms(dataset, - feat_types=ATOM_FEAT_TYPES, - track=True): +def featurize_atoms(dataset, feat_types=ATOM_FEAT_TYPES, track=True): """ Add the atom feature vectors of each species and conformer to the dataset. 
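The subtle bookkeeping in `featurize_bonds` above is the conformer shift: each conformer's bonds are indexed within its own molecule and then offset by `sum(split_sizes[:j])`, the number of atoms in all earlier conformers, so they index into the species' concatenated nxyz. A toy illustration with made-up sizes:

    import numpy as np

    mol_size = 3                        # atoms per conformer (hypothetical)
    split_sizes = [mol_size] * 3        # three conformers of one species
    bond_list = [[0, 1], [1, 2]]        # bonds within a single conformer

    j = 2                               # third conformer
    other_atoms = sum(split_sizes[:j])  # 6 atoms precede it
    shifted = np.array(bond_list) + other_atoms
    # shifted is [[6, 7], [7, 8]]: valid indices into the stacked nxyz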
@@ -862,8 +778,7 @@ def featurize_atoms(dataset, enum = get_enum_func(track) # go through each set of RDKit mols for each species - for i, rd_mols in enum(dataset.props["rd_mols"]): - + for _, rd_mols in enum(dataset.props["rd_mols"]): # initialize a list of features for each atom all_props = [] @@ -876,10 +791,9 @@ def featurize_atoms(dataset, all_props.append(torch.tensor([])) # get the atomic features - feat_dic = get_all_atom_feats(atom=atom, - feat_types=feat_types) + feat_dic = get_all_atom_feats(atom=atom, feat_types=feat_types) - for key, feat in feat_dic.items(): + for feat in feat_dic.values(): all_props[-1] = torch.cat((all_props[-1], feat)) # stack the atomic features @@ -903,18 +817,13 @@ def decode_one_hot(options, vector): return bool(vector.item()) # if the options are a single float, return the value - elif options == [float]: + if options == [float]: return vector.item() # otherwise return the option at the nonzero index # (or None if it's the last index or everything is 0) index = vector.nonzero() - if len(index) == 0 or index >= len(options): - result = None - else: - result = options[index] - - return result + return None if len(index) == 0 or index >= len(options) else options[index] def decode_atomic(features, meta_data=META_DATA): @@ -923,7 +832,7 @@ def decode_atomic(features, meta_data=META_DATA): Args: features (torch.Tensor): feature vector meta_data (dict): dictionary that tells you the - atom and bond feature types + atom and bond feature types Returns: dic (dict): dictionary of feature values """ @@ -946,8 +855,7 @@ def decode_atomic(features, meta_data=META_DATA): options = options_list[i] name = feat_names[i] - result = decode_one_hot(options=options, - vector=vector) + result = decode_one_hot(options=options, vector=vector) dic[name] = result # multiply by 100 if it's the mass @@ -963,7 +871,7 @@ def decode_bond(features, meta_data=META_DATA): Args: features (torch.Tensor): feature vector meta_data (dict): dictionary that tells you the - atom and bond feature types + atom and bond feature types Returns: dic (dict): dictionary of feature values """ @@ -985,18 +893,15 @@ def decode_bond(features, meta_data=META_DATA): options = options_list[i] name = feat_names[i] - result = decode_one_hot(options=options, - vector=vector) + result = decode_one_hot(options=options, vector=vector) dic[name] = result return dic -def featurize_dataset(dataset, - bond_feats=BOND_FEAT_TYPES, - atom_feats=ATOM_FEAT_TYPES): +def featurize_dataset(dataset, bond_feats=BOND_FEAT_TYPES, atom_feats=ATOM_FEAT_TYPES): """ - Add RDKit mols, atomic features and bond features to + Add RDKit mols, atomic features and bond features to a dataset. Note that this has been superseded by the parallel version in data/parallel.py. Args: @@ -1028,8 +933,8 @@ def featurize_dataset(dataset, def add_morgan(dataset, vec_length): """ Add Morgan fingerprints to the dataset. Note that this uses - the smiles of each species to get one fingerprint per species, - as opposed to getting the graph of each conformer and its + the smiles of each species to get one fingerprint per species, + as opposed to getting the graph of each conformer and its fingerprint. 
Args: @@ -1041,14 +946,10 @@ def add_morgan(dataset, vec_length): """ dataset.props["morgan"] = [] - for smiles in dataset.props['smiles']: + for smiles in dataset.props["smiles"]: mol = Chem.MolFromSmiles(smiles) - if vec_length != 0: - morgan = AllChem.GetMorganFingerprintAsBitVect( - mol, radius=2, nBits=vec_length) - else: - morgan = [] + morgan = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=vec_length) if vec_length != 0 else [] - arr_morgan = np.array(list(morgan)).astype('float32') + arr_morgan = np.array(list(morgan)).astype("float32") morgan_tens = torch.tensor(arr_morgan) dataset.props["morgan"].append(morgan_tens) diff --git a/nff/data/features/xyz.py b/nff/data/features/xyz.py index 7fe91901..e515a56d 100644 --- a/nff/data/features/xyz.py +++ b/nff/data/features/xyz.py @@ -4,11 +4,10 @@ import logging -from rdkit import Chem -from rdkit.Chem import rdMolDescriptors as rdMD import torch from e3fp.pipeline import fprints_from_mol - +from rdkit import Chem +from rdkit.Chem import rdMolDescriptors as rdMD from tqdm import tqdm @@ -66,11 +65,11 @@ def get_3d_representation(xyz, smiles, method, mol=None): """ representation_fn = { - 'autocorrelation_3d': rdMD.CalcAUTOCORR3D, - 'rdf': rdMD.CalcRDF, - 'morse': rdMD.CalcMORSE, - 'whim': rdMD.CalcWHIM, - 'getaway': lambda x: rdMD.CalcWHIM(x, precision=0.001) + "autocorrelation_3d": rdMD.CalcAUTOCORR3D, + "rdf": rdMD.CalcRDF, + "morse": rdMD.CalcMORSE, + "whim": rdMD.CalcWHIM, + "getaway": lambda x: rdMD.CalcWHIM(x, precision=0.001), } # if a `mol` is not given, generate it from the xyz and smiles @@ -95,23 +94,19 @@ def featurize_rdkit(dataset, method): props = dataset.props # go through each geometry - for i in range(len(props['nxyz'])): - - smiles = props['smiles'][i] - nxyz = props['nxyz'][i] + for i in range(len(props["nxyz"])): + smiles = props["smiles"][i] + nxyz = props["nxyz"][i] reps = [] # if there are RDKit mols in the dataset, you can # get the 3D representation from the mol itself - if 'rd_mols' in props: - rd_mols = props['rd_mols'][i] + if "rd_mols" in props: + rd_mols = props["rd_mols"][i] for rd_mol in rd_mols: - rep = torch.Tensor(get_3d_representation(xyz=None, - smiles=None, - method=method, - mol=rd_mol)) + rep = torch.Tensor(get_3d_representation(xyz=None, smiles=None, method=method, mol=rd_mol)) reps.append(rep) # otherwise you can get the mols from the nxyz, but this @@ -123,23 +118,22 @@ def featurize_rdkit(dataset, method): # if `mol_size` is there then split the nxyz into conformer # geomtries - if 'mol_size' in props: - mol_size = props['mol_size'][i].item() + if "mol_size" in props: + mol_size = props["mol_size"][i].item() n_confs = nxyz.shape[0] // mol_size nxyz_list = torch.split(nxyz, [mol_size] * n_confs) for sub_nxyz in nxyz_list: - - msg = ("Warning: no RDKit mols found in dataset. " - "Using nxyz and SMILES and assuming that the " - "nxyz atom ordering is the same as in the RDKit " - "mol. Make sure to check this!") + msg = ( + "Warning: no RDKit mols found in dataset. " + "Using nxyz and SMILES and assuming that the " + "nxyz atom ordering is the same as in the RDKit " + "mol. Make sure to check this!" 
+ ) print(msg) xyz = sub_nxyz.detach().cpu().numpy().tolist() - rep = torch.Tensor(get_3d_representation(xyz=xyz, - smiles=smiles, - method=method)) + rep = torch.Tensor(get_3d_representation(xyz=xyz, smiles=smiles, method=method)) reps.append(rep) reps = torch.stack(reps) @@ -159,16 +153,11 @@ def get_e3fp(mol, bits, smiles=None): smiles = Chem.MolToSmiles(mol) mol.SetProp("_Name", smiles) fprint_params = {"bits": bits} - fp = (fprints_from_mol(mol, fprint_params=fprint_params)[0] - .to_vector().toarray().astype(int) - ).reshape(-1) + fp = (fprints_from_mol(mol, fprint_params=fprint_params)[0].to_vector().toarray().astype(int)).reshape(-1) return fp -def add_e3fp(rd_dataset, - fp_length, - verbose=False, - track=True): +def add_e3fp(rd_dataset, fp_length, verbose=False, track=True): """ Add E3FP fingerprints to each conformer in the dataset. Args: @@ -194,12 +183,11 @@ def add_e3fp(rd_dataset, smiles = batch["smiles"] fps = [] for mol in mols: - fp_array = get_e3fp(mol, fp_length, smiles) fps.append(torch.Tensor(fp_array)) e3fp_list.append(torch.stack(fps)) - rd_dataset.props['e3fp'] = e3fp_list + rd_dataset.props["e3fp"] = e3fp_list return rd_dataset diff --git a/nff/data/graphs.py b/nff/data/graphs.py index a7002b03..91d2dec3 100644 --- a/nff/data/graphs.py +++ b/nff/data/graphs.py @@ -118,7 +118,7 @@ def get_neighbor_list(xyz, cutoff=5, undirected=True): indices of connected atoms. """ - if torch.is_tensor(xyz) == False: + if not torch.is_tensor(xyz): xyz = torch.Tensor(xyz) n = xyz.size(0) @@ -151,7 +151,7 @@ def to_tuple(tensor): def get_bond_idx(bonded_nbr_list, nbr_list): """ For each index in the bond list, get the - index in the neighbour list that corresponds to the + index in the neighbor list that corresponds to the same directed pair of atoms. 
Args:
        bonded_nbr_list (torch.LongTensor): pairs
@@ -232,7 +232,6 @@ def generate_subgraphs(atomsobject, unwrap=True, get_edge=False):
     atoms = AtomsBatch(atomsobject)
     z, adj, dmat, threshold = adjdistmat(atoms, unwrap=unwrap)
-    box_len = torch.Tensor(np.diag(atoms.get_cell()))
     G = nx.from_numpy_matrix(adj)
 
     for i, item in enumerate(z):
@@ -243,14 +242,13 @@ def generate_subgraphs(atomsobject, unwrap=True, get_edge=False):
     edge_list = []
     partitions = []
 
-    for i, sg in enumerate(sub_graphs):
+    for sg in sub_graphs:
         partitions.append(list(sg.nodes))
         if get_edge:
             edge_list.append(list(sg.edges))
     if len(edge_list) != 0:
         return partitions, edge_list
-    else:
-        return partitions
+    return partitions
 
 
 def get_single_molecule(atomsobject, mol_idx, single_mol_id):
@@ -281,7 +279,7 @@ def reconstruct_atoms(atomsobject, mol_idx):
 def list2adj(bond_list, size=None):
     E = bond_list
     if size is None:
-        size = max(set([n for e in E for n in e])) + 1
+        size = max({n for e in E for n in e}) + 1
     # make an empty adjacency list
     adjacency = [[0] * size for _ in range(size)]
     # populate the list for each edge
@@ -448,9 +446,9 @@ def add_ji_kj(angle_lists, nbr_lists):
     """
     Get ji and kj idx (explained more below):
     Args:
-        angle_list (list[torch.LongTensor]): list of angle
+        angle_lists (list[torch.LongTensor]): list of angle
             lists
-        nbr_list (list[torch.LongTensor]): list of directed neighbor
+        nbr_lists (list[torch.LongTensor]): list of directed neighbor
             lists
     Returns:
         ji_idx_list (list[torch.LongTensor]): ji_idx for each geom
@@ -563,7 +561,7 @@ def full_angle_idx(batch):
     for i in range(num_confs):
         max_idx = (i + 1) * mol_size
-        min_idx = (i) * mol_size
+        min_idx = i * mol_size
 
         # get only the indices for this conformer
         conf_mask = (nbr_list[:, 0] < max_idx) * (nbr_list[:, 0] >= min_idx)
diff --git a/nff/data/loader.py b/nff/data/loader.py
index 702b018c..1916bad1 100644
--- a/nff/data/loader.py
+++ b/nff/data/loader.py
@@ -1,25 +1,23 @@
-import numpy as np
-import torch
 import copy
-from torch.utils.data.sampler import Sampler, BatchSampler
 
+import numpy as np
+import torch
+from torch.utils.data.sampler import BatchSampler, Sampler
 
-REINDEX_KEYS = ['atoms_nbr_list', 'nbr_list', 'bonded_nbr_list',
-                'angle_list', 'mol_nbrs']
-NBR_LIST_KEYS = ['bond_idx', 'kj_idx', 'ji_idx']
-MOL_IDX_KEYS = ['atomwise_mol_list', 'directed_nbr_mol_list',
-                'undirected_nbr_mol_list']
-IGNORE_KEYS = ['rd_mols']
+REINDEX_KEYS = ["atoms_nbr_list", "nbr_list", "bonded_nbr_list", "angle_list", "mol_nbrs"]
+NBR_LIST_KEYS = ["bond_idx", "kj_idx", "ji_idx"]
+MOL_IDX_KEYS = ["atomwise_mol_list", "directed_nbr_mol_list", "undirected_nbr_mol_list"]
+IGNORE_KEYS = ["rd_mols"]
 
 TYPE_KEYS = {
-    'atoms_nbr_list': torch.long,
-    'nbr_list': torch.long,
-    'num_atoms': torch.long,
-    'bond_idx': torch.long,
-    'bonded_nbr_list': torch.long,
-    'angle_list': torch.long,
-    'ji_idx': torch.long,
-    'kj_idx': torch.long,
+    "atoms_nbr_list": torch.long,
+    "nbr_list": torch.long,
+    "num_atoms": torch.long,
+    "bond_idx": torch.long,
+    "bonded_nbr_list": torch.long,
+    "angle_list": torch.long,
+    "ji_idx": torch.long,
+    "kj_idx": torch.long,
 }
 
 
@@ -37,18 +35,17 @@ def collate_dicts(dicts):
 
     # new indices for the batch: the first one is zero and the
     # last does not matter
 
-    cumulative_atoms = np.cumsum([0] + [d['num_atoms'] for d in dicts])[:-1]
+    cumulative_atoms = np.cumsum([0] + [d["num_atoms"] for d in dicts])[:-1]
 
     for n, d in zip(cumulative_atoms, dicts):
         for key in REINDEX_KEYS:
             if key in d:
                 d[key] = d[key] + int(n)
 
-    if
all(['nbr_list' in d for d in dicts]): + if all("nbr_list" in d for d in dicts): # same idea, but for quantities whose maximum value is the length of # the nbr list in each batch - cumulative_nbrs = np.cumsum( - [0] + [len(d['nbr_list']) for d in dicts])[:-1] + cumulative_nbrs = np.cumsum([0] + [len(d["nbr_list"]) for d in dicts])[:-1] for n, d in zip(cumulative_nbrs, dicts): for key in NBR_LIST_KEYS: if key in d: @@ -60,24 +57,17 @@ def collate_dicts(dicts): for i, d in enumerate(dicts): d[key] += i - # batching the data batch = {} for key, val in dicts[0].items(): if key in IGNORE_KEYS: continue - if type(val) == str: + if isinstance(val, str): batch[key] = [data[key] for data in dicts] - elif hasattr(val, 'shape') and len(val.shape) > 0: - batch[key] = torch.cat([ - data[key] - for data in dicts - ], dim=0) + elif hasattr(val, "shape") and len(val.shape) > 0: + batch[key] = torch.cat([data[key] for data in dicts], dim=0) else: - batch[key] = torch.stack( - [data[key] for data in dicts], - dim=0 - ) + batch[key] = torch.stack([data[key] for data in dicts], dim=0) # adjusting the data types: for key, dtype in TYPE_KEYS.items(): @@ -102,9 +92,7 @@ class ImbalancedDatasetSampler(Sampler): """ - def __init__(self, - target_name, - props): + def __init__(self, target_name, props): """ Args: target_name (str): name of the property being classified @@ -113,10 +101,8 @@ def __init__(self, data_length = len(props[target_name]) - negative_idx = [i for i, target in enumerate( - props[target_name]) if round(target.item()) == 0] - positive_idx = [i for i in range(data_length) - if i not in negative_idx] + negative_idx = [i for i, target in enumerate(props[target_name]) if round(target.item()) == 0] + positive_idx = [i for i in range(data_length) if i not in negative_idx] num_neg = len(negative_idx) num_pos = len(positive_idx) @@ -135,22 +121,14 @@ def __init__(self, self.weights[positive_idx] = 1 / positive_weight def __iter__(self): - - return (i for i in torch.multinomial( - self.weights, self.data_length, replacement=True)) + return (i for i in torch.multinomial(self.weights, self.data_length, replacement=True)) def __len__(self): return self.data_length class BalancedFFSampler(torch.utils.data.sampler.Sampler): - - def __init__(self, - balance_type=None, - weights=None, - balance_dict=None, - **kwargs): - + def __init__(self, balance_type=None, weights=None, balance_dict=None, **kwargs): from nff.data.sampling import spec_config_zhu_balance if weights is not None: @@ -171,32 +149,21 @@ def __init__(self, self.data_length = len(self.balance_dict["weights"]) def __iter__(self): - - return (i for i in torch.multinomial( - self.balance_dict["weights"], - self.data_length, - replacement=True)) + return (i for i in torch.multinomial(self.balance_dict["weights"], self.data_length, replacement=True)) def __len__(self): return self.data_length class BalancedBatchedSpecies(BatchSampler): - def __init__(self, - base_sampler, - smiles_list, - batch_size, - min_geoms=None): + def __init__(self, base_sampler, smiles_list, batch_size, min_geoms=None): """ min_geoms (int, optional): minimum number of geoms that a species has to have to be sampled """ from nff.data.sampling import get_spec_dic - BatchSampler.__init__(self, - sampler=base_sampler, - batch_size=batch_size, - drop_last=False) + BatchSampler.__init__(self, sampler=base_sampler, batch_size=batch_size, drop_last=False) self.spec_dic = get_spec_dic({"smiles": smiles_list}) self.rev_specs = self.reverse_spec_dic() @@ -218,16 +185,13 @@ def 
reverse_spec_dic(self): def exclude_sparse(self, min_geoms): if min_geoms is None: return - invalid_specs = [key for key, val in self.spec_dic.items() - if len(val) < min_geoms] + invalid_specs = [key for key, val in self.spec_dic.items() if len(val) < min_geoms] if not invalid_specs: return - invalid_idx = torch.cat([self.spec_dic[spec] for spec - in invalid_specs]) + invalid_idx = torch.cat([self.spec_dic[spec] for spec in invalid_specs]) self.sampler.balance_dict["weights"][invalid_idx] = 0 def __iter__(self): - sampler_indices = list(iter(self.sampler)) num_samples = len(self.sampler) // self.batch_size batch = [] @@ -242,20 +206,15 @@ def __iter__(self): # to zero weights = self.sampler.balance_dict["weights"] these_weights = copy.deepcopy(weights) - zero_weight = torch.ones(len(weights), - dtype=torch.bool) + zero_weight = torch.ones(len(weights), dtype=torch.bool) zero_weight[other_geom_idx] = False these_weights[zero_weight] = 0 # sample indices from this species only # by using `these_weights` - add_idx = torch.multinomial( - these_weights, - self.batch_size, - replacement=True) + add_idx = torch.multinomial(these_weights, self.batch_size, replacement=True) - for idx in add_idx: - batch.append(idx) + batch.extend(add_idx.tolist().copy()) yield batch batch = [] diff --git a/nff/data/parallel.py b/nff/data/parallel.py index 30b556be..c6589d26 100644 --- a/nff/data/parallel.py +++ b/nff/data/parallel.py @@ -2,20 +2,15 @@ Tools for applying functions in parallel to the dataset """ -import numpy as np -from concurrent import futures import copy +from concurrent import futures + +import numpy as np import torch +from nff.data.features import ATOM_FEAT_TYPES, BOND_FEAT_TYPES, add_e3fp, featurize_atoms, featurize_bonds, make_rd_mols +from nff.data.graphs import add_bond_idx, kj_ji_to_dset from nff.utils import fprint -from nff.data.features import (make_rd_mols, - featurize_bonds, - featurize_atoms, - add_e3fp, - BOND_FEAT_TYPES, - ATOM_FEAT_TYPES) -from nff.data.graphs import kj_ji_to_dset, add_bond_idx - NUM_PROCS = 5 @@ -35,17 +30,13 @@ def split_dataset(dataset, num): splits = np.array_split(idx, num) for split in splits: - if len(split) == 0: continue min_split = split[0] max_split = split[-1] + 1 - new_props = {key: val[min_split: max_split] for key, val - in dataset.props.items()} + new_props = {key: val[min_split:max_split] for key, val in dataset.props.items()} - new_dataset = dataset.__class__(props=new_props, - check_props=False, - units=dataset.units) + new_dataset = dataset.__class__(props=new_props, check_props=False, units=dataset.units) datasets.append(new_dataset) return datasets @@ -70,14 +61,12 @@ def rejoin_props(datasets): if type(val) is list: new_props[key] += val else: - new_props[key] = torch.cat([ - new_props[key], val], dim=0) + new_props[key] = torch.cat([new_props[key], val], dim=0) return new_props def gen_parallel(func, kwargs_list): - if len(kwargs_list) == 1: kwargs = kwargs_list[0] kwargs["track"] = True @@ -87,10 +76,9 @@ def gen_parallel(func, kwargs_list): future_objs = [] # go through each set of kwargs for i, kwargs in enumerate(kwargs_list): - # monitor with tqdm for the first process only # so that they don't print on top of each other - kwargs["track"] = (i == 0) + kwargs["track"] = i == 0 result = executor.submit(func, **kwargs) # `future_objs` are the results of applying each function @@ -110,59 +98,52 @@ def rd_parallel(datasets, check_smiles=False): Args: datasets (list): list of smaller datasets check_smiles (bool): exclude any species whose 
-            SMILES strings aren't the same as the
+            SMILES strings aren't the same as the SMILES recovered by xyz2mol
     Returns:
-        results_dsets (list): list of datasets with
+        results_dsets (list): list of datasets with
            RDKit mols.
     """
 
-    kwargs_list = [{"dataset": dataset, "verbose": False,
-                    "check_smiles": check_smiles}
-                   for dataset in datasets]
-    result_dsets = gen_parallel(func=make_rd_mols,
-                                kwargs_list=kwargs_list)
+    kwargs_list = [{"dataset": dataset, "verbose": False, "check_smiles": check_smiles} for dataset in datasets]
+    result_dsets = gen_parallel(func=make_rd_mols, kwargs_list=kwargs_list)
 
     return result_dsets
 
 
 def bonds_parallel(datasets, feat_types):
     """
-    Generate bond lists and bond features for the dataset
+    Generate bond lists and bond features for the dataset
     in parallel.
     Args:
         datasets (list): list of smaller datasets
         feat_types (list[str]): types of bond features to
-            use
+            use
     Returns:
-        results_dsets (list): list of datasets with
+        results_dsets (list): list of datasets with
            bond lists and features.
     """
 
-    kwargs_list = [{"dataset": dataset, "feat_types": feat_types}
-                   for dataset in datasets]
-    result_dsets = gen_parallel(func=featurize_bonds,
-                                kwargs_list=kwargs_list)
+    kwargs_list = [{"dataset": dataset, "feat_types": feat_types} for dataset in datasets]
+    result_dsets = gen_parallel(func=featurize_bonds, kwargs_list=kwargs_list)
 
     return result_dsets
 
 
 def atoms_parallel(datasets, feat_types):
     """
-    Generate atom features for the dataset
+    Generate atom features for the dataset
     in parallel.
     Args:
         datasets (list): list of smaller datasets
         feat_types (list[str]): types of atom features to
-            use
+            use
     Returns:
-        results_dsets (list): list of datasets with
+        results_dsets (list): list of datasets with
            atom features.
     """
 
-    kwargs_list = [{"dataset": dataset, "feat_types": feat_types}
-                   for dataset in datasets]
-    result_dsets = gen_parallel(func=featurize_atoms,
-                                kwargs_list=kwargs_list)
+    kwargs_list = [{"dataset": dataset, "feat_types": feat_types} for dataset in datasets]
+    result_dsets = gen_parallel(func=featurize_atoms, kwargs_list=kwargs_list)
 
     return result_dsets
 
@@ -175,34 +156,28 @@ def e3fp_parallel(datasets, fp_length):
         datasets (list): list of smaller datasets
         fp_length (int): fingerprint length
     Returns:
-        results_dsets (list): list of datasets with
+        results_dsets (list): list of datasets with
            E3FP fingerprints.
     """
 
-    kwargs_list = [{"rd_dataset": dataset, "fp_length": fp_length} for
-                   dataset in datasets]
+    kwargs_list = [{"rd_dataset": dataset, "fp_length": fp_length} for dataset in datasets]
 
-    result_dsets = gen_parallel(func=add_e3fp,
-                                kwargs_list=kwargs_list)
+    result_dsets = gen_parallel(func=add_e3fp, kwargs_list=kwargs_list)
 
     return result_dsets
 
 
 def kj_ji_parallel(dsets):
-    kwargs_list = [{"dataset": dataset} for
-                   dataset in dsets]
+    kwargs_list = [{"dataset": dataset} for dataset in dsets]
 
-    result_dsets = gen_parallel(func=kj_ji_to_dset,
-                                kwargs_list=kwargs_list)
+    result_dsets = gen_parallel(func=kj_ji_to_dset, kwargs_list=kwargs_list)
 
     return result_dsets
 
 
 def bond_idx_parallel(dsets):
-    kwargs_list = [{"dataset": dataset} for
-                   dataset in dsets]
+    kwargs_list = [{"dataset": dataset} for dataset in dsets]
 
-    result_dsets = gen_parallel(func=add_bond_idx,
-                                kwargs_list=kwargs_list)
+    result_dsets = gen_parallel(func=add_bond_idx, kwargs_list=kwargs_list)
 
     return result_dsets
 
@@ -220,16 +195,12 @@ def summarize_rd(new_sets, first_set):
     tried = len(first_set)
     succ = sum([len(d) for d in new_sets])
     pct = succ / tried * 100
-    fprint("Converted %d of %d molecules (%.2f%%)."
% - (succ, tried, pct)) + fprint("Converted %d of %d molecules (%.2f%%)." % (succ, tried, pct)) -def featurize_parallel(dataset, - num_procs, - bond_feats=BOND_FEAT_TYPES, - atom_feats=ATOM_FEAT_TYPES): +def featurize_parallel(dataset, num_procs, bond_feats=BOND_FEAT_TYPES, atom_feats=ATOM_FEAT_TYPES): """ - Add RDKit mols, atom features and bond features to a dataset in + Add RDKit mols, atom features and bond features to a dataset in parallel. Args: dataset (nff.data.dataset): NFF dataset @@ -259,7 +230,7 @@ def featurize_parallel(dataset, datasets = split_dataset(dataset=dataset, num=num_procs) # add RDKit mols if they're not already in the dataset - has_rdmols = all(['rd_mols' in dset.props for dset in datasets]) + has_rdmols = all("rd_mols" in dset.props for dset in datasets) if not has_rdmols: fprint("Converting xyz to RDKit mols...") datasets = rd_parallel(datasets) @@ -283,9 +254,7 @@ def featurize_parallel(dataset, dataset.props["offsets"] = offsets -def add_e3fp_parallel(dataset, - fp_length, - num_procs): +def add_e3fp_parallel(dataset, fp_length, num_procs): """ Add E3FP fingerprints to a dataset in parallel. Args: @@ -308,9 +277,7 @@ def add_e3fp_parallel(dataset, def add_kj_ji_parallel(dataset, num_procs): - - fprint((f"Adding kj and ji indices with {num_procs} " - "parallel processes")) + fprint(f"Adding kj and ji indices with {num_procs} " "parallel processes") datasets = split_dataset(dataset=dataset, num=num_procs) datasets = kj_ji_parallel(datasets) @@ -319,9 +286,7 @@ def add_kj_ji_parallel(dataset, num_procs): def add_bond_idx_parallel(dataset, num_procs): - - fprint((f"Adding bond indices with {num_procs} " - "parallel processes")) + fprint(f"Adding bond indices with {num_procs} " "parallel processes") datasets = split_dataset(dataset=dataset, num=num_procs) datasets = bond_idx_parallel(datasets) diff --git a/nff/data/sampling.py b/nff/data/sampling.py index 7a7cba50..fdaa386b 100644 --- a/nff/data/sampling.py +++ b/nff/data/sampling.py @@ -5,11 +5,11 @@ import torch from tqdm import tqdm +from nff.data import Dataset from nff.train.loss import batch_zhu_p from nff.utils import constants as const -from nff.utils.misc import cat_props -from nff.data import Dataset from nff.utils.geom import compute_distances +from nff.utils.misc import cat_props def get_spec_dic(props): @@ -28,8 +28,7 @@ def get_spec_dic(props): spec_dic = {} for i, spec in enumerate(props["smiles"]): - no_stereo_spec = (spec.replace("\\", "") - .replace("/", "")) + no_stereo_spec = spec.replace("\\", "").replace("/", "") if no_stereo_spec not in spec_dic: spec_dic[no_stereo_spec] = [] spec_dic[no_stereo_spec].append(i) @@ -40,8 +39,7 @@ def get_spec_dic(props): return spec_dic -def compute_zhu(props, - zhu_kwargs): +def compute_zhu(props, zhu_kwargs): """ Compute the approximate Zhu-Nakamura hopping probabilities for each geom in the dataset. 
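Note that `get_spec_dic` above buckets geometries by species only after stripping the "/" and "\" stereo bond markers, so stereoisomers of one backbone share a single bucket. A small sketch of the grouping logic; the SMILES are chosen arbitrarily, and plain lists stand in for the index tensors used downstream:

    props = {"smiles": ["C/C=C/C", "C/C=C\\C", "CCO"]}

    spec_dic = {}
    for i, spec in enumerate(props["smiles"]):
        # delete stereo markers so E/Z isomers map to the same key
        no_stereo_spec = spec.replace("\\", "").replace("/", "")
        spec_dic.setdefault(no_stereo_spec, []).append(i)

    # {'CC=CC': [0, 1], 'CCO': [2]} -- trans- and cis-2-butene geoms
    # land in the same species bucket
    print(spec_dic)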
@@ -58,18 +56,19 @@ def compute_zhu(props, expec_gap_kcal = zhu_kwargs["expec_gap"] * const.AU_TO_KCAL["energy"] func_type = zhu_kwargs["func_type"] - zhu_p = batch_zhu_p(batch=cat_props(props), - upper_key=upper_key, - lower_key=lower_key, - expec_gap=expec_gap_kcal, - func_type=func_type, - gap_shape=None) + zhu_p = batch_zhu_p( + batch=cat_props(props), + upper_key=upper_key, + lower_key=lower_key, + expec_gap=expec_gap_kcal, + func_type=func_type, + gap_shape=None, + ) return zhu_p -def balanced_spec_zhu(spec_dic, - zhu_p): +def balanced_spec_zhu(spec_dic, zhu_p): """ Get the Zhu weights assigned to each geom, such that the probability of getting a geom in species A @@ -92,7 +91,6 @@ def balanced_spec_zhu(spec_dic, all_weights = torch.zeros(num_geoms) for lst_idx in spec_dic.values(): - idx = torch.LongTensor(lst_idx) this_zhu = zhu_p[idx] sum_zhu = this_zhu.sum() @@ -131,13 +129,7 @@ def imbalanced_spec_zhu(zhu_p): return all_weights -def assign_clusters(ref_idx, - spec_nxyz, - ref_nxyzs, - device, - num_clusters, - extra_category, - extra_rmsd): +def assign_clusters(ref_idx, spec_nxyz, ref_nxyzs, device, num_clusters, extra_category, extra_rmsd): """ Assign each geom to a cluster. @@ -205,9 +197,7 @@ def assign_clusters(ref_idx, dset_0 = Dataset(props=props_0) dset_1 = Dataset(props=props_1) - rmsds, _ = compute_distances(dataset=dset_0, - device=device, - dataset_1=dset_1) + rmsds, _ = compute_distances(dataset=dset_0, device=device, dataset_1=dset_1) # take the minimum rmsd with respect to the set of reference # nxyz's in each cluster. Put infinity if a species is missing a @@ -224,7 +214,6 @@ def assign_clusters(ref_idx, # to that cluster min_rmsds[torch.isnan(min_rmsds)] = float("inf") - clusters = min_rmsds.argmin(-1) if extra_category: @@ -232,24 +221,15 @@ def assign_clusters(ref_idx, clusters[in_extra] = num_clusters # record clusters in `cluster_dic` - cluster_dic = {i: [] for i in - range(num_clusters + int(extra_category))} + cluster_dic = {i: [] for i in range(num_clusters + int(extra_category))} for spec_idx, cluster in enumerate(clusters): cluster_dic[cluster.item()].append(spec_idx) - - return cluster_dic, min_rmsds -def per_spec_config_weights(spec_nxyz, - ref_nxyzs, - ref_idx, - num_clusters, - extra_category, - extra_rmsd, - device='cpu'): +def per_spec_config_weights(spec_nxyz, ref_nxyzs, ref_idx, num_clusters, extra_category, extra_rmsd, device="cpu"): """ Get weights to evenly sample different regions of phase space for a given species @@ -294,13 +274,15 @@ def per_spec_config_weights(spec_nxyz, """ # a dictionary that tells you which geoms are in each cluster - cluster_dic, cluster_rmsds = assign_clusters(ref_idx=ref_idx, - spec_nxyz=spec_nxyz, - ref_nxyzs=ref_nxyzs, - device=device, - num_clusters=num_clusters, - extra_category=extra_category, - extra_rmsd=extra_rmsd) + cluster_dic, cluster_rmsds = assign_clusters( + ref_idx=ref_idx, + spec_nxyz=spec_nxyz, + ref_nxyzs=ref_nxyzs, + device=device, + num_clusters=num_clusters, + extra_category=extra_category, + extra_rmsd=extra_rmsd, + ) # assign weights to each geom equal to 1 / (num geoms in cluster), # so that the probability of sampling any one cluster is equal to @@ -320,16 +302,10 @@ def per_spec_config_weights(spec_nxyz, # return normalized weights geom_weights /= geom_weights.sum() - return geom_weights, cluster_rmsds, cluster_dic -def all_spec_config_weights(props, - ref_nxyz_dic, - spec_dic, - device, - extra_category, - extra_rmsd): +def all_spec_config_weights(props, ref_nxyz_dic, spec_dic, device, 
extra_category, extra_rmsd): """ Get the "configuration weights" for each geom, i.e. the weights chosen to evenly sample each cluster @@ -363,18 +339,17 @@ def all_spec_config_weights(props, """ weight_dic = {} - num_geoms = len(props['nxyz']) - num_clusters = max([len(ref_dic['nxyz']) for - ref_dic in ref_nxyz_dic.values()]) + num_geoms = len(props["nxyz"]) + num_clusters = max([len(ref_dic["nxyz"]) for ref_dic in ref_nxyz_dic.values()]) cluster_rmsds = torch.zeros(num_geoms, num_clusters) cluster_assgn = torch.zeros(num_geoms) for spec in tqdm(list(spec_dic.keys())): idx = spec_dic[spec] - ref_nxyzs = ref_nxyz_dic[spec]['nxyz'] - ref_idx = ref_nxyz_dic[spec]['idx'] - spec_nxyz = [props['nxyz'][i] for i in idx] + ref_nxyzs = ref_nxyz_dic[spec]["nxyz"] + ref_idx = ref_nxyz_dic[spec]["idx"] + spec_nxyz = [props["nxyz"][i] for i in idx] geom_weights, these_rmsds, cluster_dic = per_spec_config_weights( spec_nxyz=spec_nxyz, ref_nxyzs=ref_nxyzs, @@ -382,7 +357,8 @@ def all_spec_config_weights(props, num_clusters=num_clusters, device=device, extra_category=extra_category, - extra_rmsd=extra_rmsd) + extra_rmsd=extra_rmsd, + ) # assign weights to each species weight_dic[spec] = geom_weights @@ -398,8 +374,7 @@ def all_spec_config_weights(props, return weight_dic, cluster_rmsds, cluster_assgn -def balanced_spec_config(weight_dic, - spec_dic): +def balanced_spec_config(weight_dic, spec_dic): """ Generate weights for geoms such that there is balance with respect to species [p(A) = p(B)], and with respect to clusters in each @@ -425,20 +400,19 @@ def balanced_spec_config(weight_dic, return all_weights -def imbalanced_spec_config(weight_dic, - spec_dic): +def imbalanced_spec_config(weight_dic, spec_dic): """ - Generate weights for geoms such that there is no balance with respect - to species [p(A) != p(B)], but there is with respect to clusters in - each species [p(A, c1) = p(A, c2), where c1 and c2 are two different - clusters in species A]. - Args: - spec_dic (dict): dictionary with indices of geoms in each species. - weight_dic (dict): dictionary of the form {smiles: geom_weights}, - where geom_weights are the set of normalized weights for - each geometry in that species. - Returns: - all_weights (torch.Tensor): normalized set of weights + Generate weights for geoms such that there is no balance with respect + to species [p(A) != p(B)], but there is with respect to clusters in + each species [p(A, c1) = p(A, c2), where c1 and c2 are two different + clusters in species A]. + Args: + spec_dic (dict): dictionary with indices of geoms in each species. + weight_dic (dict): dictionary of the form {smiles: geom_weights}, + where geom_weights are the set of normalized weights for + each geometry in that species. + Returns: + all_weights (torch.Tensor): normalized set of weights """ num_geoms = sum([i.shape[0] for i in weight_dic.values()]) @@ -486,15 +460,17 @@ def get_rand_weights(spec_dic): return balanced_spec_weights, imbalanced_spec_weights -def combine_weights(balanced_config, - imbalanced_config, - balanced_zhu, - imbalanced_zhu, - balanced_rand, - imbalanced_rand, - spec_weight, - config_weight, - zhu_weight): +def combine_weights( + balanced_config, + imbalanced_config, + balanced_zhu, + imbalanced_zhu, + balanced_rand, + imbalanced_rand, + spec_weight, + config_weight, + zhu_weight, +): """ Combine config weights, Zhu-Nakamura weights, and random weights to get the final weights for each geom. 
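For reference, the expressions reformatted below implement a convex mixture: `zhu_weight`, `config_weight`, and the implied `rand_weight = 1 - zhu_weight - config_weight` split the probability mass across the three weighting schemes, while `spec_weight` interpolates each scheme between its species-balanced and species-imbalanced variants. Assuming every input vector is a normalized distribution over geoms, as the helpers above produce, the final weights again sum to one; a quick numeric check with placeholder tensors:

    import torch

    torch.manual_seed(0)
    n = 5  # number of geoms (toy value)

    def norm(x):
        # normalize a weight vector over geoms
        return x / x.sum()

    balanced_zhu, imbalanced_zhu = norm(torch.rand(n)), norm(torch.rand(n))
    balanced_config, imbalanced_config = norm(torch.rand(n)), norm(torch.rand(n))
    balanced_rand, imbalanced_rand = norm(torch.rand(n)), norm(torch.rand(n))

    spec_weight, config_weight, zhu_weight = 0.5, 0.25, 0.25
    rand_weight = 1 - zhu_weight - config_weight

    final = (
        zhu_weight * (spec_weight * balanced_zhu + (1 - spec_weight) * imbalanced_zhu)
        + config_weight * (spec_weight * balanced_config + (1 - spec_weight) * imbalanced_config)
        + rand_weight * (spec_weight * balanced_rand + (1 - spec_weight) * imbalanced_rand)
    )
    assert abs(final.sum().item() - 1.0) < 1e-6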
@@ -530,20 +506,19 @@ def combine_weights(balanced_config, # combination of zhu weights that are balanced and imbalanced with respect # to species - weighted_zhu = (balanced_zhu * zhu_weight * spec_weight - + imbalanced_zhu * zhu_weight * (1 - spec_weight)) + weighted_zhu = balanced_zhu * zhu_weight * spec_weight + imbalanced_zhu * zhu_weight * (1 - spec_weight) # combination of config weights that are balanced and imbalanced with # respect to species - weighted_config = (balanced_config * config_weight * spec_weight - + imbalanced_config * config_weight * (1 - spec_weight)) + weighted_config = balanced_config * config_weight * spec_weight + imbalanced_config * config_weight * ( + 1 - spec_weight + ) # combination of random weights that are balanced and imbalanced with # respect to species - rand_weight = (1 - zhu_weight - config_weight) - weighted_rand = (balanced_rand * rand_weight * spec_weight - + imbalanced_rand * rand_weight * (1 - spec_weight)) + rand_weight = 1 - zhu_weight - config_weight + weighted_rand = balanced_rand * rand_weight * spec_weight + imbalanced_rand * rand_weight * (1 - spec_weight) # final weights @@ -552,15 +527,17 @@ def combine_weights(balanced_config, return final_weights -def spec_config_zhu_balance(props, - ref_nxyz_dic, - zhu_kwargs, - spec_weight, - config_weight, - zhu_weight, - extra_category=False, - extra_rmsd=None, - device='cpu'): +def spec_config_zhu_balance( + props, + ref_nxyz_dic, + zhu_kwargs, + spec_weight, + config_weight, + zhu_weight, + extra_category=False, + extra_rmsd=None, + device="cpu", +): """ Generate weights that combine balancing of species, configurations, and Zhu-Nakamura hopping rates. @@ -601,28 +578,22 @@ def spec_config_zhu_balance(props, spec_dic=spec_dic, device=device, extra_category=extra_category, - extra_rmsd=extra_rmsd) + extra_rmsd=extra_rmsd, + ) - balanced_config = balanced_spec_config( - weight_dic=config_weight_dic, - spec_dic=spec_dic) + balanced_config = balanced_spec_config(weight_dic=config_weight_dic, spec_dic=spec_dic) - imbalanced_config = imbalanced_spec_config( - weight_dic=config_weight_dic, - spec_dic=spec_dic) + imbalanced_config = imbalanced_spec_config(weight_dic=config_weight_dic, spec_dic=spec_dic) # get the species-balanced and species-imbalanced # zhu weights - zhu_p = compute_zhu(props=props, - zhu_kwargs=zhu_kwargs) - balanced_zhu = balanced_spec_zhu(spec_dic=spec_dic, - zhu_p=zhu_p) + zhu_p = compute_zhu(props=props, zhu_kwargs=zhu_kwargs) + balanced_zhu = balanced_spec_zhu(spec_dic=spec_dic, zhu_p=zhu_p) imbalanced_zhu = imbalanced_spec_zhu(zhu_p=zhu_p) # get the random weights - balanced_rand, imbalanced_rand = get_rand_weights( - spec_dic=spec_dic) + balanced_rand, imbalanced_rand = get_rand_weights(spec_dic=spec_dic) # combine them all together @@ -635,11 +606,10 @@ def spec_config_zhu_balance(props, imbalanced_rand=imbalanced_rand, spec_weight=spec_weight, config_weight=config_weight, - zhu_weight=zhu_weight) + zhu_weight=zhu_weight, + ) # put relevant info in a dictionary - results = {"weights": final_weights, - "cluster_rmsds": cluster_rmsds, - "clusters": cluster_assgn} + results = {"weights": final_weights, "cluster_rmsds": cluster_rmsds, "clusters": cluster_assgn} return results diff --git a/nff/data/sparse.py b/nff/data/sparse.py index ba1db993..b5496952 100644 --- a/nff/data/sparse.py +++ b/nff/data/sparse.py @@ -16,8 +16,7 @@ def sparsify_tensor(tensor): if len(ij) > 0: v = tensor[ij[:, 0], ij[:, 1]] return sp.FloatTensor(ij.t(), v, tensor.size()) - else: - return 0 + return 0 def 
sparsify_array(array):
diff --git a/nff/data/stats.py b/nff/data/stats.py
index 3d68ad7d..31c3685a 100644
--- a/nff/data/stats.py
+++ b/nff/data/stats.py
@@ -1,7 +1,7 @@
 """Module to deal with statistics of the datasets, removal of outliers
 and other statistical functions."""
 
 import logging
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -16,8 +16,8 @@
 def remove_outliers(
     array: Union[List, np.ndarray, torch.Tensor],
     std_away: float = 3.0,
-    reference_mean: float = None,
-    reference_std: float = None,
+    reference_mean: Optional[float] = None,
+    reference_std: Optional[float] = None,
     max_value: float = np.inf,
 ) -> Tuple[np.ndarray, np.ndarray, float, float]:
     """Remove outliers from given array using both a number of standard
@@ -53,14 +53,8 @@
     stats_array = array.copy()
     max_values = stats_array.copy()  # used for outlier removal
 
-    if reference_mean is None:
-        mean = np.mean(stats_array)
-    else:
-        mean = reference_mean
-    if reference_std is None:
-        std = np.std(stats_array)
-    else:
-        std = reference_std
+    mean = reference_mean if reference_mean is not None else np.mean(stats_array)
+    std = reference_std if reference_std is not None else np.std(stats_array)
 
     non_outlier = np.bitwise_and(np.abs(max_values - mean) < std_away * std, max_values < max_value)
     non_outlier = np.arange(len(array))[non_outlier]
@@ -76,8 +70,8 @@
 def remove_dataset_outliers(
     dset: Dataset,
     reference_key: str = "energy",
-    reference_mean: float = None,
-    reference_std: float = None,
+    reference_mean: Optional[float] = None,
+    reference_std: Optional[float] = None,
     std_away: float = 3.0,
     max_value: float = np.inf,
 ) -> Tuple[Dataset, float, float]:
@@ -119,7 +113,7 @@
 def center_dataset(
-    dset: Dataset, reference_key: str = "energy", reference_value: float = None
+    dset: Dataset, reference_key: str = "energy", reference_value: Optional[float] = None
 ) -> Tuple[Dataset, float]:
     """Center a dataset by subtracting the mean of the reference key.
@@ -151,34 +145,25 @@
 def get_atom_count(formula: str) -> Dict[str, int]:
     """Count the number of each atom type in the formula.
 
-    Parameters
-    ----------
-    formula
-        The formula parameter is a string representing a chemical formula.
-
-    Returns
-    -------
-    a dictionary containing the count of each atom in the given chemical formula.
+    Args:
+        formula (str): A chemical formula.
 
+    Returns:
+        Dict[str, int]: A dictionary containing the count of each atom in the given
+        chemical formula.
     """
-
-    # return dictionary
     formula = Formula(formula)
     return formula.count()
 
 
 def all_atoms(unique_formulas: List[str]) -> set:
-    """Return set of all atoms in the list of formulas.
+    """Return a set of all atoms present in the list of formulas.
 
-    Parameters
-    ----------
-    unique_formulas
-        list of strings representing the chemical formulas for which you want to count the
-        occurrences of each atom.
+    Args:
+        unique_formulas (List[str]): A list of chemical formula strings.
 
-    Returns
-    -------
-    a set containing all the atoms in the list of formulas.
+    Returns:
+        set: A set containing all unique atom types found in the provided formulas.
     """
     atom_set = set()
     for formula in unique_formulas:
@@ -189,47 +174,32 @@
 def reg_atom_count(formula: str, atoms: List[str]) -> np.ndarray:
-    """Count the number of each specified atom type in the formula.
- - Parameters - ---------- - formula - A string that represents a chemical formula. It can contain elements and - their corresponding subscripts. For example, "H2O" represents water, where "H" is the element - hydrogen and "O" is the element oxygen. The subscript "2" indicates that there are two - atoms - list of strings representing the atoms for which you want to count the - occurrences in the `formula`. - - Returns - ------- - an array containing the count of each atom in the given formula. + """Count the occurrence of specified atom types in the formula. + + Args: + formula (str): A chemical formula string. + atoms (List[str]): A list of atom types to count in the formula. + + Returns: + np.ndarray: An array containing the count of each specified atom in the formula. """ dictio = get_atom_count(formula) count_array = np.array([dictio.get(atom, 0) for atom in atoms]) - return count_array def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str = "energy") -> Dict[str, float]: - """Linear regression to find the per atom energy for each element in the dataset. - - Parameters - ---------- - dset - Dataset object containing properties for each data point. It is assumed to have a property - for the chemical formula of each data point and a property for the energy value of each data point. - formula_key, optional - key for chemical formula in the dset properties dictionary. - energy_key, optional - key for energy in the dset properties dictionary. - - Returns - ------- - a dictionary containing the stoichiometric energy coefficients for each element in the dataset. + """Determine per-atom energy coefficients via linear regression. + Args: + dset (Dataset): Dataset object containing chemical formulas and energy values. + formula_key (str, optional): Key for chemical formulas in the dataset properties. + energy_key (str, optional): Key for energy values in the dataset properties. + + Returns: + Dict[str, float]: A dictionary with stoichiometric energy coefficients for + each element, including an 'offset' representing the intercept. """ - # calculates the linear regresion and return the stoich dictionary formulas = dset.props[formula_key] energies = dset.props[energy_key] logging.debug("formulas: %s", formulas) @@ -237,7 +207,7 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str unique_formulas = list(set(formulas)) logging.debug("unique formulas: %s", unique_formulas) - # find the ground state energy for each formula/stoichiometry + # Find the ground state energy for each formula/stoichiometry. 
ground_en = [ min([energies[i] for i in range(len(formulas)) if formulas[i] == formula]) for formula in unique_formulas ] @@ -247,7 +217,6 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str logging.debug("unique atoms: %s", unique_atoms) x_in = np.stack([reg_atom_count(formula, unique_atoms) for formula in unique_formulas]) - y_out = np.array(ground_en) logging.debug("x_in: %s", x_in) @@ -257,14 +226,13 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str clf.fit(x_in, y_out) pred = (clf.coef_ * x_in).sum(-1) + clf.intercept_ - # pred = clf.predict(x_in) logging.info("coef: %s", clf.coef_) logging.info("intercept: %s", clf.intercept_) logging.debug("pred: %s", pred) err = abs(pred - y_out).mean() # in kcal/mol logging.info("MAE between target energy and stoich energy is %.3f kcal/mol", err) logging.info("R : %s", clf.score(x_in, y_out)) - fit_dic = {atom: coef for atom, coef in zip(unique_atoms, clf.coef_.reshape(-1))} + fit_dic = dict(zip(unique_atoms, clf.coef_.reshape(-1))) stoich_dict = {**fit_dic, "offset": clf.intercept_.item()} logging.info(stoich_dict) @@ -277,34 +245,22 @@ def perform_energy_offset( formula_key: str = "formula", energy_key: str = "energy", ) -> Dataset: - """Peform energy offset calculation on the dataset. Subtract the energy of the reference state for each atom - from the energy of each data point in the dataset. - - Parameters - ---------- - dset - Dataset object containing properties for each data point. It is assumed to have a property - for the chemical formula of each data point and a property for the energy value of each data point. - stoic_dict - a dictionary containing the stoichiometric energy coefficients for each element in the dataset. - formula_key, optional - key for chemical formula in the dset properties dictionary. - energy_key, optional - key for energy in the dset properties dictionary. - - Returns - ------- - a new dataset with the energy offset performed. + """Perform energy offset calculation by subtracting the reference energy per atom. + + Args: + dset (Dataset): Dataset object containing chemical formulas and energy values. + stoic_dict (Dict[str, float]): Dictionary with stoichiometric energy coefficients for + each element. + formula_key (str, optional): Key for chemical formulas in the dataset properties. + energy_key (str, optional): Key for energy values in the dataset properties. + Returns: + Dataset: A new dataset with the energy offset applied to each energy value. """ - # perform the energy offset formulas = dset.props[formula_key] energies = dset.props[energy_key] - if isinstance(energies, torch.Tensor): - new_energies = energies.clone() - else: - new_energies = energies.copy() + new_energies = energies.clone() if isinstance(energies, torch.Tensor) else energies.copy() for i, formula in enumerate(formulas): dictio = get_atom_count(formula) diff --git a/nff/data/utils.py b/nff/data/utils.py index c69a2454..4e6038df 100644 --- a/nff/data/utils.py +++ b/nff/data/utils.py @@ -1,8 +1,7 @@ import os import shutil -import sys import tempfile -from urllib import request as request +from urllib import request import numpy as np @@ -36,7 +35,7 @@ def get_md17_dataset(molecule, cutoff=5.0): "azobenzene_dft": "C1=CC=C(N=NC2=CC=CC=C2)C=C1", } - if molecule not in smiles_dict.keys(): + if molecule not in smiles_dict: raise ValueError("Incorrect value for molecule. 
Must be one of: ", list(smiles_dict.keys())) # make tmpdir to save npz file diff --git a/nff/io/ase.py b/nff/io/ase.py index 3296773e..f38c84fc 100644 --- a/nff/io/ase.py +++ b/nff/io/ase.py @@ -93,7 +93,6 @@ def get_mol_nbrs(self, r_cut=95): """ # periodic systems if np.array([atoms.pbc.any() for atoms in self.get_list_atoms()]).any(): - nbrs = [] nbrs_T = [] nbrs = [] z = [] @@ -418,25 +417,6 @@ def get_batch_T(self): """ return self.get_batch_kinetic_energy() / (1.5 * units.kB * self.props["num_atoms"].detach().cpu().numpy()) - def batch_properties(): - """This function is used to batch process properties. - It takes in a list of properties and performs some operations on them. - """ - - def batch_virial(): - """Calculate the virial for a batch of systems. - - This function calculates the virial for a batch of systems using a specific algorithm. - The virial is a measure of the internal forces within a system - and is commonly used in molecular dynamics simulations. - - Parameters: - None - - Returns: - None - """ - @classmethod def from_atoms(cls, atoms, **kwargs): """Create an instance of the class from an ASE Atoms object. diff --git a/nff/io/ase_ax.py b/nff/io/ase_ax.py index bd0d9532..ab3a52ae 100644 --- a/nff/io/ase_ax.py +++ b/nff/io/ase_ax.py @@ -1,42 +1,35 @@ import numpy as np import torch - from ase import Atoms from ase.calculators.calculator import Calculator, all_changes import nff.utils.constants as const -from nff.train import load_model -from nff.data.sparse import sparsify_array from nff.data import Dataset -from nff.nn.utils import torch_nbr_list -from nff.nn.models.schnet import SchNet, SchNetDiabat +from nff.data.sparse import sparsify_array +from nff.nn.models.cp3d import OnlyBondUpdateCP3D from nff.nn.models.hybridgraph import HybridGraphConv +from nff.nn.models.schnet import SchNet, SchNetDiabat from nff.nn.models.schnet_features import SchNetFeatures -from nff.nn.models.cp3d import OnlyBondUpdateCP3D - +from nff.nn.utils import torch_nbr_list +from nff.train import load_model DEFAULT_CUTOFF = 5.0 DEFAULT_SKIN = 1.0 DEFAULT_DIRECTED = False -CONVERSION_DIC = {"ev": 1 / const.EV_TO_KCAL_MOL, - "au": const.KCAL_TO_AU["energy"]} +CONVERSION_DIC = {"ev": 1 / const.EV_TO_KCAL_MOL, "au": const.KCAL_TO_AU["energy"]} -UNDIRECTED = [SchNet, - SchNetDiabat, - HybridGraphConv, - SchNetFeatures, - OnlyBondUpdateCP3D] +UNDIRECTED = [SchNet, SchNetDiabat, HybridGraphConv, SchNetFeatures, OnlyBondUpdateCP3D] def check_directed(model, atoms): model_cls = model.__class__.__name__ msg = f"{model_cls} needs a directed neighbor list" - assert (not atoms.undirected), msg + assert not atoms.undirected, msg class AtomsBatch(Atoms): """Class to deal with the Neural Force Field and batch several - Atoms objects. + Atoms objects. 
""" def __init__( @@ -47,7 +40,7 @@ def __init__( needs_angles=False, undirected=(not DEFAULT_DIRECTED), cutoff_skin=DEFAULT_SKIN, - **kwargs + **kwargs, ): """ @@ -61,16 +54,16 @@ def __init__( super().__init__(*args, **kwargs) self.props = {} if (props is None) else props.copy() - self.nbr_list = self.props.get('nbr_list', None) - self.offsets = self.props.get('offsets', None) - self.num_atoms = self.props.get('num_atoms', len(self)) + self.nbr_list = self.props.get("nbr_list", None) + self.offsets = self.props.get("offsets", None) + self.num_atoms = self.props.get("num_atoms", len(self)) self.cutoff = cutoff self.cutoff_skin = cutoff_skin self.needs_angles = needs_angles - self.kj_idx = self.props.get('kj_idx') - self.ji_idx = self.props.get('ji_idx') - self.angle_list = self.props.get('angle_list') + self.kj_idx = self.props.get("kj_idx") + self.ji_idx = self.props.get("ji_idx") + self.angle_list = self.props.get("angle_list") self.device = self.props.get("device", 0) self.undirected = undirected @@ -82,14 +75,11 @@ def get_nxyz(self): nxyz (np.array): atomic numbers + cartesian coordinates of the atoms. """ - nxyz = np.concatenate([ - self.get_atomic_numbers().reshape(-1, 1), - self.get_positions().reshape(-1, 3) - ], axis=1) + nxyz = np.concatenate([self.get_atomic_numbers().reshape(-1, 1), self.get_positions().reshape(-1, 3)], axis=1) return nxyz - def get_batch(self, device='cpu'): + def get_batch(self): """Uses the properties of Atoms to create a batch to be sent to the model. @@ -100,15 +90,15 @@ def get_batch(self, device='cpu'): if self.nbr_list is None: # or self.offsets is None: self.update_nbr_list() - self.props['nbr_list'] = self.nbr_list - self.props['angle_list'] = self.angle_list - self.props['ji_idx'] = self.ji_idx - self.props['kj_idx'] = self.kj_idx + self.props["nbr_list"] = self.nbr_list + self.props["angle_list"] = self.angle_list + self.props["ji_idx"] = self.ji_idx + self.props["kj_idx"] = self.kj_idx - self.props['offsets'] = self.offsets + self.props["offsets"] = self.offsets - self.props['nxyz'] = torch.Tensor(self.get_nxyz()) - self.props['num_atoms'] = torch.LongTensor([len(self)]) + self.props["nxyz"] = torch.Tensor(self.get_nxyz()) + self.props["num_atoms"] = torch.LongTensor([len(self)]) return self.props @@ -127,35 +117,29 @@ def update_nbr_list(self): """ if self.needs_angles: - - dataset = Dataset({key: [val] for key, val in - self.props.items()}, check_props=False) + dataset = Dataset({key: [val] for key, val in self.props.items()}, check_props=False) if "nxyz" not in dataset.props: dataset.props["nxyz"] = [self.get_nxyz()] - dataset.generate_neighbor_list((self.cutoff + self.cutoff_skin), - undirected=self.undirected) + dataset.generate_neighbor_list((self.cutoff + self.cutoff_skin), undirected=self.undirected) dataset.generate_angle_list() - self.ji_idx = dataset.props['ji_idx'][0] - self.kj_idx = dataset.props['kj_idx'][0] - self.nbr_list = dataset.props['nbr_list'][0] - self.angle_list = dataset.props['angle_list'][0] + self.ji_idx = dataset.props["ji_idx"][0] + self.kj_idx = dataset.props["kj_idx"][0] + self.nbr_list = dataset.props["nbr_list"][0] + self.angle_list = dataset.props["angle_list"][0] nbr_list = self.nbr_list if any(self.pbc): - offsets = offsets[self.nbr_list[:, 0], - self.nbr_list[:, 1], :].detach().to("cpu").numpy() + offsets = self.offsets[self.nbr_list[:, 0], self.nbr_list[:, 1], :].detach().to("cpu").numpy() else: offsets = np.zeros((self.nbr_list.shape[0], 3)) else: - edge_from, edge_to, offsets = torch_nbr_list(self, - 
(self.cutoff + - self.cutoff_skin), - self.device, - directed=(not self.undirected)) + edge_from, edge_to, offsets = torch_nbr_list( + self, (self.cutoff + self.cutoff_skin), self.device, directed=(not self.undirected) + ) nbr_list = torch.LongTensor(np.stack([edge_from, edge_to], axis=1)) self.nbr_list = nbr_list @@ -164,29 +148,19 @@ def update_nbr_list(self): return nbr_list, offsets - def batch_properties(): + def batch_properties(self): pass - def batch_kinetic_energy(): + def batch_kinetic_energy(self): pass - def batch_virial(): + def batch_virial(self): pass @classmethod - def from_atoms(cls, - atoms, - props=None, - needs_angles=False, - device=0, - **kwargs): + def from_atoms(cls, atoms, props=None, needs_angles=False, device=0, **kwargs): instance = cls( - atoms, - positions=atoms.positions, - numbers=atoms.numbers, - props=props, - needs_angles=needs_angles, - **kwargs + atoms, positions=atoms.positions, numbers=atoms.numbers, props=props, needs_angles=needs_angles, **kwargs ) instance.device = device @@ -196,28 +170,28 @@ def from_atoms(cls, class NeuralFF(Calculator): """ASE calculator using a pretrained NeuralFF model""" - implemented_properties = ['energy', 'forces'] + implemented_properties = ["energy", "forces"] def __init__( self, model, - device='cpu', - output_keys=['energy'], - conversion='ev', + device="cpu", + output_keys=["energy"], + conversion="ev", dataset_props=None, needs_angles=False, model_kwargs=None, - **kwargs + **kwargs, ): """Creates a NeuralFF calculator.nff/io/ase.py Args: model (TYPE): Description - device (str): device on which the calculations will be performed + device (str): device on which the calculations will be performed **kwargs: Description model (one of nff.nn.models) output_keys (list): values outputted by neural network (not including gradients) - conversion (str): conversion of output energies and forces from kcal/mol + conversion (str): conversion of output energies and forces from kcal/mol dataset_props (dict): dataset.props from an initial dataset """ @@ -234,8 +208,7 @@ def __init__( # output keys if getattr(model, "grad_keys", []): - keep_keys = [key for key in model.grad_keys - if key.replace("_grad", "") in self.output_keys] + keep_keys = [key for key in model.grad_keys if key.replace("_grad", "") in self.output_keys] if hasattr(model, "_grad_keys"): model._grad_keys = keep_keys else: @@ -245,12 +218,7 @@ def to(self, device): self.device = device self.model.to(device) - def calculate( - self, - atomsbatch=None, - properties=['energy', 'forces'], - system_changes=all_changes - ): + def calculate(self, atomsbatch=None, properties=["energy", "forces"], system_changes=all_changes): """Calculates the desired properties for the given AtomsBatch. 
Args: @@ -261,7 +229,7 @@ def calculate( system_changes (default from ase) """ - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atomsbatch) Calculator.calculate(self, atomsbatch, properties, system_changes) @@ -276,7 +244,7 @@ def calculate( # add keys so that the readout function can calculate these properties for key in self.output_keys: batch[key] = [] - if 'forces' in properties: + if "forces" in properties: batch[key + "_grad"] = [] kwargs = {} @@ -289,16 +257,13 @@ def calculate( # results to an empty list if len(self.output_keys) != 1: self.results["energy"] = [] - if len(self.output_keys) != 1 and 'forces' in properties: + if len(self.output_keys) != 1 and "forces" in properties: self.results["forces"] = [] for key in self.output_keys: + assert self.conversion in CONVERSION_DIC, f"Unit conversion kcal/mol to {self.conversion} not supported." - assert self.conversion in CONVERSION_DIC, "Unit conversion kcal/mol to {} not supported.".format( - self.conversion) - - value = prediction[key].detach().cpu( - ).numpy() * CONVERSION_DIC[self.conversion] + value = prediction[key].detach().cpu().numpy() * CONVERSION_DIC[self.conversion] # if you're only outputting energy, then set energy to value if len(self.output_keys) == 1: @@ -307,9 +272,8 @@ def calculate( else: self.results["energy"].append(value.reshape(-1)) - if 'forces' in properties: - value_grad = prediction[key + "_grad"].detach( - ).cpu().numpy() * CONVERSION_DIC[self.conversion] + if "forces" in properties: + value_grad = prediction[key + "_grad"].detach().cpu().numpy() * CONVERSION_DIC[self.conversion] if len(self.output_keys) == 1: self.results["forces"] = -value_grad.reshape(-1, 3) @@ -324,20 +288,13 @@ def calculate( def from_file( cls, model_path, - device='cuda', - output_keys=['energy'], - conversion='ev', + device="cuda", + output_keys=["energy"], + conversion="ev", params=None, model_type=None, needs_angles=False, - **kwargs + **kwargs, ): - model = load_model(model_path, - params=params, - model_type=model_type) - return cls(model, - device, - output_keys, - conversion, - needs_angles=needs_angles, - **kwargs) + model = load_model(model_path, params=params, model_type=model_type) + return cls(model, device, output_keys, conversion, needs_angles=needs_angles, **kwargs) diff --git a/nff/io/ase_calcs.py b/nff/io/ase_calcs.py index f2c813ce..b19a0918 100644 --- a/nff/io/ase_calcs.py +++ b/nff/io/ase_calcs.py @@ -25,6 +25,7 @@ from nff.io.ase import DEFAULT_DIRECTED, AtomsBatch from nff.nn.models.cp3d import OnlyBondUpdateCP3D from nff.nn.models.hybridgraph import HybridGraphConv +from nff.nn.models.mace import NffScaleMACE from nff.nn.models.schnet import SchNet, SchNetDiabat from nff.nn.models.schnet_features import SchNetFeatures from nff.train.builders.model import load_model @@ -85,6 +86,8 @@ def __init__( self.model_units = model_units self.prediction_units = prediction_units + print("Requested properties:", self.properties) + def to(self, device): self.device = device self.model.to(device) @@ -94,11 +97,11 @@ def log_embedding(self, jobdir, log_filename, props): sampling after calling NFF on geometries.""" log_file = os.path.join(jobdir, log_filename) + # ruff: noqa: SIM108 if os.path.exists(log_file): log = np.load(log_file) else: log = None - if log is not None: log = np.append(log, props[None, :, :, :], axis=0) else: @@ -106,8 +109,6 @@ def log_embedding(self, jobdir, log_filename, props): np.save(log_filename, 
log) - return - def calculate( self, atoms: AtomsBatch = None, @@ -123,7 +124,7 @@ def calculate( system_changes (default from ase) """ - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) # for backwards compatability @@ -131,8 +132,6 @@ def calculate( self.properties = properties Calculator.calculate(self, atoms, self.properties, system_changes) - # TODO: update atoms only when necessary - atoms.update_nbr_list(update_atoms=True) # run model # atomsbatch = AtomsBatch(atoms) @@ -148,10 +147,8 @@ def calculate( batch[grad_key] = [] kwargs = {} - requires_stress = "stress" in self.properties requires_embedding = "embedding" in self.properties - if requires_embedding: kwargs["requires_embedding"] = True if requires_stress: @@ -175,10 +172,7 @@ def calculate( else: energy = prediction_numpy[self.en_key] - if grad_key in prediction_numpy: - energy_grad = prediction_numpy[grad_key] - else: - energy_grad = None + energy_grad = prediction_numpy.get(grad_key, None) # TODO: implement unit conversion with prediction_numpy self.results = {"energy": energy.reshape(-1)} @@ -199,14 +193,21 @@ def calculate( self.results["embedding"] = embedding if requires_stress: - stress = prediction["stress_volume"].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - # TODO change to more general prediction - ) - self.results["stress"] = stress * (1 / atoms.get_volume()) + if isinstance( + self.model, NffScaleMACE + ): # the implementation of stress calculation in MACE is a bit different + # and hence this is required (ASE_suit: mace/mace/calculators/mace.py) + + self.results["stress"] = ( + torch.mean(prediction["stress"], dim=0).cpu().numpy() + ) # converting to eV/Angstrom^3 + else: # for other models + stress = prediction["stress_volume"].detach().cpu().numpy() + self.results["stress"] = stress * (1 / atoms.get_volume()) if "stress_disp" in prediction: self.results["stress"] = self.results["stress"] + prediction["stress_disp"] self.results["stress"] = full_3x3_to_voigt_6_stress(self.results["stress"]) + atoms.results = self.results.copy() def get_embedding(self, atoms=None): @@ -228,7 +229,14 @@ class EnsembleNFF(Calculator): """Produces an ensemble of NFF calculators to predict the discrepancy between the properties""" - implemented_properties = ["energy", "forces", "stress", "energy_std", "forces_std", "stress_std"] + implemented_properties = [ + "energy", + "forces", + "stress", + "energy_std", + "forces_std", + "stress_std", + ] def __init__( self, @@ -285,6 +293,7 @@ def offset_energy(self, atoms, energy: Union[float, np.ndarray]): for ele, num in ads_count.items(): ref_en += num * stoidict.get(ele, 0.0) ref_en += stoidict.get("offset", 0.0) + if self.offset_units == "atomic": energy += ref_en * HARTREE_TO_EV else: @@ -329,7 +338,7 @@ def calculate( """ for model in self.models: - if not any([isinstance(model, i) for i in UNDIRECTED]): + if not any(isinstance(model, i) for i in UNDIRECTED): check_directed(model, atoms) if getattr(self, "properties", None) is None: @@ -403,11 +412,7 @@ def calculate( gradients.append(prediction_numpy["energy_grad"]) if "stress_volume" in prediction: # TODO: implement unit conversion for stress with prediction_numpy - stresses.append( - prediction["stress_volume"].detach().cpu().numpy() - * (1 / const.EV_TO_KCAL_MOL) - * (1 / atoms.get_volume()) - ) + stresses.append(prediction["stress_volume"].detach().cpu().numpy() * (1 / atoms.get_volume())) energies = 
np.stack(energies) gradients = np.stack(gradients) @@ -425,7 +430,7 @@ def calculate( if "e_disp" in prediction: self.results["energy"] = self.results["energy"] + prediction["e_disp"] if self.jobdir is not None and system_changes: - energy_std = self.results["energy_std"][None] + self.results["energy_std"][None] self.log_ensemble(self.jobdir, "energy_nff_ensemble.npy", energies) if "forces" in properties: @@ -434,7 +439,7 @@ def calculate( if "forces_disp" in prediction: self.results["forces"] = self.results["forces"] + prediction["forces_disp"] if self.jobdir is not None: - forces_std = self.results["forces_std"][None, :, :] + forces_std = self.results["forces_std"][None, :, :] # noqa self.log_ensemble(self.jobdir, "forces_nff_ensemble.npy", -1 * gradients) if "stress" in properties: @@ -443,7 +448,7 @@ def calculate( if "stress_disp" in prediction: self.results["stress"] = self.results["stress"] + prediction["stress_disp"] if self.jobdir is not None: - stress_std = self.results["stress_std"][None, :, :] + stress_std = self.results["stress_std"][None, :, :] # noqa self.log_ensemble(self.jobdir, "stress_nff_ensemble.npy", stresses) atoms.results = self.results.copy() @@ -457,7 +462,7 @@ def set(self, **kwargs): The special keyword 'parameters' can be used to read parameters from a file.""" changed_params = Calculator.set(self, **kwargs) - if "offset_data" in self.parameters.keys(): + if "offset_data" in self.parameters: self.offset_data = self.parameters["offset_data"] print(f"offset data: {self.offset_data} is set from parameters") @@ -477,7 +482,7 @@ def __init__(self, optimizer, nbrlist_update_freq=5): def run(self, fmax=0.2, steps=1000): epochs = steps // self.update_freq - for step in range(epochs): + for _ in range(epochs): self.optimizer.run(fmax=fmax, steps=self.update_freq) self.optimizer.atoms.update_nbr_list() @@ -614,7 +619,7 @@ def calculate( system_changes=all_changes, add_steps=True, ): - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) super().calculate(atoms=atoms, properties=properties, system_changes=system_changes) @@ -781,7 +786,7 @@ def __init__( model=model, device=device, en_key=en_key, - directed=DEFAULT_DIRECTED, + directed=directed, **kwargs, ) self.V_min = V_min @@ -791,7 +796,7 @@ def __init__( self.k = self.k_0 / (self.V_max - self.V_min) def calculate(self, atoms, properties=["energy", "forces"], system_changes=all_changes): - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) super().calculate(atoms=atoms, properties=properties, system_changes=system_changes) @@ -1000,7 +1005,7 @@ def setup_contraint(self, cvdic, max_steps, device): max_steps (int): maximum number of steps of the MD simulation device: device """ - for cvname, val in cvdic.items(): + for val in cvdic.values(): if val["type"].lower() == "proj_vec_plane": mol_inds = [i - 1 for i in val["mol"]] # caution check type ring_inds = [i - 1 for i in val["ring"]] @@ -1042,7 +1047,7 @@ def create_time_dependec_arrays(self, restraint_list, max_steps): kappas = [] eq_vals = [] # in case the restraint does not start at 0 - templist = list(range(0, restraint_list[0]["step"])) + templist = list(range(restraint_list[0]["step"])) steps += templist kappas += [0 for _ in templist] eq_vals += [0 for _ in templist] @@ -1157,7 +1162,7 @@ def calculate( # print("calculating ...") self.step += 1 # print("step 
", self.step, self.step*0.0005) - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) # for backwards compatability @@ -1204,12 +1209,12 @@ def calculate( self.results["stress"] = stress * (1 / atoms.get_volume()) with open("colvar", "a") as f: - f.write("{} ".format(self.step * 0.5)) + f.write(f"{self.step * 0.5} ") # ARREGLAR, SI YA ESTA CALCULADO PARA QUE RECALCULAR LA CVS for cv in self.hr.cvs: curr_cv_val = float(cv.get_value(torch.tensor(atoms.get_positions(), device=self.device))) - f.write(" {:.6f} ".format(curr_cv_val)) - f.write("{:.6f} \n".format(float(bias_energy))) + f.write(f" {curr_cv_val:.6f} ") + f.write(f"{float(bias_energy):.6f} \n") @classmethod def from_file(cls, model_path, device="cuda", **kwargs): diff --git a/nff/io/ase_utils.py b/nff/io/ase_utils.py index 2d9446ef..e94130ba 100644 --- a/nff/io/ase_utils.py +++ b/nff/io/ase_utils.py @@ -110,7 +110,7 @@ def __init__(self, idx, atoms, force_consts, targ_angles=None): else: self.targ_angles = np.radians(atoms.get_angles(self.idx)) - if isinstance(force_consts, float) or isinstance(force_consts, int): + if isinstance(force_consts, (float, int)): self.force_consts = np.array([float(force_consts)] * len(self.idx)) else: assert len(force_consts) == len(self.idx) @@ -163,7 +163,7 @@ def __init__(self, idx, atoms, force_consts, targ_diheds=None): else: self.targ_diheds = np.radians(atoms.get_dihedrals(self.idx)) - if isinstance(force_consts, float) or isinstance(force_consts, int): + if isinstance(force_consts, (float, int)): self.force_consts = np.array([float(force_consts)] * len(self.idx)) else: assert len(force_consts) == len(self.idx) @@ -204,7 +204,7 @@ def __init__(self, idx, atoms, force_consts, targ_lengths=None): deltas = atoms.get_positions()[idx[:, 0]] - atoms.get_positions()[idx[:, 1]] self.targ_lengths = np.linalg.norm(deltas, axis=-1) - if isinstance(force_consts, float) or isinstance(force_consts, int): + if isinstance(force_consts, (float, int)): self.force_consts = np.array([float(force_consts)] * len(self.idx)) else: assert len(force_consts) == len(self.idx) @@ -238,7 +238,7 @@ def __repr__(self): def split(array, num_atoms): shape = [-1] total_atoms = num_atoms.sum() - if not all([i == total_atoms for i in np.array(array).shape]): + if not all(i == total_atoms for i in np.array(array).shape): shape = [-1, 3] split_idx = np.cumsum(num_atoms) @@ -491,11 +491,10 @@ def step(self, f=None): g = -f if self.use_line_search: raise NotImplementedError("Not yet implemented wdith line search") - else: - self.force_calls += 1 - self.function_calls += 1 - steplengths = (self.p**2).sum(1) ** 0.5 - dr = self.determine_step(dr=self.p, steplengths=steplengths, f=f) * self.damping + self.force_calls += 1 + self.function_calls += 1 + steplengths = (self.p**2).sum(1) ** 0.5 + dr = self.determine_step(dr=self.p, steplengths=steplengths, f=f) * self.damping self.atoms.set_positions(r + dr) diff --git a/nff/io/bias_calculators.py b/nff/io/bias_calculators.py index c22d73ab..daa36b9c 100644 --- a/nff/io/bias_calculators.py +++ b/nff/io/bias_calculators.py @@ -1,19 +1,17 @@ -import numpy as np -from typing import Union, Tuple +from typing import Dict, List, Optional, Tuple, Union -from ase.calculators.calculator import Calculator, all_changes +import numpy as np from ase import units +from ase.calculators.calculator import Calculator, all_changes import nff.utils.constants as const -from nff.utils.cuda import batch_to - from 
nff.io.ase_calcs import NeuralFF, check_directed from nff.md.colvars import ColVar as CV - -from nff.nn.models.schnet import SchNet, SchNetDiabat +from nff.nn.models.cp3d import OnlyBondUpdateCP3D from nff.nn.models.hybridgraph import HybridGraphConv +from nff.nn.models.schnet import SchNet, SchNetDiabat from nff.nn.models.schnet_features import SchNetFeatures -from nff.nn.models.cp3d import OnlyBondUpdateCP3D +from nff.utils.cuda import batch_to DEFAULT_CUTOFF = 5.0 DEFAULT_DIRECTED = False @@ -54,12 +52,10 @@ def __init__( device="cpu", en_key="energy", directed=DEFAULT_DIRECTED, - extra_constraints: list[dict] = None, + extra_constraints: Optional[List[Dict]] = None, **kwargs, ): - NeuralFF.__init__( - self, model=model, device=device, en_key=en_key, directed=directed, **kwargs - ) + NeuralFF.__init__(self, model=model, device=device, en_key=en_key, directed=directed, **kwargs) self.cv_defs = cv_defs self.num_cv = len(cv_defs) @@ -82,31 +78,26 @@ def __init__( self.conf_k = np.zeros(shape=(self.num_cv, 1)) for ii, cv in enumerate(self.cv_defs): - if "range" in cv.keys(): + if "range" in cv: self.ext_coords[ii] = cv["range"][0] self.ranges[ii] = cv["range"] else: raise KeyError("range") - if "margin" in cv.keys(): + if "margin" in cv: self.margins[ii] = cv["margin"] - if "conf_k" in cv.keys(): + if "conf_k" in cv: self.conf_k[ii] = cv["conf_k"] - if "ext_k" in cv.keys(): + if "ext_k" in cv: self.ext_k[ii] = cv["ext_k"] - elif "ext_sigma" in cv.keys(): - self.ext_k[ii] = (units.kB * self.equil_temp) / ( - cv["ext_sigma"] * cv["ext_sigma"] - ) + elif "ext_sigma" in cv: + self.ext_k[ii] = (units.kB * self.equil_temp) / (cv["ext_sigma"] * cv["ext_sigma"]) else: raise KeyError("ext_k/ext_sigma") - if "type" not in cv.keys(): - self.cv_defs[ii]["type"] = "not_angle" - else: - self.cv_defs[ii]["type"] = cv["type"] + self.cv_defs[ii]["type"] = cv.get("type", "not_angle") self.constraints = None self.num_const = 0 @@ -118,19 +109,14 @@ def __init__( self.constraints[-1]["func"] = CV(cv["definition"]) self.constraints[-1]["pos"] = cv["pos"] - if "k" in cv.keys(): + if "k" in cv: self.constraints[-1]["k"] = cv["k"] - elif "sigma" in cv.keys(): - self.constraints[-1]["k"] = (units.kB * self.equil_temp) / ( - cv["sigma"] * cv["sigma"] - ) + elif "sigma" in cv: + self.constraints[-1]["k"] = (units.kB * self.equil_temp) / (cv["sigma"] * cv["sigma"]) else: raise KeyError("k/sigma") - if "type" not in cv.keys(): - self.constraints[-1]["type"] = "not_angle" - else: - self.constraints[-1]["type"] = cv["type"] + self.constraints[-1]["type"] = cv.get("type", "not_angle") self.num_const = len(self.constraints) @@ -147,9 +133,7 @@ def _check_boundaries(self, xi: np.ndarray): in_bounds = (xi <= self.ranges[:, 1]).all() and (xi >= self.ranges[:, 0]).all() return in_bounds - def diff( - self, a: Union[np.ndarray, float], b: Union[np.ndarray, float], cv_type: str - ) -> Union[np.ndarray, float]: + def diff(self, a: Union[np.ndarray, float], b: Union[np.ndarray, float], cv_type: str) -> Union[np.ndarray, float]: """get difference of elements of numbers or arrays in range(-inf, inf) if is_angle is False or in range(-pi, pi) if is_angle is True Args: @@ -252,9 +236,7 @@ def harmonic_constraint( constr_ener = 0.0 for i in range(self.num_const): - dxi = self.diff( - xi[i], self.constraints[i]["pos"], self.constraints[i]["type"] - ) + dxi = self.diff(xi[i], self.constraints[i]["pos"], self.constraints[i]["type"]) constr_grad += self.constraints[i]["k"] * dxi * grad_xi[i] constr_ener += 0.5 * self.constraints[i]["k"] * 
dxi**2 @@ -287,7 +269,7 @@ def calculate( system_changes (default from ase) """ - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) # for backwards compatability @@ -314,14 +296,10 @@ def calculate( prediction = self.model(batch, **kwargs) # change energy and force to numpy array and eV - model_energy = prediction[self.en_key].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + model_energy = prediction[self.en_key].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) if grad_key in prediction: - model_grad = prediction[grad_key].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + model_grad = prediction[grad_key].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) else: raise KeyError(grad_key) @@ -339,14 +317,12 @@ def calculate( cv_grad_lens = np.zeros(shape=(self.num_cv, 1)) cv_invmass = np.zeros(shape=(self.num_cv, 1)) cv_dot_PES = np.zeros(shape=(self.num_cv, 1)) - for ii, cv_def in enumerate(self.cv_defs): + for ii, _ in enumerate(self.cv_defs): xi, xi_grad = self.the_cv[ii](atoms) cvs[ii] = xi cv_grads[ii] = xi_grad cv_grad_lens[ii] = np.linalg.norm(xi_grad) - cv_invmass[ii] = np.einsum( - "i,ii,i", xi_grad.flatten(), M_inv, xi_grad.flatten() - ) + cv_invmass[ii] = np.einsum("i,ii,i", xi_grad.flatten(), M_inv, xi_grad.flatten()) cv_dot_PES[ii] = np.dot(xi_grad.flatten(), model_grad.flatten()) self.results = { @@ -391,9 +367,7 @@ def calculate( self.results["const_vals"] = consts if requires_stress: - stress = prediction["stress_volume"].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + stress = prediction["stress_volume"].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) self.results["stress"] = stress * (1 / atoms.get_volume()) @@ -407,7 +381,8 @@ class with neural force field [["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]] equil_temp: float temperature of the simulation (important for extended system dynamics) dt: time step of the extended dynamics (has to be equal to that of the real system dyn!) - friction_per_ps: friction for the Lagevin dyn of extended system (has to be equal to that of the real system dyn!) + friction_per_ps: friction for the Langevin dyn of extended system + (has to be equal to that of the real system dyn!) 
nfull: numer of samples need for full application of bias force """ @@ -439,52 +414,38 @@ def __init__( self.nfull = nfull for ii, cv in enumerate(self.cv_defs): - if "bin_width" in cv.keys(): + if "bin_width" in cv: self.ext_binwidth[ii] = cv["bin_width"] - elif "ext_sigma" in cv.keys(): + elif "ext_sigma" in cv: self.ext_binwidth[ii] = cv["ext_sigma"] else: raise KeyError("bin_width") - if "ext_pos" in cv.keys(): + if "ext_pos" in cv: # set initial position self.ext_coords[ii] = cv["ext_pos"] else: raise KeyError("ext_pos") - if "ext_mass" in cv.keys(): + if "ext_mass" in cv: self.ext_masses[ii] = cv["ext_mass"] else: raise KeyError("ext_mass") # initialize extended system at target temp of MD simulation for i in range(self.num_cv): - self.ext_vel[i] = np.random.randn() * np.sqrt( - self.equil_temp * units.kB / self.ext_masses[i] - ) + self.ext_vel[i] = np.random.randn() * np.sqrt(self.equil_temp * units.kB / self.ext_masses[i]) self.friction = friction_per_ps * 1.0e-3 / units.fs - self.rand_push = np.sqrt( - self.equil_temp - * self.friction - * self.ext_dt - * units.kB - / (2.0e0 * self.ext_masses) - ) + self.rand_push = np.sqrt(self.equil_temp * self.friction * self.ext_dt * units.kB / (2.0e0 * self.ext_masses)) self.prefac1 = 2.0 / (2.0 + self.friction * self.ext_dt) - self.prefac2 = (2.0e0 - self.friction * self.ext_dt) / ( - 2.0e0 + self.friction * self.ext_dt - ) + self.prefac2 = (2.0e0 - self.friction * self.ext_dt) / (2.0e0 + self.friction * self.ext_dt) # set up all grid accumulators for ABF self.nbins_per_dim = np.array([1 for i in range(self.num_cv)]) self.grid = [] for i in range(self.num_cv): - self.nbins_per_dim[i] = int( - np.ceil( - np.abs(self.ranges[i, 1] - self.ranges[i, 0]) / self.ext_binwidth[i] - ) - ) + self.nbins_per_dim[i] = int(np.ceil(np.abs(self.ranges[i, 1] - self.ranges[i, 0]) / self.ext_binwidth[i])) self.grid.append( np.arange( self.ranges[i, 0] + self.ext_binwidth[i] / 2, @@ -513,9 +474,7 @@ def get_index(self, xi: np.ndarray) -> tuple: """ bin_x = np.zeros(shape=xi.shape, dtype=np.int64) for i in range(self.num_cv): - bin_x[i] = int( - np.floor(np.abs(xi[i] - self.ranges[i, 0]) / self.ext_binwidth[i]) - ) + bin_x[i] = int(np.floor(np.abs(xi[i] - self.ranges[i, 0]) / self.ext_binwidth[i])) return tuple(bin_x.reshape(1, -1)[0]) def _update_bias(self, xi: np.ndarray): @@ -524,11 +483,7 @@ def _update_bias(self, xi: np.ndarray): self.ext_hist[bink] += 1 # linear ramp function - ramp = ( - 1.0 - if self.ext_hist[bink] > self.nfull - else self.ext_hist[bink] / self.nfull - ) + ramp = 1.0 if self.ext_hist[bink] > self.nfull else self.ext_hist[bink] / self.nfull for i in range(self.num_cv): # apply bias force on extended system @@ -540,8 +495,7 @@ def _update_bias(self, xi: np.ndarray): self.ext_hist[bink], self.bias[i][bink], self.m2_force[i][bink], - self.ext_k[i] - * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), + self.ext_k[i] * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), ) self.ext_forces[i] -= ramp * self.bias[i][bink] @@ -583,16 +537,16 @@ def _up_extvel(self): class aMDeABF(eABF): """Accelerated extended-system Adaptive Biasing Force Calculator - class with neural force field + class with neural force field - Accelerated Molecular Dynamics + Accelerated Molecular Dynamics - see: - aMD: Hamelberg et. al., J. Chem. Phys. 120, 11919 (2004); https://doi.org/10.1063/1.1755656 - GaMD: Miao et. al., J. Chem. Theory Comput. (2015); https://doi.org/10.1021/acs.jctc.5b00436 - SaMD: Zhao et. al., J. Phys. Chem. Lett. 
14, 4, 1103 - 1112 (2023); https://doi.org/10.1021/acs.jpclett.2c03688 + see: + aMD: Hamelberg et al., J. Chem. Phys. 120, 11919 (2004); https://doi.org/10.1063/1.1755656 + GaMD: Miao et al., J. Chem. Theory Comput. (2015); https://doi.org/10.1021/acs.jctc.5b00436 + SaMD: Zhao et al., J. Phys. Chem. Lett. 14, 4, 1103 - 1112 (2023); https://doi.org/10.1021/acs.jpclett.2c03688 - Apply global boost potential to potential energy, that is independent of Collective Variables. + Apply global boost potential to potential energy that is independent of Collective Variables. Args: model: the neural force field model @@ -608,7 +562,8 @@ "SaMD: apply Sigmoid accelerated MD equil_temp: float temperature of the simulation (important for extended system dynamics) dt: time step of the extended dynamics (has to be equal to that of the real system dyn!) - friction_per_ps: friction for the Lagevin dyn of extended system (has to be equal to that of the real system dyn!) + friction_per_ps: friction for the Langevin dyn of extended system + (has to be equal to that of the real system dyn!) nfull: numer of samples need for full application of bias force """ @@ -659,9 +614,7 @@ def __init__( ], f"Unknown aMD method {self.amd_method}" if self.amd_method == "amd": - print( - " >>> Warning: Please use GaMD or SaMD to obtain accurate free energy estimates!\n" - ) + print(" >>> Warning: Please use GaMD or SaMD to obtain accurate free energy estimates!\n") self.pot_count = 0 self.pot_var = 0.0 @@ -745,8 +698,7 @@ def _apply_boost(self, epot): if self.amd_method == "amd": amd_pot = np.square(self.E - epot) / (self.parameter + (self.E - epot)) boost_grad = ( - ((epot - self.E) * (epot - 2.0 * self.parameter - self.E)) - / np.square(epot - self.parameter - self.E) + ((epot - self.E) * (epot - 2.0 * self.parameter - self.E)) / np.square(epot - self.parameter - self.E) ) * self.amd_forces elif self.amd_method == "samd": @@ -761,17 +713,7 @@ def _apply_boost(self, epot): ) ) boost_grad = ( - -( - 1.0 - / ( - np.exp( - -self.k * (epot - self.pot_min) + np.log((1 / self.c0) - 1) - ) - + 1 - ) - - 1 - ) - * self.amd_forces + -(1.0 / (np.exp(-self.k * (epot - self.pot_min) + np.log((1 / self.c0) - 1)) + 1) - 1) * self.amd_forces ) else: @@ -790,9 +732,7 @@ def _update_pot_distribution(self, epot: float): self.pot_min = np.min([epot, self.pot_min]) self.pot_max = np.max([epot, self.pot_max]) self.pot_count += 1 - self.pot_avg, self.pot_m2, self.pot_var = welford_var( - self.pot_count, self.pot_avg, self.pot_m2, epot - ) + self.pot_avg, self.pot_m2, self.pot_var = welford_var(self.pot_count, self.pot_avg, self.pot_m2, epot) self.pot_std = np.sqrt(self.pot_var) def _calc_E_k0(self): @@ -803,9 +743,7 @@ def _calc_E_k0(self): """ if self.amd_method == "gamd_lower": self.E = self.pot_max - ko = (self.amd_parameter / self.pot_std) * ( - (self.pot_max - self.pot_min) / (self.pot_max - self.pot_avg) - ) + ko = (self.amd_parameter / self.pot_std) * ((self.pot_max - self.pot_min) / (self.pot_max - self.pot_avg)) self.k0 = np.min([1.0, ko]) @@ -820,9 +758,7 @@ def _calc_E_k0(self): self.E = self.pot_min + (self.pot_max - self.pot_min) / self.k0 elif self.amd_method == "samd": - ko = (self.amd_parameter / self.pot_std) * ( - (self.pot_max - self.pot_min) / (self.pot_max - self.pot_avg) - ) + ko = (self.amd_parameter / self.pot_std) * ((self.pot_max - self.pot_min) / (self.pot_max - self.pot_avg)) self.k0 = np.min([1.0, ko]) if (self.pot_std / self.amd_parameter) <= 1.0: @@ -831,10 +767,7 @@ def 
_calc_E_k0(self): self.k1 = np.max( [ 0, - ( - np.log(self.c) - + np.log((self.pot_std) / (self.amd_parameter) - 1) - ) + (np.log(self.c) + np.log((self.pot_std) / (self.amd_parameter) - 1)) / (self.pot_avg - self.pot_min), ] ) @@ -857,7 +790,8 @@ class WTMeABF(eABF): [["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]] equil_temp: float temperature of the simulation (important for extended system dynamics) dt: time step of the extended dynamics (has to be equal to that of the real system dyn!) - friction_per_ps: friction for the Lagevin dyn of extended system (has to be equal to that of the real system dyn!) + friction_per_ps: friction for the Langevin dyn of extended system + (has to be equal to that of the real system dyn!) nfull: numer of samples need for full application of bias force hill_height: unscaled height of the MetaD Gaussian hills in eV hill_drop_freq: #steps between depositing Gaussians @@ -903,7 +837,7 @@ def __init__( self.center = [] for ii, cv in enumerate(self.cv_defs): - if "hill_std" in cv.keys(): + if "hill_std" in cv: self.hill_std[ii] = cv["hill_std"] self.hill_var[ii] = cv["hill_std"] * cv["hill_std"] else: @@ -922,11 +856,7 @@ def _update_bias(self, xi: np.ndarray): self.ext_hist[bink] += 1 # linear ramp function - ramp = ( - 1.0 - if self.ext_hist[bink] > self.nfull - else self.ext_hist[bink] / self.nfull - ) + ramp = 1.0 if self.ext_hist[bink] > self.nfull else self.ext_hist[bink] / self.nfull for i in range(self.num_cv): # apply bias force on extended system @@ -938,8 +868,7 @@ def _update_bias(self, xi: np.ndarray): self.ext_hist[bink], self.bias[i][bink], self.m2_force[i][bink], - self.ext_k[i] - * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), + self.ext_k[i] * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), ) self.ext_forces[i] -= ramp * self.bias[i][bink] + mtd_forces[i] @@ -974,9 +903,7 @@ def _accumulate_wtm_force(self, xi: np.ndarray) -> Tuple[list, float]: bink = self.get_index(xi) if self.call_count % self.hill_drop_freq == 0: - w = self.hill_height * np.exp( - -self.metapot[bink] / (units.kB * self.well_tempered_temp) - ) + w = self.hill_height * np.exp(-self.metapot[bink] / (units.kB * self.well_tempered_temp)) dx = self.diff(self.grid[0], xi[0], self.cv_defs[0]["type"]).reshape( -1, ) @@ -1006,34 +933,20 @@ def _analytic_wtm_force(self, xi: np.ndarray) -> Tuple[list, float]: ind = np.ma.indices((len(self.center),))[0] ind = np.ma.masked_array(ind) - dist_to_centers = [] - for ii in range(self.num_cv): - dist_to_centers.append( - self.diff( - xi[ii], np.asarray(self.center)[:, ii], self.cv_defs[ii]["type"] - ) - ) - - dist_to_centers = np.asarray(dist_to_centers) + dist_to_centers = np.array( + [self.diff(xi[ii], np.asarray(self.center)[:, ii], self.cv_defs[ii]["type"]) for ii in range(self.num_cv)] + ) if self.num_cv > 1: - ind[ - (abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0) - ] = np.ma.masked + ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0)] = np.ma.masked else: - ind[ - (abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0) - ] = np.ma.masked + ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0)] = np.ma.masked # can get slow in long run, so only iterate over significant elements for i in np.nditer(ind.compressed(), flags=["zerosize_ok"]): - w = self.hill_height * np.exp( - -local_pot / (units.kB * self.well_tempered_temp) - ) + w = self.hill_height * np.exp(-local_pot / (units.kB * 
self.well_tempered_temp)) - epot = w * np.exp( - -np.power(dist_to_centers[:, i] / self.hill_std, 2).sum() / 2.0 - ) + epot = w * np.exp(-np.power(dist_to_centers[:, i] / self.hill_std, 2).sum() / 2.0) local_pot += epot bias_force -= epot * dist_to_centers[:, i] / self.hill_var @@ -1074,12 +987,10 @@ def __init__( device="cpu", en_key="energy", directed=DEFAULT_DIRECTED, - extra_constraints: list[dict] = None, + extra_constraints: Optional[List[Dict]] = None, **kwargs, ): - NeuralFF.__init__( - self, model=model, device=device, en_key=en_key, directed=directed, **kwargs - ) + NeuralFF.__init__(self, model=model, device=device, en_key=en_key, directed=directed, **kwargs) self.gamma = gamma self.cv_defs = cv_defs @@ -1094,22 +1005,19 @@ def __init__( self.conf_k = np.zeros(shape=(self.num_cv, 1)) for ii, cv in enumerate(self.cv_defs): - if "range" in cv.keys(): + if "range" in cv: self.ext_coords[ii] = cv["range"][0] self.ranges[ii] = cv["range"] else: raise KeyError("range") - if "margin" in cv.keys(): + if "margin" in cv: self.margins[ii] = cv["margin"] - if "conf_k" in cv.keys(): + if "conf_k" in cv: self.conf_k[ii] = cv["conf_k"] - if "type" not in cv.keys(): - self.cv_defs[ii]["type"] = "not_angle" - else: - self.cv_defs[ii]["type"] = cv["type"] + self.cv_defs[ii]["type"] = cv.get("type", "not_angle") self.constraints = None self.num_const = 0 @@ -1121,25 +1029,18 @@ def __init__( self.constraints[-1]["func"] = CV(cv["definition"]) self.constraints[-1]["pos"] = cv["pos"] - if "k" in cv.keys(): + if "k" in cv: self.constraints[-1]["k"] = cv["k"] - elif "sigma" in cv.keys(): - self.constraints[-1]["k"] = (units.kB * self.equil_temp) / ( - cv["sigma"] * cv["sigma"] - ) + elif "sigma" in cv: + self.constraints[-1]["k"] = (units.kB * self.equil_temp) / (cv["sigma"] * cv["sigma"]) else: raise KeyError("k/sigma") - if "type" not in cv.keys(): - self.constraints[-1]["type"] = "not_angle" - else: - self.constraints[-1]["type"] = cv["type"] + self.constraints[-1]["type"] = cv.get("type", "not_angle") self.num_const = len(self.constraints) - def diff( - self, a: Union[np.ndarray, float], b: Union[np.ndarray, float], cv_type: str - ) -> Union[np.ndarray, float]: + def diff(self, a: Union[np.ndarray, float], b: Union[np.ndarray, float], cv_type: str) -> Union[np.ndarray, float]: """get difference of elements of numbers or arrays in range(-inf, inf) if is_angle is False or in range(-pi, pi) if is_angle is True Args: @@ -1225,9 +1126,7 @@ def harmonic_constraint( constr_ener = 0.0 for i in range(self.num_const): - dxi = self.diff( - xi[i], self.constraints[i]["pos"], self.constraints[i]["type"] - ) + dxi = self.diff(xi[i], self.constraints[i]["pos"], self.constraints[i]["type"]) constr_grad += self.constraints[i]["k"] * dxi * grad_xi[i] constr_ener += 0.5 * self.constraints[i]["k"] * dxi**2 @@ -1260,7 +1159,7 @@ def calculate( system_changes (default from ase) """ - if not any([isinstance(self.model, i) for i in UNDIRECTED]): + if not any(isinstance(self.model, i) for i in UNDIRECTED): check_directed(self.model, atoms) # for backwards compatability @@ -1287,14 +1186,10 @@ def calculate( prediction = self.model(batch, **kwargs) # change energy and force to numpy array and eV - model_energy = prediction[self.en_key].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + model_energy = prediction[self.en_key].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) if grad_key in prediction: - model_grad = prediction[grad_key].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + model_grad = 
prediction[grad_key].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) else: raise KeyError(grad_key) @@ -1312,14 +1207,12 @@ def calculate( cv_grad_lens = np.zeros(shape=(self.num_cv, 1)) cv_invmass = np.zeros(shape=(self.num_cv, 1)) cv_dot_PES = np.zeros(shape=(self.num_cv, 1)) - for ii, cv_def in enumerate(self.cv_defs): + for ii, _ in enumerate(self.cv_defs): xi, xi_grad = self.the_cv[ii](atoms) cvs[ii] = xi cv_grads[ii] = xi_grad cv_grad_lens[ii] = np.linalg.norm(xi_grad) - cv_invmass[ii] = np.einsum( - "i,ii,i", xi_grad.flatten(), M_inv, xi_grad.flatten() - ) + cv_invmass[ii] = np.einsum("i,ii,i", xi_grad.flatten(), M_inv, xi_grad.flatten()) cv_dot_PES[ii] = np.dot(xi_grad.flatten(), model_grad.flatten()) self.results = { @@ -1364,15 +1257,11 @@ def calculate( self.results["const_vals"] = consts if requires_stress: - stress = prediction["stress_volume"].detach().cpu().numpy() * ( - 1 / const.EV_TO_KCAL_MOL - ) + stress = prediction["stress_volume"].detach().cpu().numpy() * (1 / const.EV_TO_KCAL_MOL) self.results["stress"] = stress * (1 / atoms.get_volume()) -def welford_var( - count: float, mean: float, M2: float, newValue: float -) -> Tuple[float, float, float]: +def welford_var(count: float, mean: float, M2: float, newValue: float) -> Tuple[float, float, float]: """On-the-fly estimate of sample variance by Welford's online algorithm Args: count: current number of samples (with new one) diff --git a/nff/io/chgnet.py b/nff/io/chgnet.py index a0530113..c903aa16 100644 --- a/nff/io/chgnet.py +++ b/nff/io/chgnet.py @@ -3,7 +3,7 @@ from typing import Dict, List import torch -from chgnet.data.dataset import StructureData, StructureJsonData +from chgnet.data.dataset import StructureData from pymatgen.core.structure import Structure from pymatgen.io.ase import AseAtomsAdaptor @@ -17,25 +17,16 @@ def convert_nff_to_chgnet_structure_data( cutoff: float = 5.0, shuffle: bool = True, ): - """The function `convert_nff_to_chgnet_structure_data` converts a dataset in NFF format to a dataset in - CHGNet structure data format. - - Parameters - ---------- - dataset : Dataset - The `dataset` parameter is an object of the `Dataset` class. - cutoff : float - The `cutoff` parameter is a float value that represents the distance cutoff for constructing the - neighbor list in the conversion process. It determines the maximum distance between atoms within - which they are considered neighbors. Any atoms beyond this distance will not be included in the - neighbor list. - shuffle : bool - The `shuffle` parameter is a boolean value that determines whether the dataset should be shuffled + """ + Converts a dataset in NFF format to a dataset in CHGNet structure data format. - Returns: - ------- - a `chgnet_dataset` object of type `StructureData`. + Args: + dataset (Dataset): An object of the Dataset class. + cutoff (float, optional): Distance cutoff for constructing the neighbor list. Defaults to 5.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to True. + Returns: + StructureData: A CHGNet StructureData object. """ dataset = dataset.copy() dataset.to_units("eV/atom") # convert units to eV @@ -66,29 +57,29 @@ def convert_chgnet_structure_targets_to_nff( stresses: bool = False, magmoms: bool = False, ) -> Dataset: - """Converts a dataset in CHGNet structure json data format to a dataset in - NFF format. + """ + Converts a dataset in CHGNet structure JSON data format to a dataset in NFF format. 
Args: - structures: List of pymatgen structures - targets: List of dictionaries containing the properties of each structure in the batch. - stresses: Whether the dataset should include stresses - magmoms: Whether the dataset should include magnetic moments + structures (List[Structure]): List of pymatgen structures. + targets (List[Dict]): List of dictionaries containing the properties of each structure. + stresses (bool, optional): Whether the dataset should include stresses. Defaults to False. + magmoms (bool, optional): Whether the dataset should include magnetic moments. Defaults to False. Returns: - NFF Dataset + Dataset: An NFF Dataset. """ energies_per_atom = [] energy_grad = [] - stresses = [] - magmoms = [] + stresses_list = [] + magmoms_list = [] for target in targets: energies_per_atom.append(target["e"]) energy_grad.append(-target["f"]) if stresses: - stresses.append(target["s"]) + stresses_list.append(target["s"]) if magmoms: - magmoms.append(target["m"]) + magmoms_list.append(target["m"]) lattice = [] num_atoms = [] # TODO: check if this is correct @@ -112,32 +103,27 @@ def convert_chgnet_structure_targets_to_nff( "units": units, } if stresses: - concated_batch["stress"] = stresses + concated_batch["stress"] = stresses_list if magmoms: - concated_batch["magmoms"] = magmoms + concated_batch["magmoms"] = magmoms_list return Dataset(concated_batch, units=units[0]) def convert_chgnet_structure_data_to_nff( structure_data: StructureData, cutoff: float = 6.0, - shuffle: bool = True, + shuffle: bool = False, ) -> Dataset: - """Converts a dataset in CHGNet structure data format to a dataset in - NFF format. - - Parameters - ---------- - structure_data : StructureData - A `structure_data` object of type `StructureData`. - cutoff - Distance cutoff for constructing the neighbor list in the conversion process. - shuffle : bool - Whether the dataset should be shuffled + """ + Converts a dataset in CHGNet structure data format to a dataset in NFF format. + + Args: + structure_data (StructureData): A CHGNet StructureData object. + cutoff (float, optional): Distance cutoff for constructing the neighbor list. Defaults to 6.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to False. Returns: - ------- - a `nff_dataset` object of type `Dataset`. + Dataset: An NFF Dataset. """ pymatgen_structures = structure_data.structures energies_per_atom = structure_data.energies @@ -154,7 +140,7 @@ def convert_chgnet_structure_data_to_nff( for structure in pymatgen_structures: lattice.append(structure.lattice.matrix) nxyz.append( - torch.cat([torch.tensor([atom.species.number]), torch.tensor(atom.coords)]).tolist() for atom in structure + [torch.cat([torch.tensor([atom.species.number]), torch.tensor(atom.coords)]).tolist() for atom in structure] ) concated_batch = { @@ -172,38 +158,28 @@ def convert_chgnet_structure_data_to_nff( def convert_data_batch( - data_batch: Dict, # noqa: FA100 + data_batch: Dict, cutoff: float = 5.0, shuffle: bool = True, ): - """Converts a dataset in NFF format to a dataset in - CHGNet structure data format. - - Parameters - ---------- - data_batch : Dict - A dictionary of properties for each structure in the batch. 
- Basically the props in NFF Dataset - Example: - props = { - 'nxyz': [np.array([[1, 0, 0, 0], [1, 1.1, 0, 0]]), - np.array([[1, 3, 0, 0], [1, 1.1, 5, 0]])], - 'lattice': [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], - np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])], - 'num_atoms': [2, 2], - } - cutoff : float - The `cutoff` parameter is a float value that represents the distance cutoff for constructing the - neighbor list in the conversion process. It determines the maximum distance between atoms within - which they are considered neighbors. Any atoms beyond this distance will not be included in the - neighbor list. - shuffle : bool - The `shuffle` parameter is a boolean value that determines whether the dataset should be shuffled + """ + Converts a dataset in NFF format to a dataset in CHGNet structure data format. - Returns: - ------- - a `chgnet_dataset` object of type `StructureData`. + Args: + data_batch (Dict): Dictionary of properties for each structure in the batch. + Example: + props = { + 'nxyz': [np.array([[1, 0, 0, 0], [1, 1.1, 0, 0]]), + np.array([[1, 3, 0, 0], [1, 1.1, 5, 0]])], + 'lattice': [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])], + 'num_atoms': [2, 2], + } + cutoff (float, optional): Distance cutoff for neighbor list construction. Defaults to 5.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to True. + Returns: + StructureData: A CHGNet StructureData object. """ detached_batch = batch_detach(data_batch) nxyz = detached_batch["nxyz"] diff --git a/nff/io/cprop.py b/nff/io/cprop.py index 99710dd2..53705599 100644 --- a/nff/io/cprop.py +++ b/nff/io/cprop.py @@ -4,15 +4,13 @@ import json import os + import numpy as np from nff.utils import bash_command, fprint -def get_cp_cmd(script, - config_path, - data_path, - dataset_type): +def get_cp_cmd(script, config_path, data_path, dataset_type): """ Get the string for a ChemProp command. Args: @@ -25,15 +23,11 @@ def get_cp_cmd(script, cmd (str): the chemprop command """ - cmd = (f"python {script} --config_path {config_path} " - f" --data_path {data_path} " - f" --dataset_type {dataset_type}") + cmd = f"python {script} --config_path {config_path} --data_path {data_path} --dataset_type {dataset_type}" return cmd -def cp_hyperopt(cp_folder, - hyp_folder, - rerun): +def cp_hyperopt(cp_folder, hyp_folder, rerun): """ Run hyperparameter optimization with ChemProp. Args: @@ -44,7 +38,7 @@ def cp_hyperopt(cp_folder, `hyp_folder` already exists and has the completion file `best_params.json`. Returns: - best_params (dict): best parameters from hyperparameter + best_params (dict): best parameters from hyperparameter optimization """ @@ -54,7 +48,6 @@ def cp_hyperopt(cp_folder, # If it exists and you don't want to re-run, then load it if params_exist and (not rerun): - fprint(f"Loading hyperparameter results from {param_file}\n") with open(param_file, "r") as f: @@ -71,10 +64,7 @@ def cp_hyperopt(cp_folder, data_path = config["data_path"] dataset_type = config["dataset_type"] - cmd = get_cp_cmd(hyp_script, - config_path, - data_path, - dataset_type) + cmd = get_cp_cmd(hyp_script, config_path, data_path, dataset_type) cmd += f" --config_save_path {param_file}" fprint(f"Running hyperparameter optimization in folder {hyp_folder}\n") @@ -89,8 +79,7 @@ def cp_hyperopt(cp_folder, return best_params -def cp_train(cp_folder, - train_folder): +def cp_train(cp_folder, train_folder): """ Train a chemprop model. 
Args: @@ -108,10 +97,7 @@ def cp_train(cp_folder, data_path = config["data_path"] dataset_type = config["dataset_type"] - cmd = get_cp_cmd(train_script, - config_path, - data_path, - dataset_type) + cmd = get_cp_cmd(train_script, config_path, data_path, dataset_type) p = bash_command(f"source activate chemprop && {cmd}") p.wait() @@ -126,22 +112,20 @@ def make_feat_paths(feat_path): paths (list): list of paths """ - if feat_path is not None: - paths = [feat_path] - else: - paths = None - return paths - - -def modify_config(base_config_path, - metric, - train_feat_path, - val_feat_path, - test_feat_path, - train_folder, - features_only, - hyp_params, - no_features): + return [feat_path] if feat_path is not None else None + + +def modify_config( + base_config_path, + metric, + train_feat_path, + val_feat_path, + test_feat_path, + train_folder, + features_only, + hyp_params, + no_features, +): """ Modify a chemprop config file with new parameters. Args: @@ -164,16 +148,17 @@ def modify_config(base_config_path, with open(base_config_path, "r") as f: config = json.load(f) - dic = {"metric": metric, - "features_path": make_feat_paths(train_feat_path), - "separate_val_features_path": make_feat_paths(val_feat_path), - "separate_test_features_path": make_feat_paths(test_feat_path), - "save_dir": train_folder, - "features_only": features_only, - **hyp_params} + dic = { + "metric": metric, + "features_path": make_feat_paths(train_feat_path), + "separate_val_features_path": make_feat_paths(val_feat_path), + "separate_test_features_path": make_feat_paths(test_feat_path), + "save_dir": train_folder, + "features_only": features_only, + **hyp_params, + } - config.update({key: val for key, val in - dic.items() if val is not None}) + config.update({key: val for key, val in dic.items() if val is not None}) if no_features: for key in list(config.keys()): @@ -188,12 +173,7 @@ def modify_config(base_config_path, json.dump(config, f, indent=4, sort_keys=True) -def modify_hyp_config(hyp_config_path, - metric, - hyp_feat_path, - hyp_folder, - features_only, - no_features): +def modify_hyp_config(hyp_config_path, metric, hyp_feat_path, hyp_folder, features_only, no_features): """ Modfiy a hyperparameter optimization config file with new parameters. 
Args: @@ -215,13 +195,14 @@ def modify_hyp_config(hyp_config_path, with open(hyp_config_path, "r") as f: config = json.load(f) - dic = {"metric": metric, - "features_path": make_feat_paths(hyp_feat_path), - "save_dir": hyp_folder, - "features_only": features_only} + dic = { + "metric": metric, + "features_path": make_feat_paths(hyp_feat_path), + "save_dir": hyp_folder, + "features_only": features_only, + } - config.update({key: val for key, val in - dic.items() if val is not None}) + config.update({key: val for key, val in dic.items() if val is not None}) if no_features: for key in list(config.keys()): @@ -243,7 +224,7 @@ def get_smiles(smiles_folder, name): smiles_folder (str): folder with the csvs name (str): csv file name Returns: - smiles_list (list[str]): SMILES strings + smiles_list (list[str]): SMILES strings """ path = os.path.join(smiles_folder, name) @@ -275,8 +256,7 @@ def save_smiles(smiles_folder, smiles_list, name): # no bind): file_names = [f"{name}_smiles.csv", f"{name}_full.csv"] - paths = [os.path.join(smiles_folder, name) for name in - file_names] + paths = [os.path.join(smiles_folder, name) for name in file_names] for path in paths: with open(path, "r") as f: lines = f.readlines() @@ -333,11 +313,7 @@ def make_hyperopt_csvs(smiles_folder, all_smiles): save_smiles(smiles_folder, all_smiles, name="hyperopt") -def save_hyperopt(feat_folder, - metric, - smiles_folder, - cp_save_folder, - dset_size): +def save_hyperopt(feat_folder, metric, smiles_folder, cp_save_folder, dset_size): """ Aggregate and save the train and validation SMILES for hyperparameter optimization. Args: @@ -345,7 +321,7 @@ def save_hyperopt(feat_folder, metric (str): metric with which you're evaluating the model performance smiles_folder (str): folder with the csvs cp_save_folder (str): folder in which you're saving features for chemprop use - dset_size (int, optional): maximum size of the entire dataset to use in hyperparameter + dset_size (int, optional): maximum size of the entire dataset to use in hyperparameter optimization. 
Returns: hyp_np_path (str): path of npz features file for hyperparameter optimization @@ -357,9 +333,8 @@ def save_hyperopt(feat_folder, for name in names: smiles_list = get_smiles(smiles_folder, f"{name}_smiles.csv") - np_save_path = os.path.join(cp_save_folder, - f"{name}_{metric}.npz") - feats = np.load(np_save_path)['features'] + np_save_path = os.path.join(cp_save_folder, f"{name}_{metric}.npz") + feats = np.load(np_save_path)["features"] all_feats.append(feats) all_smiles += smiles_list @@ -370,12 +345,10 @@ def save_hyperopt(feat_folder, all_feats = all_feats[:dset_size] # save the entire train + val dataset features - hyp_np_path = os.path.join(cp_save_folder, - f"hyperopt_{metric}.npz") + hyp_np_path = os.path.join(cp_save_folder, f"hyperopt_{metric}.npz") np.savez_compressed(hyp_np_path, features=all_feats) # save csvs for the train + val dataset - make_hyperopt_csvs(smiles_folder=smiles_folder, - all_smiles=all_smiles) + make_hyperopt_csvs(smiles_folder=smiles_folder, all_smiles=all_smiles) return hyp_np_path diff --git a/nff/io/gmm.py b/nff/io/gmm.py index 2dd1e94f..0b2735d2 100644 --- a/nff/io/gmm.py +++ b/nff/io/gmm.py @@ -12,7 +12,6 @@ # Modified by Thierry Guillemot # License: BSD 3 clause - import numpy as np from scipy import linalg from sklearn.mixture._base import BaseMixture, _check_shape @@ -67,16 +66,12 @@ def _check_weights(weights, n_components): if any(np.less(weights, 0.0)) or any(np.greater(weights, 1.0)): raise ValueError( "The parameter 'weights' should be in the range " - "[0, 1], but got max value %.5f, min value %.5f" - % (np.min(weights), np.max(weights)) + "[0, 1], but got max value %.5f, min value %.5f" % (np.min(weights), np.max(weights)) ) # check normalization if not np.allclose(np.abs(1.0 - np.sum(weights)), 0.0): - raise ValueError( - "The parameter 'weights' should be normalized, but got sum(weights) = %.5f" - % np.sum(weights) - ) + raise ValueError("The parameter 'weights' should be normalized, but got sum(weights) = %.5f" % np.sum(weights)) return weights @@ -111,12 +106,8 @@ def _check_precision_positivity(precision, covariance_type): def _check_precision_matrix(precision, covariance_type): """Check a precision matrix is symmetric and positive-definite.""" - if not ( - np.allclose(precision, precision.T) and np.all(linalg.eigvalsh(precision) > 0.0) - ): - raise ValueError( - "'%s precision' should be symmetric, positive-definite" % covariance_type - ) + if not (np.allclose(precision, precision.T) and np.all(linalg.eigvalsh(precision) > 0.0)): + raise ValueError("'%s precision' should be symmetric, positive-definite" % covariance_type) def _check_precisions_full(precisions, covariance_type): @@ -161,9 +152,7 @@ def _check_precisions(precisions, covariance_type, n_components, n_features): "diag": (n_components, n_features), "spherical": (n_components,), } - _check_shape( - precisions, precisions_shape[covariance_type], "%s precision" % covariance_type - ) + _check_shape(precisions, precisions_shape[covariance_type], "%s precision" % covariance_type) _check_precisions = { "full": _check_precisions_full, @@ -341,7 +330,7 @@ def _compute_precision_cholesky(covariances, covariance_type): ------- precisions_cholesky : array-like The cholesky decomposition of sample precisions of the current - components. The shape depends of the covariance_type. + components. The shape depends on the covariance_type. 
""" estimate_precision_error_message = ( "Fitting the mixture model failed because some components have " @@ -356,20 +345,16 @@ def _compute_precision_cholesky(covariances, covariance_type): for k, covariance in enumerate(covariances): try: cov_chol = linalg.cholesky(covariance, lower=True) - except linalg.LinAlgError: - raise ValueError(estimate_precision_error_message) - precisions_chol[k] = linalg.solve_triangular( - cov_chol, np.eye(n_features), lower=True - ).T + except linalg.LinAlgError as e: + raise ValueError(estimate_precision_error_message) from e + precisions_chol[k] = linalg.solve_triangular(cov_chol, np.eye(n_features), lower=True).T elif covariance_type == "tied": _, n_features = covariances.shape try: cov_chol = linalg.cholesky(covariances, lower=True) - except linalg.LinAlgError: - raise ValueError(estimate_precision_error_message) - precisions_chol = linalg.solve_triangular( - cov_chol, np.eye(n_features), lower=True - ).T + except linalg.LinAlgError as e: + raise ValueError(estimate_precision_error_message) from e + precisions_chol = linalg.solve_triangular(cov_chol, np.eye(n_features), lower=True).T else: if np.any(np.less_equal(covariances, 0.0)): raise ValueError(estimate_precision_error_message) @@ -403,9 +388,7 @@ def _compute_log_det_cholesky(matrix_chol, covariance_type, n_features): """ if covariance_type == "full": n_components, _, _ = matrix_chol.shape - log_det_chol = np.sum( - np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1 - ) + log_det_chol = np.sum(np.log(matrix_chol.reshape(n_components, -1)[:, :: n_features + 1]), 1) elif covariance_type == "tied": log_det_chol = np.sum(np.log(np.diag(matrix_chol))) @@ -714,9 +697,7 @@ def _check_parameters(self, X): self.weights_init = _check_weights(self.weights_init, self.n_components) if self.means_init is not None: - self.means_init = _check_means( - self.means_init, self.n_components, n_features - ) + self.means_init = _check_means(self.means_init, self.n_components, n_features) if self.precisions_init is not None: self.precisions_init = _check_precisions( @@ -737,9 +718,7 @@ def _initialize(self, X, resp): """ n_samples, _ = X.shape - weights, means, covariances = _estimate_gaussian_parameters( - X, resp, self.reg_covar, self.covariance_type - ) + weights, means, covariances = _estimate_gaussian_parameters(X, resp, self.reg_covar, self.covariance_type) weights /= n_samples self.weights_ = weights if self.weights_init is None else self.weights_init @@ -747,20 +726,13 @@ def _initialize(self, X, resp): if self.precisions_init is None: self.covariances_ = covariances - self.precisions_cholesky_ = _compute_precision_cholesky( - covariances, self.covariance_type - ) + self.precisions_cholesky_ = _compute_precision_cholesky(covariances, self.covariance_type) elif self.covariance_type == "full": self.precisions_cholesky_ = np.array( - [ - linalg.cholesky(prec_init, lower=True) - for prec_init in self.precisions_init - ] + [linalg.cholesky(prec_init, lower=True) for prec_init in self.precisions_init] ) elif self.covariance_type == "tied": - self.precisions_cholesky_ = linalg.cholesky( - self.precisions_init, lower=True - ) + self.precisions_cholesky_ = linalg.cholesky(self.precisions_init, lower=True) else: self.precisions_cholesky_ = np.sqrt(self.precisions_init) @@ -779,14 +751,10 @@ def _m_step(self, X, log_resp): X, np.exp(log_resp), self.reg_covar, self.covariance_type ) self.weights_ /= self.weights_.sum() - self.precisions_cholesky_ = _compute_precision_cholesky( - self.covariances_, 
self.covariance_type - ) + self.precisions_cholesky_ = _compute_precision_cholesky(self.covariances_, self.covariance_type) def _estimate_log_prob(self, X): - return _estimate_log_gaussian_prob( - X, self.means_, self.precisions_cholesky_, self.covariance_type - ) + return _estimate_log_gaussian_prob(X, self.means_, self.precisions_cholesky_, self.covariance_type) def _estimate_log_weights(self): return np.log(self.weights_) @@ -819,9 +787,7 @@ def _set_parameters(self, params): self.precisions_[k] = batched_dot_product(prec_chol, prec_chol.T) elif self.covariance_type == "tied": - self.precisions_ = batched_dot_product( - self.precisions_cholesky_, self.precisions_cholesky_.T - ) + self.precisions_ = batched_dot_product(self.precisions_cholesky_, self.precisions_cholesky_.T) else: self.precisions_ = self.precisions_cholesky_**2 @@ -855,9 +821,7 @@ def bic(self, X): bic : float The lower the better. """ - return -2 * self.score(X) * X.shape[0] + self._n_parameters() * np.log( - X.shape[0] - ) + return -2 * self.score(X) * X.shape[0] + self._n_parameters() * np.log(X.shape[0]) def aic(self, X): """Akaike information criterion for the current model on the input X. diff --git a/nff/io/mace.py b/nff/io/mace.py index b376da1f..8e4c5147 100644 --- a/nff/io/mace.py +++ b/nff/io/mace.py @@ -1,7 +1,8 @@ import logging import os import urllib -from typing import Dict, Iterable, List, Tuple, Union +from collections.abc import Iterable +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -18,13 +19,9 @@ from nff.utils.cuda import detach # get the path to NFF models dir, which is the parent directory of this file -module_dir = os.path.abspath( - os.path.join(os.path.abspath(__file__), "..", "..", "..", "models") -) +module_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "..", "..", "..", "models")) print(module_dir) -LOCAL_MODEL_PATH = os.path.join( - module_dir, "foundation_models/mace/2023-12-03-mace-mp.model" -) +LOCAL_MODEL_PATH = os.path.join(module_dir, "foundation_models/mace/2023-12-03-mace-mp.model") MACE_URLS = dict( small="http://tinyurl.com/46jrkm3v", # 2023-12-10-mace-128-L0_energy_epoch-249.model @@ -35,13 +32,12 @@ def _check_non_zero(std): if std == 0.0: - logging.warning( - "Standard deviation of the scaling is zero, Changing to no scaling" - ) + logging.warning("Standard deviation of the scaling is zero, Changing to no scaling") std = 1.0 return std -def get_mace_mp_model_path(model: str = None) -> str: + +def get_mace_mp_model_path(model: Optional[str] = None) -> str: """Get the default MACE MP model. Replicated from the MACE codebase, Copyright (c) 2022 ACEsuit/mace and licensed under the MIT license. 
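
For reviewers, a minimal usage sketch of the helper shown above (not part of the patch; it assumes nff is importable and that either the bundled local checkpoint or a network connection is available):

# Hedged sketch, not part of the patch: resolve a Materials Project MACE
# checkpoint to a local file path via the helper added in nff/io/mace.py.
from nff.io.mace import get_mace_mp_model_path

# accepts "small"/"medium"/"large", an https:// URL, or None (medium default)
model_path = get_mace_mp_model_path("medium")
print(model_path)  # bundled LOCAL_MODEL_PATH, or a cached file under ~/.cache/mace/
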
@@ -57,20 +53,14 @@ def get_mace_mp_model_path(model: str = None) -> str: """ if model in (None, "medium") and os.path.isfile(LOCAL_MODEL_PATH): model_path = LOCAL_MODEL_PATH - print( - f"Using local medium Materials Project MACE model for MACECalculator {model}" - ) + print(f"Using local medium Materials Project MACE model for MACECalculator {model}") elif model in (None, "small", "medium", "large") or str(model).startswith("https:"): try: checkpoint_url = ( - MACE_URLS.get(model, MACE_URLS["medium"]) - if model in (None, "small", "medium", "large") - else model + MACE_URLS.get(model, MACE_URLS["medium"]) if model in (None, "small", "medium", "large") else model ) cache_dir = os.path.expanduser("~/.cache/mace") - checkpoint_url_name = "".join( - c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_" - ) + checkpoint_url_name = "".join(c for c in os.path.basename(checkpoint_url) if c.isalnum() or c in "_") model_path = f"{cache_dir}/{checkpoint_url_name}" if not os.path.isfile(model_path): os.makedirs(cache_dir, exist_ok=True) @@ -81,13 +71,9 @@ def get_mace_mp_model_path(model: str = None) -> str: msg = f"Loading Materials Project MACE with {model_path}" print(msg) except Exception as exc: - raise RuntimeError( - "Model download failed and no local model found" - ) from exc + raise RuntimeError("Model download failed and no local model found") from exc else: - raise RuntimeError( - "Model download failed and no local model found" - ) + raise RuntimeError("Model download failed and no local model found") return model_path @@ -98,7 +84,7 @@ def get_init_kwargs_from_model(model: Union[ScaleShiftMACE, MACE]) -> dict: radial_type = "bessel" elif isinstance(model.radial_embedding.bessel_fn, GaussianBasis): radial_type = "gaussian" - + init_kwargs = { "r_max": model.r_max.item(), "num_bessel": model.radial_embedding.out_dim, @@ -113,15 +99,15 @@ def get_init_kwargs_from_model(model: Union[ScaleShiftMACE, MACE]) -> dict: "atomic_energies": model.atomic_energies_fn.atomic_energies, "avg_num_neighbors": model.interactions[0].avg_num_neighbors, "atomic_numbers": model.atomic_numbers.tolist(), - "correlation": model.products[0] - .symmetric_contractions.contractions[0] - .correlation, + "correlation": model.products[0].symmetric_contractions.contractions[0].correlation, "gate": model.readouts[-1].non_linearity.acts[0].f, "radial_MLP": model.interactions[0].conv_tp_weights.hs[1:-1], - "radial_type": radial_type + "radial_type": radial_type, } if isinstance(model, ScaleShiftMACE): - init_kwargs.update({"atomic_inter_scale":model.scale_shift.scale, "atomic_inter_shift": model.scale_shift.shift}) + init_kwargs.update( + {"atomic_inter_scale": model.scale_shift.scale, "atomic_inter_shift": model.scale_shift.shift} + ) return init_kwargs @@ -140,9 +126,8 @@ def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: z_set.add(z) return AtomicNumberTable(sorted(z_set)) -def compute_average_E0s( - train_dset: Dataset, z_table: AtomicNumberTable, desired_units: str = "eV" -) -> Dict[int, float]: + +def compute_average_E0s(train_dset: Dataset, z_table: AtomicNumberTable, desired_units: str = "eV") -> Dict[int, float]: """Function to compute the average interaction energy of each chemical element returns dictionary of E0s @@ -171,16 +156,15 @@ def compute_average_E0s( for i, z in enumerate(z_table.zs): atomic_energies_dict[z] = E0s[i] except np.linalg.LinAlgError: - logging.warning( - "Failed to compute E0s using least squares regression, using the same for all atoms" - ) + 
logging.warning("Failed to compute E0s using least squares regression, using the same for all atoms") atomic_energies_dict = {} - for i, z in enumerate(z_table.zs): + for z in z_table.zs: atomic_energies_dict[z] = 0.0 - + train_dset.to_units(original_units) return atomic_energies_dict + def compute_mean_rms_energy_forces( data_loader: torch.utils.data.DataLoader, atomic_energies: np.ndarray, @@ -209,7 +193,7 @@ def compute_mean_rms_energy_forces( one_hot_zs[i, z_table.z_to_index(z)] = 1 # compute atomic energies node_e0 = atomic_energies_fn(one_hot_zs) - graph_sizes = batch['num_atoms'] # list of num atoms + graph_sizes = batch["num_atoms"] # list of num atoms # given graph_sizes, transform to list of indices # index starts from 0, denoting the first graph @@ -217,17 +201,13 @@ def compute_mean_rms_energy_forces( counter = 0 batch_indices = torch.zeros(sum(graph_sizes), dtype=torch.long) for i, size in enumerate(graph_sizes): - batch_indices[counter:counter + size] = i + batch_indices[counter : counter + size] = i counter += size - + # get the graph energy - graph_e0s = scatter_sum( - src=node_e0, index=batch_indices, dim=-1, dim_size=len(graph_sizes) - ) - atom_energy_list.append( - (batch['energy'] - graph_e0s) / graph_sizes - ) # {[n_graphs], } - forces_list.append(-batch['energy_grad']) # {[n_graphs*n_atoms,3], } + graph_e0s = scatter_sum(src=node_e0, index=batch_indices, dim=-1, dim_size=len(graph_sizes)) + atom_energy_list.append((batch["energy"] - graph_e0s) / graph_sizes) # {[n_graphs], } + forces_list.append(-batch["energy_grad"]) # {[n_graphs*n_atoms,3], } atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs] forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], } @@ -237,6 +217,7 @@ def compute_mean_rms_energy_forces( rms = _check_non_zero(rms) return mean, rms + def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float: """Compute the average number of neighbors in a dataset. @@ -249,19 +230,23 @@ def compute_avg_num_neighbors(data_loader: torch.utils.data.DataLoader) -> float num_neighbors = [] for batch in data_loader: - unique_neighbors_list = torch.unique(batch['nbr_list'], dim=0) # remove repeated neighbors + unique_neighbors_list = torch.unique(batch["nbr_list"], dim=0) # remove repeated neighbors receivers = unique_neighbors_list[:, 1] _, counts = torch.unique(receivers, return_counts=True) num_neighbors.append(counts) - avg_num_neighbors = torch.mean( - torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) - ) + avg_num_neighbors = torch.mean(torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())) return detach(avg_num_neighbors, to_numpy=True).item() -def update_mace_init_params(train: Dataset, val: Dataset, train_loader: torch.utils.data.DataLoader, model_params: Dict, logger: logging.Logger = None) -> Dict[str, Union[int, float, np.ndarray, List[int]]]: +def update_mace_init_params( + train: Dataset, + val: Dataset, + train_loader: torch.utils.data.DataLoader, + model_params: Dict, + logger: Optional[logging.Logger] = None, +) -> Dict[str, Union[int, float, np.ndarray, List[int]]]: """Update the MACE model initialization parameters based values obtained from training and validation datasets. 
Args: @@ -276,9 +261,16 @@ def update_mace_init_params(train: Dataset, val: Dataset, train_loader: torch.ut """ if not logger: logger = logging.getLogger(__name__) - + # z_table - z_table = get_atomic_number_table_from_zs([int(z) for data_split in (train, val) for data in data_split for z in detach(data["nxyz"][:, 0], to_numpy=True)]) + z_table = get_atomic_number_table_from_zs( + [ + int(z) + for data_split in (train, val) + for data in data_split + for z in detach(data["nxyz"][:, 0], to_numpy=True) + ] + ) logger.info("Z Table %s", z_table.zs) # avg_num_neighbors @@ -290,9 +282,7 @@ def update_mace_init_params(train: Dataset, val: Dataset, train_loader: torch.ut # atomic_energies # {8: -4.930998234144857, 38: -5.8572783662579795, 77: -8.316066722236071} atomic_energies_dict = compute_average_E0s(train, z_table) - atomic_energies: np.ndarray = np.array( - [atomic_energies_dict[z] for z in z_table.zs] - ) + atomic_energies: np.ndarray = np.array([atomic_energies_dict[z] for z in z_table.zs]) logger.info("Atomic energies: %s", atomic_energies.tolist()) # mean & std @@ -309,6 +299,7 @@ def update_mace_init_params(train: Dataset, val: Dataset, train_loader: torch.ut return model_params + class NffBatch(Batch): def __init__(self, batch=None, ptr=None, **kwargs): super().__init__(batch=batch, ptr=ptr, **kwargs) @@ -321,16 +312,14 @@ def get_example(self, idx: int) -> Data: if self.__slices__ is None: raise RuntimeError( - ( - "Cannot reconstruct data list from batch because the batch " - "object was not created using `Batch.from_data_list()`." - ) + "Cannot reconstruct data list from batch because the batch " + "object was not created using `Batch.from_data_list()`." ) data = {} idx = self.num_graphs + idx if idx < 0 else idx - for key in self.__slices__.keys(): + for key in self.__slices__: item = self[key] if self.__cat_dims__[key] is None: # The item was concatenated along a new batch dimension, diff --git a/nff/io/openmm_calculators.py b/nff/io/openmm_calculators.py index a7b8caac..6430a4a0 100644 --- a/nff/io/openmm_calculators.py +++ b/nff/io/openmm_calculators.py @@ -1,180 +1,163 @@ -import os -import numpy as np -import torch -from typing import Union, Tuple -import copy - -from ase import Atoms -from ase.neighborlist import neighbor_list -from ase.calculators.calculator import Calculator, all_changes -from ase import units +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import openmm as omm import openmm.app as app import openmm.unit as unit -import openmm as omm import parmed as pmd +from ase import units +from ase.calculators.calculator import Calculator, all_changes import nff.utils.constants as const -from nff.md.colvars import ColVar - - +from nff.md.colvars import ColVar nonbondedMethod = { - 'NonPeriodic': app.CutoffNonPeriodic, - } - + "NonPeriodic": app.CutoffNonPeriodic, +} +class PropertyNotPresent(Exception): + pass class BiasBase(Calculator): """Basic Calculator class with neural force field - + Args: model: the deural force field model cv_def: lsit of Collective Variable (CV) definitions [["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]] equil_temp: float temperature of the simulation (important for extended system dynamics) """ - - implemented_properties = ['energy', 'forces', - 'energy_unbiased', 'forces_unbiased', - 'cv_vals', 'ext_pos', 'cv_invmass', - 'grad_length', 'cv_grad_lengths', - 'cv_dot_PES', 'const_vals'] - - def __init__(self, - mmparms, - cv_defs: list[dict], - equil_temp: float = 
300.0, - extra_constraints: list[dict] = None, - **kwargs): - + + implemented_properties = [ + "energy", + "forces", + "energy_unbiased", + "forces_unbiased", + "cv_vals", + "ext_pos", + "cv_invmass", + "grad_length", + "cv_grad_lengths", + "cv_dot_PES", + "const_vals", + ] + + def __init__( + self, + mmparms, + cv_defs: List[Dict], + equil_temp: float = 300.0, + extra_constraints: Optional[List[Dict]] = None, + **kwargs, + ): Calculator.__init__(self, **kwargs) - - ## OpenMM setup - if 'prmtop' in mmparms.keys(): - parm = pmd.load_file(mmparms['prmtop']) - elif 'parm7' in mmparms.keys(): + + # OpenMM setup + if "prmtop" in mmparms: + parm = pmd.load_file(mmparms["prmtop"]) + elif "parm7" in mmparms: # for a list with parm7 and rst7 - parm = pmd.load_file(mmparms['parm7'], mmparms['rst7']) + parm = pmd.load_file(mmparms["parm7"], mmparms["rst7"]) else: - raise NotImplemented + raise NotImplementedError # in case we need PBC, the pdb contains the box values - pdb = app.PDBFile(mmparms['pdb']) - - system = parm.createSystem(nonbondedMethod=nonbondedMethod[mmparms['nonbonded']], - nonbondedCutoff=mmparms['nonbonded_cutoff'] * unit.nanometers, - ) - platform = omm.Platform.getPlatformByName(mmparms['platform']) - # this will not be used - integrator = omm.LangevinIntegrator( - 0.0 * unit.kelvin, 1.0 / unit.picoseconds, 0.1 * unit.picoseconds + app.PDBFile(mmparms["pdb"]) + + system = parm.createSystem( + nonbondedMethod=nonbondedMethod[mmparms["nonbonded"]], + nonbondedCutoff=mmparms["nonbonded_cutoff"] * unit.nanometers, ) + platform = omm.Platform.getPlatformByName(mmparms["platform"]) + # this will not be used + integrator = omm.LangevinIntegrator(0.0 * unit.kelvin, 1.0 / unit.picoseconds, 0.1 * unit.picoseconds) self.context = omm.Context(system, integrator, platform) - - ## BiasBase setup + + # BiasBase setup self.cv_defs = cv_defs self.num_cv = len(cv_defs) self.the_cv = [] for cv_def in self.cv_defs: self.the_cv.append(ColVar(cv_def["definition"])) - + self.equil_temp = equil_temp - - self.ext_coords = np.zeros(shape=(self.num_cv,1)) - self.ext_masses = np.zeros(shape=(self.num_cv,1)) - self.ext_forces = np.zeros(shape=(self.num_cv,1)) - self.ext_vel = np.zeros(shape=(self.num_cv,1)) - self.ext_binwidth = np.zeros(shape=(self.num_cv,1)) - self.ext_k = np.zeros(shape=(self.num_cv,)) + + self.ext_coords = np.zeros(shape=(self.num_cv, 1)) + self.ext_masses = np.zeros(shape=(self.num_cv, 1)) + self.ext_forces = np.zeros(shape=(self.num_cv, 1)) + self.ext_vel = np.zeros(shape=(self.num_cv, 1)) + self.ext_binwidth = np.zeros(shape=(self.num_cv, 1)) + self.ext_k = np.zeros(shape=(self.num_cv,)) self.ext_dt = 0.0 - - self.ranges = np.zeros(shape=(self.num_cv,2)) - self.margins = np.zeros(shape=(self.num_cv,1)) - self.conf_k = np.zeros(shape=(self.num_cv,1)) - + + self.ranges = np.zeros(shape=(self.num_cv, 2)) + self.margins = np.zeros(shape=(self.num_cv, 1)) + self.conf_k = np.zeros(shape=(self.num_cv, 1)) + for ii, cv in enumerate(self.cv_defs): - if 'range' in cv.keys(): - self.ext_coords[ii] = cv['range'][0] - self.ranges[ii] = cv['range'] + if "range" in cv: + self.ext_coords[ii] = cv["range"][0] + self.ranges[ii] = cv["range"] else: - raise PropertyNotPresent('range') - - if 'margin' in cv.keys(): - self.margins[ii] = cv['margin'] - - if 'conf_k' in cv.keys(): - self.conf_k[ii] = cv['conf_k'] - - if 'ext_k' in cv.keys(): - self.ext_k[ii] = cv['ext_k'] - elif 'ext_sigma' in cv.keys(): - self.ext_k[ii] = (units.kB * self.equil_temp) / ( - cv['ext_sigma'] * cv['ext_sigma']) - else: - raise 
PropertyNotPresent('ext_k/ext_sigma')
+                raise PropertyNotPresent("range")
+
+            if "margin" in cv:
+                self.margins[ii] = cv["margin"]
-
-            if 'type' not in cv.keys():
-                self.cv_defs[ii]['type'] = 'not_angle'
+            if "conf_k" in cv:
+                self.conf_k[ii] = cv["conf_k"]
+
+            if "ext_k" in cv:
+                self.ext_k[ii] = cv["ext_k"]
+            elif "ext_sigma" in cv:
+                self.ext_k[ii] = (units.kB * self.equil_temp) / (cv["ext_sigma"] * cv["ext_sigma"])
             else:
-                self.cv_defs[ii]['type'] = cv['type']
-
+                raise PropertyNotPresent("ext_k/ext_sigma")
+
+            self.cv_defs[ii]["type"] = cv.get("type", "not_angle")
+
         self.constraints = None
         self.num_const = 0
-        if extra_constraints != None:
+        if extra_constraints is not None:
             self.constraints = []
             for cv in extra_constraints:
                 self.constraints.append({})
-
-                self.constraints[-1]['func'] = CV(cv["definition"])
-
-                self.constraints[-1]['pos'] = cv['pos']
-                if 'k' in cv.keys():
-                    self.constraints[-1]['k'] = cv['k']
-                elif 'sigma' in cv.keys():
-                    self.constraints[-1]['k'] = (units.kB * self.equil_temp) / (
-                        cv['sigma'] * cv['sigma'])
-                else:
-                    raise PropertyNotPresent('k/sigma')
-
-                if 'type' not in cv.keys():
-                    self.constraints[-1]['type'] = 'not_angle'
+
+                # evaluate the constraint CV with ColVar (imported above); bare CV was undefined
+                self.constraints[-1]["func"] = ColVar(cv["definition"])
+
+                self.constraints[-1]["pos"] = cv["pos"]
+                if "k" in cv:
+                    self.constraints[-1]["k"] = cv["k"]
+                elif "sigma" in cv:
+                    self.constraints[-1]["k"] = (units.kB * self.equil_temp) / (cv["sigma"] * cv["sigma"])
                 else:
-                    self.constraints[-1]['type'] = cv['type']
-
+                    raise PropertyNotPresent("k/sigma")
+
+                self.constraints[-1]["type"] = cv.get("type", "not_angle")
+
             self.num_const = len(self.constraints)
-
-
-        self.cvs = np.zeros(shape=(self.num_cv,1))
-        self.cv_grads = np.zeros(shape=(self.num_cv,
-                                        atoms.get_positions().shape[0],
-                                        atoms.get_positions().shape[1]))
-        self.cv_grad_lens = np.zeros(shape=(self.num_cv,1))
-        self.cv_invmass = np.zeros(shape=(self.num_cv,1))
-        self.cv_dot_PES = np.zeros(shape=(self.num_cv,1))
-
-
+
+        self.cvs = np.zeros(shape=(self.num_cv, 1))
+        # size the CV-gradient buffer from the attached atoms (pass atoms=... to the constructor);
+        # a bare `atoms` is undefined in __init__ scope
+        self.cv_grads = np.zeros(shape=(self.num_cv, self.atoms.get_positions().shape[0], self.atoms.get_positions().shape[1]))
+        self.cv_grad_lens = np.zeros(shape=(self.num_cv, 1))
+        self.cv_invmass = np.zeros(shape=(self.num_cv, 1))
+        self.cv_dot_PES = np.zeros(shape=(self.num_cv, 1))
+
     def _update_bias(self, xi: np.ndarray):
         pass
-
+
     def _propagate_ext(self):
         pass
-
+
     def _up_extvel(self):
         pass
-
+
     def _check_boundaries(self, xi: np.ndarray):
-        in_bounds = ((xi <= self.ranges[:,1]).all() and
-                     (xi >= self.ranges[:,0]).all())
+        in_bounds = (xi <= self.ranges[:, 1]).all() and (xi >= self.ranges[:, 0]).all()
         return in_bounds
-
-    def diff(self,
-             a: Union[np.ndarray, float],
-             b: Union[np.ndarray, float],
-             cv_type: str
-             ) -> Union[np.ndarray, float]:
+
+    def diff(self, a: Union[np.ndarray, float], b: Union[np.ndarray, float], cv_type: str) -> Union[np.ndarray, float]:
         """get difference of elements of numbers or arrays
         in range(-inf, inf) if is_angle is False or in range(-pi, pi) if is_angle is True
         Args:
@@ -187,97 +170,102 @@ def diff(self,
         # wrap to range(-pi,pi) for angle
         if isinstance(diff, np.ndarray) and cv_type == "angle":
-
             diff[diff > np.pi] -= 2 * np.pi
             diff[diff < -np.pi] += 2 * np.pi
         elif cv_type == "angle":
-
             if diff < -np.pi:
                 diff += 2 * np.pi
             elif diff > np.pi:
                 diff -= 2 * np.pi
         return diff
-
-    def step_bias(self,
-                  xi: np.ndarray,
-                  grad_xi: np.ndarray,
-                  ) -> Tuple[np.ndarray, np.ndarray]:
+
+    def step_bias(
+        self,
+        xi: np.ndarray,
+        grad_xi: np.ndarray,
+    ) -> Tuple[np.ndarray, np.ndarray]:
        """energy and gradient of
bias - + Args: curr_cv: current value of the cv cv_index: for multidimensional FES - + Returns: bias_ener: bias energy bias_grad: gradiant of the bias in CV space, needs to be dotted with the cv_gradient """ - + self._propagate_ext() - + bias_grad = np.zeros_like(grad_xi[0]) - bias_ener = 0.0 - + bias_ener = 0.0 + for i in range(self.num_cv): # harmonic coupling of extended coordinate to reaction coordinate - dxi = self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]['type']) + dxi = self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]) self.ext_forces[i] = self.ext_k[i] * dxi bias_grad += self.ext_k[i] * dxi * grad_xi[i] bias_ener += 0.5 * self.ext_k[i] * dxi**2 # harmonic walls for confinement to range of interest if self.ext_coords[i] > (self.ranges[i][1] + self.margins[i]): - r = self.diff(self.ranges[i][1] + self.margins[i], self.ext_coords[i], self.cv_defs[i]['type']) + r = self.diff(self.ranges[i][1] + self.margins[i], self.ext_coords[i], self.cv_defs[i]["type"]) self.ext_forces[i] += self.conf_k[i] * r elif self.ext_coords[i] < (self.ranges[i][0] - self.margins[i]): - r = self.diff(self.ranges[i][0] - self.margins[i], self.ext_coords[i], self.cv_defs[i]['type']) + r = self.diff(self.ranges[i][0] - self.margins[i], self.ext_coords[i], self.cv_defs[i]["type"]) self.ext_forces[i] += self.conf_k[i] * r - + self._update_bias(xi) - self._up_extvel() - + self._up_extvel() + return bias_ener, bias_grad - - - def harmonic_constraint(self, - xi: np.ndarray, - grad_xi: np.ndarray, - ) -> Tuple[np.ndarray, np.ndarray]: + + def harmonic_constraint( + self, + xi: np.ndarray, + grad_xi: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: """energy and gradient of additional harmonic constraint - + Args: xi: current value of constraint "CV" grad_xi: Cartesian gradient of these CVs - + Returns: constr_ener: constraint energy constr_grad: gradient of the constraint energy - + """ - + constr_grad = np.zeros_like(grad_xi[0]) - constr_ener = 0.0 - + constr_ener = 0.0 + for i in range(self.num_const): - dxi = self.diff(xi[i], self.constraints[i]['pos'], self.constraints[i]['type']) - constr_grad += self.constraints[i]['k'] * dxi * grad_xi[i] - constr_ener += 0.5 * self.constraints[i]['k'] * dxi**2 - + dxi = self.diff(xi[i], self.constraints[i]["pos"], self.constraints[i]["type"]) + constr_grad += self.constraints[i]["k"] * dxi * grad_xi[i] + constr_ener += 0.5 * self.constraints[i]["k"] * dxi**2 + return constr_ener, constr_grad - - def calculate( - self, - atoms=None, - properties=['energy', 'forces', - 'energy_unbiased', 'forces_unbiased', - 'cv_vals', 'cv_invmass', - 'grad_length', 'cv_grad_lengths', 'cv_dot_PES', 'const_vals'], - system_changes=all_changes, + self, + atoms=None, + properties=[ + "energy", + "forces", + "energy_unbiased", + "forces_unbiased", + "cv_vals", + "cv_invmass", + "grad_length", + "cv_grad_lengths", + "cv_dot_PES", + "const_vals", + ], + system_changes=all_changes, ): """Calculates the desired properties for the given AtomsBatch. 
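
For context, a small numeric sketch (not part of the patch) of the harmonic coupling that step_bias applies between each collective variable and its extended-system coordinate; the numbers are invented for illustration:

# Hedged sketch: bias energy 0.5*k*(xi - lambda)**2 and its Cartesian gradient,
# mirroring the dxi/ext_k bookkeeping in BiasBase.step_bias above.
import numpy as np

ext_k = 2.0                             # coupling constant (eV per CV unit^2)
xi, lam = 1.3, 1.0                      # current CV value and extended coordinate
grad_xi = np.array([[0.1, 0.0, 0.0],
                    [-0.1, 0.0, 0.0]])  # dCV/dR per atom (illustrative)

dxi = xi - lam                          # diff() additionally wraps angular CVs
bias_ener = 0.5 * ext_k * dxi**2        # 0.09 eV here
bias_grad = ext_k * dxi * grad_xi       # chain rule: dE/dR = k * dxi * dCV/dR
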
@@ -288,142 +276,131 @@ def calculate( properties: list of keywords that can be present in self.results system_changes (default from ase) """ - + # for backwards compatability if getattr(self, "properties", None) is None: self.properties = properties Calculator.calculate(self, atoms, self.properties, system_changes) - + # run OpenMM - numpy_pos = 0.1 * atoms.get_positions() # A to nm - new_positions = ([omm.Vec3(*xyz.tolist()) for xyz in numpy_pos])*unit.nanometers + numpy_pos = 0.1 * atoms.get_positions() # A to nm + new_positions = ([omm.Vec3(*xyz.tolist()) for xyz in numpy_pos]) * unit.nanometers self.context.setPositions(new_positions) state = self.context.getState(getForces=True, getEnergy=True) - model_energy = (state.getPotentialEnergy().value_in_unit( - unit.kilocalories/unit.moles) / const.EV_TO_KCAL_MOL) - model_forces = (state.getForces(asNumpy=True).value_in_unit( - unit.kilocalories/(unit.moles*unit.angstroms)) / const.EV_TO_KCAL_MOL) - - inv_masses = 1. / atoms.get_masses() - M_inv = np.diag(np.repeat(inv_masses, 3).flatten()) - - for ii, cv_def in enumerate(self.cv_defs): - xi, xi_grad = self.the_cv[ii](atoms) - self.cvs[ii] = xi - self.cv_grads[ii] = xi_grad + model_energy = state.getPotentialEnergy().value_in_unit(unit.kilocalories / unit.moles) / const.EV_TO_KCAL_MOL + model_forces = ( + state.getForces(asNumpy=True).value_in_unit(unit.kilocalories / (unit.moles * unit.angstroms)) + / const.EV_TO_KCAL_MOL + ) + + inv_masses = 1.0 / atoms.get_masses() + M_inv = np.diag(np.repeat(inv_masses, 3).flatten()) + + for ii, _ in enumerate(self.cv_defs): + xi, xi_grad = self.the_cv[ii](atoms) + self.cvs[ii] = xi + self.cv_grads[ii] = xi_grad self.cv_grad_lens[ii] = np.linalg.norm(xi_grad) - self.cv_invmass[ii] = np.matmul(xi_grad.flatten(), np.matmul(M_inv, xi_grad.flatten())) - self.cv_dot_PES[ii] = np.dot(xi_grad.flatten(), model_forces.flatten()) - + self.cv_invmass[ii] = np.matmul(xi_grad.flatten(), np.matmul(M_inv, xi_grad.flatten())) + self.cv_dot_PES[ii] = np.dot(xi_grad.flatten(), model_forces.flatten()) + bias_ener, bias_grad = self.step_bias(self.cvs, self.cv_grads) energy = model_energy + bias_ener forces = model_forces - bias_grad - + if self.constraints: - consts = np.zeros(shape=(self.num_const,1)) - const_grads = np.zeros(shape=(self.num_const, - atoms.get_positions().shape[0], - atoms.get_positions().shape[1])) + consts = np.zeros(shape=(self.num_const, 1)) + const_grads = np.zeros( + shape=(self.num_const, atoms.get_positions().shape[0], atoms.get_positions().shape[1]) + ) for ii, const_dict in enumerate(self.constraints): - consts[ii], const_grads[ii] = const_dict['func'](atoms) - + consts[ii], const_grads[ii] = const_dict["func"](atoms) + const_ener, const_grad = self.harmonic_constraint(consts, const_grads) energy += const_ener forces -= const_grad - + self.results = { - 'energy': energy.reshape(-1), - 'forces': forces.reshape(-1, 3), - 'energy_unbiased': model_energy, - 'forces_unbiased': model_forces.reshape(-1, 3), - 'grad_length': np.linalg.norm(model_forces), - 'cv_vals': self.cvs, - 'cv_grad_lengths': self.cv_grad_lens, - 'cv_invmass': self.cv_invmass, - 'cv_dot_PES': self.cv_dot_PES, - 'ext_pos': self.ext_coords, + "energy": energy.reshape(-1), + "forces": forces.reshape(-1, 3), + "energy_unbiased": model_energy, + "forces_unbiased": model_forces.reshape(-1, 3), + "grad_length": np.linalg.norm(model_forces), + "cv_vals": self.cvs, + "cv_grad_lengths": self.cv_grad_lens, + "cv_invmass": self.cv_invmass, + "cv_dot_PES": self.cv_dot_PES, + "ext_pos": 
self.ext_coords, } - + if self.constraints: - self.results['const_vals'] = consts + self.results["const_vals"] = consts + - class eABF(BiasBase): - """extended-system Adaptive Biasing Force Calculator + """extended-system Adaptive Biasing Force Calculator class with neural force field - + Args: model: the neural force field model cv_def: lsit of Collective Variable (CV) definitions [["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]] equil_temp: float temperature of the simulation (important for extended system dynamics) dt: time step of the extended dynamics (has to be equal to that of the real system dyn!) - friction_per_ps: friction for the Lagevin dyn of extended system (has to be equal to that of the real system dyn!) + friction_per_ps: friction for the Lagevin dyn of extended system + (has to be equal to that of the real system dyn!) nfull: numer of samples need for full application of bias force """ - def __init__(self, - mmparms, - cv_defs: list[dict], - dt: float, - friction_per_ps: float, - equil_temp: float = 300.0, - nfull: int = 100, - **kwargs): - - BiasBase.__init__(self, - mmparms=mmparms, - cv_defs=cv_defs, - equil_temp=equil_temp, - **kwargs) - - + def __init__( + self, + mmparms, + cv_defs: list[dict], + dt: float, + friction_per_ps: float, + equil_temp: float = 300.0, + nfull: int = 100, + **kwargs, + ): + BiasBase.__init__(self, mmparms=mmparms, cv_defs=cv_defs, equil_temp=equil_temp, **kwargs) + self.ext_dt = dt * units.fs - self.nfull = nfull - + self.nfull = nfull + for ii, cv in enumerate(self.cv_defs): - if 'bin_width' in cv.keys(): - self.ext_binwidth[ii] = cv['bin_width'] - elif 'ext_sigma' in cv.keys(): - self.ext_binwidth[ii] = cv['ext_sigma'] + if "bin_width" in cv: + self.ext_binwidth[ii] = cv["bin_width"] + elif "ext_sigma" in cv: + self.ext_binwidth[ii] = cv["ext_sigma"] else: - raise PropertyNotPresent('bin_width') - - if 'ext_pos' in cv.keys(): + raise PropertyNotPresent("bin_width") + + if "ext_pos" in cv: # set initial position - self.ext_coords[ii] = cv['ext_pos'] + self.ext_coords[ii] = cv["ext_pos"] else: - raise PropertyNotPresent('ext_pos') - - - if 'ext_mass' in cv.keys(): - self.ext_masses[ii] = cv['ext_mass'] + raise PropertyNotPresent("ext_pos") + + if "ext_mass" in cv: + self.ext_masses[ii] = cv["ext_mass"] else: - raise PropertyNotPresent('ext_mass') - + raise PropertyNotPresent("ext_mass") + # initialize extended system at target temp of MD simulation for i in range(self.num_cv): - self.ext_vel[i] = (np.random.randn() * - np.sqrt(self.equil_temp * units.kB / - self.ext_masses[i])) - - self.friction = friction_per_ps * 1.0e-3 / units.fs - self.rand_push = np.sqrt(self.equil_temp * self.friction * - self.ext_dt * units.kB / (2.0e0 * self.ext_masses)) - self.prefac1 = 2.0 / (2.0 + self.friction * self.ext_dt) - self.prefac2 = ((2.0e0 - self.friction * self.ext_dt) / - (2.0e0 + self.friction * self.ext_dt)) - + self.ext_vel[i] = np.random.randn() * np.sqrt(self.equil_temp * units.kB / self.ext_masses[i]) + + self.friction = friction_per_ps * 1.0e-3 / units.fs + self.rand_push = np.sqrt(self.equil_temp * self.friction * self.ext_dt * units.kB / (2.0e0 * self.ext_masses)) + self.prefac1 = 2.0 / (2.0 + self.friction * self.ext_dt) + self.prefac2 = (2.0e0 - self.friction * self.ext_dt) / (2.0e0 + self.friction * self.ext_dt) # set up all grid accumulators for ABF self.nbins_per_dim = np.array([1 for i in range(self.num_cv)]) self.grid = [] for i in range(self.num_cv): - self.nbins_per_dim[i] = ( - 
int(np.ceil(np.abs(self.ranges[i,1] - self.ranges[i,0]) / - self.ext_binwidth[i])) - ) + self.nbins_per_dim[i] = int(np.ceil(np.abs(self.ranges[i, 1] - self.ranges[i, 0]) / self.ext_binwidth[i])) self.grid.append( np.arange( self.ranges[i, 0] + self.ext_binwidth[i] / 2, @@ -434,21 +411,15 @@ def __init__(self, self.nbins = np.prod(self.nbins_per_dim) # accumulators and conditional averages - self.bias = np.zeros( - (self.num_cv, *self.nbins_per_dim), dtype=float - ) + self.bias = np.zeros((self.num_cv, *self.nbins_per_dim), dtype=float) self.var_force = np.zeros_like(self.bias) self.m2_force = np.zeros_like(self.bias) - + self.cv_crit = np.copy(self.bias) - self.histogram = np.zeros( - self.nbins_per_dim, dtype=float - ) + self.histogram = np.zeros(self.nbins_per_dim, dtype=float) self.ext_hist = np.zeros_like(self.histogram) - - def get_index(self, xi: np.ndarray) -> tuple: """get list of bin indices for current position of CVs or extended variables Args: @@ -458,27 +429,18 @@ def get_index(self, xi: np.ndarray) -> tuple: """ bin_x = np.zeros(shape=xi.shape, dtype=np.int64) for i in range(self.num_cv): - bin_x[i] = int(np.floor(np.abs(xi[i] - self.ranges[i,0]) / - self.ext_binwidth[i])) + bin_x[i] = int(np.floor(np.abs(xi[i] - self.ranges[i, 0]) / self.ext_binwidth[i])) return tuple(bin_x.reshape(1, -1)[0]) - - - def _update_bias(self, - xi: np.ndarray): - if self._check_boundaries(self.ext_coords): + def _update_bias(self, xi: np.ndarray): + if self._check_boundaries(self.ext_coords): bink = self.get_index(self.ext_coords) self.ext_hist[bink] += 1 - + # linear ramp function - ramp = ( - 1.0 - if self.ext_hist[bink] > self.nfull - else self.ext_hist[bink] / self.nfull - ) + ramp = 1.0 if self.ext_hist[bink] > self.nfull else self.ext_hist[bink] / self.nfull for i in range(self.num_cv): - # apply bias force on extended system ( self.bias[i][bink], @@ -488,130 +450,119 @@ def _update_bias(self, self.ext_hist[bink], self.bias[i][bink], self.m2_force[i][bink], - self.ext_k[i] * - self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]['type']), + self.ext_k[i] * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), ) - self.ext_forces[i] -= ramp * self.bias[i][bink] + self.ext_forces[i] -= ramp * self.bias[i][bink] - """ + """ Not sure how this can be dumped/printed to work with the rest # xi-conditioned accumulators for CZAR - if (xi <= self.ranges[:,1]).all() and + if (xi <= self.ranges[:,1]).all() and (xi >= self.ranges[:,0]).all(): bink = self.get_index(xi) self.histogram[bink] += 1 for i in range(self.num_cv): - dx = diff(self.ext_coords[i], self.grid[i][bink[i]], + dx = diff(self.ext_coords[i], self.grid[i][bink[i]], self.cv_defs[i]['type']) self.correction_czar[i][bink] += self.ext_k[i] * dx """ - - + def _propagate_ext(self): + self.ext_rand_gauss = np.random.randn(len(self.ext_vel), 1) - self.ext_rand_gauss = np.random.randn(len(self.ext_vel),1) - self.ext_vel += self.rand_push * self.ext_rand_gauss self.ext_vel += 0.5e0 * self.ext_dt * self.ext_forces / self.ext_masses - self.ext_coords += self.prefac1 * self.ext_dt * self.ext_vel - + self.ext_coords += self.prefac1 * self.ext_dt * self.ext_vel + # wrap to range(-pi,pi) for angle for ii in range(self.num_cv): - if self.cv_defs[ii]['type'] == 'angle': + if self.cv_defs[ii]["type"] == "angle": if self.ext_coords[ii] > np.pi: - self.ext_coords[ii] -= 2*np.pi + self.ext_coords[ii] -= 2 * np.pi elif self.ext_coords[ii] < -np.pi: - self.ext_coords[ii] += 2*np.pi - - + self.ext_coords[ii] += 2 * np.pi + def _up_extvel(self): - 
self.ext_vel *= self.prefac2 - self.ext_vel += self.rand_push * self.ext_rand_gauss + self.ext_vel += self.rand_push * self.ext_rand_gauss self.ext_vel += 0.5e0 * self.ext_dt * self.ext_forces / self.ext_masses class WTMeABF(eABF): - """Well tempered MetaD extended-system Adaptive Biasing Force Calculator + """Well tempered MetaD extended-system Adaptive Biasing Force Calculator based on eABF class - + Args: model: the neural force field model cv_def: lsit of Collective Variable (CV) definitions [["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]] equil_temp: float temperature of the simulation (important for extended system dynamics) dt: time step of the extended dynamics (has to be equal to that of the real system dyn!) - friction_per_ps: friction for the Lagevin dyn of extended system (has to be equal to that of the real system dyn!) + friction_per_ps: friction for the Langevin dyn of extended system + (has to be equal to that of the real system dyn!) nfull: numer of samples need for full application of bias force hill_height: unscaled height of the MetaD Gaussian hills in eV hill_drop_freq: #steps between depositing Gaussians well_tempered_temp: ficticious temperature for the well-tempered scaling """ - def __init__(self, - mmparms, - cv_defs: list[dict], - dt: float, - friction_per_ps: float, - equil_temp: float = 300.0, - nfull: int = 100, - hill_height: float = 0.0, - hill_drop_freq: int = 20, - well_tempered_temp: float = 4000.0, - **kwargs): - - eABF.__init__(self, - mmparms=mmparms, - cv_defs=cv_defs, - equil_temp=equil_temp, - dt=dt, - friction_per_ps=friction_per_ps, - nfull=nfull, - **kwargs) - - self.hill_height = hill_height - self.hill_drop_freq = hill_drop_freq - self.hill_std = np.zeros(shape=(self.num_cv)) - self.hill_var = np.zeros(shape=(self.num_cv)) + def __init__( + self, + mmparms, + cv_defs: list[dict], + dt: float, + friction_per_ps: float, + equil_temp: float = 300.0, + nfull: int = 100, + hill_height: float = 0.0, + hill_drop_freq: int = 20, + well_tempered_temp: float = 4000.0, + **kwargs, + ): + eABF.__init__( + self, + mmparms=mmparms, + cv_defs=cv_defs, + equil_temp=equil_temp, + dt=dt, + friction_per_ps=friction_per_ps, + nfull=nfull, + **kwargs, + ) + + self.hill_height = hill_height + self.hill_drop_freq = hill_drop_freq + self.hill_std = np.zeros(self.num_cv) + self.hill_var = np.zeros(self.num_cv) self.well_tempered_temp = well_tempered_temp - self.call_count = 0 - self.center = [] - + self.call_count = 0 + self.center = [] + for ii, cv in enumerate(self.cv_defs): - if 'hill_std' in cv.keys(): - self.hill_std[ii] = cv['hill_std'] - self.hill_var[ii] = cv['hill_std']*cv['hill_std'] + if "hill_std" in cv: + self.hill_std[ii] = cv["hill_std"] + self.hill_var[ii] = cv["hill_std"] * cv["hill_std"] else: - raise PropertyNotPresent('hill_std') - + raise PropertyNotPresent("hill_std") # set up all grid for MetaD potential self.metapot = np.zeros_like(self.histogram) - self.metaforce = np.zeros_like(self.bias) - - - def _update_bias(self, - xi: np.ndarray): - + self.metaforce = np.zeros_like(self.bias) + + def _update_bias(self, xi: np.ndarray): mtd_forces = self.get_wtm_force(self.ext_coords) self.call_count += 1 - - if self._check_boundaries(self.ext_coords): + if self._check_boundaries(self.ext_coords): bink = self.get_index(self.ext_coords) self.ext_hist[bink] += 1 - + # linear ramp function - ramp = ( - 1.0 - if self.ext_hist[bink] > self.nfull - else self.ext_hist[bink] / self.nfull - ) + ramp = 1.0 if 
self.ext_hist[bink] > self.nfull else self.ext_hist[bink] / self.nfull for i in range(self.num_cv): - # apply bias force on extended system ( self.bias[i][bink], @@ -621,12 +572,10 @@ def _update_bias(self, self.ext_hist[bink], self.bias[i][bink], self.m2_force[i][bink], - self.ext_k[i] * - self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]['type']), + self.ext_k[i] * self.diff(xi[i], self.ext_coords[i], self.cv_defs[i]["type"]), ) self.ext_forces[i] -= ramp * self.bias[i][bink] + mtd_forces[i] - def get_wtm_force(self, xi: np.ndarray) -> np.ndarray: """compute well-tempered metadynamics bias force from superposition of gaussian hills Args: @@ -634,21 +583,20 @@ def get_wtm_force(self, xi: np.ndarray) -> np.ndarray: Returns: bias_force: bias force from metadynamics """ - + is_in_bounds = self._check_boundaries(xi) - + if (self.call_count % self.hill_drop_freq == 0) and is_in_bounds: - self.center.append(np.copy(xi.reshape(-1))) - + self.center.append(np.copy(xi.reshape(-1))) + if is_in_bounds and self.num_cv == 1: bias_force, _ = self._accumulate_wtm_force(xi) else: bias_force, _ = self._analytic_wtm_force(xi) - - return bias_force - def _accumulate_wtm_force(self, - xi: np.ndarray) -> Tuple[list, float]: + return bias_force + + def _accumulate_wtm_force(self, xi: np.ndarray) -> Tuple[list, float]: """compute numerical WTM bias force from a grid Right now this works only for 1D CVs Args: @@ -659,20 +607,17 @@ def _accumulate_wtm_force(self, bink = self.get_index(xi) if self.call_count % self.hill_drop_freq == 0: + w = self.hill_height * np.exp(-self.metapot[bink] / (units.kB * self.well_tempered_temp)) - w = self.hill_height * np.exp( - -self.metapot[bink] - / (units.kB * self.well_tempered_temp) + dx = self.diff(self.grid[0], xi[0], self.cv_defs[0]["type"]).reshape( + -1, ) - - dx = self.diff(self.grid[0], xi[0], self.cv_defs[0]['type']).reshape(-1,) epot = w * np.exp(-(dx * dx) / (2.0 * self.hill_var[0])) self.metapot += epot self.metaforce[0] -= epot * dx / self.hill_var[0] return self.metaforce[:, bink], self.metapot[bink] - def _analytic_wtm_force(self, xi: np.ndarray) -> Tuple[list, float]: """compute analytic WTM bias force from sum of gaussians hills Args: @@ -688,45 +633,31 @@ def _analytic_wtm_force(self, xi: np.ndarray) -> Tuple[list, float]: if len(self.center) == 0: print(" >>> Warning: no metadynamics hills stored") return bias_force - + ind = np.ma.indices((len(self.center),))[0] ind = np.ma.masked_array(ind) - - dist_to_centers = [] - for ii in range(self.num_cv): - dist_to_centers.append(self.diff(xi[ii], np.asarray(self.center)[:,ii], self.cv_defs[ii]['type'])) - - dist_to_centers = np.asarray(dist_to_centers) - + + dist_to_centers = np.array( + [self.diff(xi[ii], np.asarray(self.center)[:, ii], self.cv_defs[ii]["type"]) for ii in range(self.num_cv)] + ) + if self.num_cv > 1: - ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1,1)).all(axis=0)] = np.ma.masked + ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0)] = np.ma.masked else: - ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1,1)).all(axis=0)] = np.ma.masked - + ind[(abs(dist_to_centers) > 3 * self.hill_std.reshape(-1, 1)).all(axis=0)] = np.ma.masked + # can get slow in long run, so only iterate over significant elements for i in np.nditer(ind.compressed(), flags=["zerosize_ok"]): - w = self.hill_height * np.exp( - -local_pot / (units.kB * self.well_tempered_temp) - ) - - epot = w * np.exp(-np.power(dist_to_centers[:,i]/self.hill_std, 2).sum() / 2.0) - local_pot += epot - 
bias_force -= epot * dist_to_centers[:,i] / self.hill_var - - return bias_force.reshape(-1,1), local_pot - - - - + w = self.hill_height * np.exp(-local_pot / (units.kB * self.well_tempered_temp)) + epot = w * np.exp(-np.power(dist_to_centers[:, i] / self.hill_std, 2).sum() / 2.0) + local_pot += epot + bias_force -= epot * dist_to_centers[:, i] / self.hill_var + return bias_force.reshape(-1, 1), local_pot -def welford_var( - count: float, - mean: float, - M2: float, - newValue: float) -> Tuple[float, float, float]: +def welford_var(count: float, mean: float, M2: float, newValue: float) -> Tuple[float, float, float]: """On-the-fly estimate of sample variance by Welford's online algorithm Args: count: current number of samples (with new one) diff --git a/nff/md/TI.py b/nff/md/TI.py index 55404b16..474f6f56 100644 --- a/nff/md/TI.py +++ b/nff/md/TI.py @@ -87,7 +87,7 @@ def run(self): self.atomsbatch.props["aggr_wgt"] = self.init_aggr - for step in range(epochs): + for _step in range(epochs): self.integrator.run(self.mdparam["nbr_list_update_freq"]) self.atomsbatch.update_nbr_list() self.atomsbatch.props["aggr_wgt"] += dlambda diff --git a/nff/md/aims/calcs/basis.py b/nff/md/aims/calcs/basis.py index 41c32366..8ed3fff6 100644 --- a/nff/md/aims/calcs/basis.py +++ b/nff/md/aims/calcs/basis.py @@ -1,8 +1,9 @@ -from torch.utils.data import DataLoader -import torch -import numpy as np import copy +import numpy as np +import torch +from torch.utils.data import DataLoader + from nff.data import Dataset, collate_dicts from nff.train.evaluate import evaluate @@ -51,7 +52,7 @@ def dgamma_dt(en, p, m): """ reshape_m = m.reshape(1, -1, 1) - deriv = -en + (p ** 2 / (2 * reshape_m)).sum((1, 2)) + deriv = -en + (p**2 / (2 * reshape_m)).sum((1, 2)) return deriv @@ -65,8 +66,7 @@ def to_dset(r, atom_nums, nbrs, gen_nbrs): """ atom_num_reshape = atom_nums.reshape(-1, 1) - nxyz = [torch.cat([atom_num_reshape, xyz], dim=-1) - for xyz in r] + nxyz = [torch.cat([atom_num_reshape, xyz], dim=-1) for xyz in r] dataset = Dataset(props={"nxyz": nxyz}) @@ -78,14 +78,7 @@ def to_dset(r, atom_nums, nbrs, gen_nbrs): return dataset -def get_engrad(r, - atom_nums, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys): +def get_engrad(r, atom_nums, nbrs, gen_nbrs, batch_size, device, model, diabat_keys): """ Args: r (torch.Tensor): a position tensor of dimension N_J x N_at x 3, @@ -93,53 +86,25 @@ def get_engrad(r, J and N_at is the number of atoms. 
""" - dataset = to_dset(r=r, - atom_nums=atom_nums, - nbrs=nbrs, - gen_nbrs=gen_nbrs) + dataset = to_dset(r=r, atom_nums=atom_nums, nbrs=nbrs, gen_nbrs=gen_nbrs) - loader = DataLoader(dataset, - batch_size=batch_size, - collate_fn=collate_dicts) + loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_dicts) - results, _, _ = evaluate(model=model, - loader=loader, - loss_fn=lambda x, y: 0, - device=device, - debatch=True) + results, _, _ = evaluate(model=model, loader=loader, loss_fn=lambda x, y: 0, device=device, debatch=True) for key, val in results.items(): - if key.endswith("_grad") or key.startswith('nacv_'): + if key.endswith("_grad") or key.startswith("nacv_"): results[key] = torch.stack(val) return results, dataset -def compute_derivs(r, - m, - atom_num, - p, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic): - - results, dataset = get_engrad(r, - atom_num, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys) +def compute_derivs(r, m, atom_num, p, nbrs, gen_nbrs, batch_size, device, model, diabat_keys, diabatic): + results, dataset = get_engrad(r, atom_num, nbrs, gen_nbrs, batch_size, device, model, diabat_keys) num_states = len(diabat_keys) derivs = [] for i in range(num_states): - if diabatic: en_key = diabat_keys[i][i] grad_key = en_key + "_grad" @@ -153,22 +118,14 @@ def compute_derivs(r, p_deriv = dp_dt(en_grad) r_deriv = dr_dt(p, m) - dic = {"gamma": gamma_deriv, - "p": p_deriv, - "r": r_deriv} + dic = {"gamma": gamma_deriv, "p": p_deriv, "r": r_deriv} derivs.append(dic) return derivs -def overlap_formula(expand_r_i, - expand_r_j, - expand_alpha_i, - expand_alpha_j, - expand_p_i, - expand_p_j): - +def overlap_formula(expand_r_i, expand_r_j, expand_alpha_i, expand_alpha_j, expand_p_i, expand_p_j): r_i = expand_r_i.numpy() r_j = expand_r_j.numpy() alpha_i = expand_alpha_i.numpy() @@ -176,17 +133,16 @@ def overlap_formula(expand_r_i, p_i = expand_p_i.numpy() p_j = expand_p_j.numpy() - A = (-alpha_j * r_j ** 2 - alpha_i * r_i ** 2 - + 1j * p_j * (-r_j) - 1j * p_i * (-r_i)) + A = -alpha_j * r_j**2 - alpha_i * r_i**2 + 1j * p_j * (-r_j) - 1j * p_i * (-r_i) B = alpha_i + alpha_j - C = (2 * alpha_j * r_j + 1j * p_j - + 2 * alpha_i * r_i - 1j * p_i) + C = 2 * alpha_j * r_j + 1j * p_j + 2 * alpha_i * r_i - 1j * p_i # has dimension N_I x N_J x N_at x 3 - overlaps = ((2 / np.pi) ** 0.5 * (alpha_i * alpha_j) ** 0.25 - * np.exp(A) * (np.pi / B) ** 0.5 * np.exp(C ** 2 / (4 * B))) + overlaps = ( + (2 / np.pi) ** 0.5 * (alpha_i * alpha_j) ** 0.25 * np.exp(A) * (np.pi / B) ** 0.5 * np.exp(C**2 / (4 * B)) + ) # take the product over the last two dimensions N_I = expand_r_i.shape[0] @@ -201,15 +157,7 @@ def overlap_formula(expand_r_i, return overlap_prod -def tile_params(r_i, - r_j, - p_i, - p_j, - alpha_i, - alpha_j, - m_i=None, - m_j=None): - +def tile_params(r_i, r_j, p_i, p_j, alpha_i, alpha_j, m_i=None, m_j=None): N_I = r_i.shape[0] N_J = r_i.shape[0] N_at = r_i.shape[1] @@ -220,35 +168,20 @@ def tile_params(r_i, expand_p_i = p_i.expand(N_J, N_I, N_at, 3).transpose(0, 1) expand_p_j = p_j.expand(N_I, N_J, N_at, 3) - expand_alpha_i = (alpha_i.reshape(1, 1, N_at, 1) - .expand(N_I, N_J, N_at, 3)) + expand_alpha_i = alpha_i.reshape(1, 1, N_at, 1).expand(N_I, N_J, N_at, 3) - expand_alpha_j = (alpha_j.reshape(1, 1, N_at, 1) - .expand(N_J, N_I, N_at, 3) - .transpose(0, 1)) + expand_alpha_j = alpha_j.reshape(1, 1, N_at, 1).expand(N_J, N_I, N_at, 3).transpose(0, 1) if m_i is not None and m_j is not None: - expand_mi = (m_i.reshape(1, 1, 
N_at, 1) - .expand(N_I, N_J, N_at, 3)) + expand_mi = m_i.reshape(1, 1, N_at, 1).expand(N_I, N_J, N_at, 3) - expand_mj = (m_j.reshape(1, 1, N_at, 1) - .expand(N_J, N_I, N_at, 3) - .transpose(0, 1)) + expand_mj = m_j.reshape(1, 1, N_at, 1).expand(N_J, N_I, N_at, 3).transpose(0, 1) - return (expand_r_i, expand_r_j, expand_p_i, - expand_p_j, expand_alpha_i, expand_alpha_j, - expand_mi, expand_mj) - else: - return (expand_r_i, expand_r_j, expand_p_i, - expand_p_j, expand_alpha_i, expand_alpha_j) + return (expand_r_i, expand_r_j, expand_p_i, expand_p_j, expand_alpha_i, expand_alpha_j, expand_mi, expand_mj) + return (expand_r_i, expand_r_j, expand_p_i, expand_p_j, expand_alpha_i, expand_alpha_j) -def get_overlaps(r_i, - r_j, - alpha_i, - alpha_j, - p_i, - p_j): +def get_overlaps(r_i, r_j, alpha_i, alpha_j, p_i, p_j): """ Args: r_i: Gaussian positions in state i. Tensor of @@ -260,34 +193,27 @@ def get_overlaps(r_i, """ - (expand_r_i, expand_r_j, expand_p_i, - expand_p_j, expand_alpha_i, expand_alpha_j) = tile_params(r_i, - r_j, - p_i, - p_j, - alpha_i, - alpha_j) + (expand_r_i, expand_r_j, expand_p_i, expand_p_j, expand_alpha_i, expand_alpha_j) = tile_params( + r_i, r_j, p_i, p_j, alpha_i, alpha_j + ) - r_max = ((expand_alpha_i * expand_r_i + expand_alpha_j * expand_r_j) - / (expand_alpha_i + expand_alpha_j)) + r_max = (expand_alpha_i * expand_r_i + expand_alpha_j * expand_r_j) / (expand_alpha_i + expand_alpha_j) # G_ij - overlap = overlap_formula(expand_r_i=expand_r_i, - expand_r_j=expand_r_j, - expand_alpha_i=expand_alpha_i, - expand_alpha_j=expand_alpha_j, - expand_p_i=expand_p_i, - expand_p_j=expand_p_j) + overlap = overlap_formula( + expand_r_i=expand_r_i, + expand_r_j=expand_r_j, + expand_alpha_i=expand_alpha_i, + expand_alpha_j=expand_alpha_j, + expand_p_i=expand_p_i, + expand_p_j=expand_p_j, + ) return overlap, r_max -def get_coupling_r(r_list, - p_list, - alpha_dic, - atom_nums, - min_overlap): +def get_coupling_r(r_list, p_list, alpha_dic, atom_nums, min_overlap): """ Get all overlaps betwene nuclear wave functions on different states, and get the positions at which the overlaps are large enough that we'll want to calculate matrix elements @@ -304,32 +230,27 @@ def get_coupling_r(r_list, for i in range(num_states): for j in range(num_states): - r_i = r_list[i] # N_I x N_at x 3 r_j = r_list[j] # N_J x N_at x 3 p_i = p_list[i] p_j = p_list[j] - alpha_i = torch.Tensor([alpha_dic[atom_num] - for atom_num in atom_nums]) + alpha_i = torch.Tensor([alpha_dic[atom_num] for atom_num in atom_nums]) alpha_j = copy.deepcopy(alpha_i) - overlap, r_max = get_overlaps(r_i=r_i, - r_j=r_j, - alpha_i=alpha_i, - alpha_j=alpha_j, - p_i=p_i, - p_j=p_j) + overlap, r_max = get_overlaps(r_i=r_i, r_j=r_j, alpha_i=alpha_i, alpha_j=alpha_j, p_i=p_i, p_j=p_j) couple_mask = abs(overlap) > min_overlap couple_idx = couple_mask.nonzero() couple_r = r_max[couple_idx[:, 0], couple_idx[:, 1]] - couple_dic[f"{i}_{j}"] = {"overlap": overlap, - "couple_idx": couple_idx, - "couple_r": couple_r, - "couple_mask": couple_mask} + couple_dic[f"{i}_{j}"] = { + "overlap": overlap, + "couple_idx": couple_idx, + "couple_r": couple_r, + "couple_mask": couple_mask, + } return couple_dic @@ -350,19 +271,12 @@ def compute_A(m, nacv, hbar=1): """ m_reshape = m.reshape(1, -1, 1) - A_ij = (-hbar ** 2 / m_reshape * nacv).sum((1, 2)) + A_ij = (-(hbar**2) / m_reshape * nacv).sum((1, 2)) return A_ij -def nonad_ham_ij(couple_r, - nacv, - overlap, - mask, - alpha_j, - r_j, - p_j, - m): +def nonad_ham_ij(couple_r, nacv, overlap, mask, alpha_j, r_j, 
p_j, m): """ Construct the off-diagonal elements of the Hamiltonian in the adiabatic basis. Computes f_ij = f(R_ij), the value of the matrix @@ -412,8 +326,7 @@ def nonad_ham_ij(couple_r, # we take the sum along the atomic dimensions n_ij = idx.shape[0] - nabla_ij_real = ((-alpha_j_reshape * (couple_r - mask_rj)) - .reshape(n_ij, -1).sum(-1)) + nabla_ij_real = (-alpha_j_reshape * (couple_r - mask_rj)).reshape(n_ij, -1).sum(-1) nabla_ij_im = mask_pj.reshape(n_ij, -1).sum(-1) # Convert to numpy to make it complex @@ -498,15 +411,8 @@ def nonad_ham_ij(couple_r, # return h_ij -def nuc_ke(r_j, - p_j, - alpha_j, - r_i, - p_i, - alpha_i, - mask, - m, - hbar=1): + +def nuc_ke(r_j, p_j, alpha_j, r_i, p_i, alpha_i, mask, m, hbar=1): """ Get the diagonal kinetic energy part of the Hamiltonian. Args: @@ -524,38 +430,31 @@ def nuc_ke(r_j, # **** this should actually be done analytically - (expand_r_i, expand_r_j, expand_p_i, - expand_p_j, expand_alpha_i, expand_alpha_j, - expand_mi, expand_mj) = tile_params(r_i, - r_j, - p_i, - p_j, - alpha_i, - alpha_j) + (expand_r_i, expand_r_j, expand_p_i, expand_p_j, expand_alpha_i, expand_alpha_j, expand_mi, expand_mj) = ( + tile_params(r_i, r_j, p_i, p_j, alpha_i, alpha_j) + ) - A = (-2 * expand_alpha_j - (expand_p_j) ** 2 - + 4 * 1j * expand_alpha_j * expand_p_j * expand_r_j) + A = -2 * expand_alpha_j - (expand_p_j) ** 2 + 4 * 1j * expand_alpha_j * expand_p_j * expand_r_j - B = (-4 * 1j * expand_alpha_j * expand_p_j - - 8 * (expand_alpha_j) ** 2 * expand_r_j) + B = -4 * 1j * expand_alpha_j * expand_p_j - 8 * (expand_alpha_j) ** 2 * expand_r_j - C = 4 * expand_alpha_j ** 2 + C = 4 * expand_alpha_j**2 - D = (1j * expand_p_i * expand_r_i - 1j * expand_p_j * expand_r_j - - expand_alpha_i * expand_r_i ** 2 - expand_alpha_j * expand_r_j ** 2) + D = ( + 1j * expand_p_i * expand_r_i + - 1j * expand_p_j * expand_r_j + - expand_alpha_i * expand_r_i**2 + - expand_alpha_j * expand_r_j**2 + ) - E = (1j * expand_p_j - 1j * expand_p_i + 2 * expand_r_i * expand_alpha_i - + 2 * expand_r_j * expand_alpha_j) + E = 1j * expand_p_j - 1j * expand_p_i + 2 * expand_r_i * expand_alpha_i + 2 * expand_r_j * expand_alpha_j F = expand_alpha_i + expand_alpha_j # *** are we dividing by the right mass here?? 
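    # A hedged 1D sanity check for the mass convention (hbar = 1, toy numbers,
    # not part of the production code): for a single normalized Gaussian
    # exp(-alpha * x**2 + 1j * p * x), the kinetic-energy expectation value is
    # alpha / (2 * m) + p**2 / (2 * m), e.g.
    #
    #     alpha, p, m = 0.5, 1.3, 2.0
    #     t_ref = alpha / (2 * m) + p**2 / (2 * m)
    #
    # Comparing t_ref against the corresponding diagonal element computed below
    # would confirm whether expand_mj is the right mass to divide by.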
- prefactor = ((2) ** 0.5 * (expand_alpha_i * expand_alpha_j) ** 0.25 - * (-hbar ** 2) / (2 * expand_mj)) + prefactor = (2) ** 0.5 * (expand_alpha_i * expand_alpha_j) ** 0.25 * (-(hbar**2)) / (2 * expand_mj) - main_term = (1 / (4 * F ** (5/2)) - * np.exp(D + E ** 2 / (4 * F)) - * (C * E ** 2 + 2 * (C + B * E) * F + 4 * A * F ** 2)) + main_term = 1 / (4 * F ** (5 / 2)) * np.exp(D + E**2 / (4 * F)) * (C * E**2 + 2 * (C + B * E) * F + 4 * A * F**2) # dimension N_I x N_J x N_at x 3 ke_vec = prefactor * main_term @@ -567,9 +466,7 @@ def nuc_ke(r_j, return ke -def elec_e(energies, - overlap, - mask): +def elec_e(energies, overlap, mask): """ Args: energies (torch.Tensor): n_ij dimensional tensor @@ -587,18 +484,9 @@ def elec_e(energies, return h_ij -def construct_ham(r_list, - p_list, - atom_nums, - m, - couple_dic, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - alpha_dic): +def construct_ham( + r_list, p_list, atom_nums, m, couple_dic, nbrs, gen_nbrs, batch_size, device, model, diabat_keys, alpha_dic +): """ This needs to be fixed -- needs a proper H term for adiabatic, and also nuclear kinetic energy @@ -617,39 +505,37 @@ def construct_ham(r_list, h_ad (np.array): Hamiltonian in adiabatic basis """ - num_states = int(couple_dic ** 0.5) + num_states = int(len(couple_dic) ** 0.5) max_basis = max([r.shape[0] for r in r_list]) # padded, as different states have different number of # trj basis functions - h_d = torch.zeros(num_states, num_states, - max_basis, max_basis).numpy() + h_d = torch.zeros(num_states, num_states, max_basis, max_basis).numpy() - h_ad = torch.zeros(num_states, num_states, - max_basis, max_basis).numpy() + h_ad = torch.zeros(num_states, num_states, max_basis, max_basis).numpy() for key, sub_dic in couple_dic.items(): - i, j = (int(s) for s in key.split("_")) couple_r = sub_dic["couple_r"] - results, dataset = get_engrad(r=couple_r, - atom_nums=atom_nums, - nbrs=nbrs, - gen_nbrs=gen_nbrs, - batch_size=batch_size, - device=device, - model=model, - diabat_keys=diabat_keys) + results, dataset = get_engrad( + r=couple_r, + atom_nums=atom_nums, + nbrs=nbrs, + gen_nbrs=gen_nbrs, + batch_size=batch_size, + device=device, + model=model, + diabat_keys=diabat_keys, + ) mask = sub_dic["couple_mask"] # numpy array overlap = sub_dic["overlap"] # numpy array (complex) # h_d_ij = torch.zeros_like(overlap).numpy() - alpha_j = torch.Tensor([alpha_dic[atom_num] - for atom_num in atom_nums]) + alpha_j = torch.Tensor([alpha_dic[atom_num] for atom_num in atom_nums]) alpha_i = copy.deepcopy(alpha_j) r_j = r_list[j] @@ -659,17 +545,9 @@ def construct_ham(r_list, p_i = p_list[i] if i == j: - # nuclear kinetic energy component - h_ad_ij = nuc_ke(r_j, - p_j, - alpha_j, - r_i, - p_i, - alpha_i, - mask, - m) + h_ad_ij = nuc_ke(r_j, p_j, alpha_j, r_i, p_i, alpha_i, mask, m) h_d_ij = copy.deepcopy(h_ad_ij) @@ -678,36 +556,24 @@ def construct_ham(r_list, ad_key = f"energy_{i}" diabat_key = diabat_keys[i][j] - h_ad_ij += elec_e(results[ad_key], - overlap, - mask) + h_ad_ij += elec_e(results[ad_key], overlap, mask) - h_d_ij += elec_e(results[diabat_key], - overlap, - mask) + h_d_ij += elec_e(results[diabat_key], overlap, mask) else: - # The off-diagonal Hamiltonian in the adiabatic # basis involves the non-adiabatic coupling vector nacv = results[f"nacv_{i}{j}"] - h_ad_ij = nonad_ham_ij(couple_r=couple_r, - nacv=nacv, - overlap=overlap, - mask=mask, - alpha_j=alpha_j, - r_j=r_j, - p_j=p_j, - m=m) + h_ad_ij = nonad_ham_ij( + couple_r=couple_r, nacv=nacv, overlap=overlap, mask=mask, alpha_j=alpha_j, r_j=r_j,
p_j=p_j, m=m + ) # The off-diagonal Hamiltonian in the diabatic # basis is the diabatic eletronic energy diabat_key = diabat_keys[i][j] - h_d_ij = elec_e(results[diabat_key], - overlap, - mask) + h_d_ij = elec_e(results[diabat_key], overlap, mask) N_I, N_J = overlap.shape[:2] h_d[i, j, :N_I, :N_J] = h_d_ij @@ -716,26 +582,23 @@ def construct_ham(r_list, return h_d, h_ad -def diabat_spawn_criterion(states, - results, - diabat_keys, - threshold): +def diabat_spawn_criterion(states, results, diabat_keys, threshold): """ - Args: + Args: - results_list (list[dict]): list of dictionaries. Each - dictionary corresponds to an electronic state. - It contains model predictions for the - positions of the nuclear wave packets on - that state. - Returns: - thresh_dic (dict): dictionary with keys for each state, - the value of which is a subdictionary. The subdictionary - contains keys for each other state. Say we're looking at - main key i and subdictionary key j. Then thresh_dic[i][j] - is a boolean tensor of dimension N_I. For each Gaussian - basis function in state i, it tells you whether you should - replicate it on state j. + results_list (list[dict]): list of dictionaries. Each + dictionary corresponds to an electronic state. + It contains model predictions for the + positions of the nuclear wave packets on + that state. + Returns: + thresh_dic (dict): dictionary with keys for each state, + the value of which is a subdictionary. The subdictionary + contains keys for each other state. Say we're looking at + main key i and subdictionary key j. Then thresh_dic[i][j] + is a boolean tensor of dimension N_I. For each Gaussian + basis function in state i, it tells you whether you should + replicate it on state j. """ @@ -768,28 +631,25 @@ def diabat_spawn_criterion(states, return thresh_dic -def adiabat_spawn_criterion(states, - results, - v_list, - threshold): +def adiabat_spawn_criterion(states, results, v_list, threshold): """ - Args: + Args: - results_list (list[dict]): list of dictionaries. Each - dictionary corresponds to an electronic state. - It contains model predictions for the - positions of the nuclear wave packets on - that state. - v_list (list[torch.Tensor]): list of velocities for wave - packets on each state. - Returns: - thresh_dic (dict): dictionary with keys for each state, - the value of which is a subdictionary. The subdictionary - contains keys for each other state. Say we're looking at - main key i and subdictionary key j. Then thresh_dic[i][j] - is a boolean tensor of dimension N_I. For each Gaussian - basis function in state i, it tells you whether you should - replicate it on state j. + results_list (list[dict]): list of dictionaries. Each + dictionary corresponds to an electronic state. + It contains model predictions for the + positions of the nuclear wave packets on + that state. + v_list (list[torch.Tensor]): list of velocities for wave + packets on each state. + Returns: + thresh_dic (dict): dictionary with keys for each state, + the value of which is a subdictionary. The subdictionary + contains keys for each other state. Say we're looking at + main key i and subdictionary key j. Then thresh_dic[i][j] + is a boolean tensor of dimension N_I. For each Gaussian + basis function in state i, it tells you whether you should + replicate it on state j. 
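        Illustrative shape of the return value (values invented; note that in
        the implementation below each thresh_dic[i][j] is a small dict rather
        than a bare boolean tensor):

            thresh_dic = {
                0: {1: {"criterion": tensor(True), "val": tensor(0.12)}},
                1: {0: {"criterion": tensor(False), "val": tensor(0.03)}},
            }

        where "criterion" flags whether the effective coupling exceeds
        `threshold` and "val" is its norm.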
""" @@ -813,17 +673,12 @@ def adiabat_spawn_criterion(states, # Effective coupling h_eff = (vel * nacv).sum() - thresh_dic[i][j] = {"criterion": (abs(h_eff) > threshold).any(), - "val": h_eff.norm()} + thresh_dic[i][j] = {"criterion": (abs(h_eff) > threshold).any(), "val": h_eff.norm()} return thresh_dic -def get_vals(diabatic, - surf, - diabat_keys, - results): - +def get_vals(diabatic, surf, diabat_keys, results): i = surf if diabatic: en_key = diabat_keys[i][i] @@ -837,29 +692,27 @@ def get_vals(diabatic, return en, en_grad -def nuc_classical(r, - gamma, - m, - atom_num, - p, - nbrs, - gen_nbrs, - batch_size, - device, - model, - states, - diabat_keys, - diabatic, - dt, - surf, - old_results): - +def nuc_classical( + r, + gamma, + m, + atom_num, + p, + nbrs, + gen_nbrs, + batch_size, + device, + model, + states, + diabat_keys, + diabatic, + dt, + surf, + old_results, +): # classical propagation of nuclei - old_en, old_grad = get_vals(diabatic=diabatic, - surf=surf, - diabat_keys=diabat_keys, - results=old_results) + old_en, old_grad = get_vals(diabatic=diabatic, surf=surf, diabat_keys=diabat_keys, results=old_results) # note: we need a p + 1/2 dt and a p + 3/2 dt # The p that we keep track of will always be 1/2 dt @@ -874,19 +727,9 @@ def nuc_classical(r, # m has dim N_at r_new = r + 1 / m.reshape(1, -1, 1) * dt * p - results, dataset = get_engrad(r_new, - atom_num, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys) + results, dataset = get_engrad(r_new, atom_num, nbrs, gen_nbrs, batch_size, device, model, diabat_keys) - new_en, new_grad = get_vals(diabatic=diabatic, - surf=surf, - diabat_keys=diabat_keys, - results=results) + new_en, new_grad = get_vals(diabatic=diabatic, surf=surf, diabat_keys=diabat_keys, results=results) p_new = p - new_grad * dt @@ -896,24 +739,25 @@ def nuc_classical(r, return r_new, p_new, gamma_new, results -def find_spawn(r, - gamma, - m, - atom_num, - p, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic, - dt, - new_surf, - old_results, - old_surf, - threshold): - +def find_spawn( + r, + gamma, + m, + atom_num, + p, + nbrs, + gen_nbrs, + batch_size, + device, + model, + diabat_keys, + diabatic, + dt, + new_surf, + old_results, + old_surf, + threshold, +): # classical propagation too_big = False @@ -929,38 +773,34 @@ def find_spawn(r, while too_big: # this is right: keep propagating along the old surface - r_new, p_new, gamma_new, new_results = nuc_classical(r_new, - gamma_new, - m, - atom_num, - p_new, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic, - dt, - old_surf, - old_results) + r_new, p_new, gamma_new, new_results = nuc_classical( + r_new, + gamma_new, + m, + atom_num, + p_new, + nbrs, + gen_nbrs, + batch_size, + device, + model, + diabat_keys, + diabatic, + dt, + old_surf, + old_results, + ) states = [old_surf, new_surf] if diabatic: - spawn_dic = diabat_spawn_criterion(states, - new_results, - diabat_keys, - threshold) + spawn_dic = diabat_spawn_criterion(states, new_results, diabat_keys, threshold) else: - spawn_dic = adiabat_spawn_criterion(states, - new_results, - v_list, - threshold) + spawn_dic = adiabat_spawn_criterion(states, new_results, v_list, threshold) - coupling = spawn_dic[old_surf][new_surf]['val'] + coupling = spawn_dic[old_surf][new_surf]["val"] couplings.append(coupling) - too_big = spawn_dic[old_surf][new_surf]['criterion'] + too_big = spawn_dic[old_surf][new_surf]["criterion"] r_list.append(r_new) p_list.append(p_new) @@ -980,81 +820,74 @@ def 
find_spawn(r, return spawn_r, spawn_p, spawn_idx, old_results_list -def rescale(p_new, - m, - diabatic, - results, - old_surf, - new_surf): - +def rescale(p_new, m, diabatic, results, old_surf, new_surf): if diabatic: raise NotImplementedError - else: - # p has dimension N_J x N_at x 3 - # nacv has dimension N_J x N_at x 3 + # p has dimension N_J x N_at x 3 + # nacv has dimension N_J x N_at x 3 - nacv = results[f'nacv_{old_surf}{new_surf}'] - norm = (nacv ** 2).sum(-1) ** 0.5 - nacv_unit = nacv / norm + nacv = results[f"nacv_{old_surf}{new_surf}"] + norm = (nacv**2).sum(-1) ** 0.5 + nacv_unit = nacv / norm - # dot product - projection = (nacv_unit * p_new).sum(-1) + # dot product + projection = (nacv_unit * p_new).sum(-1) - # p_parallel - N_J, N_at = projection.shape - p_par = (projection.reshape(N_J, N_at, 1) - * nacv_unit) + # p_parallel + N_J, N_at = projection.shape + p_par = projection.reshape(N_J, N_at, 1) * nacv_unit - # p perpendicular - p_perp = p_new - p_par + # p perpendicular + p_perp = p_new - p_par - # get energies before and after hop - # m has shape N_at - # is this right? + # get energies before and after hop + # m has shape N_at + # is this right? - t_old = (p_new ** 2 / (2 * m.reshape(1, -1, 1))).sum() - t_old_perp = (p_perp ** 2 / (2 * m.reshape(1, -1, 1))).sum() - t_old_par = (p_par ** 2 / (2 * m.reshape(1, -1, 1))).sum() - v_old = results[f'energy_{old_surf}'] - v_new = results[f'energy_{new_surf}'] + t_old = (p_new**2 / (2 * m.reshape(1, -1, 1))).sum() + t_old_perp = (p_perp**2 / (2 * m.reshape(1, -1, 1))).sum() + t_old_par = (p_par**2 / (2 * m.reshape(1, -1, 1))).sum() + v_old = results[f"energy_{old_surf}"] + v_new = results[f"energy_{new_surf}"] - # re-scale p_parallel - # not 100% sure if this is right + # re-scale p_parallel + # not 100% sure if this is right - scale_sq = (t_old + v_old - (t_old_perp + v_new)) / t_old_par + scale_sq = (t_old + v_old - (t_old_perp + v_new)) / t_old_par - if scale_sq < 0: - # kinetic energy can't compensate the change in - # potential energy - return None + if scale_sq < 0: + # kinetic energy can't compensate the change in + # potential energy + return None - scale = scale_sq ** 0.5 + scale = scale_sq**0.5 - new_p = p_par * scale + p_perp + new_p = p_par * scale + p_perp return new_p -def backward_prop(spawn_r, - spawn_p, - spawn_gamma, - spawn_idx, - m, - atom_num, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic, - dt, - new_surf, - old_results_list, - old_surf, - threshold, - dr): - +def backward_prop( + spawn_r, + spawn_p, + spawn_gamma, + spawn_idx, + m, + atom_num, + nbrs, + gen_nbrs, + batch_size, + device, + model, + diabat_keys, + diabatic, + dt, + new_surf, + old_results_list, + old_surf, + threshold, + dr, +): num_steps = spawn_idx old_results = old_results_list[spawn_idx] @@ -1062,48 +895,46 @@ def backward_prop(spawn_r, p_new = copy.deepcopy(spawn_p) # re-scale the momentum along the nacv # to ensure energy conservation - p_new = rescale(p_new=p_new, - m=m, - diabatic=diabatic, - results=old_results, - old_surf=old_surf, - new_surf=new_surf) + p_new = rescale(p_new=p_new, m=m, diabatic=diabatic, results=old_results, old_surf=old_surf, new_surf=new_surf) if p_new is None: - - grad = old_results[f'energy_{new_surf}_grad'] + grad = old_results[f"energy_{new_surf}_grad"] r_new = spawn_r - grad * dr - new_results = get_engrad(r_new=r_new, - atom_num=atom_num, - nbrs=nbrs, - gen_nbrs=gen_nbrs, - batch_size=batch_size, - device=device, - model=model, - diabat_keys=diabat_keys) + new_results = 
get_engrad( + r_new=r_new, + atom_num=atom_num, + nbrs=nbrs, + gen_nbrs=gen_nbrs, + batch_size=batch_size, + device=device, + model=model, + diabat_keys=diabat_keys, + ) old_results_list[spawn_idx] = new_results - return backward_prop(r_new, - spawn_p, - spawn_gamma, - spawn_idx, - m, - atom_num, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic, - dt, - new_surf, - old_results_list, - old_surf, - threshold, - dr) + return backward_prop( + r_new, + spawn_p, + spawn_gamma, + spawn_idx, + m, + atom_num, + nbrs, + gen_nbrs, + batch_size, + device, + model, + diabat_keys, + diabatic, + dt, + new_surf, + old_results_list, + old_surf, + threshold, + dr, + ) gamma_new = copy.deepcopy(spawn_gamma) @@ -1114,21 +945,23 @@ def backward_prop(spawn_r, for _ in range(num_steps): # I don't think just replacing it with -dt is right - don't # we have to replace the forces with their negative values? - r_new, p_new, gamma_new, new_results = nuc_classical(r_new, - gamma_new, - m, - atom_num, - p_new, - nbrs, - gen_nbrs, - batch_size, - device, - model, - diabat_keys, - diabatic, - (-dt), - new_surf, - old_results) + r_new, p_new, gamma_new, new_results = nuc_classical( + r_new, + gamma_new, + m, + atom_num, + p_new, + nbrs, + gen_nbrs, + batch_size, + device, + model, + diabat_keys, + diabatic, + (-dt), + new_surf, + old_results, + ) old_results = new_results return r_new, p_new diff --git a/nff/md/ci/opt.py b/nff/md/ci/opt.py index cc7ee352..87bfe7e3 100644 --- a/nff/md/ci/opt.py +++ b/nff/md/ci/opt.py @@ -1,25 +1,12 @@ -import sys - -sys.path.append("/home/saxelrod/htvs-ax/htvs") - -import os - -import django - -os.environ["DJANGO_SETTINGS_MODULE"] = "djangochem.settings.orgel" -django.setup() - -# Shell Plus Model Imports - - import copy import json +import os import pdb -import random +import sys +import django import numpy as np -from ase import Atoms, optimize, units -from ase.calculators.calculator import Calculator +from ase import optimize, units from ase.io.trajectory import Trajectory as AseTrajectory from ase.md.verlet import VelocityVerlet from django.contrib.auth.models import Group @@ -28,40 +15,30 @@ from jobs.models import Job, JobConfig from neuralnet.utils import vib from pgmols.models import ( - AtomBasis, Geom, - Hessian, - Jacobian, - MDFrame, - Mechanism, Method, - Mol, - MolGroupObjectPermission, - MolSet, - MolUserObjectPermission, - PathImage, - ProductLink, - ReactantLink, - Reaction, - ReactionPath, - ReactionType, - SinglePoint, - Species, - Stoichiometry, - Trajectory, ) from rdkit import Chem from torch.utils.data import DataLoader from nff.data import Dataset, collate_dicts - -# from nff.nn.models import PostProcessModel from nff.io.ase_ax import AtomsBatch, NeuralFF from nff.nn.tensorgrad import get_schnet_hessians from nff.train import load_model from nff.utils import constants as const from nff.utils.cuda import batch_to +sys.path.append("/home/saxelrod/htvs-ax/htvs") + + +os.environ["DJANGO_SETTINGS_MODULE"] = "djangochem.settings.orgel" +django.setup() + +# Shell Plus Model Imports + + +# from nff.nn.models import PostProcessModel + KT = 0.000944853 FS_TO_AU = 41.341374575751 AU_TO_ANGS = 0.529177 @@ -230,8 +207,8 @@ def opt_ci(model, nxyz, penalty=0.5, lower_idx=0, upper_idx=1, method="BFGS", st atoms.set_calculator(init_calc) ref_energy = atoms.get_potential_energy().item() * const.EV_TO_KCAL_MOL - lower_key = "energy_{}".format(lower_idx) - upper_key = "energy_{}".format(upper_idx) + lower_key = f"energy_{lower_idx}" + upper_key = 
f"energy_{upper_idx}" set_ci_calc( atoms=atoms, model=model, lower_key=lower_key, upper_key=upper_key, ref_energy=ref_energy, penalty=penalty @@ -335,7 +312,7 @@ def opt_and_sample_ci( model=model, nxyz=nxyz, penalty=penalty, lower_idx=lower_idx, upper_idx=upper_idx, method=method, steps=steps ) - energy_keys = ["energy_{}".format(lower_idx), "energy_{}".format(upper_idx)] + energy_keys = [f"energy_{lower_idx}", f"energy_{upper_idx}"] lower_atoms, upper_atoms = sample_ci( ci_atoms=ci_atoms, model=model, cutoff=cutoff, energy_keys=energy_keys, device=device, kt=KT @@ -348,7 +325,6 @@ def test(): # weightpath = "/home/saxelrod/engaging/models/971" weightpath = "/home/saxelrod/engaging/models/953" nxyz = BASE_NXYZ - penalty = 0.5 # atoms = opt_ci(weightpath=weightpath, nxyz=nxyz, # penalty=penalty) @@ -368,8 +344,8 @@ def test(): kt=KT, ) - lower_calc = NeuralFF(model=model, output_keys=["energy_{}".format(lower_idx)]) - upper_calc = NeuralFF(model=model, output_keys=["energy_{}".format(upper_idx)]) + lower_calc = NeuralFF(model=model, output_keys=[f"energy_{lower_idx}"]) + upper_calc = NeuralFF(model=model, output_keys=[f"energy_{upper_idx}"]) lower_atoms.set_calculator(lower_calc) upper_atoms.set_calculator(upper_calc) @@ -382,21 +358,21 @@ def test(): def run_ci_md(model, lower_atoms, upper_atoms, lower_idx, upper_idx, base_name="test", dt=0.5, tmax=500): - lower_calc = NeuralFF(model=model, output_keys=["energy_{}".format(lower_idx)]) - upper_calc = NeuralFF(model=model, output_keys=["energy_{}".format(upper_idx)]) + lower_calc = NeuralFF(model=model, output_keys=[f"energy_{lower_idx}"]) + upper_calc = NeuralFF(model=model, output_keys=[f"energy_{upper_idx}"]) lower_atoms.set_calculator(lower_calc) upper_atoms.set_calculator(upper_calc) - lower_log = "{}_lower.log".format(base_name) - lower_trj_name = "{}_lower.traj".format(base_name) + lower_log = f"{base_name}_lower.log" + lower_trj_name = f"{base_name}_lower.traj" num_steps = int(tmax / dt) lower_integrator = VelocityVerlet(lower_atoms, dt=dt * units.fs, logfile=lower_log, trajectory=lower_trj_name) lower_integrator.run(num_steps) - upper_log = "{}_upper.log".format(base_name) - upper_trj_name = "{}_upper.traj".format(base_name) + upper_log = f"{base_name}_upper.log" + upper_trj_name = f"{base_name}_upper.traj" upper_integrator = VelocityVerlet(upper_atoms, dt=dt * units.fs, logfile=upper_log, trajectory=upper_trj_name) upper_integrator.run(num_steps) @@ -549,8 +525,8 @@ def to_db( md_job.details = md_details md_job.save() - lower_key = "energy_{}".format(lower_idx) - upper_key = "energy_{}".format(upper_idx) + lower_key = f"energy_{lower_idx}" + upper_key = f"energy_{upper_idx}" best_atoms = [] # pdb.set_trace() @@ -583,7 +559,7 @@ def to_db( for atoms in best_atoms[:num_samples]: nxyz = AtomsBatch(atoms).get_nxyz() coords = make_coords(nxyz) - new_geom = make_geom(method=md_method, job=md_job, coords=coords, parentgeom=ci_geom) + make_geom(method=md_method, job=md_job, coords=coords, parentgeom=ci_geom) @pdb_wrap @@ -599,7 +575,7 @@ def make_plots(): smiles = "c1ccc(/N=N\\c2ccccc2)cc1" group = Group.objects.get(name="switches") parentgeom = Geom.objects.filter(species__smiles=smiles, species__group=group, converged=True).first() - trj_name = "{}_upper.traj".format(parentgeom.id) + trj_name = f"{parentgeom.id}_upper.traj" print(trj_name) return diff --git a/nff/md/colvars.py b/nff/md/colvars.py index c48fb8e3..a0ea768e 100644 --- a/nff/md/colvars.py +++ b/nff/md/colvars.py @@ -1,128 +1,128 @@ -from typing import Union, Tuple -import 
itertools as itertools -import os +"""This file contains a class that helps define collective variables +for running biased MD simulations with NFF. +""" + +from __future__ import annotations + +import itertools +from itertools import repeat +from typing import TYPE_CHECKING + import numpy as np import torch +from rdkit import Chem +from torch import nn +from torch.nn import ModuleDict -from ase import Atoms -from nff.io.ase import AtomsBatch -from nff.utils.scatter import compute_grad -from nff.train import load_model, evaluate +from nff.train import load_model from nff.utils.cuda import batch_to +from nff.utils.scatter import compute_grad + +if TYPE_CHECKING: + from ase import Atoms + + from nff.io.ase import AtomsBatch class ColVar(torch.nn.Module): """collective variable class - - computes cv and its Cartesian gradient + + computes cv and its Cartesian gradient """ - - implemented_cvs = ['distance', 'angle', 'dihedral', - 'coordination_number', 'coordination', - 'minimal_distance', - 'projecting_centroidvec', - 'projecting_veconplane', - 'projecting_veconplanenormal', - 'projection_channelnormal', - 'Sp', 'Sd', - 'adjecencey_matrix', - 'energy_gap' - ] - + + implemented_cvs = [ + "distance", + "angle", + "dihedral", + "coordination_number", + "coordination", + "minimal_distance", + "projecting_centroidvec", + "projecting_veconplane", + "projecting_veconplanenormal", + "projection_channelnormal", + "Sp", + "Sd", + "adjecencey_matrix", + "energy_gap", + ] + def __init__(self, info_dict: dict): - """initialization of many class variables to avoid recurrent assignment + """Initialization of many class variables to avoid recurrent assignment with every forward call + Args: info_dict (dict): dictionary that contains all the definitions of the CV, the common key is name, which defines the CV function all other keys are specific to each CV """ - super(ColVar, self).__init__() - self.info_dict=info_dict - - if 'name' not in info_dict.keys(): - raise TypeError("CV definition is missing the key \"name\"!") - - if self.info_dict['name'] not in self.implemented_cvs: + super().__init__() + self.info_dict = info_dict + + if "name" not in info_dict: + raise TypeError('CV definition is missing the key "name"!') + + if self.info_dict["name"] not in self.implemented_cvs: raise NotImplementedError(f"The CV {self.info_dict['name']} is not implemented!") - - if self.info_dict['name'] == 'Sp': - self.Oacid = torch.tensor(self.info_dict['x']) - self.Owater = torch.tensor(self.info_dict['y']) - self.H = torch.tensor(self.info_dict['z']) - self.Box = torch.tensor(self.info_dict.get('box',None)) - self.O = torch.cat((Oacid,Owater)) - self.do = self.info_dict['dcv1'] - self.d = self.info_dict['dcv2'] - self.ro = self.info_dict['acidhyd'] - self.r1 = self.info_dict['waterhyd'] - - elif self.info_dict['name'] == 'Sd': - self.Oacid = torch.tensor(self.info_dict['x']) - self.Owater = torch.tensor(self.info_dict['y']) - self.H = torch.tensor(self.info_dict['z']) - self.Box = torch.tensor(self.info_dict.get('box',None)) - self.O = torch.cat((Oacid,Owater)) - self.do = self.info_dict['dcv1'] - self.d = self.info_dict['dcv2'] - self.ro = self.info_dict['acidhyd'] - self.r1 = self.info_dict['waterhyd'] - - elif self.info_dict['name'] == 'adjecencey_matrix': - self.model = self.info_dict['model'] - self.device = self.info_dict['device'] - self.bond_length = self.info_dict['bond_length'] - self.cell = self.info_dict.get('box',None) - self.atom_numbers = torch.tensor(self.info_dict['atom_numbers']) - self.target = 
self.info_dict['target'] - self.model = self.model.to(self.device) + + if self.info_dict["name"] in ["Sp", "Sd"]: + self.Oacid = torch.tensor(self.info_dict["x"]) + self.Owater = torch.tensor(self.info_dict["y"]) + self.H = torch.tensor(self.info_dict["z"]) + self.Box = torch.tensor(self.info_dict.get("box", None)) + self.O = torch.cat((self.Oacid, self.Owater)) + self.do = self.info_dict["dcv1"] + self.d = self.info_dict["dcv2"] + self.ro = self.info_dict["acidhyd"] + self.r1 = self.info_dict["waterhyd"] + + elif self.info_dict["name"] == "adjecencey_matrix": + self.model = self.info_dict["model"] + self.device = self.info_dict["device"] + self.bond_length = self.info_dict["bond_length"] + self.cell = self.info_dict.get("box", None) + self.atom_numbers = torch.tensor(self.info_dict["atom_numbers"]) + self.target = self.info_dict["target"] + self.model = self.model.to(self.device) self.model.eval() - - elif self.info_dict['name'] == 'projecting_centroidvec': - self.vector_inds = self.info_dict['vector'] - self.mol_inds = torch.LongTensor(self.info_dict['indices']) - self.reference_inds = self.info_dict['reference'] - - elif self.info_dict['name'] == 'projecting_veconplane': - self.mol_inds = torch.LongTensor(self.info_dict['mol_inds']) - self.ring_inds = torch.LongTensor(self.info_dict['ring_inds']) - - elif self.info_dict['name'] == 'projecting_veconplanenormal': - self.mol_inds = torch.LongTensor(self.info_dict['mol_inds']) - self.ring_inds = torch.LongTensor(self.info_dict['ring_inds']) - - elif self.info_dict['name'] == 'projection_channelnormal': - self.mol_inds = torch.LongTensor(self.info_dict['mol_inds']) - self.g1_inds = torch.LongTensor(self.info_dict['g1_inds']) - self.g2_inds = torch.LongTensor(self.info_dict['g2_inds']) - - elif self.info_dict['name'] == 'energy_gap': - self.device = self.info_dict['device'] - path = self.info_dict['path'] - model_type = self.info_dict['model_type'] - self.model = load_model(path, - model_type=model_type, - device=self.device) - self.model = self.model.to(self.device) + + elif self.info_dict["name"] == "projecting_centroidvec": + self.vector_inds = self.info_dict["vector"] + self.mol_inds = torch.LongTensor(self.info_dict["indices"]) + self.reference_inds = self.info_dict["reference"] + + elif self.info_dict["name"] in ["projecting_veconplane", "projecting_veconplanenormal"]: + self.mol_inds = torch.LongTensor(self.info_dict["mol_inds"]) + self.ring_inds = torch.LongTensor(self.info_dict["ring_inds"]) + + elif self.info_dict["name"] == "projection_channelnormal": + self.mol_inds = torch.LongTensor(self.info_dict["mol_inds"]) + self.g1_inds = torch.LongTensor(self.info_dict["g1_inds"]) + self.g2_inds = torch.LongTensor(self.info_dict["g2_inds"]) + + elif self.info_dict["name"] == "energy_gap": + self.device = self.info_dict["device"] + path = self.info_dict["path"] + model_type = self.info_dict["model_type"] + self.model = load_model(path, model_type=model_type, device=self.device) + self.model = self.model.to(self.device) self.model.eval() - - - - def _get_com(self, indices: Union[int, list]) -> torch.tensor: - """get center of mass (com) of group of atoms + def _get_com(self, indices: int | list[int]) -> torch.Tensor: + """Get center of mass (com) of group of atoms + Args: indices (Union[int, list]): atom index or list of atom indices Returns: com (torch.tensor): Center of Mass """ masses = torch.from_numpy(self.atoms.get_masses()) - + if hasattr(indices, "__len__"): # compute center of mass for group of atoms center = 
torch.matmul(self.xyz[indices].T, masses[indices]) - m_tot = masses[indices].sum() - com = center / m_tot + m_tot = masses[indices].sum() + com = center / m_tot else: # only one atom @@ -130,20 +130,20 @@ def _get_com(self, indices: Union[int, list]) -> torch.tensor: com = self.xyz[atom] return com - - def distance(self, - index_list: list[Union[int, list]]) -> torch.tensor: - """distance between two mass centers in range(0, inf) + + def distance(self, index_list: list[int | list]) -> torch.Tensor: + """Distance between two mass centers in range(0, inf) + Args: - distance beteen atoms: [ind0, ind1] - distance between mass centers: [[ind00, ind01, ...], [ind10, ind11, ...]] + index_list (list): can be the distance between atoms ([ind0, ind1]) or + distance between mass centers (a list of lists) such + as [[ind00, ind01, ...], [ind10, ind11, ...]] + Returns: cv (torch.tensor): computed distance """ if len(index_list) != 2: - raise ValueError( - "CV ERROR: Invalid number of centers in definition of distance!" - ) + raise ValueError("CV ERROR: Invalid number of centers in definition of distance!") p1 = self._get_com(index_list[0]) p2 = self._get_com(index_list[1]) @@ -153,21 +153,20 @@ def distance, cv = torch.linalg.norm(r12) return cv - - def angle(self, - index_list: list[Union[int, list]]) -> torch.tensor: - """get angle between three mass centers in range(-pi,pi) + + def angle(self, index_list: list[int | list]) -> torch.Tensor: + """Get angle between three mass centers in range(-pi,pi) + Args: - index_list + index_list (list): can be the angle between three atoms: [ind0, ind1, ind2] angle between centers of mass: [[ind00, ind01, ...], [ind10, ind11, ...], [ind20, ind21, ...]] + Returns: cv (torch.tensor): computed angle """ if len(index_list) != 3: - raise ValueError( - "CV ERROR: Invalid number of centers in definition of angle!" - ) + raise ValueError("CV ERROR: Invalid number of centers in definition of angle!") p1 = self._get_com(index_list[0]) p2 = self._get_com(index_list[1]) @@ -186,24 +185,23 @@ def angle, cv = torch.arccos(torch.dot(-q12_u, q23_u)) return cv - - def dihedral(self, - index_list: list[Union[int, list]]) -> torch.tensor: - """torsion angle between four mass centers in range(-pi,pi) - Params: - self.info_dict['index_list'] + + def dihedral(self, index_list: list[int | list]) -> torch.Tensor: + """Torsion angle between four mass centers in range(-pi,pi) + + Args: + index_list (list): can be the dihedral between atoms: [ind0, ind1, ind2, ind3] dihedral between center of mass: [[ind00, ind01, ...], [ind10, ind11, ...], [ind20, ind21, ...], [ind30, ind31, ...]] + Returns: cv (float): computed torsional angle """ if len(index_list) != 4: - raise ValueError( - "CV ERROR: Invalid number of centers in definition of dihedral!"
- ) + raise ValueError("CV ERROR: Invalid number of centers in definition of dihedral!") p1 = self._get_com(index_list[0]) p2 = self._get_com(index_list[1]) @@ -221,131 +219,119 @@ def dihedral(self, n2 = q34 - torch.dot(q34, q23_u) * q23_u cv = torch.atan2(torch.dot(torch.cross(q23_u, n1), n2), torch.dot(n1, n2)) - + return cv - - def coordination_number(self, - index_list: list[int], - switch_distance: float) -> torch.tensor: - """coordination number between two atoms in range(0, 1) + + def coordination_number(self, index_list: list[int], switch_distance: float) -> torch.Tensor: + """Coordination number between two atoms in range(0, 1) + Args: - distance between atoms: [ind00, ind01] - switch_distance: value at which the switching function is 0.5 + index_list (list): the indices of the atoms defining the coordination number as a + list of ints: [ind00, ind01] + switch_distance: value at which the switching function is 0.5 + Returns: cv (torch.tensor): computed distance """ if len(index_list) != 2: - raise ValueError( - "CV ERROR: Invalid number of atom in definition of coordination_number!" - ) + raise ValueError("CV ERROR: Invalid number of atom in definition of coordination_number!") scaled_distance = self.distance(index_list) / switch_distance - cv = (1. - scaled_distance.pow(6)) / ((1. - scaled_distance.pow(12))) + cv = (1.0 - scaled_distance.pow(6)) / (1.0 - scaled_distance.pow(12)) return cv - - def coordination(self, - index_list: list[list[int]], - switch_distance: float) -> torch.tensor: - """sum of coordination numbers between two sets of atoms in range(0, 1) + + def coordination(self, index_list: list[list[int]], switch_distance: float) -> torch.Tensor: + """Sum of coordination numbers between two sets of atoms in range(0, 1) + Args: - distance between atoms: [[ind00, ind01, ...], [ind10, ind11, ...]] - switch_distance: value at which the switching function is 0.5 + index_list (list): a list of lists of atom indices such as + [[ind00, ind01, ...], [ind10, ind11, ...]] + switch_distance: value at which the switching function is 0.5 + Returns: cv (torch.tensor): computed distance """ if len(index_list) != 2: - raise ValueError( - "CV ERROR: Invalid number of atom lists in definition of coordination_number!" - ) + raise ValueError("CV ERROR: Invalid number of atom lists in definition of coordination_number!") cv = torch.tensor(0.0) - + for idx1, idx2 in itertools.product(index_list[0], index_list[1]): cv = cv + self.coordination_number([idx1, idx2], switch_distance) - + return cv - - def minimal_distance(self, - index_list: list[list[int]]) -> torch.tensor: - """minimal distance between two sets of atoms + + def minimal_distance(self, index_list: list[list[int]]) -> torch.Tensor: + """Minimal distance between two sets of atoms + Args: - distance between atoms: [[ind00, ind01, ...], [ind10, ind11, ...]] + index_list (list): used to determine the distance between atoms, in the format + [[ind00, ind01, ...], [ind10, ind11, ...]] + Returns: cv (torch.tensor): computed distance """ if len(index_list) != 2: - raise ValueError( - "CV ERROR: Invalid number of atom lists in definition of minimal_distance!" 
- ) + raise ValueError("CV ERROR: Invalid number of atom lists in definition of minimal_distance!") distances = torch.zeros(len(index_list[0]) * len(index_list[1])) - + for ii, (idx1, idx2) in enumerate(itertools.product(index_list[0], index_list[1])): distances[ii] = self.distance([idx1, idx2]) - + return distances.min() - - def projecting_centroidvec(self): - """ - Projection of a position vector onto a reference vector - Atomic indices are used to determine the coordiantes of the vectors. - Params - ------ - vector: list of int - List of the indices of atoms that define the vector on which the position vector is projected - indices: list if int - List of indices of the mol/fragment - reference: list of int - List of atomic indices that are used as reference for the position vector + + def projecting_centroidvec(self) -> torch.Tensor: + """Projection of a position vector onto a reference vector + Atomic indices are used to determine the coordinates of the vectors. """ - vector_pos = self.xyz[self.vector_inds] - vector = vector_pos[1] - vector_pos[0] - vector = vector / torch.linalg.norm(vector) - mol_pos = self.xyz[self.mol_inds] + vector_pos = self.xyz[self.vector_inds] + vector = vector_pos[1] - vector_pos[0] + vector = vector / torch.linalg.norm(vector) + mol_pos = self.xyz[self.mol_inds] reference_pos = self.xyz[self.reference_inds] - mol_centroid = mol_pos.mean(axis=0) # mol center - - reference_centroid = reference_pos.mean(axis=0) # centroid of the whole structure - + mol_centroid = mol_pos.mean(axis=0) # mol center + + reference_centroid = reference_pos.mean(axis=0) # centroid of the whole structure + # position vector with respect to the structure centroid - rel_mol_pos = mol_centroid - reference_centroid - + rel_mol_pos = mol_centroid - reference_centroid + # projection - cv = torch.dot(rel_mol_pos, vector) - return cv - - def projecting_veconplane(self): - """ - Projection of a position vector onto a the average plane + return torch.dot(rel_mol_pos, vector) + + def projecting_veconplane(self) -> torch.Tensor: + """Projection of a position vector onto the average plane of an arbitrary ring defined in the structure. Atomic indices are used to determine the coordinates of the vectors. - Params - ------ - mol_inds: list of int - List of indices of the mol/fragment tracked by the CV - ring_inds: list of int - List of atomic indices of the ring for which the average plane is calculated. + + Args: + mol_inds (list[int]): List of indices of the mol/fragment tracked by the CV + ring_inds (list[int]): List of atomic indices of the ring for which the average plane is calculated.
+ + Returns: + cv (torch.tensor): tensor for the computed CV """ - mol_coors = self.xyz[self.mol_inds] + mol_coors = self.xyz[self.mol_inds] ring_coors = self.xyz[self.ring_inds] - - mol_cm = mol_coors.mean(axis=0) # mol center - ring_cm = ring_coors.mean(axis=0) # ring center + + mol_cm = mol_coors.mean(axis=0) # mol center + ring_cm = ring_coors.mean(axis=0) # ring center # ring atoms to center ring_coors = ring_coors - ring_cm r1 = torch.zeros(3, device=ring_coors.device) - N = len(ring_coors) # number of atoms in the ring + N = len(ring_coors) # number of atoms in the ring for i, rl0 in enumerate(ring_coors): r1 = r1 + rl0 * np.sin(2 * np.pi * i / N) - r1 = r1/N + r1 = r1 / N r2 = torch.zeros(3, device=ring_coors.device) for i, rl0 in enumerate(ring_coors): r2 = r2 + rl0 * np.cos(2 * np.pi * i / N) - r2 = r2/N + r2 = r2 / N plane_vec = torch.cross(r1, r2) plane_vec = plane_vec / torch.linalg.norm(plane_vec) @@ -353,39 +339,38 @@ def projecting_veconplane, cv = torch.dot(pos_vec, plane_vec) return cv - - def projecting_veconplanenormal(self): - """ - Projection of a position vector onto the average plane + + def projecting_veconplanenormal(self) -> torch.Tensor: + """Projection of a position vector onto the average plane of an arbitrary ring defined in the structure. Atomic indices are used to determine the coordinates of the vectors. - Params - ------ - mol_inds: list of int - List of indices of the mol/fragment tracked by the CV - ring_inds: list of int - List of atomic indices of the ring for which the average plane is calculated. + + Args: + mol_inds (list[int]): List of indices of the mol/fragment tracked by the CV + ring_inds (list[int]): List of atomic indices of the ring for which the average plane is calculated. + + Returns: + cv (torch.tensor): tensor for the computed CV """ - - mol_coors = self.xyz[self.mol_inds] + mol_coors = self.xyz[self.mol_inds] ring_coors = self.xyz[self.ring_inds] - - mol_cm = mol_coors.mean(axis=0) # mol center -# mol_cm = self._get_com(self.mol_inds) - ring_cm = ring_coors.mean(axis=0) # ring center + + mol_cm = mol_coors.mean(axis=0) # mol center + # mol_cm = self._get_com(self.mol_inds) + ring_cm = ring_coors.mean(axis=0) # ring center # ring atoms to center, center of geometry! ring_coors = ring_coors - ring_cm r1 = torch.zeros(3, device=ring_coors.device) - N = len(ring_coors) # number of atoms in the ring + N = len(ring_coors) # number of atoms in the ring for i, rl0 in enumerate(ring_coors): r1 = r1 + rl0 * np.sin(2 * np.pi * i / N) - r1 = r1/N + r1 = r1 / N r2 = torch.zeros(3, device=ring_coors.device) for i, rl0 in enumerate(ring_coors): r2 = r2 + rl0 * np.cos(2 * np.pi * i / N) - r2 = r2/N + r2 = r2 / N # normalize r1 and r2 r1 = r1 / torch.linalg.norm(r1) @@ -396,313 +381,293 @@ def projecting_veconplanenormal, proj2 = torch.dot(pos_vec, r2) cv = proj1 + proj2 return torch.abs(cv) - - def projection_channelnormal(self): - """ - Projection of a position vector onto the vector + + def projection_channelnormal(self) -> torch.Tensor: + """Projection of a position vector onto the vector along a channel. Atomic indices are used to determine the coordinates of the vectors.
- Params - ------ - mol_inds: list of int - List of indices of the mol/fragment tracked by the CV - g1_inds: list of int - List of atomic indices denoting "start" of channel - g2_inds: list of int - List of atomic indices denoting "end" of channel + + Args: + mol_inds (list[int]): List of indices of the mol/fragment tracked by the CV + g1_inds (list[int]): List of atomic indices denoting "start" of channel + g2_inds (list[int]): List of atomic indices denoting "end" of channel + + Returns: + cv (torch.tensor): tensor for the computed CV """ - - mol_coors = self.xyz[self.mol_inds] - g1_coors = self.xyz[self.g1_inds] - g2_coors = self.xyz[self.g2_inds] - - mol_cm = self._get_com(self.mol_inds) - center_g1 = g1_coors.mean(axis=0) - center_g2 = g2_coors.mean(axis=0) - center = (center_g1 + center_g2)/2 - - normal_vec = (center_g2 - center_g1)/torch.linalg.norm(center_g2 - center_g1) - rel_pos = mol_cm - center + g1_coors = self.xyz[self.g1_inds] + g2_coors = self.xyz[self.g2_inds] + + mol_cm = self._get_com(self.mol_inds) + center_g1 = g1_coors.mean(axis=0) + center_g2 = g2_coors.mean(axis=0) + center = (center_g1 + center_g2) / 2 + + normal_vec = (center_g2 - center_g1) / torch.linalg.norm(center_g2 - center_g1) + rel_pos = mol_cm - center cv = torch.dot(rel_pos, normal_vec) return cv - - def adjacency_matrix_cv(self): - """Docstring + + def adjacency_matrix_cv(self) -> torch.Tensor: + """Create the adjacency matrix for for a given structure as the CV + + Returns: + cv (torch.tensor): computed adjacency matrix """ - edges, atomslist, Natoms, adjacency_matrix = get_adjacency_matrix(self.xyz, - self.atom_numbers, - self.bond_length, - cell=self.cell, - device=self.device) - - pred = self.model(atomslist, edges, Natoms, adjacency_matrix)[0] - rmsd = (pred-self.target).norm() - cv = rmsd.to('cpu').view(-1,1) - + edges, atomslist, Natoms, adjacency_matrix = get_adjacency_matrix( + self.xyz, self.atom_numbers, self.bond_length, cell=self.cell, device=self.device + ) + + pred = self.model(atomslist, edges, Natoms, adjacency_matrix)[0] + rmsd = (pred - self.target).norm() + cv = rmsd.to("cpu").view(-1, 1) + return cv - - def deproton1(self): - """ Emanuele Grifoni, GiovanniMaria Piccini, and Michele Parrinello, PNAS (2019), 116 (10) 4054-40 - https://www.pnas.org/doi/10.1073/pnas.1819771116 - - Sp describes the proton exchange between acid-base pairs + + def deproton1(self) -> torch.Tensor: + """Emanuele Grifoni, GiovanniMaria Piccini, and Michele Parrinello, PNAS (2019), 116 (10) 4054-40 + https://www.pnas.org/doi/10.1073/pnas.1819771116 + + Sp describes the proton exchange between acid-base pairs + + Returns: + cv (torch.tensor): computed Sp """ - dis_mat = self.xyz[None, :, :] - self.xyz[:, None, :] - - if Box is not None: - cell_dim = Box.to(dis_mat.device) - shift = torch.round(torch.divide(dis_mat,cell_dim)) - offsets = -shift - dis_mat = dis_mat+offsets*cell_dim - + + if self.Box is not None: + cell_dim = self.Box.to(dis_mat.device) + shift = torch.round(torch.divide(dis_mat, cell_dim)) + offsets = -shift + dis_mat = dis_mat + offsets * cell_dim + dis_sq = torch.linalg.norm(dis_mat, dim=-1) - dis = dis_sq[self.O,:][:,self.H] - - dis1 = dis_sq[self.Oacid,:][:,self.Owater] - cvmatrix = torch.exp(-self.do * dis) - cvmatrix = cvmatrix / cvmatrix.sum(0) - cvmatrixw = cvmatrix[self.Oacid.shape[0]:].sum(-1) - self.r1 - cvmatrix = cvmatrix[:self.Oacid.shape[0]].sum(-1) - self.ro - cv1 = 2 * cvmatrix.sum() + cvmatrixw.sum() - + dis = dis_sq[self.O, :][:, self.H] + + cvmatrix = torch.exp(-self.do * dis) 
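+ # Descriptive note: exp(-do * r), normalized over the oxygen axis (the
+ # sum over dim 0 below), gives softmax-like weights that smoothly assign
+ # each hydrogen to its closest oxygen, so the per-oxygen sums further
+ # down count coordinated protons without a hard distance cutoff.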
+ cvmatrix = cvmatrix / cvmatrix.sum(0) + cvmatrixw = cvmatrix[self.Oacid.shape[0] :].sum(-1) - self.r1 + cvmatrix = cvmatrix[: self.Oacid.shape[0]].sum(-1) - self.ro + cv1 = 2 * cvmatrix.sum() + cvmatrixw.sum() + return cv1 - - def deproton2(self): - """ Emanuele Grifoni, GiovanniMaria Piccini, and Michele Parrinello, PNAS (2019), 116 (10) 4054-40 - https://www.pnas.org/doi/10.1073/pnas.1819771116 - - Sd describes tge distance between acid-base pairs + + def deproton2(self) -> torch.Tensor: + """Emanuele Grifoni, GiovanniMaria Piccini, and Michele Parrinello, PNAS (2019), 116 (10) 4054-40 + https://www.pnas.org/doi/10.1073/pnas.1819771116 + + Sd describes the distance between acid-base pairs + + Returns: + cv2 (torch.tensor): computed distance """ - dis_mat = self.xyz[None, :, :] - self.xyz[:, None, :] - - if Box is not None: - cell_dim = Box.to(dis_mat.device) - shift = torch.round(torch.divide(dis_mat,cell_dim)) - offsets = -shift - dis_mat = dis_mat + offsets * cell_dim - - dis_sq = torch.linalg.norm(dis_mat,dim=-1) - dis = dis_sq[self.O,:][:,self.H] - dis1 = dis_sq[self.Oacid,:][:,self.Owater] - cvmatrix = torch.exp(-self.do * dis) - cvmatrix = cvmatrix / cvmatrix.sum(0) + + if self.Box is not None: + cell_dim = self.Box.to(dis_mat.device) + shift = torch.round(torch.divide(dis_mat, cell_dim)) + offsets = -shift + dis_mat = dis_mat + offsets * cell_dim + + dis_sq = torch.linalg.norm(dis_mat, dim=-1) + dis = dis_sq[self.O, :][:, self.H] + dis1 = dis_sq[self.Oacid, :][:, self.Owater] + cvmatrix = torch.exp(-self.do * dis) + cvmatrix = cvmatrix / cvmatrix.sum(0) cvmatrixx = torch.exp(-self.d * dis) cvmatrixx = cvmatrixx / cvmatrixx.sum(0) - cvmatrixw = cvmatrixx[self.Oacid.shape[0]:].sum(-1) - self.r1 - cvmatrix = cvmatrixx[:self.Oacid.shape[0]].sum(-1) - self.ro - cvmatrix1 = torch.cat((cvmatrix,cvmatrixw)) - cvmatrix2 = torch.matmul(cvmatrix.view(1,-1).t(),cvmatrixw.view(1,-1)) + cvmatrixw = cvmatrixx[self.Oacid.shape[0] :].sum(-1) - self.r1 + cvmatrix = cvmatrixx[: self.Oacid.shape[0]].sum(-1) - self.ro + cvmatrix2 = torch.matmul(cvmatrix.view(1, -1).t(), cvmatrixw.view(1, -1)) cvmatrix2 = -cvmatrix2 * dis1 - cv2 = cvmatrix2.sum() - + cv2 = cvmatrix2.sum() + return cv2 - - - def energy_gap(self, enkey1, enkey2) : - """get energy gap betweentwo adiabatic PES + + def energy_gap(self, enkey1: str, enkey2: str): + """Get energy gap between two adiabatic PES + Args: enkey1 (str): key of one adiabatic PES enkey2 (str): key of the other PES - + Returns: - cv (torch.tensor): computed energy gap + cv (torch.tensor): computed energy gap """ - - batch = batch_to(self.atoms.get_batch(), self.device) - pred = self.model(batch, device=self.device) + batch = batch_to(self.atoms.get_batch(), self.device) + pred = self.model(batch, device=self.device) energy_1 = pred[enkey1] energy_2 = pred[enkey2] - e_diff = energy_2 - energy_1 - - cv = torch.abs(e_diff) - cv_grad = pred[enkey2+'_grad'] - pred[enkey1+'_grad'] + e_diff = energy_2 - energy_1 + + cv = torch.abs(e_diff) + cv_grad = pred[enkey2 + "_grad"] - pred[enkey1 + "_grad"] if e_diff < 0: cv_grad *= -1.0 - + return cv, cv_grad - - def forward(self, atoms): - """switch function to call the right CV-func + def forward(self, atoms: Atoms) -> tuple[np.ndarray, np.ndarray]: + """Switch function to call the right CV-func + + Args: + atoms (Atoms): ASE Atoms object + + Returns: + cv (np.ndarray): computed CV """ - self.xyz = torch.from_numpy(atoms.get_positions()) - self.xyz.requires_grad=True - + self.xyz.requires_grad = True + self.atoms = atoms - - if 
self.info_dict['name'] == 'distance': - cv = self.distance(self.info_dict['index_list']) + + if self.info_dict["name"] == "distance": + cv = self.distance(self.info_dict["index_list"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'angle': - cv = self.angle(self.info_dict['index_list']) + + elif self.info_dict["name"] == "angle": + cv = self.angle(self.info_dict["index_list"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'dihedral': - cv = self.dihedral(self.info_dict['index_list']) + + elif self.info_dict["name"] == "dihedral": + cv = self.dihedral(self.info_dict["index_list"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'coordination_number': - cv = self.coordination_number(self.info_dict['index_list'], self.info_dict['switching_dist']) + + elif self.info_dict["name"] == "coordination_number": + cv = self.coordination_number(self.info_dict["index_list"], self.info_dict["switching_dist"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'coordination': - cv = self.coordination(self.info_dict['index_list'], self.info_dict['switching_dist']) + + elif self.info_dict["name"] == "coordination": + cv = self.coordination(self.info_dict["index_list"], self.info_dict["switching_dist"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'minimal_distance': - cv = self.minimal_distance(self.info_dict['index_list']) + + elif self.info_dict["name"] == "minimal_distance": + cv = self.minimal_distance(self.info_dict["index_list"]) cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'projecting_centroidvec': - cv = self.projecting_centroidvec() + + elif self.info_dict["name"] == "projecting_centroidvec": + cv = self.projecting_centroidvec() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'projecting_veconplane': - cv = self.projecting_veconplane() + + elif self.info_dict["name"] == "projecting_veconplane": + cv = self.projecting_veconplane() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'projecting_veconplanenormal': - cv = self.projecting_veconplanenormal() + + elif self.info_dict["name"] == "projecting_veconplanenormal": + cv = self.projecting_veconplanenormal() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'projection_channelnormal': - cv = self.projection_channelnormal() + + elif self.info_dict["name"] == "projection_channelnormal": + cv = self.projection_channelnormal() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'Sp': + + elif self.info_dict["name"] == "Sp": cv = self.deproton1() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'Sd': + + elif self.info_dict["name"] == "Sd": cv = self.deproton2() cv_grad = compute_grad(inputs=self.xyz, output=cv) - - elif self.info_dict['name'] == 'energy_gap': - cv, cv_grad = self.energy_gap(self.info_dict['enkey_1'], self.info_dict['enkey_2']) - + + elif self.info_dict["name"] == "energy_gap": + cv, cv_grad = self.energy_gap(self.info_dict["enkey_1"], self.info_dict["enkey_2"]) + return cv.detach().cpu().numpy(), cv_grad.detach().cpu().numpy() - - - - - - - - - - - - - - - - - - - - - -# implement SMILES to graph function + + +# implement SMILES to graph function def smiles2graph(smiles): - ''' - Transfrom smiles into a list nodes (atomic number) 
- - Args: smiles (str): SMILES strings - - return: - z(np.array), A (np.array): list of atomic numbers, adjancency matrix - ''' - - mol = Chem.MolFromSmiles( smiles ) # no hydrogen - z = np.array( [atom.GetAtomicNum() for atom in mol.GetAtoms()] ) + """Transform smiles into a list of nodes (atomic numbers) + + Args: smiles (str): SMILES strings + + Return: + z(np.array), A (np.array): list of atomic numbers, adjacency matrix + """ + mol = Chem.MolFromSmiles(smiles) # no hydrogen + z = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()]) A = np.stack(Chem.GetAdjacencyMatrix(mol)) - #np.fill_diagonal(A,1) + # np.fill_diagonal(A,1) return z, A + + class GraphDataset(torch.utils.data.Dataset): - def __init__(self, - AtomicNum_list, - Edge_list, - Natom_list, - Adjacency_matrix_list - ): - - ''' - GraphDataset object - - Args: - z_list (list of torch.LongTensor) - a_list (list of torch.LongTensor) - N_list (list of int) - - - ''' - self.AtomicNum_list = AtomicNum_list # atomic number - self.Edge_list = Edge_list # edge list - self.Natom_list = Natom_list # Number of atoms - self.Adjacency_matrix_list=Adjacency_matrix_list + """Class for datasets of graphs""" + + def __init__( + self, + AtomicNum_list: list[torch.LongTensor], + Edge_list: list[torch.LongTensor], + Natom_list: list[int], + Adjacency_matrix_list: list, + ) -> None: + """GraphDataset object + + Args: + AtomicNum_list (list of torch.LongTensor): list of atomic numbers + Edge_list (list of torch.LongTensor): list of edges in the graph + Natom_list (list of int): list of number of atoms in each graph + Adjacency_matrix_list (list of torch.LongTensor): list of adjacency matrices + """ + self.AtomicNum_list = AtomicNum_list # atomic number + self.Edge_list = Edge_list # edge list + self.Natom_list = Natom_list # Number of atoms + self.Adjacency_matrix_list = Adjacency_matrix_list + def __len__(self): return len(self.Natom_list) def __getitem__(self, idx): - AtomicNum = torch.LongTensor(self.AtomicNum_list[idx]) Edge = torch.LongTensor(self.Edge_list[idx]) Natom = self.Natom_list[idx] Adjacency_matrix = self.Adjacency_matrix_list[idx] - - return AtomicNum, Edge, Natom,Adjacency_matrix -def collate_graphs(batch): - '''Batch multiple graphs into one batched graph - + + return AtomicNum, Edge, Natom, Adjacency_matrix + + +def collate_graphs(batch: tuple) -> tuple: + """Batch multiple graphs into one batched graph + Args: - - batch (tuple): tuples of AtomicNum, Edge, Natom obtained from GraphDataset.__getitem__() - - Return - (tuple): Batched AtomicNum, Edge, Natom - - ''' - + batch (tuple): tuples of AtomicNum, Edge, Natom obtained from GraphDataset.__getitem__() + + Return: + tuple: Batched AtomicNum, Edge, Natom + """ AtomicNum_batch = [] Edge_batch = [] Natom_batch = [] - Adjacency_matrix_batch=[] + Adjacency_matrix_batch = [] cumulative_atoms = np.cumsum([0] + [b[2] for b in batch])[:-1] - + for i in range(len(batch)): - z, a, N,A = batch[i] + z, a, N, A = batch[i] index_shift = cumulative_atoms[i] a = a + index_shift - AtomicNum_batch.append(z) + AtomicNum_batch.append(z) Edge_batch.append(a) Natom_batch.append(N) Adjacency_matrix_batch.append(A) - + AtomicNum_batch = torch.cat(AtomicNum_batch) Edge_batch = torch.cat(Edge_batch, dim=1) - Natom_batch = Natom_batch - #Adjacency_matrix_batch=torch.block_diag(*Adjacency_matrix_batch) - Adjacency_matrix_batch=torch.cat(Adjacency_matrix_batch,dim=0).view(-1,1) - - return AtomicNum_batch, Edge_batch, Natom_batch,Adjacency_matrix_batch -from itertools import repeat -def scatter_add(src, index, dim_size, dim=-1,
fill_value=0): - - ''' - Sums all values from the src tensor into out at the indices specified in the index - tensor along a given axis dim. - ''' - + # Adjacency_matrix_batch=torch.block_diag(*Adjacency_matrix_batch) + Adjacency_matrix_batch = torch.cat(Adjacency_matrix_batch, dim=0).view(-1, 1) + + return AtomicNum_batch, Edge_batch, Natom_batch, Adjacency_matrix_batch + + +def scatter_add(src, index: torch.Tensor, dim_size: int, dim: int = -1, fill_value: int = 0) -> torch.Tensor: + """Sums all values from the src tensor into out at the indices specified in the index + tensor along a given axis dim. + """ index_size = list(repeat(1, src.dim())) index_size[dim] = src.size(dim) index = index.view(index_size).expand_as(src) - + dim = range(src.dim())[dim] out_size = list(src.size()) out_size[dim] = dim_size @@ -710,172 +675,240 @@ def scatter_add(src, index, dim_size, dim=-1, fill_value=0): out = src.new_full(out_size, fill_value) return out.scatter_add_(dim, index, src) -from torch import nn -from torch.nn import ModuleDict + class GNN(torch.nn.Module): - ''' - A GNN model - ''' + """A GNN model""" + def __init__(self, n_convs=3, n_embed=64): - super(GNN, self).__init__() + super().__init__() self.atom_embed = nn.Embedding(100, n_embed) # Declare MLPs in a ModuleList self.convolutions = nn.ModuleList( - [ - ModuleDict({ - 'update_mlp': nn.Sequential(nn.Linear(n_embed, n_embed), - nn.ReLU(), - nn.Linear(n_embed, n_embed)), - 'message_mlp': nn.Sequential(nn.Linear(n_embed, n_embed), - nn.ReLU(), - nn.Linear(n_embed, n_embed)) - }) + [ + ModuleDict( + { + "update_mlp": nn.Sequential( + nn.Linear(n_embed, n_embed), nn.ReLU(), nn.Linear(n_embed, n_embed) + ), + "message_mlp": nn.Sequential( + nn.Linear(n_embed, n_embed), nn.ReLU(), nn.Linear(n_embed, n_embed) + ), + } + ) for _ in range(n_convs) ] - ) + ) # Declare readout layers - #self.readout = nn.Sequential(nn.Linear(n_embed, n_embed), nn.ReLU(), nn.Linear(n_embed, 1)) - - def forward(self, AtomicNum, Edge, Natom,adjacency_matrix): + # self.readout = nn.Sequential(nn.Linear(n_embed, n_embed), nn.ReLU(), nn.Linear(n_embed, 1)) + + def forward(self, AtomicNum, Edge, Natom, adjacency_matrix): ################ Code ################# - - # Parametrize embedding - h = self.atom_embed(AtomicNum) #eqn. 1 + + # Parametrize embedding + h = self.atom_embed(AtomicNum) # eqn. 
1 for conv in self.convolutions: - messagei2j=conv.message_mlp(h[Edge[0]]*h[Edge[1]]) - messagei2j=messagei2j*adjacency_matrix - node_message=scatter_add(src=messagei2j, index=Edge[1], dim=0, dim_size=len(AtomicNum)) #+ scatter_add(src=messagei2j, index=Edge[0], dim=0, dim_size=len(AtomicNum)) - h=h+conv.update_mlp(node_message) - output=[split.sum(0) for split in torch.split(h, Natom)] - - - + messagei2j = conv.message_mlp(h[Edge[0]] * h[Edge[1]]) + messagei2j = messagei2j * adjacency_matrix + node_message = scatter_add( + src=messagei2j, index=Edge[1], dim=0, dim_size=len(AtomicNum) + ) # + scatter_add(src=messagei2j, index=Edge[0], dim=0, dim_size=len(AtomicNum)) + h = h + conv.update_mlp(node_message) + output = [split.sum(0) for split in torch.split(h, Natom)] + ################ Code ################# return output -def adjfunc(x,m,s): - return 4/((torch.exp(s * (x-m)) + 1) * (torch.exp((-s) * (x-m))+1)) -def gauss(x,m,s,a,b): - #return torch.exp(-abs((x-m)/2*s)**p) - G=(1+(2**(a/b)-1)*abs((x-m)/s)**a)**(-b/a) - G[torch.where(x tuple[torch.LongTensor, torch.LongTensor, list[int], torch.Tensor]: + """Get the adjacency matrix of a molecule + + Args: + xyz (list | torch.Tensor): the xyz coordinates of the atoms + atom_numbers (torch.Tensor): the atomic numbers of the atoms + bond_length (dict): the bond lengths of the atoms in the format "atom1-atom2" + for each key, the value is the bond length + oxygeninvolved (list[int]): list of indices of the oxygen atoms + cell (torch.Tensor | None, optional): The dimensions of the lattice if the system + is periodic. Defaults to None. + device (str, optional): the device to use. Defaults to "cpu". + + Returns: + tuple[torch.LongTensor, torch.LongTensor, list[int], torch.Tensor]: the edges, the atomic + numbers of the nodes, the number of nodes as a one-element list, and the flattened + adjacency matrix + """ dis_mat = xyz[None, :, :] - xyz[:, None, :] if cell is not None: cell_dim = torch.tensor(np.diag(cell)) - shift = torch.round(torch.divide(dis_mat,cell_dim)) + shift = torch.round(torch.divide(dis_mat, cell_dim)) offsets = -shift - - dis_mat=dis_mat+offsets*cell_dim - dis_sq=dis_mat.norm(dim=-1) - bondlen=torch.ones(dis_sq.shape) - bondlen[torch.where(atom_numbers==8)[0],torch.where(atom_numbers==14)[0].view(-1,1)]=bond_length['8-14'] - bondlen[torch.where(atom_numbers==14)[0],torch.where(atom_numbers==8)[0].view(-1,1)]=bond_length['14-8'] - bondlen[torch.where(atom_numbers==8)[0],torch.where(atom_numbers==1)[0].view(-1,1)]=bond_length['8-1'] - bondlen[torch.where(atom_numbers==1)[0],torch.where(atom_numbers==8)[0].view(-1,1)]=bond_length['1-8'] - #adjacency=dis_sq-bondlen - #adjacency=(dis_sq-0.0001)/bondlen - #adjacency_matrix=torch.exp(-((torch.abs(adjacency)/d)**p)) - #adjacency_matrix=(1-adjacency**m)/(1-adjacency**n) - #adjacency_matrix=adjacency_matrix-torch.eye(adjacency_matrix.shape[0]) - adjacency_matrix= gauss(dis_sq,bondlen,0.5,2,2) - adjacency_matrix=adjacency_matrix[torch.where(atom_numbers==14)[0].view(-1,1),torch.where(atom_numbers==8)[0][oxygeninvolved]] - adjacency_matrix=torch.matmul(adjacency_matrix,adjacency_matrix.t()) - adjacency_matrix=adjacency_matrix.fill_diagonal_(0) - edges=torch.stack([i for i in torch.where(adjacency_matrix>=0)]) - adjacency_matrix=adjacency_matrix[torch.where(adjacency_matrix>=0)[0],torch.where(adjacency_matrix>=0)[1]].view(-1,1) - atomslist=torch.tensor([14 for i in torch.where(atom_numbers==14)[0]]).view(-1) - #molecules,edge_list,atom_list=get_molecules(xyz=xyz.detach(),atom_numbers=atom_numbers,bond_length=bond_length,periodic=False) - #adjacency_matrix_list=[] - #print(compute_grad(xyz,dis_sq[0,1])) -
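
(For orientation, a minimal end-to-end sketch of the graph pipeline defined above: smiles2graph, GraphDataset, collate_graphs, and GNN. Not part of the diff; it assumes RDKit is installed and, since the edge construction is not shown here, takes every atom pair as an edge so the flattened adjacency weights line up one-to-one with how forward() multiplies messages by adjacency_matrix.)

import numpy as np
import torch
from torch.utils.data import DataLoader

z, A = smiles2graph("CCO")  # ethanol: three heavy atoms
n = len(z)
edges = np.stack(np.nonzero(np.ones((n, n))))  # all (i, j) pairs, row-major
adj = torch.FloatTensor(A).view(-1, 1)  # one weight per pair, same ordering

dataset = GraphDataset([z], [edges], [n], [adj])
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_graphs)

model = GNN(n_convs=3, n_embed=64)
AtomicNum, Edge, Natom, A_batch = next(iter(loader))
per_mol = model(AtomicNum, Edge, Natom, A_batch)
print(per_mol[0].shape)  # torch.Size([64]): one pooled embedding per molecule
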
#for i,m in enumerate(molecules): + + dis_mat = dis_mat + offsets * cell_dim + dis_sq = dis_mat.norm(dim=-1) + bondlen = torch.ones(dis_sq.shape) + bondlen[torch.where(atom_numbers == 8)[0], torch.where(atom_numbers == 14)[0].view(-1, 1)] = bond_length["8-14"] + bondlen[torch.where(atom_numbers == 14)[0], torch.where(atom_numbers == 8)[0].view(-1, 1)] = bond_length["14-8"] + bondlen[torch.where(atom_numbers == 8)[0], torch.where(atom_numbers == 1)[0].view(-1, 1)] = bond_length["8-1"] + bondlen[torch.where(atom_numbers == 1)[0], torch.where(atom_numbers == 8)[0].view(-1, 1)] = bond_length["1-8"] + # adjacency=dis_sq-bondlen + # adjacency=(dis_sq-0.0001)/bondlen + # adjacency_matrix=torch.exp(-((torch.abs(adjacency)/d)**p)) + # adjacency_matrix=(1-adjacency**m)/(1-adjacency**n) + # adjacency_matrix=adjacency_matrix-torch.eye(adjacency_matrix.shape[0]) + adjacency_matrix = gauss(dis_sq, bondlen, 0.5, 2, 2) + adjacency_matrix = adjacency_matrix[ + torch.where(atom_numbers == 14)[0].view(-1, 1), torch.where(atom_numbers == 8)[0][oxygeninvolved] + ] + adjacency_matrix = torch.matmul(adjacency_matrix, adjacency_matrix.t()) + adjacency_matrix = adjacency_matrix.fill_diagonal_(0) + edges = torch.stack(list(torch.where(adjacency_matrix >= 0))) + adjacency_matrix = adjacency_matrix[ + torch.where(adjacency_matrix >= 0)[0], torch.where(adjacency_matrix >= 0)[1] + ].view(-1, 1) + atomslist = torch.tensor([14 for i in torch.where(atom_numbers == 14)[0]]).view(-1) + # molecules,edge_list,atom_list=get_molecules( + # xyz=xyz.detach(),atom_numbers=atom_numbers,bond_length=bond_length,periodic=False + # ) + # adjacency_matrix_list=[] + # print(compute_grad(xyz,dis_sq[0,1])) + # for i,m in enumerate(molecules): # n=torch.tensor(m) # adjacency_matrix_list.append(adjacency_matrix[edge_list[i][0],edge_list[i][1]].view(-1,1)) - return torch.LongTensor(edges).to(device),torch.LongTensor(atomslist).to(device),[(len(atomslist))],adjacency_matrix.float().to(device) + return ( + torch.LongTensor(edges).to(device), + torch.LongTensor(atomslist).to(device), + [(len(atomslist))], + adjacency_matrix.float().to(device), + ) + + +def get_molecules(atom: AtomsBatch, bond_length: dict, mode: str = "bond", periodic: bool = True) -> list: + """Get the molecules in the system from the Atoms object -def get_molecules(atom,bond_length,mode='bond',periodic=True): - types=list(set(atom.numbers)) - xyz=atom.positions - #A=np.lexsort((xyz[:,2],xyz[:,1],xyz[:,0])) + Args: + atom (AtomsBatch): Atoms from which to extract the molecules + bond_length (dict): Dictionary of bond lengths + mode (str, optional): Mode of identifying distinct molecules. Defaults to "bond". + periodic (bool, optional): Whether or not the Atoms are periodic. Defaults to True.
+ + Returns: + list: the molecules in the system + """ + types = list(set(atom.numbers)) + xyz = atom.positions + # A=np.lexsort((xyz[:,2],xyz[:,1],xyz[:,0])) dis_mat = xyz[None, :, :] - xyz[:, None, :] - if periodic==True: + if periodic: cell_dim = np.diag(np.array(atom.get_cell())) - shift = np.round(np.divide(dis_mat,cell_dim)) + shift = np.round(np.divide(dis_mat, cell_dim)) offsets = -shift - dis_mat=dis_mat+offsets*cell_dim + dis_mat = dis_mat + offsets * cell_dim dis_sq = torch.tensor(dis_mat).pow(2).sum(-1).numpy() - dis_sq=dis_sq**0.5 - clusters=np.array([0 for i in range(xyz.shape[0])]) + dis_sq = dis_sq**0.5 + clusters = np.array([0 for i in range(xyz.shape[0])]) for i in range(xyz.shape[0]): - mm=max(clusters) - ty=atom.numbers[i] - oxy_neighbors=[] - if mode=='bond': + mm = max(clusters) + ty = atom.numbers[i] + oxy_neighbors = [] + if mode == "bond": for t in types: - if bond_length.get('%s-%s'%(ty,t))!=None: - oxy_neighbors.extend(list(np.where(atom.numbers==t)[0][np.where(dis_sq[i,np.where(atom.numbers==t)[0]]<=bond_length['%s-%s'%(ty,t)])[0]])) - elif mode=='cutoff': - oxy_neighbors.extend(list(np.where(dis_sq[i]<=6)[0])) - oxy_neighbors=np.array(oxy_neighbors) - if len(oxy_neighbors)==0: - clusters[i]=mm+1 + if bond_length.get(f"{ty}-{t}") is not None: + oxy_neighbors.extend( + list( + np.where(atom.numbers == t)[0][ + np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length[f"{ty}-{t}"])[0] + ] + ) + ) + elif mode == "cutoff": + oxy_neighbors.extend(list(np.where(dis_sq[i] <= 6)[0])) + oxy_neighbors = np.array(oxy_neighbors) + if len(oxy_neighbors) == 0: + clusters[i] = mm + 1 continue - if (clusters[oxy_neighbors]==0).all() and clusters[i]!=0: - clusters[oxy_neighbors]=clusters[i] - elif (clusters[oxy_neighbors]==0).all() and clusters[i]==0: - clusters[oxy_neighbors]=mm+1 - clusters[i]=mm+1 - elif (clusters[oxy_neighbors]==0).all() == False and clusters[i]==0: - clusters[i]=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]) - clusters[oxy_neighbors]=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]) - elif (clusters[oxy_neighbors]==0).all() == False and clusters[i]!=0: - tmp=clusters[oxy_neighbors][clusters[oxy_neighbors]!=0][clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]!=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0])] - clusters[i]=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]) - clusters[oxy_neighbors]=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]) + if (clusters[oxy_neighbors] == 0).all() and clusters[i] != 0: + clusters[oxy_neighbors] = clusters[i] + elif (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0: + clusters[oxy_neighbors] = mm + 1 + clusters[i] = mm + 1 + elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0: + clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) + clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) + elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] != 0: + tmp = clusters[oxy_neighbors][clusters[oxy_neighbors] != 0][ + clusters[oxy_neighbors][clusters[oxy_neighbors] != 0] + != min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) + ] + clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) + clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) for tr in tmp: - clusters[np.where(clusters==tr)[0]]=min(clusters[oxy_neighbors][clusters[oxy_neighbors]!=0]) - - molecules=[] - for i in range(1,max(clusters)+1): - if
np.size(np.where(clusters==i)[0])==0: + clusters[np.where(clusters == tr)[0]] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]) + + molecules = [] + for i in range(1, max(clusters) + 1): + if np.size(np.where(clusters == i)[0]) == 0: continue - molecules.append(np.where(clusters==i)[0]) + molecules.append(np.where(clusters == i)[0]) + + return molecules - return molecules -def reconstruct_atoms(atomsobject, mol_idx,centre=None): + +def reconstruct_atoms(atomsobject: Atoms | AtomsBatch, mol_idx: list, centre: int | None = None) -> np.ndarray: + """Reconstruct the atoms in the system + + Args: + atomsobject (Atoms | AtomsBatch): The Atoms object to reconstruct + mol_idx (list): lists/arrays of atom indices, one per molecule to reconstruct + centre (int | None, optional): index of the atom used as the manual centre. If unspecified, + calculated within the function. Defaults to None. + + Returns: + np.ndarray: the new positions of the atoms + """ sys_xyz = torch.Tensor(atomsobject.get_positions(wrap=True)) box_len = torch.Tensor(atomsobject.get_cell_lengths_and_angles()[:3]) print(box_len) for idx in mol_idx: mol_xyz = sys_xyz[idx] - center = mol_xyz.shape[0]//2 - if centre!=None: - center=centre - intra_dmat = (mol_xyz[None, :,...] - mol_xyz[:, None, ...])[center] - if np.count_nonzero(atomsobject.cell.T-np.diag(np.diagonal(atomsobject.cell.T)))!=0: - M,N=intra_dmat.shape[0],intra_dmat.shape[1] - f=torch.linalg.solve(torch.Tensor(atomsobject.cell.T),(intra_dmat.view(-1,3).T)).T - g=f-torch.floor(f+0.5) - intra_dmat=torch.matmul(g,torch.Tensor(atomsobject.cell)) - intra_dmat=intra_dmat.view(M,3) - offsets=-torch.floor(f+0.5).view(M,3) - traj_unwrap = mol_xyz+torch.matmul(offsets,torch.Tensor(atomsobject.cell)) + center = mol_xyz.shape[0] // 2 + if centre is not None: + center = centre + intra_dmat = (mol_xyz[None, :, ...]
- mol_xyz[:, None, ...])[center] + if np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) != 0: + M, _N = intra_dmat.shape[0], intra_dmat.shape[1] + f = torch.linalg.solve(torch.Tensor(atomsobject.cell.T), (intra_dmat.view(-1, 3).T)).T + g = f - torch.floor(f + 0.5) + intra_dmat = torch.matmul(g, torch.Tensor(atomsobject.cell)) + intra_dmat = intra_dmat.view(M, 3) + offsets = -torch.floor(f + 0.5).view(M, 3) + traj_unwrap = mol_xyz + torch.matmul(offsets, torch.Tensor(atomsobject.cell)) else: - sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len - add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len - shift=torch.round(torch.divide(intra_dmat,box_len)) - offsets=-shift - traj_unwrap = mol_xyz+offsets*box_len - #traj_unwrap=mol_xyz+add-sub + # sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len + # add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len + shift = torch.round(torch.divide(intra_dmat, box_len)) + offsets = -shift + traj_unwrap = mol_xyz + offsets * box_len + # traj_unwrap=mol_xyz+add-sub sys_xyz[idx] = traj_unwrap new_pos = sys_xyz.numpy() - return new_pos \ No newline at end of file + return new_pos diff --git a/nff/md/nms.py b/nff/md/nms.py index dd31feaa..5977712d 100644 --- a/nff/md/nms.py +++ b/nff/md/nms.py @@ -1,38 +1,39 @@ -from tqdm import tqdm -import numpy as np -import os +import contextlib import copy +import os import pickle -from rdkit import Chem import shutil -from torch.utils.data import DataLoader -from torch.nn.modules.container import ModuleDict +import numpy as np from ase import optimize, units -from ase.md.verlet import VelocityVerlet from ase.io.trajectory import Trajectory as AseTrajectory -from ase.vibrations import Vibrations -from ase.units import kg, kB, mol, J, m +from ase.md.verlet import VelocityVerlet from ase.thermochemistry import IdealGasThermo +from ase.units import J, kB, kg, m, mol +from ase.vibrations import Vibrations +from rdkit import Chem +from torch.nn.modules.container import ModuleDict +from torch.utils.data import DataLoader +from tqdm import tqdm -from nff.io.ase_ax import NeuralFF, AtomsBatch -from nff.train import load_model -from nff.data import collate_dicts, Dataset +from nff.data import Dataset, collate_dicts +from nff.io.ase_ax import AtomsBatch, NeuralFF from nff.md import nve -from nff.utils.constants import FS_TO_AU, ASE_TO_FS, EV_TO_AU, BOHR_RADIUS -from nff.utils import constants as const -from nff.nn.tensorgrad import get_schnet_hessians -from nff.utils.cuda import batch_to from nff.nn.models.schnet import SchNet +from nff.nn.tensorgrad import get_schnet_hessians from nff.nn.tensorgrad import hess_from_atoms as analytical_hess +from nff.train import load_model +from nff.utils import constants as const +from nff.utils.constants import ASE_TO_FS, BOHR_RADIUS, EV_TO_AU, FS_TO_AU +from nff.utils.cuda import batch_to PT = Chem.GetPeriodicTable() PERIODICTABLE = PT -HA2J = 4.359744E-18 +HA2J = 4.359744e-18 BOHRS2ANG = 0.529177 -SPEEDOFLIGHT = 2.99792458E8 -AMU2KG = 1.660538782E-27 +SPEEDOFLIGHT = 2.99792458e8 +AMU2KG = 1.660538782e-27 TEMP = 298.15 PRESSURE = 101325 @@ -40,7 +41,7 @@ ROTOR_CUTOFF = 50 # cm^-1 CM_TO_EV = 1.2398e-4 GAS_CONST = 8.3144621 * J / mol -B_AV = 1e-44 * kg * m ** 2 +B_AV = 1e-44 * kg * m**2 RESTART_FILE = "restart.pickle" @@ -68,15 +69,8 @@ def get_key(iroot, num_states): Returns: key (str): energy key """ - - # energy if only one state - if iroot == 0 and num_states == 1: - key = "energy" - - # otherwise energy with state suffix - 
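
(A quick sanity check, not part of the diff, for the get_key simplification that this hunk completes just below: the one-liner preserves the old behaviour, returning the bare "energy" key only for a single-state ground-state calculation and a state-suffixed key otherwise.)

assert get_key(iroot=0, num_states=1) == "energy"
assert get_key(iroot=0, num_states=3) == "energy_0"
assert get_key(iroot=2, num_states=3) == "energy_2"
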
else: - key = "energy_{}".format(iroot) - return key + # energy if only one state, other energy with state suffix + return "energy" if iroot == 0 and num_states == 1 else f"energy_{iroot}" def init_calculator(atoms, params): @@ -88,31 +82,29 @@ def init_calculator(atoms, params): params (dict): dictionary of parameters Returns: model (nn.Module): nnpotential model - en_key (str): energy key + en_key (str): energy key """ opt_state = params.get("iroot", 0) num_states = params.get("num_states", 1) en_key = get_key(iroot=opt_state, num_states=num_states) - nn_id = params['nnid'] + nn_id = params["nnid"] # get the right weightpath (either regular or cluster-mounted) # depending on which exists - weightpath = os.path.join(params['weightpath'], str(nn_id)) + weightpath = os.path.join(params["weightpath"], str(nn_id)) if not os.path.isdir(weightpath): - weightpath = os.path.join(params['mounted_weightpath'], str(nn_id)) + weightpath = os.path.join(params["mounted_weightpath"], str(nn_id)) # get the model nn_params = params.get("networkhyperparams", {}) model_type = params.get("model_type") - model = load_model(weightpath, - model_type=model_type, - params=nn_params) + model = load_model(weightpath, model_type=model_type, params=nn_params) # get and set the calculator nff_ase = NeuralFF.from_file( weightpath, - device=params.get('device', 'cuda'), + device=params.get("device", "cuda"), output_keys=[en_key], params=nn_params, model_type=model_type, @@ -153,38 +145,28 @@ def correct_hessian(restart_file, hessian): def get_output_keys(model): - atomwisereadout = model.atomwisereadout # get the names of all the attributes of the readout dict readout_attr_names = dir(atomwisereadout) # restrict to the attributes that are ModuleDicts - readout_dict_names = [name for name in readout_attr_names if - type(getattr(atomwisereadout, name)) is ModuleDict] + readout_dict_names = [name for name in readout_attr_names if type(getattr(atomwisereadout, name)) is ModuleDict] # get the ModuleDicts - readout_dicts = [getattr(atomwisereadout, name) - for name in readout_dict_names] + readout_dicts = [getattr(atomwisereadout, name) for name in readout_dict_names] # get their keys - output_keys = [key for dic in readout_dicts for key in dic.keys()] + output_keys = [key for dic in readout_dicts for key in dic] return output_keys -def get_loader(model, - nxyz_list, - num_states, - cutoff, - needs_angles=False, - base_keys=['energy']): - +def get_loader(model, nxyz_list, num_states, cutoff, needs_angles=False, base_keys=["energy"]): # base_keys = get_output_keys(model) grad_keys = [key + "_grad" for key in base_keys] ref_quant = [0] * len(nxyz_list) - ref_quant_grad = [ - np.zeros(((len(nxyz_list[0])), 3)).tolist()] * len(nxyz_list) + ref_quant_grad = [np.zeros(((len(nxyz_list[0])), 3)).tolist()] * len(nxyz_list) props = {"nxyz": nxyz_list} props.update({key: ref_quant for key in base_keys}) @@ -201,42 +183,32 @@ def get_loader(model, def check_convg(model, loader, energy_key, device, restart_file): - - mode_dic = get_modes(model=model, - loader=loader, - energy_key=energy_key, - device=device) + mode_dic = get_modes(model=model, loader=loader, energy_key=energy_key, device=device) freqs = mode_dic["freqs"] neg_freqs = list(filter(lambda x: x < 0, freqs)) num_neg = len(neg_freqs) if num_neg != 0: - print(("Found {} negative frequencies; " - "restarting optimization.").format(num_neg)) + print(f"Found {num_neg} negative frequencies; " "restarting optimization.") correct_hessian(restart_file=restart_file, 
hessian=mode_dic["hess"]) return False, mode_dic - else: - print(("Found no negative frequencies; " - "optimization complete.")) + print("Found no negative frequencies; " "optimization complete.") - return True, mode_dic + return True, mode_dic def get_opt_kwargs(params): - # params with the right name for max_step new_params = copy.deepcopy(params) new_params["steps"] = new_params["opt_max_step"] new_params.pop("opt_max_step") - opt_kwargs = {key: val for key, - val in new_params.items() if key in OPT_KEYS} + opt_kwargs = {key: val for key, val in new_params.items() if key in OPT_KEYS} return opt_kwargs def opt_conformer(atoms, params): - converged = False device = params.get("device", "cuda") restart_file = params.get("restart_file", RESTART_FILE) @@ -247,8 +219,7 @@ def opt_conformer(atoms, params): nn_params = params.get("networkhyperparams", {}) output_keys = nn_params.get("output_keys", ["energy"]) - for iteration in tqdm(range(max_rounds)): - + for _iteration in tqdm(range(max_rounds)): model, energy_key = init_calculator(atoms=atoms, params=params) opt_module = getattr(optimize, params.get("opt_type", "BFGS")) @@ -258,19 +229,18 @@ def opt_conformer(atoms, params): nxyz_list = [atoms.get_nxyz()] - model, loader = get_loader(model=model, - nxyz_list=nxyz_list, - num_states=num_states, - cutoff=cutoff, - needs_angles=params.get( - "needs_angles", False), - base_keys=output_keys) - - hess_converged, mode_dic = check_convg(model=model, - loader=loader, - energy_key=energy_key, - device=device, - restart_file=restart_file) + model, loader = get_loader( + model=model, + nxyz_list=nxyz_list, + num_states=num_states, + cutoff=cutoff, + needs_angles=params.get("needs_angles", False), + base_keys=output_keys, + ) + + hess_converged, mode_dic = check_convg( + model=model, loader=loader, energy_key=energy_key, device=device, restart_file=restart_file + ) if dyn_converged and hess_converged: converged = True break @@ -279,15 +249,12 @@ def opt_conformer(atoms, params): def get_confs(traj_filename, thermo_filename, num_starting_poses): - with open(thermo_filename, "r") as f: lines = f.readlines() energies = [] for line in lines: - try: + with contextlib.suppress(ValueError): energies.append(float(line.split()[2])) - except ValueError: - pass sort_idx = np.argsort(energies) sorted_steps = np.array(range(len(lines)))[sort_idx[:num_starting_poses]] @@ -300,7 +267,7 @@ def get_confs(traj_filename, thermo_filename, num_starting_poses): def get_nve_params(params): nve_params = copy.deepcopy(nve.DEFAULTNVEPARAMS) - common_keys = [key for key in nve_params.keys() if key in params] + common_keys = [key for key in nve_params if key in params] for key in common_keys: nve_params[key] = params[key] @@ -313,35 +280,29 @@ def get_nve_params(params): def md_to_conf(params): - - thermo_filename = params.get( - "thermo_filename", nve.DEFAULTNVEPARAMS["thermo_filename"]) + thermo_filename = params.get("thermo_filename", nve.DEFAULTNVEPARAMS["thermo_filename"]) if os.path.isfile(thermo_filename): os.remove(thermo_filename) nve_params = get_nve_params(params) - nxyz = np.array(params['nxyz']) + nxyz = np.array(params["nxyz"]) atoms = AtomsBatch(nxyz[:, 0], nxyz[:, 1:]) _, _ = init_calculator(atoms=atoms, params=params) - nve_instance = nve.Dynamics(atomsbatch=atoms, - mdparam=nve_params) + nve_instance = nve.Dynamics(atomsbatch=atoms, mdparam=nve_params) nve_instance.run() - thermo_filename = params.get( - "thermo_filename", nve.DEFAULTNVEPARAMS["thermo_filename"]) - traj_filename = params.get( - "traj_filename", 
nve.DEFAULTNVEPARAMS["traj_filename"]) + thermo_filename = params.get("thermo_filename", nve.DEFAULTNVEPARAMS["thermo_filename"]) + traj_filename = params.get("traj_filename", nve.DEFAULTNVEPARAMS["traj_filename"]) num_starting_poses = params.get("num_starting_poses", NUM_CONFS) - best_confs = get_confs(traj_filename=traj_filename, - thermo_filename=thermo_filename, - num_starting_poses=num_starting_poses) + best_confs = get_confs( + traj_filename=traj_filename, thermo_filename=thermo_filename, num_starting_poses=num_starting_poses + ) return best_confs def confs_to_opt(params, best_confs): - convg_atoms = [] energy_list = [] mode_list = [] @@ -367,10 +328,8 @@ def get_opt_and_modes(params): - best_confs = md_to_conf(params) - all_geoms, all_modes = confs_to_opt(params=params, - best_confs=best_confs) + all_geoms, all_modes = confs_to_opt(params=params, best_confs=best_confs) opt_geom = all_geoms[0] mode_dic = all_modes[0] @@ -378,7 +337,7 @@ def get_opt_and_modes(params): def get_orca_form(cc_mat, cc_freqs, n_atoms): - """ Converts cclib version of Orca's (almost orthogonalizing) matrix + """Converts cclib version of Orca's (almost orthogonalizing) matrix and mode frequencies back into the original Orca forms. Also converts frequencies from cm^{-1} into atomic units (Hartree).""" @@ -386,46 +345,44 @@ def get_orca_form(cc_mat, cc_freqs, n_atoms): pure_matrix = np.asarray(cc_mat) pure_freqs = np.asarray(cc_freqs) n_modes = len(pure_matrix[:, 0]) - n_inactive = n_atoms*3 - len(pure_matrix[:, 0]) + n_inactive = n_atoms * 3 - len(pure_matrix[:, 0]) n_tot = n_modes + n_inactive for i in range(len(pure_matrix)): - - new_col = pure_matrix[i].reshape(3*len(pure_matrix[i])) + new_col = pure_matrix[i].reshape(3 * len(pure_matrix[i])) if i == 1: new_mat = np.column_stack((old_col, new_col)) elif i > 1: new_mat = np.column_stack((new_mat, new_col)) - old_col = new_col[:] + old_col = new_col[:] matrix = np.asarray(new_mat[:]).reshape(n_tot, n_modes) - zero_col = np.asarray([[0]]*len(matrix)) - for i in range(0, n_inactive): + zero_col = np.asarray([[0]] * len(matrix)) + for _ in range(n_inactive): matrix = np.insert(matrix, [0], zero_col, axis=1) freqs = np.asarray(pure_freqs[:]) - for i in range(0, n_inactive): + for _ in range(n_inactive): freqs = np.insert(freqs, 0, 0) return matrix, freqs * CM_2_AU def get_orth(mass_vec, matrix): - """Makes orthogonalizing matrix given the outputted - (non-orthogonal) matrix from Orca. The mass_vec variable - is a list of the masses of the atoms in the molecule (must be) - in the order given to Orca when it calculated normal modes). - Note that this acts directly on the matrix outputted from Orca, - not on the cclib version that divides columns into sets of - three entries for each atom.""" + """Makes orthogonalizing matrix given the outputted + (non-orthogonal) matrix from Orca. The mass_vec variable + is a list of the masses of the atoms in the molecule (must be + in the order given to Orca when it calculated normal modes).
+ Note that this acts directly on the matrix outputted from Orca, + not on the cclib version that divides columns into sets of + three entries for each atom.""" m = np.array([[mass] for mass in mass_vec]) # repeat sqrt(m) three times, one for each direction - sqrt_m_vec = np.kron(m ** 0.5, np.ones((3, 1))) + sqrt_m_vec = np.kron(m**0.5, np.ones((3, 1))) # a matrix with sqrt_m repeated K times, where # K = 3N - 5 or 3N-6 is the number of modes - sqrt_m_mat = np.kron(sqrt_m_vec, np.ones( - (1, len(sqrt_m_vec)))) + sqrt_m_mat = np.kron(sqrt_m_vec, np.ones((1, len(sqrt_m_vec)))) # orthogonalize the matrix by element-wise multiplication with 1/sqrt(m) orth = sqrt_m_mat * matrix @@ -439,7 +396,7 @@ def get_orth(mass_vec, matrix): def get_n_in(matrix): - """ Get number of inactive modes """ + """Get number of inactive modes""" n_in = 0 for entry in matrix[0]: @@ -449,7 +406,7 @@ def get_n_in(matrix): def get_disp(mass_vec, matrix, freqs, q, p, hb=1): - """Makes position and momentum displacements from + """Makes position and momentum displacements from unitless harmonic oscillator displacements and unitless momenta. Uses atomic units (hbar = 1). For different units change the value of hbar.""" @@ -472,18 +429,18 @@ def get_disp(mass_vec, matrix, freqs, q, p, hb=1): def wigner_sample(w, kt=25.7 / 1000 / 27.2, hb=1): - """ Sample unitless x and unitless p from a Wigner distribution. + """Sample unitless x and unitless p from a Wigner distribution. Takes frequency and temperature in au as inputs. Default temperature is 300 K.""" - sigma = (1/np.tanh((hb*w)/(2*kt)))**0.5/2**0.5 + sigma = (1 / np.tanh((hb * w) / (2 * kt))) ** 0.5 / 2**0.5 cov = [[sigma**2, 0], [0, sigma**2]] mean = (0, 0) x, p = np.random.multivariate_normal(mean, cov) return x, p -def classical_sample(w, kt=25.7 / 1000 / 27.2, hb=1): +def classical_sample(w, kt=25.7 / 1000 / 27.2, hb=1): sigma = (kt / (hb * w)) ** 0.5 cov = [[sigma**2, 0], [0, sigma**2]] mean = (0, 0) @@ -491,12 +448,7 @@ def classical_sample(w, kt=25.7 / 1000 / 27.2, hb=1): return x, p -def make_dx_dp(mass_vec, - cc_matrix, - cc_freqs, - kt=25.7 / 1000 / 27.2, - hb=1, - classical=False): +def make_dx_dp(mass_vec, cc_matrix, cc_freqs, kt=25.7 / 1000 / 27.2, hb=1, classical=False): """Make Wigner-sampled p and dx, where dx is the displacement about the equilibrium geometry. Takes mass vector, CClib matrix, and CClib vib freqs as inputs. @@ -518,12 +470,7 @@ def make_dx_dp(mass_vec, unitless_x = np.append(np.zeros(n_in), unitless_x) unitless_p = np.append(np.zeros(n_in), unitless_p) - dx, dp = get_disp(mass_vec=mass_vec, - matrix=matrix, - freqs=freqs, - q=unitless_x, - p=unitless_p, - hb=hb) + dx, dp = get_disp(mass_vec=mass_vec, matrix=matrix, freqs=freqs, q=unitless_x, p=unitless_p, hb=hb) # re-shape to have form of [[dx1, dy1, dz1], [dx2, dy2, dz2], ...] 
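
(Editorial sketch, not part of the diff: the width drawn by wigner_sample above is sigma = sqrt(coth(hbar*w / 2kT) / 2) in unitless coordinates, and classical_sample is its high-temperature limit, since coth(x) -> 1/x as x -> 0 gives sigma -> sqrt(kT / hbar*w). A standalone illustration with a 1500 cm^-1 mode; 4.556e-6 is the approximate cm^-1-to-Hartree factor that CM_2_AU denotes elsewhere in this module.)

import numpy as np

kt = 25.7 / 1000 / 27.2  # roughly 300 K in Hartree, the default above
w = 1500 * 4.556e-6  # a 1500 cm^-1 mode in Hartree
sigma_wigner = (1 / np.tanh(w / (2 * kt))) ** 0.5 / 2**0.5
sigma_classical = (kt / w) ** 0.5  # classical_sample's width (hb = 1)
print(sigma_wigner, sigma_classical)  # Wigner stays wider: zero-point motion
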
@@ -534,16 +481,15 @@ def make_dx_dp(mass_vec, def split_convert_xyz(xyz): - """ Splits xyz into Z, coordinates in au, and masses in au """ - coords = [(np.array(element[1:])*ANGS_2_AU).tolist() for element in xyz] - mass_vec = [PERIODICTABLE.GetAtomicWeight( - int(element[0]))*AMU_2_AU for element in xyz] + """Splits xyz into Z, coordinates in au, and masses in au""" + coords = [(np.array(element[1:]) * ANGS_2_AU).tolist() for element in xyz] + mass_vec = [PERIODICTABLE.GetAtomicWeight(int(element[0])) * AMU_2_AU for element in xyz] Z = [element[0] for element in xyz] return Z, coords, mass_vec def join_xyz(Z, coords): - """ Joins Z's and coordinates back into xyz """ + """Joins Z's and coordinates back into xyz""" out = [] for i in range(len(coords)): this_quad = [Z[i]] @@ -551,34 +497,24 @@ def join_xyz(Z, coords): out.append(this_quad) -def make_wigner_init(init_atoms, - vibdisps, - vibfreqs, - num_samples, - kt=25.7 / 1000 / 27.2, - hb=1, - classical=False): - """Generates Wigner-sampled coordinates and velocities. +def make_wigner_init(init_atoms, vibdisps, vibfreqs, num_samples, kt=25.7 / 1000 / 27.2, hb=1, classical=False): + """Generates Wigner-sampled coordinates and velocities. xyz is the xyz array at the optimized - geometry. xyz is in Angstrom, so xyz is first converted to + geometry. xyz is in Angstrom, so xyz is first converted to au, added to Wigner dx, and then - converted back to Angstrom. Velocity is in au. + converted back to Angstrom. Velocity is in au. vibdisps and vibfreqs are the CClib quantities found in the database.""" - xyz = np.concatenate([init_atoms.get_atomic_numbers().reshape(-1, 1), - init_atoms.get_positions()], axis=1) + xyz = np.concatenate([init_atoms.get_atomic_numbers().reshape(-1, 1), init_atoms.get_positions()], axis=1) atoms_list = [] for _ in range(num_samples): - assert min( - vibfreqs) >= 0, ("Negative frequencies found. " - "Geometry must not be converged.") + assert min(vibfreqs) >= 0, "Negative frequencies found. " "Geometry must not be converged." 
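
(Editorial sketch, not part of the diff: the assert just above is what keeps the sampling well defined. For a negative, i.e. imaginary, frequency the Wigner variance goes negative, so there is no Gaussian to draw from.)

import numpy as np

kt = 25.7 / 1000 / 27.2
w = -1e-4  # an imaginary mode, stored as a negative frequency
sigma_sq = (1 / np.tanh(w / (2 * kt))) / 2
print(sigma_sq)  # negative variance: no valid Wigner distribution
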
Z, opt_coords, mass_vec = split_convert_xyz(xyz) - dx, dp = make_dx_dp(mass_vec, vibdisps, vibfreqs, - kt, hb, classical=classical) - wigner_coords = ((np.asarray(opt_coords) + dx)/ANGS_2_AU).tolist() + dx, dp = make_dx_dp(mass_vec, vibdisps, vibfreqs, kt, hb, classical=classical) + wigner_coords = ((np.asarray(opt_coords) + dx) / ANGS_2_AU).tolist() nxyz = np.array(join_xyz(Z, wigner_coords)) velocity = (dp / np.array([[m] for m in mass_vec])).tolist() @@ -594,101 +530,96 @@ def make_wigner_init(init_atoms, return atoms_list -def nms_sample(params, - classical, - num_samples, - kt=25.7 / 1000 / 27.2, - hb=1): - +def nms_sample(params, classical, num_samples, kt=25.7 / 1000 / 27.2, hb=1): atoms, mode_dic = get_opt_and_modes(params) vibdisps = np.array(mode_dic["modes"]) vibdisps = vibdisps.reshape(vibdisps.shape[0], -1, 3).tolist() vibfreqs = mode_dic["freqs"] - atoms_list = make_wigner_init(init_atoms=atoms, - vibdisps=vibdisps, - vibfreqs=vibfreqs, - num_samples=num_samples, - kt=kt, - hb=hb, - classical=classical) + atoms_list = make_wigner_init( + init_atoms=atoms, + vibdisps=vibdisps, + vibfreqs=vibfreqs, + num_samples=num_samples, + kt=kt, + hb=hb, + classical=classical, + ) return atoms_list def get_modes(model, loader, energy_key, device): - batch = next(iter(loader)) batch = batch_to(batch, device) model = model.to(device) if isinstance(model, SchNet): - hessian = get_schnet_hessians(batch=batch, - model=model, - device=device, - energy_key=energy_key)[ - 0].cpu().detach().numpy() + hessian = ( + get_schnet_hessians(batch=batch, model=model, device=device, energy_key=energy_key)[0] + .cpu() + .detach() + .numpy() + ) else: raise NotImplementedError # convert to Ha / bohr^2 hessian *= (const.BOHR_RADIUS) ** 2 - hessian *= const.KCAL_TO_AU['energy'] + hessian *= const.KCAL_TO_AU["energy"] force_consts, vib_freqs, eigvec = vib_analy( - r=batch["nxyz"][:, 0].cpu().detach().numpy(), - xyz=batch["nxyz"][:, 1:].cpu().detach().numpy(), - hessian=hessian) + r=batch["nxyz"][:, 0].cpu().detach().numpy(), xyz=batch["nxyz"][:, 1:].cpu().detach().numpy(), hessian=hessian + ) # from https://gaussian.com/vib/#SECTION00036000000000000000 nxyz = batch["nxyz"].cpu().detach().numpy() - masses = np.array([PT.GetMostCommonIsotopeMass(int(z)) - for z in nxyz[:, 0]]) + masses = np.array([PT.GetMostCommonIsotopeMass(int(z)) for z in nxyz[:, 0]]) triple_mass = np.concatenate([np.array([item] * 3) for item in masses]) - red_mass = 1 / np.matmul(eigvec ** 2, 1 / triple_mass) + red_mass = 1 / np.matmul(eigvec**2, 1 / triple_mass) # un-mass weight the modes modes = [] for vec in eigvec: - col = vec / triple_mass ** 0.5 + col = vec / triple_mass**0.5 col /= np.linalg.norm(col) modes.append(col) modes = np.array(modes) - out_dic = {"nxyz": nxyz.tolist(), - "hess": hessian.tolist(), - "modes": modes.tolist(), - "red_mass": red_mass.tolist(), - "freqs": vib_freqs.tolist()} + out_dic = { + "nxyz": nxyz.tolist(), + "hess": hessian.tolist(), + "modes": modes.tolist(), + "red_mass": red_mass.tolist(), + "freqs": vib_freqs.tolist(), + } return out_dic def moi_tensor(massvec, expmassvec, xyz): # Center of Mass - com = np.sum(expmassvec.reshape(-1, 3) * - xyz.reshape(-1, 3), axis=0 - ) / np.sum(massvec) + com = np.sum(expmassvec.reshape(-1, 3) * xyz.reshape(-1, 3), axis=0) / np.sum(massvec) # xyz shifted to COM xyz_com = xyz.reshape(-1, 3) - com # Compute elements need to calculate MOI tensor - mass_xyz_com_sq_sum = np.sum( - expmassvec.reshape(-1, 3) * xyz_com ** 2, axis=0) + mass_xyz_com_sq_sum = 
np.sum(expmassvec.reshape(-1, 3) * xyz_com**2, axis=0) mass_xy = np.sum(massvec * xyz_com[:, 0] * xyz_com[:, 1], axis=0) mass_yz = np.sum(massvec * xyz_com[:, 1] * xyz_com[:, 2], axis=0) mass_xz = np.sum(massvec * xyz_com[:, 0] * xyz_com[:, 2], axis=0) # MOI tensor - moi = np.array([[mass_xyz_com_sq_sum[1] + mass_xyz_com_sq_sum[2], -1 * - mass_xy, -1 * mass_xz], - [-1 * mass_xy, mass_xyz_com_sq_sum[0] + - mass_xyz_com_sq_sum[2], -1 * mass_yz], - [-1 * mass_xz, -1 * mass_yz, mass_xyz_com_sq_sum[0] + - mass_xyz_com_sq_sum[1]]]) + moi = np.array( + [ + [mass_xyz_com_sq_sum[1] + mass_xyz_com_sq_sum[2], -1 * mass_xy, -1 * mass_xz], + [-1 * mass_xy, mass_xyz_com_sq_sum[0] + mass_xyz_com_sq_sum[2], -1 * mass_yz], + [-1 * mass_xz, -1 * mass_yz, mass_xyz_com_sq_sum[0] + mass_xyz_com_sq_sum[1]], + ] + ) # MOI eigenvectors and eigenvalues moi_eigval, moi_eigvec = np.linalg.eig(moi) @@ -697,7 +628,6 @@ def moi_tensor(massvec, expmassvec, xyz): def trans_rot_vec(massvec, xyz_com, moi_eigvec): - # Mass-weighted translational vectors zero_vec = np.zeros([len(massvec)]) sqrtmassvec = np.sqrt(massvec) @@ -710,23 +640,20 @@ def trans_rot_vec(massvec, xyz_com, moi_eigvec): # Mass-weighted rotational vectors big_p = np.matmul(xyz_com, moi_eigvec) - d4 = (np.repeat(big_p[:, 1], 3).reshape(-1) * - np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 2], 3).reshape(-1) * - np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) - ) * expsqrtmassvec + d4 = ( + np.repeat(big_p[:, 1], 3).reshape(-1) * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 2], 3).reshape(-1) * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) + ) * expsqrtmassvec - d5 = (np.repeat(big_p[:, 2], 3).reshape(-1) * - np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 0], 3).reshape(-1) * - np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) - ) * expsqrtmassvec + d5 = ( + np.repeat(big_p[:, 2], 3).reshape(-1) * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 0], 3).reshape(-1) * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) + ) * expsqrtmassvec - d6 = (np.repeat(big_p[:, 0], 3).reshape(-1) * - np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 1], 3).reshape(-1) * - np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) - ) * expsqrtmassvec + d6 = ( + np.repeat(big_p[:, 0], 3).reshape(-1) * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 1], 3).reshape(-1) * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) + ) * expsqrtmassvec d1_norm = d1 / np.linalg.norm(d1) d2_norm = d2 / np.linalg.norm(d2) @@ -735,28 +662,20 @@ def trans_rot_vec(massvec, xyz_com, moi_eigvec): d5_norm = d5 / np.linalg.norm(d5) d6_norm = d6 / np.linalg.norm(d6) - dx_norms = np.stack((d1_norm, - d2_norm, - d3_norm, - d4_norm, - d5_norm, - d6_norm)) + dx_norms = np.stack((d1_norm, d2_norm, d3_norm, d4_norm, d5_norm, d6_norm)) return dx_norms def vib_analy(r, xyz, hessian): - # r is the proton number of atoms # xyz is the cartesian coordinates in Angstrom # Hessian elements in atomic units (Ha/bohr^2) - massvec = np.array([PT.GetAtomicWeight(i.item()) * AMU2KG - for i in list(np.array(r.reshape(-1)).astype(int))]) + massvec = np.array([PT.GetAtomicWeight(i.item()) * AMU2KG for i in list(np.array(r.reshape(-1)).astype(int))]) expmassvec = np.repeat(massvec, 3) sqrtinvmassvec = np.divide(1.0, np.sqrt(expmassvec)) - hessian_mwc = np.einsum('i,ij,j->ij', sqrtinvmassvec, - hessian, sqrtinvmassvec) + hessian_mwc = 
np.einsum("i,ij,j->ij", sqrtinvmassvec, hessian, sqrtinvmassvec) hessian_eigval, hessian_eigvec = np.linalg.eig(hessian_mwc) xyz_com, moi_eigvec = moi_tensor(massvec, expmassvec, xyz) @@ -768,8 +687,7 @@ def vib_analy(r, xyz, hessian): # Projecting the T and R modes out of the hessian mwhess_proj = np.dot(P.T, hessian_mwc).dot(P) - hess_proj = np.einsum('i,ij,j->ij', 1 / sqrtinvmassvec, - mwhess_proj, 1 / sqrtinvmassvec) + hess_proj = np.einsum("i,ij,j->ij", 1 / sqrtinvmassvec, mwhess_proj, 1 / sqrtinvmassvec) hessian_eigval, hessian_eigvec = np.linalg.eigh(mwhess_proj) @@ -780,9 +698,7 @@ def vib_analy(r, xyz, hessian): hessian_eigval_abs = np.abs(hessian_eigval) - pre_vib_freq_cm_1 = np.sqrt( - hessian_eigval_abs * HA2J * 10e19) / (SPEEDOFLIGHT * 2 * np.pi * - BOHRS2ANG * 100) + pre_vib_freq_cm_1 = np.sqrt(hessian_eigval_abs * HA2J * 10e19) / (SPEEDOFLIGHT * 2 * np.pi * BOHRS2ANG * 100) vib_freq_cm_1 = pre_vib_freq_cm_1.copy() @@ -797,20 +713,17 @@ def vib_analy(r, xyz, hessian): if np.abs(freq) < 1.0: trans_rot_elms.append(i) - force_constants_J_m_2 = np.delete( - hessian_eigval * HA2J * 1e20 / (BOHRS2ANG ** 2) * AMU2KG, - trans_rot_elms) + force_constants_J_m_2 = np.delete(hessian_eigval * HA2J * 1e20 / (BOHRS2ANG**2) * AMU2KG, trans_rot_elms) proj_vib_freq_cm_1 = np.delete(vib_freq_cm_1, trans_rot_elms) proj_hessian_eigvec = np.delete(hessian_eigvec.T, trans_rot_elms, 0) - return (force_constants_J_m_2, proj_vib_freq_cm_1, proj_hessian_eigvec, - mwhess_proj, hess_proj) + return (force_constants_J_m_2, proj_vib_freq_cm_1, proj_hessian_eigvec, mwhess_proj, hess_proj) def free_rotor_moi(freqs): freq_ev = freqs * CM_TO_EV - mu = 1 / (8 * np.pi ** 2 * freq_ev) + mu = 1 / (8 * np.pi**2 * freq_ev) return mu @@ -819,45 +732,30 @@ def eff_moi(mu, b_av): return mu_prime -def low_freq_entropy(freqs, - temperature, - b_av=B_AV): +def low_freq_entropy(freqs, temperature, b_av=B_AV): mu = free_rotor_moi(freqs) mu_prime = eff_moi(mu, b_av) - arg = (8 * np.pi ** 3 * mu_prime * kB * temperature) - entropy = GAS_CONST * (1 / 2 + np.log(arg ** 0.5)) + arg = 8 * np.pi**3 * mu_prime * kB * temperature + entropy = GAS_CONST * (1 / 2 + np.log(arg**0.5)) return entropy -def high_freq_entropy(freqs, - temperature): - +def high_freq_entropy(freqs, temperature): freq_ev = freqs * CM_TO_EV exp_pos = np.exp(freq_ev / (kB * temperature)) - 1 exp_neg = 1 - np.exp(-freq_ev / (kB * temperature)) - entropy = GAS_CONST * ( - freq_ev / (kB * temperature * exp_pos) - - np.log(exp_neg) - ) + entropy = GAS_CONST * (freq_ev / (kB * temperature * exp_pos) - np.log(exp_neg)) return entropy -def mrrho_entropy(freqs, - temperature, - rotor_cutoff, - b_av, - alpha): - +def mrrho_entropy(freqs, temperature, rotor_cutoff, b_av, alpha): func = 1 / (1 + (rotor_cutoff / freqs) ** alpha) - s_r = low_freq_entropy(freqs=freqs, - b_av=b_av, - temperature=temperature) - s_v = high_freq_entropy(freqs=freqs, - temperature=temperature) + s_r = low_freq_entropy(freqs=freqs, b_av=b_av, temperature=temperature) + s_v = high_freq_entropy(freqs=freqs, temperature=temperature) new_vib_s = (func * s_v + (1 - func) * s_r).sum() old_vib_s = s_v.sum() @@ -865,63 +763,55 @@ def mrrho_entropy(freqs, return old_vib_s, new_vib_s -def mrrho_quants(ase_atoms, - freqs, - imag_cutoff=IMAG_CUTOFF, - temperature=TEMP, - pressure=PRESSURE, - rotor_cutoff=ROTOR_CUTOFF, - b_av=B_AV, - alpha=4, - flip_all_but_ts=False): - +def mrrho_quants( + ase_atoms, + freqs, + imag_cutoff=IMAG_CUTOFF, + temperature=TEMP, + pressure=PRESSURE, + rotor_cutoff=ROTOR_CUTOFF, + 
b_av=B_AV, + alpha=4, + flip_all_but_ts=False, +): potentialenergy = ase_atoms.get_potential_energy() if flip_all_but_ts: - print(("Flipping all imaginary frequencies except " - "the lowest one")) + print("Flipping all imaginary frequencies except " "the lowest one") abs_freqs = abs(freqs[1:]) else: abs_freqs = abs(freqs[freqs > imag_cutoff]) ens = abs_freqs * CM_TO_EV - ideal_gas = IdealGasThermo(vib_energies=ens, - potentialenergy=potentialenergy, - atoms=ase_atoms, - geometry='nonlinear', - symmetrynumber=1, - spin=0) + ideal_gas = IdealGasThermo( + vib_energies=ens, + potentialenergy=potentialenergy, + atoms=ase_atoms, + geometry="nonlinear", + symmetrynumber=1, + spin=0, + ) # full entropy including rotation, translation etc - old_entropy = (ideal_gas.get_entropy(temperature=temperature, - pressure=pressure).item()) - enthalpy = (ideal_gas.get_enthalpy(temperature=temperature) - .item()) + old_entropy = ideal_gas.get_entropy(temperature=temperature, pressure=pressure).item() + enthalpy = ideal_gas.get_enthalpy(temperature=temperature).item() # correction to vibrational entropy - out = mrrho_entropy(freqs=abs_freqs, - temperature=temperature, - rotor_cutoff=rotor_cutoff, - b_av=b_av, - alpha=alpha) + out = mrrho_entropy(freqs=abs_freqs, temperature=temperature, rotor_cutoff=rotor_cutoff, b_av=b_av, alpha=alpha) old_vib_s, new_vib_s = out final_entropy = old_entropy - old_vib_s + new_vib_s - free_energy = (enthalpy - temperature * final_entropy) + free_energy = enthalpy - temperature * final_entropy return final_entropy, enthalpy, free_energy -def convert_modes(atoms, - modes): - - masses = (atoms.get_masses().reshape(-1, 1) - .repeat(3, 1) - .reshape(1, -1)) +def convert_modes(atoms, modes): + masses = atoms.get_masses().reshape(-1, 1).repeat(3, 1).reshape(1, -1) # Multiply by 1 / sqrt(M) to be consistent with the DB - vibdisps = modes / (masses ** 0.5) + vibdisps = modes / (masses**0.5) norm = np.linalg.norm(vibdisps, axis=1).reshape(-1, 1) # Normalize @@ -935,14 +825,15 @@ def convert_modes(atoms, return vibdisps -def hessian_and_modes(ase_atoms, - imag_cutoff=IMAG_CUTOFF, - rotor_cutoff=ROTOR_CUTOFF, - temperature=TEMP, - pressure=PRESSURE, - flip_all_but_ts=False, - analytical=False): - +def hessian_and_modes( + ase_atoms, + imag_cutoff=IMAG_CUTOFF, + rotor_cutoff=ROTOR_CUTOFF, + temperature=TEMP, + pressure=PRESSURE, + flip_all_but_ts=False, + analytical=False, +): # comparison to the analytical Hessian # shows that delta=0.005 is indistinguishable # from the real result, whereas delta=0.05 @@ -952,8 +843,8 @@ def hessian_and_modes(ase_atoms, # because it might mess up the Hessian # calculation - if os.path.isdir('vib'): - shutil.rmtree('vib') + if os.path.isdir("vib"): + shutil.rmtree("vib") if analytical: hessian = analytical_hess(atoms=ase_atoms) @@ -964,42 +855,40 @@ def hessian_and_modes(ase_atoms, vib_results = vib.get_vibrations() dim = len(ase_atoms) - hessian = (vib_results.get_hessian() - .reshape(dim * 3, dim * 3) * - EV_TO_AU * - BOHR_RADIUS ** 2) + hessian = vib_results.get_hessian().reshape(dim * 3, dim * 3) * EV_TO_AU * BOHR_RADIUS**2 print(vib.get_frequencies()[:20]) - vib_results = vib_analy(r=ase_atoms.get_atomic_numbers(), - xyz=ase_atoms.get_positions(), - hessian=hessian) + vib_results = vib_analy(r=ase_atoms.get_atomic_numbers(), xyz=ase_atoms.get_positions(), hessian=hessian) _, freqs, modes, mwhess_proj, hess_proj = vib_results mwhess_proj *= AMU2KG - vibdisps = convert_modes(atoms=ase_atoms, - modes=modes) + vibdisps = convert_modes(atoms=ase_atoms, 
modes=modes) - mrrho_results = mrrho_quants(ase_atoms=ase_atoms, - freqs=freqs, - imag_cutoff=imag_cutoff, - temperature=temperature, - pressure=pressure, - rotor_cutoff=rotor_cutoff, - flip_all_but_ts=flip_all_but_ts) + mrrho_results = mrrho_quants( + ase_atoms=ase_atoms, + freqs=freqs, + imag_cutoff=imag_cutoff, + temperature=temperature, + pressure=pressure, + rotor_cutoff=rotor_cutoff, + flip_all_but_ts=flip_all_but_ts, + ) entropy, enthalpy, free_energy = mrrho_results imgfreq = len(freqs[freqs < 0]) - results = {"vibdisps": vibdisps.tolist(), - "vibfreqs": freqs.tolist(), - "modes": modes, - "hessianmatrix": hessian.tolist(), - "mwhess_proj": mwhess_proj.tolist(), - "hess_proj": hess_proj.tolist(), - "imgfreq": imgfreq, - "freeenergy": free_energy * EV_TO_AU, - "enthalpy": enthalpy * EV_TO_AU, - "entropy": entropy * temperature * EV_TO_AU} + results = { + "vibdisps": vibdisps.tolist(), + "vibfreqs": freqs.tolist(), + "modes": modes, + "hessianmatrix": hessian.tolist(), + "mwhess_proj": mwhess_proj.tolist(), + "hess_proj": hess_proj.tolist(), + "imgfreq": imgfreq, + "freeenergy": free_energy * EV_TO_AU, + "enthalpy": enthalpy * EV_TO_AU, + "entropy": entropy * temperature * EV_TO_AU, + } return results diff --git a/nff/md/npt.py b/nff/md/npt.py index 25a1acef..054d1ba1 100644 --- a/nff/md/npt.py +++ b/nff/md/npt.py @@ -1,6 +1,7 @@ import copy import math import os +from typing import Optional import numpy as np from ase import units @@ -57,12 +58,8 @@ def __init__( self.T = temperature * units.kB - # initial Maxwell-Boltmann temperature for atoms - if T_init is not None: - # convert units - T_init = T_init * units.kB - else: - T_init = 2 * self.T + # initial Maxwell-Boltzmann temperature for atoms + T_init = T_init * units.kB if T_init is not None else 2 * self.T MaxwellBoltzmannDistribution(self.atoms, T_init) Stationary(self.atoms) @@ -133,12 +130,8 @@ def __init__( self.T = temperature * units.kB - # initial Maxwell-Boltmann temperature for atoms - if T_init is not None: - # convert units - T_init = T_init * units.kB - else: - T_init = 2 * self.T + # initial Maxwell-Boltzmann temperature for atoms + T_init = T_init * units.kB if T_init is not None else 2 * self.T MaxwellBoltzmannDistribution(self.atoms, T_init) Stationary(self.atoms) @@ -222,12 +215,8 @@ def __init__( self.nbr_update_period = nbr_update_period - # initial Maxwell-Boltmann temperature for atoms - if maxwell_temp is not None: - # convert units - maxwell_temp = maxwell_temp * units.kB - else: - maxwell_temp = 2 * self.T + # initial Maxwell-Boltzmann temperature for atoms + maxwell_temp = maxwell_temp * units.kB if maxwell_temp is not None else 2 * self.T MaxwellBoltzmannDistribution(self.atoms, maxwell_temp) Stationary(self.atoms) @@ -306,11 +295,11 @@ def __init__( freq_thermostat_per_fs: float = 0.01, freq_barostat_per_fs: float = 0.0005, num_chains: int = 10, - maxwell_temp: float = None, - trajectory: str = None, - logfile: str = None, + maxwell_temp: Optional[float] = None, + trajectory: Optional[str] = None, + logfile: Optional[str] = None, loginterval: int = 1, - max_steps: int = None, + max_steps: Optional[int] = None, nbr_update_period: int = 10, append_trajectory: bool = True, **kwargs, @@ -412,7 +401,7 @@ def step(self): scale_coords = np.exp(delta_eps) scale_volume = np.exp(self.d * delta_eps) - V_t = V * scale_volume + V * scale_volume h_t = h * scale_coords # half time for all velocities @@ -532,11 +521,11 @@ def __init__( freq_thermostat_per_fs: float = 0.01, freq_barostat_per_fs: float = 0.0005, 
num_chains: int = 10, - maxwell_temp: float = None, - trajectory: str = None, - logfile: str = None, + maxwell_temp: Optional[float] = None, + trajectory: Optional[str] = None, + logfile: Optional[str] = None, loginterval: int = 1, - max_steps: int = None, + max_steps: Optional[int] = None, nbr_update_period: int = 10, append_trajectory: bool = True, **kwargs, @@ -666,7 +655,7 @@ def step(self): h0, ) # not sure if they matrix multiplication or not # eqs (E1, E2) - while np.isclose(1.0, np.linalg.det(h0_t), atol=1e-6) == False: + while not np.isclose(1.0, np.linalg.det(h0_t), atol=1e-6): # print(np.linalg.det(h0_t), h0_t) if np.linalg.det(h0_t) is np.nan: print("Failed to enforce det=1 of unit lattice vectors!") diff --git a/nff/md/nve.py b/nff/md/nve.py index 283626ab..46efac23 100644 --- a/nff/md/nve.py +++ b/nff/md/nve.py @@ -86,7 +86,10 @@ def __init__( interval=self.mdparam["save_frequency"], ) - def check_restart(self): + def check_restart(self) -> int: + """Check if the MD path is being restarted from an existing traj file and adjust the number of + steps accordingly. + """ if os.path.exists(self.mdparam["traj_filename"]): new_atoms = Trajectory(self.mdparam["traj_filename"])[-1] @@ -129,8 +132,7 @@ def check_restart(self): return self.steps - else: - return self.steps + return self.steps def setup_restart(self, restart_param): """If you want to restart a simulations with predfined mdparams but @@ -147,7 +149,6 @@ def setup_restart(self, restart_param): Args: restart_param (dict): dictionary to contains restart paramsters and file paths """ - if restart_param["thermo_filename"] == self.mdparam["thermo_filename"]: raise ValueError( "{} is also used, \ @@ -189,19 +190,23 @@ def setup_restart(self, restart_param): self.mdparam["steps"] = restart_param["steps"] - def run(self): + def run(self) -> None: + """Run the MD simulation for the specified number of steps. If the stability_check + parameter is set to True, the simulation stops early when the temperature leaves + reasonable bounds. The neighbor list is updated every nbr_list_update_freq steps.
+ """ epochs = int(self.steps // self.mdparam["nbr_list_update_freq"]) # In case it had neighbors that didn't include the cutoff skin, # for example, it's good to update the neighbor list here self.atomsbatch.update_nbr_list() if self.mdparam.get("stability_check", False): - for step in range(epochs): + for _step in range(epochs): T = self.atomsbatch.get_batch_kinetic_energy() / (1.5 * units.kB * self.atomsbatch.num_atoms) if ( - (T > (10 * self.mdparam["thermostat_params"]["temperature"] / units.kB)).any() - or (T < 1e-1).any() - and self.mdparam.get("stability_check", False) + ((10 * self.mdparam["thermostat_params"]["temperature"] / units.kB) < T).any() + or ((T < 1e-1).any() + and self.mdparam.get("stability_check", False)) ): break @@ -215,7 +220,7 @@ def run(self): self.atomsbatch.update_nbr_list() else: - for step in range(epochs): + for _step in range(epochs): self.integrator.run(self.mdparam["nbr_list_update_freq"]) # # unwrap coordinates if mol_idx is defined diff --git a/nff/md/nvt.py b/nff/md/nvt.py index 0a7e37e3..a2d6b65e 100644 --- a/nff/md/nvt.py +++ b/nff/md/nvt.py @@ -2,17 +2,56 @@ import math import os import pickle +import warnings +from typing import Optional +import ase import numpy as np from ase import units from ase.md.logger import MDLogger from ase.md.md import MolecularDynamics from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation from ase.optimize.optimize import Dynamics +from packaging.version import Version, parse from tqdm import tqdm from nff.io.ase import AtomsBatch +ASE_VERSION = parse(ase.__version__) +ASE_CUTOFF_VERSION = parse("3.23.0") + + +def run_with_ase_check( + integrator: MolecularDynamics, + steps_per_epoch: int, + ase_ver: Version = ASE_VERSION, + ase_cut: Version = ASE_CUTOFF_VERSION, +) -> None: + """Run the ASE dynamics with a check for the ASE version. ASE v3.23 has updated + the `run` method in the `Dynamics` class, so we need to check for the version + and run the appropriate method. This function will be deprecated in the future, + as ASE v3.23 will be the minimum version required for nff, and contains a warning + to that effect. + Args: + integrator (MolecularDynamics): ASE integrator object or thermostat like NoseHoover + steps_per_epoch (int): number of steps per epoch + ase_ver (Version): ASE version + ase_cut (Version): ASE cutoff version where Dynamics approach was changed + Raises: + DeprecationWarning: if the ASE version is less than 3.23 + """ + if ase_ver < ase_cut: + warnings.warn( + f"ASE version {ase_ver} uses outdated `run` method in" + " its `Dynamics` class. Please update to a newer version of ASE as this" + " method will be deprecated in nff in the future.", + DeprecationWarning, + stacklevel=2, + ) + Dynamics.run(integrator) + else: + Dynamics.run(integrator, steps=steps_per_epoch) + class NoseHoover(MolecularDynamics): def __init__( @@ -64,7 +103,7 @@ def __init__( self.nbr_update_period = nbr_update_period - # initial Maxwell-Boltmann temperature for atoms + # initial Maxwell-Boltzmann temperature for atoms if maxwell_temp is None: maxwell_temp = temperature @@ -90,10 +129,8 @@ def remove_constrained_vel(self, atoms): has_keys = True if not has_keys: print( - ( - "WARNING: velocity not set to zero for any atoms in constraint " - "%s; do not know how to find its fixed indices." % constraint - ) + "WARNING: velocity not set to zero for any atoms in constraint " + "%s; do not know how to find its fixed indices." 
% constraint ) if not fixed_idx: @@ -155,7 +192,7 @@ def run(self, steps=None): for _ in tqdm(range(epochs)): self.max_steps += steps_per_epoch - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) self.atoms.update_nbr_list() @@ -247,82 +284,6 @@ def step(self): self.p_zeta += 0.5 * dpzeta_dt * self.dt -# Does anyone use this? -# class NoseHooverChainsBiased(NoseHooverChain): -# def __init__(self, -# atoms, -# timestep, -# temperature, -# ttime, -# num_chains, -# maxwell_temp=None, -# trajectory=None, -# logfile=None, -# loginterval=1, -# max_steps=None, -# nbr_update_period=20, -# append_trajectory=True, -# **kwargs): - -# NoseHooverChain.__init__(self, -# atoms=atoms, -# timestep=timestep, -# temperature=temperature, -# ttime=ttime, -# num_chains=num_chains, -# maxwell_temp=maxwell_temp, -# trajectory=trajectory, -# logfile=logfile, -# loginterval=loginterval, -# max_steps=max_steps, -# nbr_update_period=nbr_update_period, -# append_trajectory=append_trajectory, -# **kwargs) - - -# def update_bias(self): -# # update the bias function if necessary, e.g., add aconfiguration to MetaD -# self.atoms.calc.update(self) - -# def irun(self): -# # run the algorithm max_steps reached -# while self.nsteps < self.max_steps: - -# # compute the next step -# self.step() -# self.nsteps += 1 -# self.update_bias() - -# # log the step -# self.log() -# self.call_observers() - - -# def run(self, steps=None): -# if steps is None: -# steps = self.num_steps - -# epochs = math.ceil(steps / self.nbr_update_period) -# # number of steps in between nbr updates -# steps_per_epoch = int(steps / epochs) -# # maximum number of steps starts at `steps_per_epoch` -# # and increments after every nbr list update - -# self.atoms.update_nbr_list() - -# # compute initial structure and log the first step -# if self.nsteps == 0: -# self.update_bias() -# self.atoms.get_forces() -# self.log() -# self.call_observers() - -# for _ in tqdm(range(epochs)): -# self.max_steps += steps_per_epoch -# self.irun() -# self.atoms.update_nbr_list() - - class Langevin(MolecularDynamics): def __init__( self, @@ -330,7 +291,7 @@ def __init__( timestep: float, temperature: float, friction_per_ps: float = 1.0, - maxwell_temp: float = None, + maxwell_temp: Optional[float] = None, random_seed=None, trajectory=None, logfile=None, @@ -341,16 +302,16 @@ def __init__( **kwargs, ): # Random Number Generator - if random_seed == None: + if random_seed is None: random_seed = np.random.randint(2147483647) if type(random_seed) is int: np.random.seed(random_seed) - print("THE RANDOM NUMBER SEED WAS: %i" % (random_seed)) + print(f"THE RANDOM NUMBER SEED WAS: {random_seed}") else: try: np.random.set_state(random_seed) - except: - raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") + except BaseException as e: + raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") from e if os.path.isfile(str(trajectory)): os.remove(trajectory) @@ -409,10 +370,8 @@ def remove_constrained_vel(self, atoms): has_keys = True if not has_keys: print( - ( - "WARNING: velocity not set to zero for any atoms in constraint " - "%s; do not know how to find its fixed indices." % constraint - ) + "WARNING: velocity not set to zero for any atoms in constraint " + "%s; do not know how to find its fixed indices." 
% constraint ) if not fixed_idx: @@ -461,7 +420,7 @@ def run(self, steps=None): for _ in tqdm(range(epochs)): self.max_steps += steps_per_epoch - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) x = self.atoms.get_positions(wrap=True) self.atoms.set_positions(x) @@ -485,7 +444,7 @@ def __init__( timestep: float, temperature: float, friction_per_ps: float = 1.0, - maxwell_temp: float = None, + maxwell_temp: Optional[float] = None, random_seed=None, trajectory=None, logfile=None, @@ -499,16 +458,16 @@ def __init__( os.remove(trajectory) # Random Number Generator - if random_seed == None: + if random_seed is None: random_seed = np.random.randint(2147483647) if type(random_seed) is int: - np.random.seed(radnom_seed) - print("THE RANDOM NUMBER SEED WAS: %i" % (random_seed)) + np.random.seed(random_seed) + print(f"THE RANDOM NUMBER SEED WAS: {random_seed}") else: try: np.random.set_state(random_seed) - except: - raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") + except BaseException as e: + raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") from e MolecularDynamics.__init__( self, @@ -549,9 +508,7 @@ def __init__( self.nbr_update_period = nbr_update_period # initial Maxwell-Boltmann temperature for atoms - if maxwell_temp is not None: - maxwell_temp = maxwell_temp - else: + if maxwell_temp is None: maxwell_temp = self.T # intialize system momentum @@ -586,10 +543,8 @@ def remove_constrained_vel(self, atoms): has_keys = True if not has_keys: print( - ( - "WARNING: velocity not set to zero for any atoms in constraint " - "%s; do not know how to find its fixed indices." % constraint - ) + "WARNING: velocity not set to zero for any atoms in constraint " + "%s; do not know how to find its fixed indices." % constraint ) if not fixed_idx: @@ -650,7 +605,7 @@ def run(self, steps=None): for _ in tqdm(range(epochs)): self.max_steps += steps_per_epoch - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) self.atoms.update_nbr_list() momenta = [] @@ -676,7 +631,7 @@ def __init__( timestep: float, temperature: float, relaxation_const: float = 100.0, - maxwell_temp: float = None, + maxwell_temp: Optional[float] = None, random_seed=None, trajectory=None, logfile=None, @@ -687,16 +642,16 @@ def __init__( **kwargs, ): # Random Number Generator - if random_seed == None: + if random_seed is None: random_seed = np.random.randint(2147483647) if type(random_seed) is int: np.random.seed(random_seed) - print("THE RANDOM NUMBER SEED WAS: %i" % (random_seed)) + print(f"THE RANDOM NUMBER SEED WAS: {random_seed}") else: try: np.random.set_state(random_seed) - except: - raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") + except BaseException as e: + raise ValueError("\tThe provided seed was neither an int nor a state of numpy random") from e if os.path.isfile(str(trajectory)): os.remove(trajectory) @@ -753,10 +708,8 @@ def remove_constrained_vel(self, atoms): has_keys = True if not has_keys: print( - ( - "WARNING: velocity not set to zero for any atoms in constraint " - "%s; do not know how to find its fixed indices." % constraint - ) + "WARNING: velocity not set to zero for any atoms in constraint " + "%s; do not know how to find its fixed indices." 
% constraint ) if not fixed_idx: @@ -818,7 +771,7 @@ def run(self, steps=None): for _ in tqdm(range(epochs)): self.max_steps += steps_per_epoch - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) self.atoms.update_nbr_list() Stationary(self.atoms) ZeroRotation(self.atoms) @@ -906,7 +859,7 @@ def run(self, steps=None): # set hydrogen mass to 2 AMU (deuterium, following Grimme's mTD approach) self.increase_h_mass() - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) # reset the masses self.decrease_h_mass() @@ -1050,7 +1003,7 @@ def run(self, steps=None): for _ in range(epochs): self.max_steps += steps_per_epoch - Dynamics.run(self) + run_with_ase_check(self, steps_per_epoch) self.atoms.update_nbr_list() @@ -1094,7 +1047,7 @@ def __call__(self): epot = self.atoms.get_potential_energy() temp = self.atoms.get_batch_T() - for i, this_ek in enumerate(ekin): + for i, _this_ek in enumerate(ekin): this_epot = epot[i] this_temp = float(temp[i]) dat += (this_epot, this_temp) diff --git a/nff/md/nvt_ax.py b/nff/md/nvt_ax.py index 8a4c323f..96c9466d 100644 --- a/nff/md/nvt_ax.py +++ b/nff/md/nvt_ax.py @@ -1,40 +1,41 @@ -import copy -import os -import numpy as np import math +import os +import numpy as np +from ase import units from ase.md.md import MolecularDynamics +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation from ase.optimize.optimize import Dynamics -from ase import units -from ase.md.velocitydistribution import (MaxwellBoltzmannDistribution, - Stationary, ZeroRotation) class NoseHoover(MolecularDynamics): - def __init__(self, - atoms, - timestep, - temperature, - ttime, - maxwell_temp=None, - trajectory=None, - logfile=None, - loginterval=1, - max_steps=None, - nbr_update_period=20, - append_trajectory=True, - **kwargs): - + def __init__( + self, + atoms, + timestep, + temperature, + ttime, + maxwell_temp=None, + trajectory=None, + logfile=None, + loginterval=1, + max_steps=None, + nbr_update_period=20, + append_trajectory=True, + **kwargs, + ): if os.path.isfile(trajectory): os.remove(trajectory) - MolecularDynamics.__init__(self, - atoms=atoms, - timestep=timestep * units.fs, - trajectory=trajectory, - logfile=logfile, - loginterval=loginterval, - append_trajectory=append_trajectory) + MolecularDynamics.__init__( + self, + atoms=atoms, + timestep=timestep * units.fs, + trajectory=trajectory, + logfile=logfile, + loginterval=loginterval, + append_trajectory=append_trajectory, + ) # Initialize simulation parameters # convert units @@ -48,33 +49,26 @@ def __init__(self, # no rotation or translation, so target kinetic energy is 1/2 (3N - 6) kT self.targeEkin = 0.5 * (3.0 * self.Natom - 6) * self.T - self.Q = (3.0 * self.Natom - 6) * self.T * (self.ttime * self.dt)**2 + self.Q = (3.0 * self.Natom - 6) * self.T * (self.ttime * self.dt) ** 2 self.zeta = 0.0 self.num_steps = max_steps self.n_steps = 0 self.nbr_update_period = nbr_update_period self.max_steps = 0 - # initial Maxwell-Boltmann temperature for atoms - if maxwell_temp is not None: - # convert units - maxwell_temp = maxwell_temp * units.kB - else: - maxwell_temp = 2 * self.T - + # initial Maxwell-Boltzmann temperature for atoms + maxwell_temp = maxwell_temp * units.kB if maxwell_temp is not None else 2 * self.T MaxwellBoltzmannDistribution(self.atoms, maxwell_temp) Stationary(self.atoms) ZeroRotation(self.atoms) def step(self): - # get current acceleration and velocity: accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) vel = 
self.atoms.get_velocities() # make full step in position - x = self.atoms.get_positions() + vel * self.dt + \ - (accel - self.zeta * vel) * (0.5 * self.dt ** 2) + x = self.atoms.get_positions() + vel * self.dt + (accel - self.zeta * vel) * (0.5 * self.dt**2) self.atoms.set_positions(x) # record current velocities @@ -89,22 +83,18 @@ def step(self): accel = f / self.atoms.get_masses().reshape(-1, 1) # make a half step in self.zeta - self.zeta = self.zeta + 0.5 * self.dt * \ - (1/self.Q) * (KE_0 - self.targeEkin) + self.zeta = self.zeta + 0.5 * self.dt * (1 / self.Q) * (KE_0 - self.targeEkin) # make another halfstep in self.zeta - self.zeta = self.zeta + 0.5 * self.dt * \ - (1/self.Q) * (self.atoms.get_kinetic_energy() - self.targeEkin) + self.zeta = self.zeta + 0.5 * self.dt * (1 / self.Q) * (self.atoms.get_kinetic_energy() - self.targeEkin) # make another half step in velocity - vel = (self.atoms.get_velocities() + 0.5 * self.dt * accel) / \ - (1 + 0.5 * self.dt * self.zeta) + vel = (self.atoms.get_velocities() + 0.5 * self.dt * accel) / (1 + 0.5 * self.dt * self.zeta) self.atoms.set_velocities(vel) return f def run(self, steps=None): - if steps is None: steps = self.num_steps @@ -120,26 +110,23 @@ def run(self, steps=None): Dynamics.run(self) self.atoms.update_nbr_list() + class NoseHooverChain(MolecularDynamics): - def __init__(self, - atoms, - timestep, - temperature, - ttime, - num_chains, - maxwell_temp, - trajectory=None, - logfile=None, - loginterval=1, - max_steps=None, - **kwargs): - - MolecularDynamics.__init__(self, - atoms, - timestep * units.fs, - trajectory, - logfile, - loginterval) + def __init__( + self, + atoms, + timestep, + temperature, + ttime, + num_chains, + maxwell_temp, + trajectory=None, + logfile=None, + loginterval=1, + max_steps=None, + **kwargs, + ): + MolecularDynamics.__init__(self, atoms, timestep * units.fs, trajectory, logfile, loginterval) # Initialize simulation parameters @@ -151,61 +138,57 @@ def __init__(self, # in units of fs: self.ttime = ttime - self.Q = 2 * np.array([self.N_dof * self.T * (self.ttime * self.dt)**2, - *[self.T * (self.ttime * self.dt)**2]*(num_chains-1)]) + self.Q = 2 * np.array( + [ + self.N_dof * self.T * (self.ttime * self.dt) ** 2, + *[self.T * (self.ttime * self.dt) ** 2] * (num_chains - 1), + ] + ) # no rotation or translation, so target kinetic energy is 3/2 N kT - 6 - self.targeEkin = 1/2 * self.N_dof * self.T + self.targeEkin = 1 / 2 * self.N_dof * self.T # self.zeta = np.array([0.0]*num_chains) - self.p_zeta = np.array([0.0]*num_chains) + self.p_zeta = np.array([0.0] * num_chains) self.num_steps = max_steps self.n_steps = 0 self.max_steps = 0 - # initial Maxwell-Boltmann temperature for atoms - if maxwell_temp is not None: - # convert units - maxwell_temp = maxwell_temp * units.kB - else: - maxwell_temp = 2 * self.T + # initial Maxwell-Boltzmann temperature for atoms + maxwell_temp = maxwell_temp * units.kB if maxwell_temp is not None else 2 * self.T MaxwellBoltzmannDistribution(self.atoms, maxwell_temp) def get_zeta_accel(self): - - p0_dot = 2 * (self.atoms.get_kinetic_energy() - self.targeEkin) - \ - self.p_zeta[0]*self.p_zeta[1] / self.Q[1] - p_middle_dot = self.p_zeta[:-2]**2 / self.Q[:-2] - \ - self.T - self.p_zeta[1:-1] * self.p_zeta[2:]/self.Q[2:] - p_last_dot = self.p_zeta[-2]**2 / self.Q[-2] - self.T + p0_dot = 2 * (self.atoms.get_kinetic_energy() - self.targeEkin) - self.p_zeta[0] * self.p_zeta[1] / self.Q[1] + p_middle_dot = self.p_zeta[:-2] ** 2 / self.Q[:-2] - self.T - self.p_zeta[1:-1] * self.p_zeta[2:] / 
self.Q[2:] + p_last_dot = self.p_zeta[-2] ** 2 / self.Q[-2] - self.T p_dot = np.array([p0_dot, *p_middle_dot, p_last_dot]) return p_dot / self.Q def half_step_v_zeta(self): - v = self.p_zeta / self.Q accel = self.get_zeta_accel() - v_half = v + 1/2 * accel * self.dt + v_half = v + 1 / 2 * accel * self.dt return v_half def half_step_v_system(self): - v = self.atoms.get_velocities() accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) accel -= v * self.p_zeta[0] / self.Q[0] - v_half = v + 1/2 * accel * self.dt + v_half = v + 1 / 2 * accel * self.dt return v_half def full_step_positions(self): - accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) - new_positions = self.atoms.get_positions() + self.atoms.get_velocities() * self.dt + \ - (accel - self.p_zeta[0] / self.Q[0])*(self.dt)**2 + new_positions = ( + self.atoms.get_positions() + + self.atoms.get_velocities() * self.dt + + (accel - self.p_zeta[0] / self.Q[0]) * (self.dt) ** 2 + ) return new_positions def step(self): - new_positions = self.full_step_positions() self.atoms.set_positions(new_positions) @@ -217,14 +200,12 @@ def step(self): v_full_zeta = self.half_step_v_zeta() accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1) - v_full_system = (v_half_system + 1/2 * accel * self.dt) / \ - (1 + 0.5 * self.dt * v_full_zeta[0]) + v_full_system = (v_half_system + 1 / 2 * accel * self.dt) / (1 + 0.5 * self.dt * v_full_zeta[0]) self.atoms.set_velocities(v_full_system) self.p_zeta = v_full_zeta * self.Q def run(self, steps=None): - if steps is None: steps = self.num_steps diff --git a/nff/md/special_thermostats.py b/nff/md/special_thermostats.py index 662fa71a..d28e8815 100644 --- a/nff/md/special_thermostats.py +++ b/nff/md/special_thermostats.py @@ -1,51 +1,47 @@ import os -import numpy as np -import copy -import math -import pickle -from tqdm import tqdm -from ase.optimize.optimize import Dynamics -from ase.md.md import MolecularDynamics -from ase.md.logger import MDLogger +import numpy as np from ase import units -from ase.md.velocitydistribution import (MaxwellBoltzmannDistribution, - Stationary, ZeroRotation) - -from nff.io.ase import AtomsBatch +from ase.md.md import MolecularDynamics +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation +from ase.optimize.optimize import Dynamics +from tqdm import tqdm class TempRamp(MolecularDynamics): - def __init__(self, - atoms, - timestep, - target_temp, - num_steps, - maxwell_temp=None, - trajectory=None, - logfile=None, - loginterval=1, - nbr_update_period=20, - append_trajectory=True, - **kwargs): - + def __init__( + self, + atoms, + timestep, + target_temp, + num_steps, + maxwell_temp=None, + trajectory=None, + logfile=None, + loginterval=1, + nbr_update_period=20, + append_trajectory=True, + **kwargs, + ): if os.path.isfile(str(trajectory)): os.remove(trajectory) - MolecularDynamics.__init__(self, - atoms=atoms, - timestep=timestep * units.fs, - trajectory=trajectory, - logfile=logfile, - loginterval=loginterval, - append_trajectory=append_trajectory) + MolecularDynamics.__init__( + self, + atoms=atoms, + timestep=timestep * units.fs, + trajectory=trajectory, + logfile=logfile, + loginterval=loginterval, + append_trajectory=append_trajectory, + ) # Initialize simulation parameters # convert units self.dt = timestep * units.fs self.Natom = len(atoms) - + if self.atoms.pbc: self.activeDoF = (3 * self.Natom) - len(self.atoms.constraints) else: @@ -57,8 +53,10 @@ def __init__(self, if 
self.num_steps < self.nbr_update_period:
             print("WARNING: Ramp will be performed in a single rescaling step!")
         if self.num_steps % self.nbr_update_period != 0:
-            print("WARNING: Number of steps is adjusted to "
-                  f"{self.num_steps + self.nbr_update_period - (self.num_steps % self.nbr_update_period)}!")
+            print(
+                "WARNING: Number of steps is adjusted to "
+                f"{self.num_steps + self.nbr_update_period - (self.num_steps % self.nbr_update_period)}!"
+            )
 
         # initial Maxwell-Boltmann temperature for atoms
         if maxwell_temp is not None:
@@ -66,14 +64,15 @@ def __init__(self,
             self.start_temp = maxwell_temp
         else:
             self.start_temp = (2 * self.atoms.get_kinetic_energy()) / (units.kB * self.activeDoF)
-
+
         self.num_epochs = int(np.ceil(self.num_steps / self.nbr_update_period))
-        self.ramp_targets = np.linspace(self.start_temp, target_temp,
-                                        num = self.num_epochs + 1, endpoint=True)[1:]
+        self.ramp_targets = np.linspace(self.start_temp, target_temp, num=self.num_epochs + 1, endpoint=True)[1:]
         self.max_steps = 0
-        print(f"Info: Temperature is adjusted {self.num_epochs} times"
-              "in {self.ramp_targets[1] - self.ramp_targets[0]}K increments.")
-
+        print(
+            f"Info: Temperature is adjusted {self.num_epochs} times "
+            f"in {self.ramp_targets[1] - self.ramp_targets[0]}K increments."
+        )
+
         self.remove_constrained_vel(atoms)
         Stationary(self.atoms)
         ZeroRotation(self.atoms)
@@ -87,16 +86,17 @@ def remove_constrained_vel(self, atoms):
         fixed_idx = []
         for constraint in constraints:
             has_keys = False
-            keys = ['idx', 'indices', 'index']
+            keys = ["idx", "indices", "index"]
             for key in keys:
                 if hasattr(constraint, key):
-                    val = np.array(getattr(constraint, key)
-                                   ).reshape(-1).tolist()
+                    val = np.array(getattr(constraint, key)).reshape(-1).tolist()
                     fixed_idx += val
                     has_keys = True
             if not has_keys:
-                print(("WARNING: velocity not set to zero for any atoms in constraint "
-                       "%s; do not know how to find its fixed indices." % constraint))
+                print(
+                    "WARNING: velocity not set to zero for any atoms in constraint "
+                    "%s; do not know how to find its fixed indices." % constraint
+                )
 
         if not fixed_idx:
             return
@@ -107,17 +107,15 @@ def remove_constrained_vel(self, atoms):
         self.atoms.set_velocities(vel)
 
     def step(self):
-
         # get current acceleration and velocity:
-        accel = (self.atoms.get_forces() /
-                 self.atoms.get_masses().reshape(-1, 1))
+        accel = self.atoms.get_forces() / self.atoms.get_masses().reshape(-1, 1)
         vel = self.atoms.get_velocities()
-
+
         # make half a step in velocity
         vel_half = vel + 0.5 * self.dt * accel
 
         # make full step in position
-        x = self.atoms.get_positions() + vel_half * self.dt 
+        x = self.atoms.get_positions() + vel_half * self.dt
         self.atoms.set_positions(x)
 
         # new accelerations
@@ -125,7 +123,7 @@ def step(self):
         accel = f / self.atoms.get_masses().reshape(-1, 1)
 
         # make another half step in velocity
-        vel = vel_half + 0.5 * self.dt * accel 
+        vel = vel_half + 0.5 * self.dt * accel
         self.atoms.set_velocities(vel)
 
         self.remove_constrained_vel(self.atoms)
@@ -133,18 +131,16 @@ def step(self):
         return f
 
     def run(self):
-
         self.atoms.update_nbr_list()
 
         for ii in tqdm(range(self.num_epochs)):
             self.max_steps += self.nbr_update_period
             Dynamics.run(self)
-
-            curr_temp = (2.*self.atoms.get_kinetic_energy() /
-                         (units.kB * self.activeDoF))
+
+            curr_temp = 2.0 * self.atoms.get_kinetic_energy() / (units.kB * self.activeDoF)
             curr_target = self.ramp_targets[ii]
-            rescale_fac = np.sqrt(curr_target/curr_temp)
-            new_vel = rescale_fac * self.atom.get_velocities()
+            rescale_fac = np.sqrt(curr_target / curr_temp)
+            new_vel = rescale_fac * self.atoms.get_velocities()
             self.atoms.set_velocities(new_vel)
-
-            self.atoms.update_nbr_list()
\ No newline at end of file
+
+            self.atoms.update_nbr_list()
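
For context, the rescaling at the end of TempRamp.run() multiplies every velocity by sqrt(T_target / T_current), which scales the kinetic energy, and hence the instantaneous temperature, by exactly T_target / T_current. A minimal standalone sketch of that identity, with unit masses and made-up numbers (not part of the patch):

    import numpy as np

    curr_temp, target_temp = 300.0, 450.0     # hypothetical temperatures in K
    vel = np.random.rand(10, 3)               # 10 atoms, unit masses
    fac = np.sqrt(target_temp / curr_temp)    # same factor as TempRamp.run()

    ke_before = 0.5 * (vel**2).sum()
    ke_after = 0.5 * ((fac * vel) ** 2).sum()
    assert np.isclose(ke_after / ke_before, target_temp / curr_temp)
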
diff --git a/nff/md/tully/ab_dynamics.py b/nff/md/tully/ab_dynamics.py
index 4ece5df4..130dc512 100644
--- a/nff/md/tully/ab_dynamics.py
+++ b/nff/md/tully/ab_dynamics.py
@@ -3,37 +3,33 @@
 """
 
 import argparse
-import shutil
-import os
+import copy
 import math
+import os
+import shutil
+
 import numpy as np
 from ase import Atoms
-import copy
 
-from nff.md.tully.dynamics import (NeuralTully,
-                                   TULLY_LOG_FILE,
-                                   TULLY_SAVE_FILE)
-from nff.md.tully.io import load_json, coords_to_xyz
 from nff.md.tully.ab_io import get_results as ab_results
+from nff.md.tully.dynamics import TULLY_LOG_FILE, TULLY_SAVE_FILE, NeuralTully
+from nff.md.tully.io import coords_to_xyz, load_json
 from nff.utils import constants as const
 
 
 def load_params(file):
     all_params = load_json(file)
 
-    all_params['nacv_details'] = {**all_params,
-                                  **all_params['nacv_details']}
-    all_params['grad_details'] = {**all_params,
-                                  **all_params['grad_details']}
+    all_params["nacv_details"] = {**all_params, **all_params["nacv_details"]}
+    all_params["grad_details"] = {**all_params, **all_params["grad_details"]}
 
     return all_params
 
 
 def make_atoms(all_params):
-    vel = np.array(all_params['velocities'])
+    vel = np.array(all_params["velocities"])
     nxyz = coords_to_xyz(all_params["coords"])
-    atoms = Atoms(nxyz[:, 0],
-                  positions=nxyz[:, 1:])
+    atoms = Atoms(nxyz[:, 0], positions=nxyz[:, 1:])
     atoms.set_velocities(vel)
 
     atoms_list = [atoms]
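
The dict merges in load_params() above make each sub-job inherit every top-level setting, with keys repeated inside the sub-dict taking precedence (later entries win in a {**a, **b} merge). A minimal sketch with made-up keys, separate from the patch:

    all_params = {"charge": 0, "nprocs": 8, "nacv_details": {"nprocs": 4}}

    nacv_details = {**all_params, **all_params["nacv_details"]}
    assert nacv_details["charge"] == 0   # inherited from the top level
    assert nacv_details["nprocs"] == 4   # overridden by the sub-config
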
@@ -42,22 +38,23 @@ def make_atoms(all_params):
 
 class AbTully(NeuralTully):
-    def __init__(self,
-                 charge,
-                 grad_config,
-                 nacv_config,
-                 grad_details,
-                 nacv_details,
-                 atoms_list,
-                 num_states,
-                 initial_surf,
-                 dt,
-                 max_time,
-                 elec_substeps,
-                 decoherence,
-                 hop_eqn,
-                 **kwargs):
-
+    def __init__(
+        self,
+        charge,
+        grad_config,
+        nacv_config,
+        grad_details,
+        nacv_details,
+        atoms_list,
+        num_states,
+        initial_surf,
+        dt,
+        max_time,
+        elec_substeps,
+        decoherence,
+        hop_eqn,
+        **kwargs,
+    ):
         self.atoms_list = atoms_list
         self.vel = self.get_vel()
         self.T = None
@@ -67,14 +64,13 @@ def __init__(self,
         self.num_atoms = len(self.atoms_list[0])
         self.num_samples = len(atoms_list)
         self.num_states = num_states
-        self.surfs = np.ones(self.num_samples,
-                             dtype=np.int) * initial_surf
+        self.surfs = np.ones(self.num_samples, dtype=int) * initial_surf
 
         self.dt = dt * const.FS_TO_AU
         self.elec_substeps = elec_substeps
 
         self.max_time = max_time * const.FS_TO_AU
-        self.max_gap_hop = float('inf')
+        self.max_gap_hop = float("inf")
 
         self.log_file = TULLY_LOG_FILE
         self.save_file = TULLY_SAVE_FILE
@@ -86,7 +82,7 @@ def __init__(self,
         self.c = self.init_c()
         self.decoherence = self.init_decoherence(params=decoherence)
-        self.decoherence_type = decoherence['name']
+        self.decoherence_type = decoherence["name"]
         self.hop_eqn = hop_eqn
         self.diabat_propagate = False
         self.simple_vel_scale = False
@@ -108,11 +104,8 @@ def __init__(self,
 
     @property
     def forces(self):
-        inf = np.ones((self.num_atoms,
-                       3)) * float('inf')
-        _forces = np.stack([-self.props.get(f'energy_{i}_grad',
-                                            inf).reshape(-1, 3)
-                            for i in range(self.num_states)])
+        inf = np.ones((self.num_atoms, 3)) * float("inf")
+        _forces = np.stack([-self.props.get(f"energy_{i}_grad", inf).reshape(-1, 3) for i in range(self.num_states)])
         _forces = _forces.reshape(1, *_forces.shape)
 
         return _forces
@@ -120,21 +113,17 @@ def forces(self):
     @forces.setter
     def forces(self, _forces):
         for i in range(self.num_states):
-            self.props[f'energy_{i}_grad'] = -_forces[:, i]
-
-    def correct_phase(self,
-                      old_force_nacv):
+            self.props[f"energy_{i}_grad"] = -_forces[:, i]
 
+    def correct_phase(self, old_force_nacv):
         if old_force_nacv is None:
             return
 
         new_force_nacv = self.force_nacv
         new_nacv = self.nacv
 
-        delta = np.max(np.linalg.norm(old_force_nacv - new_force_nacv,
-                                      axis=((-1, -2))), axis=-1)
-        sigma = np.max(np.linalg.norm(old_force_nacv + new_force_nacv,
-                                      axis=((-1, -2))), axis=-1)
+        delta = np.max(np.linalg.norm(old_force_nacv - new_force_nacv, axis=((-1, -2))), axis=-1)
+        sigma = np.max(np.linalg.norm(old_force_nacv + new_force_nacv, axis=((-1, -2))), axis=-1)
 
         delta = delta.reshape(*delta.shape, 1, 1, 1)
         sigma = sigma.reshape(*sigma.shape, 1, 1, 1)
@@ -151,13 +140,10 @@ def correct_phase(self,
         num_states = new_nacv.shape[1]
         for i in range(num_states):
             for j in range(num_states):
-                self.props[f'force_nacv_{i}{j}'] = new_force_nacv[:, i, j]
-                self.props[f'nacv_{i}{j}'] = new_nacv[:, i, j]
-
-    def update_props(self,
-                     *args,
-                     **kwargs):
+                self.props[f"force_nacv_{i}{j}"] = new_force_nacv[:, i, j]
+                self.props[f"nacv_{i}{j}"] = new_nacv[:, i, j]
 
+    def update_props(self, *args, **kwargs):
         old_force_nacv = copy.deepcopy(self.force_nacv)
 
         job_dir = os.path.join(os.getcwd(), str(self.step_num))
@@ -166,15 +152,17 @@ def update_props(self,
         else:
             os.makedirs(job_dir)
 
-        self.props = ab_results(nxyz=self.nxyz,
-                                charge=self.charge,
-                                num_states=self.num_states,
-                                surf=self.surfs[0],
-                                job_dir=job_dir,
-                                grad_config=self.grad_config,
-                                nacv_config=self.nacv_config,
-                                grad_details=self.grad_details,
-                                nacv_details=self.nacv_details)
+        self.props = ab_results(
+            nxyz=self.nxyz,
+            charge=self.charge,
+            num_states=self.num_states,
+            surf=self.surfs[0],
+            job_dir=job_dir,
+            grad_config=self.grad_config,
+            nacv_config=self.nacv_config,
+            grad_details=self.grad_details,
+            nacv_details=self.nacv_details,
+        )
 
         self.correct_phase(old_force_nacv=old_force_nacv)
         self.step_num += 1
@@ -182,8 +170,7 @@ def get_vel(self):
         """
         Velocities are in a.u. 
here, not ASE units """ - vel = np.stack([atoms.get_velocities() - for atoms in self.atoms_list]) + vel = np.stack([atoms.get_velocities() for atoms in self.atoms_list]) return vel @@ -197,35 +184,32 @@ def new_force_calc(self): """ surf = self.surfs[0] - needs_calc = np.bitwise_not( - np.isfinite( - self.forces[0, surf] - ) - ).any() + needs_calc = np.bitwise_not(np.isfinite(self.forces[0, surf])).any() if not needs_calc: return - new_job_dir = os.path.join(os.getcwd(), - f"{self.step_num - 1}_extra") + new_job_dir = os.path.join(os.getcwd(), f"{self.step_num - 1}_extra") if os.path.isdir(new_job_dir): shutil.rmtree(new_job_dir) else: os.makedirs(new_job_dir) - props = ab_results(nxyz=self.nxyz, - charge=self.charge, - num_states=self.num_states, - surf=surf, - job_dir=new_job_dir, - grad_config=self.grad_config, - nacv_config=self.nacv_config, - grad_details=self.grad_details, - nacv_details=self.nacv_details, - calc_nacv=False) - - key = f'energy_{surf}_grad' + props = ab_results( + nxyz=self.nxyz, + charge=self.charge, + num_states=self.num_states, + surf=surf, + job_dir=new_job_dir, + grad_config=self.grad_config, + nacv_config=self.nacv_config, + grad_details=self.grad_details, + nacv_details=self.nacv_details, + calc_nacv=False, + ) + + key = f"energy_{surf}_grad" self.props[key] = props[key] def run(self): @@ -240,28 +224,22 @@ def run(self): self.save() self.step(needs_nbrs=False) - with open(self.log_file, 'a') as f: - f.write('\nTully surface hopping terminated normally.') + with open(self.log_file, "a") as f: + f.write("\nTully surface hopping terminated normally.") @classmethod - def from_file(cls, - file): - + def from_file(cls, file): all_params = load_params(file) atoms_list = make_atoms(all_params) - instance = cls(atoms_list=atoms_list, - **all_params) + instance = cls(atoms_list=atoms_list, **all_params) return instance def main(): parser = argparse.ArgumentParser() - parser.add_argument('--params_file', - type=str, - help='Info file with parameters', - default='job_info.json') + parser.add_argument("--params_file", type=str, help="Info file with parameters", default="job_info.json") args = parser.parse_args() path = args.params_file @@ -272,8 +250,9 @@ def main(): except Exception as e: print(e) import pdb + pdb.post_mortem() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/nff/md/tully/ab_io.py b/nff/md/tully/ab_io.py index 30d0b7e0..48f01037 100644 --- a/nff/md/tully/ab_io.py +++ b/nff/md/tully/ab_io.py @@ -1,70 +1,55 @@ -import os +import copy import json -from jinja2 import Template -from rdkit import Chem +import os import time -import numpy as np -import copy -from chemconfigs.parsers.qchem import (get_cis_grads, - get_nacv, - get_sf_energies) +import numpy as np +from chemconfigs.parsers.qchem import get_cis_grads, get_nacv, get_sf_energies +from jinja2 import Template +from rdkit import Chem from nff.utils.misc import bash_command - CONFIG_DIRS = { - "bhhlyp_6-31gs_sf_tddft_engrad_qchem": - "qchem/bhhlyp_6-31gs_sf_tddft_engrad", - - "bhhlyp_6-31gs_sf_tddft_nacv_qchem": - "qchem/bhhlyp_6-31gs_sf_tddft_nacv" + "bhhlyp_6-31gs_sf_tddft_engrad_qchem": "qchem/bhhlyp_6-31gs_sf_tddft_engrad", + "bhhlyp_6-31gs_sf_tddft_nacv_qchem": "qchem/bhhlyp_6-31gs_sf_tddft_nacv", } -SPIN_FLIP_CONFIGS = ["bhhlyp_6-31gs_sf_tddft_engrad_qchem", - "bhhlyp_6-31gs_sf_tddft_engrad_qchem_pcm", - "bhhlyp_6-31gs_sf_tddft_nacv_qchem", - "bhhlyp_6-31gs_sf_tddft_nacv_qchem_pcm"] +SPIN_FLIP_CONFIGS = [ + "bhhlyp_6-31gs_sf_tddft_engrad_qchem", + 
"bhhlyp_6-31gs_sf_tddft_engrad_qchem_pcm", + "bhhlyp_6-31gs_sf_tddft_nacv_qchem", + "bhhlyp_6-31gs_sf_tddft_nacv_qchem_pcm", +] PERIODICTABLE = Chem.GetPeriodicTable() -def render(temp_text, - jobspec, - write_path): - +def render(temp_text, jobspec, write_path): template = Template(temp_text) inp = template.render(jobspec=jobspec) - with open(write_path, 'w') as f_open: + with open(write_path, "w") as f_open: f_open.write(inp) def get_files(config, jobspec): - platform = jobspec['details'].get("platform") + platform = jobspec["details"].get("platform") dic = config if platform is not None: dic = config[platform] - files = [dic['job_template_filename'], - *dic['extra_template_filenames']] + files = [dic["job_template_filename"], *dic["extra_template_filenames"]] - if config['name'] == "bhhlyp_6-31gs_sf_tddft_engrad_qchem": - rm_file = 'qchem_bhhlyp_6-31gs_sf_tddft_engrad.inp' + if config["name"] == "bhhlyp_6-31gs_sf_tddft_engrad_qchem": + rm_file = "qchem_bhhlyp_6-31gs_sf_tddft_engrad.inp" if rm_file in files: files.remove(rm_file) return files -def render_config(config_name, - config_dir, - config, - jobspec, - job_dir, - num_parallel, - run_parallel=True): - +def render_config(config_name, config_dir, config, jobspec, job_dir, num_parallel, run_parallel=True): files = get_files(config, jobspec) # use 1 / num_parallel * total number of cores @@ -72,123 +57,91 @@ def render_config(config_name, this_jobspec = copy.deepcopy(jobspec) if run_parallel: - nprocs = this_jobspec['details']['nprocs'] - this_jobspec['details']['nprocs'] = int(nprocs / num_parallel) + nprocs = this_jobspec["details"]["nprocs"] + this_jobspec["details"]["nprocs"] = int(nprocs / num_parallel) for file in files: temp_path = os.path.join(config_dir, file) write_path = os.path.join(job_dir, file) - with open(temp_path, 'r') as f: + with open(temp_path, "r") as f: temp_text = f.read() - render(temp_text=temp_text, - jobspec=this_jobspec, - write_path=write_path) + render(temp_text=temp_text, jobspec=this_jobspec, write_path=write_path) - info_path = os.path.join(job_dir, 'job_info.json') - with open(info_path, 'w') as f: + info_path = os.path.join(job_dir, "job_info.json") + with open(info_path, "w") as f: json.dump(this_jobspec, f, indent=4) def translate_dir(direc): - if '$HOME' in direc: - direc = direc.replace("$HOME", - os.environ["HOME"]) + if "$HOME" in direc: + direc = direc.replace("$HOME", os.environ["HOME"]) return direc -def load_config(config_name, - htvs_dir): - +def load_config(config_name, htvs_dir): config_dir_name = CONFIG_DIRS[config_name] - config_dir = os.path.join(translate_dir(htvs_dir), - 'chemconfigs', - config_dir_name) + config_dir = os.path.join(translate_dir(htvs_dir), "chemconfigs", config_dir_name) - config_path = os.path.join(config_dir, - 'config.json') + config_path = os.path.join(config_dir, "config.json") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) return config, config_dir -def render_all(config_name, - jobspec, - job_dir, - num_parallel): +def render_all(config_name, jobspec, job_dir, num_parallel): + htvs_dir = jobspec["details"]["htvs"] + config, config_dir = load_config(config_name=config_name, htvs_dir=htvs_dir) - htvs_dir = jobspec['details']['htvs'] - config, config_dir = load_config(config_name=config_name, - htvs_dir=htvs_dir) - - render_config(config_name=config_name, - config_dir=config_dir, - config=config, - jobspec=jobspec, - job_dir=job_dir, - num_parallel=num_parallel) + render_config( + config_name=config_name, + 
config_dir=config_dir, + config=config, + jobspec=jobspec, + job_dir=job_dir, + num_parallel=num_parallel, + ) def get_coords(nxyz): coords = [] - for l in nxyz: - this_coord = {"element": PERIODICTABLE.GetElementSymbol(int(l[0])), - "x": l[1], - "y": l[2], - "z": l[3]} + for line in nxyz: + this_coord = {"element": PERIODICTABLE.GetElementSymbol(int(line[0])), "x": line[1], "y": line[2], "z": line[3]} coords.append(this_coord) return coords -def init_jobspec(nxyz, - details, - charge): - +def init_jobspec(nxyz, details, charge): coords = get_coords(nxyz) - jobspec = {'details': details, - 'coords': coords, - 'charge': charge} + jobspec = {"details": details, "coords": coords, "charge": charge} return jobspec -def sf_grad_jobspec(jobspec, - surf): - - jobspec['details'].update({'grad_roots': [int(surf)], - 'num_grad_roots': 1}) +def sf_grad_jobspec(jobspec, surf): + jobspec["details"].update({"grad_roots": [int(surf)], "num_grad_roots": 1}) return jobspec -def sf_nacv_jobspec(jobspec, - singlet_path, - num_states): - - with open(singlet_path, 'r') as f: +def sf_nacv_jobspec(jobspec, singlet_path, num_states): + with open(singlet_path, "r") as f: singlets = json.load(f) coupled_states = singlets[:num_states] - details = jobspec['details'] + details = jobspec["details"] details.update({"coupled_states": coupled_states}) return jobspec -def run_job(config_name, - jobspec, - job_dir, - num_parallel): - - render_all(config_name=config_name, - jobspec=jobspec, - job_dir=job_dir, - num_parallel=num_parallel) +def run_job(config_name, jobspec, job_dir, num_parallel): + render_all(config_name=config_name, jobspec=jobspec, job_dir=job_dir, num_parallel=num_parallel) cmd = f"cd {job_dir} && bash job.sh && rm *fchk" p = bash_command(cmd) @@ -196,95 +149,59 @@ def run_job(config_name, return p -def bhhlyp_6_31gs_sf_tddft_engrad_qchem(nxyz, - details, - charge, - surf, - job_dir, - num_parallel): - - jobspec = init_jobspec(nxyz=nxyz[0], - details=details, - charge=charge) - jobspec = sf_grad_jobspec(jobspec=jobspec, - surf=surf) +def bhhlyp_6_31gs_sf_tddft_engrad_qchem(nxyz, details, charge, surf, job_dir, num_parallel): + jobspec = init_jobspec(nxyz=nxyz[0], details=details, charge=charge) + jobspec = sf_grad_jobspec(jobspec=jobspec, surf=surf) - config_name = 'bhhlyp_6-31gs_sf_tddft_engrad_qchem' + config_name = "bhhlyp_6-31gs_sf_tddft_engrad_qchem" - grad_dir = os.path.join(job_dir, 'grad') + grad_dir = os.path.join(job_dir, "grad") if not os.path.isdir(grad_dir): os.makedirs(grad_dir) # copy job_info.json - p = run_job(config_name=config_name, - jobspec=jobspec, - job_dir=grad_dir, - num_parallel=num_parallel) + p = run_job(config_name=config_name, jobspec=jobspec, job_dir=grad_dir, num_parallel=num_parallel) return p def get_singlet_path(job_dir): - singlet_path = os.path.join(job_dir, 'grad', 'singlets.json') + singlet_path = os.path.join(job_dir, "grad", "singlets.json") return singlet_path -def bhhlyp_6_31gs_sf_tddft_nacv_qchem(nxyz, - details, - charge, - num_states, - job_dir, - num_parallel): - +def bhhlyp_6_31gs_sf_tddft_nacv_qchem(nxyz, details, charge, num_states, job_dir, num_parallel): singlet_path = get_singlet_path(job_dir) exists = False while not exists: exists = os.path.isfile(singlet_path) time.sleep(5) - jobspec = init_jobspec(nxyz=nxyz[0], - details=details, - charge=charge) - jobspec = sf_nacv_jobspec(jobspec=jobspec, - singlet_path=singlet_path, - num_states=num_states) + jobspec = init_jobspec(nxyz=nxyz[0], details=details, charge=charge) + jobspec = 
sf_nacv_jobspec(jobspec=jobspec, singlet_path=singlet_path, num_states=num_states) config_name = "bhhlyp_6-31gs_sf_tddft_nacv_qchem" - nacv_dir = os.path.join(job_dir, 'nacv') + nacv_dir = os.path.join(job_dir, "nacv") if not os.path.isdir(nacv_dir): os.makedirs(nacv_dir) - p = run_job(config_name=config_name, - jobspec=jobspec, - job_dir=nacv_dir, - num_parallel=num_parallel) + p = run_job(config_name=config_name, jobspec=jobspec, job_dir=nacv_dir, num_parallel=num_parallel) return p -def run_sf(job_dir, - nxyz, - charge, - num_states, - surf, - grad_details, - nacv_details, - grad_config, - nacv_config, - calc_nacv=True): - +def run_sf( + job_dir, nxyz, charge, num_states, surf, grad_details, nacv_details, grad_config, nacv_config, calc_nacv=True +): procs = [] proc_names = [] num_parallel = 2 if calc_nacv else 1 - if grad_config == 'bhhlyp_6-31gs_sf_tddft_engrad_qchem': - p = bhhlyp_6_31gs_sf_tddft_engrad_qchem(nxyz=nxyz, - details=grad_details, - charge=charge, - surf=surf, - job_dir=job_dir, - num_parallel=num_parallel) + if grad_config == "bhhlyp_6-31gs_sf_tddft_engrad_qchem": + p = bhhlyp_6_31gs_sf_tddft_engrad_qchem( + nxyz=nxyz, details=grad_details, charge=charge, surf=surf, job_dir=job_dir, num_parallel=num_parallel + ) procs.append(p) proc_names.append("Q-Chem engrad") @@ -294,12 +211,14 @@ def run_sf(job_dir, if calc_nacv: if nacv_config == "bhhlyp_6-31gs_sf_tddft_nacv_qchem": - p = bhhlyp_6_31gs_sf_tddft_nacv_qchem(nxyz=nxyz, - details=nacv_details, - charge=charge, - num_states=num_states, - job_dir=job_dir, - num_parallel=num_parallel) + p = bhhlyp_6_31gs_sf_tddft_nacv_qchem( + nxyz=nxyz, + details=nacv_details, + charge=charge, + num_states=num_states, + job_dir=job_dir, + num_parallel=num_parallel, + ) procs.append(p) proc_names.append("Q-Chem NACV") @@ -315,8 +234,8 @@ def run_sf(job_dir, def parse_sf_grads(job_dir): - path = os.path.join(job_dir, 'singlet_grad.out') - with open(path, 'r') as f: + path = os.path.join(job_dir, "singlet_grad.out") + with open(path, "r") as f: lines = f.readlines() output_dics = get_cis_grads(lines) @@ -324,84 +243,68 @@ def parse_sf_grads(job_dir): def parse_sf_ens(job_dir): - path = os.path.join(job_dir, 'singlet_energy.out') - with open(path, 'r') as f: + path = os.path.join(job_dir, "singlet_energy.out") + with open(path, "r") as f: lines = f.readlines() output_dics = get_sf_energies(lines) return output_dics -def parse_sf_nacv(job_dir, - conifg_name): - +def parse_sf_nacv(job_dir, conifg_name): if conifg_name == "bhhlyp_6-31gs_sf_tddft_nacv_qchem": - out_name = 'qchem_bhhlyp_6-31gs_sf_tddft_nacv' + out_name = "qchem_bhhlyp_6-31gs_sf_tddft_nacv" else: raise NotImplementedError - path = os.path.join(job_dir, f'{out_name}.out') - with open(path, 'r') as f: + path = os.path.join(job_dir, f"{out_name}.out") + with open(path, "r") as f: lines = f.readlines() output_dics = get_nacv(lines) return output_dics -def check_sf(grad_config, - nacv_config): +def check_sf(grad_config, nacv_config): configs = [grad_config, nacv_config] - is_sf = any([config in SPIN_FLIP_CONFIGS for config in configs]) + is_sf = any(config in SPIN_FLIP_CONFIGS for config in configs) return is_sf -def parse_sf(job_dir, - nacv_config, - calc_nacv=True): - - nacv_dir = os.path.join(job_dir, 'nacv') - grad_dir = os.path.join(job_dir, 'grad') +def parse_sf(job_dir, nacv_config, calc_nacv=True): + nacv_dir = os.path.join(job_dir, "nacv") + grad_dir = os.path.join(job_dir, "grad") en_dics = parse_sf_ens(job_dir=grad_dir) grad_dics = parse_sf_grads(job_dir=grad_dir) - if 
calc_nacv: - nacv_dic = parse_sf_nacv(job_dir=nacv_dir, - conifg_name=nacv_config) - else: - nacv_dic = {} + nacv_dic = parse_sf_nacv(job_dir=nacv_dir, conifg_name=nacv_config) if calc_nacv else {} return en_dics, grad_dics, nacv_dic -def en_to_arr(results, - en_dics, - singlets): - +def en_to_arr(results, en_dics, singlets): for dic in en_dics: - state = dic['state'] + state = dic["state"] if state not in singlets: continue idx = singlets.index(state) key = f"energy_{idx}" - en = dic['energy'] + en = dic["energy"] results[key] = np.array([en]) return results -def grad_to_arr(results, - grad_dics, - singlets): - +def grad_to_arr(results, grad_dics, singlets): combined_grad = {} for dic in grad_dics: combined_grad.update(dic) for abs_state, grad in combined_grad.items(): idx = singlets.index(abs_state) - key = f'energy_{idx}_grad' + key = f"energy_{idx}_grad" grad = np.array(grad) shape = grad.shape @@ -410,12 +313,9 @@ def grad_to_arr(results, return results -def nacv_to_arr(results, - nacv_dic, - singlets): - - translation = {"deriv_nacv_etf": 'nacv'} - keys = ['deriv_nacv_etf', 'force_nacv'] +def nacv_to_arr(results, nacv_dic, singlets): + translation = {"deriv_nacv_etf": "nacv"} + keys = ["deriv_nacv_etf", "force_nacv"] for key in keys: if key not in nacv_dic: @@ -427,8 +327,7 @@ def nacv_to_arr(results, singlet_end = singlets.index(end_state) translate_base = translation.get(key, key) - results_key = (f"{translate_base}_{singlet_start}" - f"{singlet_end}") + results_key = f"{translate_base}_{singlet_start}" f"{singlet_end}" nacv = np.array(nacv) shape = nacv.shape @@ -437,46 +336,26 @@ def nacv_to_arr(results, return results -def combine_results(singlets, - en_dics, - grad_dics, - nacv_dic): - +def combine_results(singlets, en_dics, grad_dics, nacv_dic): results = {} - results = en_to_arr(results=results, - en_dics=en_dics, - singlets=singlets) - results = grad_to_arr(results=results, - grad_dics=grad_dics, - singlets=singlets) - results = nacv_to_arr(results=results, - nacv_dic=nacv_dic, - singlets=singlets) + results = en_to_arr(results=results, en_dics=en_dics, singlets=singlets) + results = grad_to_arr(results=results, grad_dics=grad_dics, singlets=singlets) + results = nacv_to_arr(results=results, nacv_dic=nacv_dic, singlets=singlets) return results -def parse(job_dir, - grad_config, - nacv_config, - calc_nacv=True): - - is_sf = check_sf(grad_config=grad_config, - nacv_config=nacv_config) +def parse(job_dir, grad_config, nacv_config, calc_nacv=True): + is_sf = check_sf(grad_config=grad_config, nacv_config=nacv_config) if is_sf: - en_dics, grad_dics, nacv_dic = parse_sf(job_dir=job_dir, - nacv_config=nacv_config, - calc_nacv=calc_nacv) + en_dics, grad_dics, nacv_dic = parse_sf(job_dir=job_dir, nacv_config=nacv_config, calc_nacv=calc_nacv) singlet_path = get_singlet_path(job_dir) - with open(singlet_path, 'r') as f: + with open(singlet_path, "r") as f: singlets = json.load(f) - results = combine_results(singlets=singlets, - en_dics=en_dics, - grad_dics=grad_dics, - nacv_dic=nacv_dic) + results = combine_results(singlets=singlets, en_dics=en_dics, grad_dics=grad_dics, nacv_dic=nacv_dic) else: raise NotImplementedError @@ -484,38 +363,28 @@ def parse(job_dir, return results -def get_results(nxyz, - charge, - num_states, - surf, - job_dir, - grad_config, - nacv_config, - grad_details, - nacv_details, - calc_nacv=True): - - is_sf = check_sf(grad_config=grad_config, - nacv_config=nacv_config) +def get_results( + nxyz, charge, num_states, surf, job_dir, grad_config, nacv_config, 
grad_details, nacv_details, calc_nacv=True +): + is_sf = check_sf(grad_config=grad_config, nacv_config=nacv_config) if is_sf: - run_sf(job_dir=job_dir, - nxyz=nxyz, - charge=charge, - surf=surf, - num_states=num_states, - grad_details=grad_details, - nacv_details=nacv_details, - grad_config=grad_config, - nacv_config=nacv_config, - calc_nacv=calc_nacv) + run_sf( + job_dir=job_dir, + nxyz=nxyz, + charge=charge, + surf=surf, + num_states=num_states, + grad_details=grad_details, + nacv_details=nacv_details, + grad_config=grad_config, + nacv_config=nacv_config, + calc_nacv=calc_nacv, + ) else: raise NotImplementedError - results = parse(job_dir=job_dir, - grad_config=grad_config, - nacv_config=nacv_config, - calc_nacv=calc_nacv) + results = parse(job_dir=job_dir, grad_config=grad_config, nacv_config=nacv_config, calc_nacv=calc_nacv) return results diff --git a/nff/md/tully/dynamics.py b/nff/md/tully/dynamics.py index cd4319e3..1e69186e 100644 --- a/nff/md/tully/dynamics.py +++ b/nff/md/tully/dynamics.py @@ -137,7 +137,7 @@ def setup_save(self): def init_decoherence(self, params): if not params: - return + return None name = params["name"] kwargs = params.get("kwargs", {}) @@ -163,10 +163,7 @@ def get_vel(self): return vel def init_c(self): - if self.explicit_diabat: - num_states = self.num_diabat - else: - num_states = self.num_states + num_states = self.num_diabat if self.explicit_diabat else self.num_states c = np.zeros((self.num_samples, num_states), dtype="complex128") c[:, self.surfs[0]] = 1 @@ -191,9 +188,9 @@ def init_sigma(self): @property def U(self): if not self.props: - return + return None if "U" not in self.props: - return + return None return self.props["U"] @U.setter @@ -238,7 +235,7 @@ def nacv(self): continue key = f"nacv_{i}{j}" if key not in self.props: - return + return None _nacv[:, i, j, :] = self.props[key] return _nacv @@ -273,7 +270,7 @@ def force_nacv(self): nacv = self.nacv if nacv is None: - return + return None gap = self.gap.reshape(self.num_samples, self.num_states, self.num_states, 1, 1) @@ -338,7 +335,7 @@ def full_pot_V(self): @property def H_plus_nacv(self): if self.nacv is None: - return + return None pot_V = self.pot_V nac_term = -1j * (self.nacv * self.vel.reshape(self.num_samples, 1, 1, self.num_atoms, 3)).sum((-1, -2)) @@ -348,8 +345,8 @@ def H_plus_nacv(self): def H_d(self): diabat_keys = getattr(self, "diabat_keys", [None]) reshaped = np.array(diabat_keys).reshape(-1).tolist() - if not all([i in self.props for i in reshaped]): - return + if not all(i in self.props for i in reshaped): + return None _H_d = np.zeros((self.num_samples, self.num_diabat, self.num_diabat)) @@ -448,7 +445,7 @@ def setup_logging(self, remove_old=True): f.write(hdr) template = "%-10.1f " - for i, state in enumerate(states): + for _ in states: template += "%15.4f%%" template += "%15.4f" template += "%15.4f" @@ -508,7 +505,7 @@ def from_pickle(cls, file, max_time=None): nxyz = state_dict["nxyz"] if single: nxyz = [nxyz] - for i, nxyz in enumerate(nxyz): + for i, nxyz in enumerate(nxyz): # noqa if nxyz is None: trjs[i].append(None) continue diff --git a/nff/md/tully/io.py b/nff/md/tully/io.py index e481c849..d4f5d631 100644 --- a/nff/md/tully/io.py +++ b/nff/md/tully/io.py @@ -2,33 +2,28 @@ Link between Tully surface hopping and both NFF models and JSON parameter files. 
""" + import json import os -import torch -from torch.utils.data import DataLoader import numpy as np - -from rdkit import Chem +import torch from ase import Atoms +from rdkit import Chem +from torch.utils.data import DataLoader -from nff.train import batch_to, batch_detach -from nff.nn.utils import single_spec_nbrs from nff.data import Dataset, collate_dicts +from nff.io.ase_ax import AtomsBatch, NeuralFF +from nff.nn.utils import single_spec_nbrs +from nff.train import batch_detach, batch_to from nff.utils import constants as const from nff.utils.scatter import compute_grad -from nff.io.ase_ax import NeuralFF, AtomsBatch PERIODICTABLE = Chem.GetPeriodicTable() ANGLE_MODELS = ["DimeNet", "DimeNetDiabat", "DimeNetDiabatDelta"] -def check_hop(model, - results, - max_gap_hop, - surf, - num_states): - +def check_hop(model, results, max_gap_hop, surf, num_states): # **** this won't work - assumes surf is an integer """ `max_gap_hop` in a.u. @@ -39,35 +34,29 @@ def check_hop(model, continue upper = max([i, surf]) lower = min([i, surf]) - key = f'energy_{upper}_energy_{lower}_delta' + key = f"energy_{upper}_energy_{lower}_delta" gap_keys.append(key) # convert max_gap_hop to kcal - max_conv = max_gap_hop * const.AU_TO_KCAL['energy'] - gaps = torch.cat([results[key].reshape(-1, 1) - for key in gap_keys], dim=-1) + max_conv = max_gap_hop * const.AU_TO_KCAL["energy"] + gaps = torch.cat([results[key].reshape(-1, 1) for key in gap_keys], dim=-1) can_hop = (gaps <= max_conv).sum(-1).to(torch.bool) return can_hop -def split_by_hop(dic, - can_hop, - num_atoms): - +def split_by_hop(dic, can_hop, num_atoms): hop_dic = {} no_hop_dic = {} for key, val in dic.items(): - if any(['nacv' in key, 'grad' in key, 'nxyz' in key]): + if any(["nacv" in key, "grad" in key, "nxyz" in key]): val = torch.split(val, num_atoms) - hop_tensor = torch.cat([item for i, item in enumerate(val) - if can_hop[i]]) + hop_tensor = torch.cat([item for i, item in enumerate(val) if can_hop[i]]) - no_hop_tensor = torch.cat([item for i, item in enumerate(val) - if not can_hop[i]]) + no_hop_tensor = torch.cat([item for i, item in enumerate(val) if not can_hop[i]]) hop_dic[key] = hop_tensor no_hop_dic[key] = no_hop_tensor @@ -75,70 +64,46 @@ def split_by_hop(dic, return hop_dic, no_hop_dic -def split_all(model, - xyz, - max_gap_hop, - surf, - num_states, - batch, - results): - - can_hop = check_hop(model=model, - results=results, - max_gap_hop=max_gap_hop, - surf=surf, - num_states=num_states) +def split_all(model, xyz, max_gap_hop, surf, num_states, batch, results): + can_hop = check_hop(model=model, results=results, max_gap_hop=max_gap_hop, surf=surf, num_states=num_states) - num_atoms = batch['num_atoms'].tolist() - batch['xyz'] = xyz + num_atoms = batch["num_atoms"].tolist() + batch["xyz"] = xyz - hop_batch, no_hop_batch = split_by_hop(dic=batch, - can_hop=can_hop, - num_atoms=num_atoms) + hop_batch, no_hop_batch = split_by_hop(dic=batch, can_hop=can_hop, num_atoms=num_atoms) - hop_results, no_hop_results = split_by_hop(dic=results, - can_hop=can_hop, - num_atoms=num_atoms) + hop_results, no_hop_results = split_by_hop(dic=results, can_hop=can_hop, num_atoms=num_atoms) splits = (hop_batch, no_hop_batch, hop_results, no_hop_results) return splits, can_hop -def init_results(num_atoms, - num_states): - - en_keys = [f'energy_{i}' for i in range(num_states)] +def init_results(num_atoms, num_states): + en_keys = [f"energy_{i}" for i in range(num_states)] grad_keys = [key + "_grad" for key in en_keys] - nacv_keys = [f"nacv_{i}{j}" for i in 
 
 
-def fill_results(batch,
-                 these_results,
-                 results,
-                 idx):
-
-    num_atoms = batch['num_atoms'].tolist()
-    grad_flags = ['_grad', 'nacv']
+def fill_results(batch, these_results, results, idx):
+    num_atoms = batch["num_atoms"].tolist()
+    grad_flags = ["_grad", "nacv"]
 
-    for key, val in these_results.keys():
-        if any([flag in key for flag in grad_flags]):
+    for key, val in these_results.items():
+        if any(flag in key for flag in grad_flags):
             val = torch.stack(torch.split(val, num_atoms))
         results[key][idx] = val
@@ -146,96 +111,68 @@ def fill_results(batch,
     return results
 
 
-def combine_all(no_hop_results,
-                hop_results,
-                no_hop_batch,
-                hop_batch,
-                can_hop,
-                num_states,
-                batch):
-
-    num_atoms = batch['num_atoms'].tolist()
-    results = init_results(num_atoms=num_atoms,
-                           num_states=num_states)
+def combine_all(no_hop_results, hop_results, no_hop_batch, hop_batch, can_hop, num_states, batch):
+    num_atoms = batch["num_atoms"].tolist()
+    results = init_results(num_atoms=num_atoms, num_states=num_states)
 
     hop_idx = can_hop.nonzero()
     no_hop_idx = torch.bitwise_not(can_hop).nonzero()
 
-    tuples = [(no_hop_batch, no_hop_results, no_hop_idx),
-              (hop_batch, hop_results, hop_idx)]
+    tuples = [(no_hop_batch, no_hop_results, no_hop_idx), (hop_batch, hop_results, hop_idx)]
 
     for tup in tuples:
         batch, these_results, idx = tup
-        results = fill_results(batch=batch,
-                               these_results=these_results,
-                               results=results,
-                               idx=idx)
+        results = fill_results(batch=batch, these_results=these_results, results=results, idx=idx)
 
     return results
 
 
-def grad_by_split(model,
-                  hop_batch,
-                  hop_results,
-                  no_hop_batch,
-                  no_hop_results,
-                  surf):
-
+def grad_by_split(model, hop_batch, hop_results, no_hop_batch, no_hop_results, surf):
     # add all the gradients for the hop batch and results
-    model.diabatic_readout.add_all_grads(xyz=hop_batch['xyz'],
-                                         results=hop_results,
-                                         num_atoms=hop_batch['num_atoms'],
-                                         u=hop_results['U'],
-                                         add_u=False)
+    model.diabatic_readout.add_all_grads(
+        xyz=hop_batch["xyz"], results=hop_results, num_atoms=hop_batch["num_atoms"], u=hop_results["U"], add_u=False
+    )
 
     # just add the state gradient for the non-hop batch / results
-    key = f'energy_{surf}'
-    surf_grad = compute_grad(inputs=no_hop_batch['xyz'],
-                             output=no_hop_results[key])
-    no_hop_results[key + '_grad'] = surf_grad
+    key = f"energy_{surf}"
+    surf_grad = compute_grad(inputs=no_hop_batch["xyz"], output=no_hop_results[key])
+    no_hop_results[key + "_grad"] = surf_grad
 
     return hop_results, no_hop_results
 
 
-def add_grad(model,
-             batch,
-             xyz,
-             results,
-             max_gap_hop,
-             surf,
-             num_states):
-
+def add_grad(model, batch, xyz, results, max_gap_hop, surf, num_states):
     # split batches and results into those that require NACVs
     # and gradients on all states, and those that only require
    # the gradient on the 
current state - splits, can_hop = split_all(model=model, - xyz=xyz, - max_gap_hop=max_gap_hop, - surf=surf, - num_states=num_states, - batch=batch, - results=results) + splits, can_hop = split_all( + model=model, xyz=xyz, max_gap_hop=max_gap_hop, surf=surf, num_states=num_states, batch=batch, results=results + ) (hop_batch, no_hop_batch, hop_results, no_hop_results) = splits # add the relevant gradients - hop_results, no_hop_results = grad_by_split(model=model, - hop_batch=hop_batch, - hop_results=hop_results, - no_hop_batch=no_hop_batch, - no_hop_results=no_hop_results, - surf=surf) + hop_results, no_hop_results = grad_by_split( + model=model, + hop_batch=hop_batch, + hop_results=hop_results, + no_hop_batch=no_hop_batch, + no_hop_results=no_hop_results, + surf=surf, + ) # combine everything together - results = combine_all(no_hop_results=no_hop_results, - hop_results=hop_results, - no_hop_batch=no_hop_batch, - hop_batch=hop_batch, - can_hop=can_hop, - num_states=num_states, - batch=batch) + results = combine_all( + no_hop_results=no_hop_results, + hop_results=hop_results, + no_hop_batch=no_hop_batch, + hop_batch=hop_batch, + can_hop=can_hop, + num_states=num_states, + batch=batch, + ) return results @@ -255,14 +192,7 @@ def add_grad(model, # for i in range(num_states)} -def run_model(model, - batch, - device, - surf, - max_gap_hop, - num_states, - all_engrads, - nacv): +def run_model(model, batch, device, surf, max_gap_hop, num_states, all_engrads, nacv): """ `max_gap_hop` in a.u. """ @@ -284,14 +214,16 @@ def run_model(model, xyz = None model.add_nacv = nacv - results = model(batch, - xyz=xyz, - add_nacv=nacv, - # add_grad=all_engrads, - add_grad=True, - add_gap=True, - add_u=True, - inference=True) + results = model( + batch, + xyz=xyz, + add_nacv=nacv, + # add_grad=all_engrads, + add_grad=True, + add_gap=True, + add_u=True, + inference=True, + ) # If we use NACV then we can come back to what's commented # out below, where you only ask for gradients NACVs among states @@ -318,8 +250,7 @@ def run_model(model, def get_phases(U, old_U): # Compute overlap - S = np.einsum('...ki, ...kj -> ...ij', - old_U, U) + S = np.einsum("...ki, ...kj -> ...ij", old_U, U) # Take the element in each column with the # largest absolute value, not just the diagonal. @@ -333,48 +264,33 @@ def get_phases(U, old_U): max_idx = abs(S).argmax(axis=1) num_samples = S.shape[0] - S_max = np.take_along_axis( - S.transpose(0, 2, 1), - max_idx.reshape(num_samples, num_states, 1), - axis=2 - ).transpose(0, 2, 1) + S_max = np.take_along_axis(S.transpose(0, 2, 1), max_idx.reshape(num_samples, num_states, 1), axis=2).transpose( + 0, 2, 1 + ) new_phases = np.sign(S_max) return new_phases -def update_phase(new_phases, - i, - j, - results, - key, - num_atoms): - - phase = ((new_phases[:, :, i] * new_phases[:, :, j]) - .reshape(-1, 1, 1)) +def update_phase(new_phases, i, j, results, key, num_atoms): + phase = (new_phases[:, :, i] * new_phases[:, :, j]).reshape(-1, 1, 1) - updated = np.concatenate( - np.split(results[key], num_atoms) - ).reshape(-1, num_atoms[0], 3) * phase + updated = np.concatenate(np.split(results[key], num_atoms)).reshape(-1, num_atoms[0], 3) * phase results[key] = updated return results -def correct_nacv(results, - old_U, - num_atoms, - num_states): +def correct_nacv(results, old_U, num_atoms, num_states): """ Stack the non-adiabatic couplings and correct their phases. Also correct the phases of U. 
""" # get phase correction - new_phases = get_phases(U=results["U"], - old_U=old_U) + new_phases = get_phases(U=results["U"], old_U=old_U) new_U = results["U"] * new_phases results["U"] = new_U @@ -397,45 +313,32 @@ def correct_nacv(results, if key not in results: continue - results = update_phase( - new_phases=new_phases, - i=i, - j=j, - results=results, - key=key, - num_atoms=num_atoms) + results = update_phase(new_phases=new_phases, i=i, j=j, results=results, key=key, num_atoms=num_atoms) return results -def batched_calc(model, - batch, - device, - num_states, - surf, - max_gap_hop, - all_engrads, - nacv): +def batched_calc(model, batch, device, num_states, surf, max_gap_hop, all_engrads, nacv): """ Get model results from a batch, including nacv phase correction """ - results = run_model(model=model, - batch=batch, - device=device, - surf=surf, - max_gap_hop=max_gap_hop, - num_states=num_states, - all_engrads=all_engrads, - nacv=nacv) + results = run_model( + model=model, + batch=batch, + device=device, + surf=surf, + max_gap_hop=max_gap_hop, + num_states=num_states, + all_engrads=all_engrads, + nacv=nacv, + ) return results -def concat_and_conv(results_list, - num_atoms, - diabat_keys): +def concat_and_conv(results_list, num_atoms, diabat_keys): """ Concatenate results from separate batches and convert to atomic units @@ -450,13 +353,13 @@ def concat_and_conv(results_list, for key in keys: val = torch.cat([i[key] for i in results_list]) - if ('energy' in key and '_grad' in key) or 'force_nacv' in key: - val *= conv['energy'] * conv['_grad'] + if ("energy" in key and "_grad" in key) or "force_nacv" in key: + val *= conv["energy"] * conv["_grad"] val = val.reshape(*grad_shape) - elif 'energy' in key or key in diabat_keys: - val *= conv['energy'] - elif 'nacv' in key: - val *= conv['_grad'] + elif "energy" in key or key in diabat_keys: + val *= conv["energy"] + elif "nacv" in key: + val *= conv["_grad"] val = val.reshape(*grad_shape) # else: # msg = f"{key} has no known conversion" @@ -467,35 +370,18 @@ def concat_and_conv(results_list, return all_results -def make_loader(nxyz, - nbr_list, - num_atoms, - needs_nbrs, - cutoff, - cutoff_skin, - device, - batch_size): - - props = {"nxyz": [torch.Tensor(i) - for i in nxyz]} +def make_loader(nxyz, nbr_list, num_atoms, needs_nbrs, cutoff, cutoff_skin, device, batch_size): + props = {"nxyz": [torch.Tensor(i) for i in nxyz]} - dataset = Dataset(props=props, - units='kcal/mol', - check_props=True) + dataset = Dataset(props=props, units="kcal/mol", check_props=True) if needs_nbrs or nbr_list is None: - nbrs = single_spec_nbrs(dset=dataset, - cutoff=(cutoff + - cutoff_skin), - device=device, - directed=True) - dataset.props['nbr_list'] = nbrs + nbrs = single_spec_nbrs(dset=dataset, cutoff=(cutoff + cutoff_skin), device=device, directed=True) + dataset.props["nbr_list"] = nbrs else: - dataset.props['nbr_list'] = nbr_list + dataset.props["nbr_list"] = nbr_list - loader = DataLoader(dataset, - batch_size=batch_size, - collate_fn=collate_dicts) + loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_dicts) return loader @@ -516,55 +402,58 @@ def my_func(*args, **kwargs): # @timing -def get_results(model, - nxyz, - nbr_list, - num_atoms, - needs_nbrs, - cutoff, - cutoff_skin, - device, - batch_size, - old_U, - num_states, - surf, - max_gap_hop, - all_engrads, - nacv, - diabat_keys): +def get_results( + model, + nxyz, + nbr_list, + num_atoms, + needs_nbrs, + cutoff, + cutoff_skin, + device, + batch_size, + old_U, + num_states, + surf, + 
max_gap_hop, + all_engrads, + nacv, + diabat_keys, +): """ `nxyz_list` assumed to be in Angstroms """ - loader = make_loader(nxyz=nxyz, - nbr_list=nbr_list, - num_atoms=num_atoms, - needs_nbrs=needs_nbrs, - cutoff=cutoff, - cutoff_skin=cutoff_skin, - device=device, - batch_size=batch_size) + loader = make_loader( + nxyz=nxyz, + nbr_list=nbr_list, + num_atoms=num_atoms, + needs_nbrs=needs_nbrs, + cutoff=cutoff, + cutoff_skin=cutoff_skin, + device=device, + batch_size=batch_size, + ) results_list = [] for batch in loader: - results = batched_calc(model=model, - batch=batch, - device=device, - num_states=num_states, - surf=surf, - max_gap_hop=max_gap_hop, - all_engrads=all_engrads, - nacv=nacv) + results = batched_calc( + model=model, + batch=batch, + device=device, + num_states=num_states, + surf=surf, + max_gap_hop=max_gap_hop, + all_engrads=all_engrads, + nacv=nacv, + ) results_list.append(results) - all_results = concat_and_conv(results_list=results_list, - num_atoms=num_atoms, - diabat_keys=diabat_keys) + all_results = concat_and_conv(results_list=results_list, num_atoms=num_atoms, diabat_keys=diabat_keys) if old_U is not None: - all_results = correct_nacv(results=all_results, - old_U=old_U, - num_atoms=[num_atoms] * old_U.shape[0], - num_states=num_states) + all_results = correct_nacv( + results=all_results, old_U=old_U, num_atoms=[num_atoms] * old_U.shape[0], num_states=num_states + ) return all_results @@ -572,7 +461,7 @@ def get_results(model, def coords_to_xyz(coords): nxyz = [] for dic in coords: - directions = ['x', 'y', 'z'] + directions = ["x", "y", "z"] n = float(PERIODICTABLE.GetAtomicNumber(dic["element"])) xyz = [dic[i] for i in directions] nxyz.append([n, *xyz]) @@ -580,36 +469,27 @@ def coords_to_xyz(coords): def load_json(file): - - with open(file, 'r') as f: + with open(file, "r") as f: info = json.load(f) - if 'details' in info: - details = info['details'] - else: - details = {} - all_params = {key: val for key, val in info.items() - if key != "details"} + details = info.get("details", {}) + all_params = {key: val for key, val in info.items() if key != "details"} all_params.update(details) return all_params -def make_dataset(nxyz, - ground_params): - props = { - 'nxyz': [torch.Tensor(nxyz)] - } +def make_dataset(nxyz, ground_params): + props = {"nxyz": [torch.Tensor(nxyz)]} cutoff = ground_params["cutoff"] cutoff_skin = ground_params["cutoff_skin"] - dataset = Dataset(props.copy(), units='kcal/mol') - dataset.generate_neighbor_list(cutoff=(cutoff + cutoff_skin), - undirected=False) + dataset = Dataset(props.copy(), units="kcal/mol") + dataset.generate_neighbor_list(cutoff=(cutoff + cutoff_skin), undirected=False) model_type = ground_params["model_type"] - needs_angles = (model_type in ANGLE_MODELS) + needs_angles = model_type in ANGLE_MODELS if needs_angles: dataset.generate_angle_list() @@ -626,13 +506,8 @@ def get_batched_props(dataset): return batched_props -def add_calculator(atomsbatch, - model_path, - model_type, - device, - batched_props): - - needs_angles = (model_type in ANGLE_MODELS) +def add_calculator(atomsbatch, model_path, model_type, device, batched_props): + needs_angles = model_type in ANGLE_MODELS nff_ase = NeuralFF.from_file( model_path=model_path, @@ -642,41 +517,40 @@ def add_calculator(atomsbatch, params=None, model_type=model_type, needs_angles=needs_angles, - dataset_props=batched_props + dataset_props=batched_props, ) atomsbatch.set_calculator(nff_ase) -def get_atoms(ground_params, - all_params): - +def get_atoms(ground_params, all_params): 
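+    """
+    Build an ASE Atoms object from the coordinates in `all_params`, wrap it in
+    an AtomsBatch (with neighbor and, if needed, angle information), and attach
+    a NeuralFF calculator loaded from `model_path` or `weightpath`/`nnid`.
+    """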
coords = all_params["coords"] nxyz = coords_to_xyz(coords) - atoms = Atoms(nxyz[:, 0], - positions=nxyz[:, 1:]) + atoms = Atoms(nxyz[:, 0], positions=nxyz[:, 1:]) - dataset, needs_angles = make_dataset(nxyz=nxyz, - ground_params=ground_params) + dataset, needs_angles = make_dataset(nxyz=nxyz, ground_params=ground_params) batched_props = get_batched_props(dataset) - device = ground_params.get('device', 'cuda') + device = ground_params.get("device", "cuda") - atomsbatch = AtomsBatch.from_atoms(atoms=atoms, - props=batched_props, - needs_angles=needs_angles, - device=device, - undirected=False, - cutoff_skin=ground_params['cutoff_skin']) + atomsbatch = AtomsBatch.from_atoms( + atoms=atoms, + props=batched_props, + needs_angles=needs_angles, + device=device, + undirected=False, + cutoff_skin=ground_params["cutoff_skin"], + ) - if 'model_path' in all_params: - model_path = all_params['model_path'] + if "model_path" in all_params: + model_path = all_params["model_path"] else: - model_path = os.path.join(all_params['weightpath'], - str(all_params["nnid"])) - add_calculator(atomsbatch=atomsbatch, - model_path=model_path, - model_type=ground_params["model_type"], - device=device, - batched_props=batched_props) + model_path = os.path.join(all_params["weightpath"], str(all_params["nnid"])) + add_calculator( + atomsbatch=atomsbatch, + model_path=model_path, + model_type=ground_params["model_type"], + device=device, + batched_props=batched_props, + ) return atomsbatch diff --git a/nff/md/tully/step.py b/nff/md/tully/step.py index a631e150..603dfdbd 100644 --- a/nff/md/tully/step.py +++ b/nff/md/tully/step.py @@ -9,41 +9,30 @@ import torch -def compute_T(nacv, - vel, - c): - +def compute_T(nacv, vel, c): # vel has shape num_samples x num_atoms x 3 # nacv has shape num_samples x num_states x num_states # x num_atoms x 3 # T has shape num_samples x (num_states x num_states) - T = (vel.reshape(vel.shape[0], 1, 1, -1, 3) - * nacv).sum((-1, -2)) + T = (vel.reshape(vel.shape[0], 1, 1, -1, 3) * nacv).sum((-1, -2)) # anything that's nan has too big a gap # for hopping and should therefore have T=0 T[np.isnan(T)] = 0 num_states = nacv.shape[1] - coupling = np.einsum('nij, nj-> ni', T, c[:, :num_states]) + coupling = np.einsum("nij, nj-> ni", T, c[:, :num_states]) return T, coupling -def get_dc_dt(c, - vel, - nacv, - energy, - hbar=1): - +def get_dc_dt(c, vel, nacv, energy, hbar=1): # energies have shape num_samples x num_states w = energy / hbar # T has dimension num_samples x (num_states x num_states) - T, coupling = compute_T(nacv=nacv, - vel=vel, - c=c) + T, coupling = compute_T(nacv=nacv, vel=vel, c=c) dc_dt = -(1j * w * c + coupling) @@ -55,33 +44,21 @@ def get_a(c): num_samples = c.shape[0] num_states = c.shape[1] - a = np.zeros((num_samples, num_states, num_states), - dtype='complex128') + a = np.zeros((num_samples, num_states, num_states), dtype="complex128") for i in range(num_states): for j in range(num_states): - a[..., i, j] = (np.conj(c[..., i]) - * c[..., j]) + a[..., i, j] = np.conj(c[..., i]) * c[..., j] return a -def remove_self_hop(p, - surfs): - +def remove_self_hop(p, surfs): same_surfs = surfs.reshape(-1, 1) - np.put_along_axis(p, - same_surfs, - np.zeros_like(same_surfs), - axis=-1) + np.put_along_axis(p, same_surfs, np.zeros_like(same_surfs), axis=-1) return p -def get_tully_p(c, - T, - dt, - surfs, - num_adiabat, - **kwargs): +def get_tully_p(c, T, dt, surfs, num_adiabat, **kwargs): """ Tully surface hopping probability """ @@ -95,12 +72,10 @@ def get_tully_p(c, b = -2 * 
np.real(np.conj(a) * T) # a_surf has dimension num_samples x 1 - a_surf = np.stack([sample_a[surf, surf] for - sample_a, surf in zip(a, surfs)]).reshape(-1, 1) + a_surf = np.stack([sample_a[surf, surf] for sample_a, surf in zip(a, surfs)]).reshape(-1, 1) # b_surf has dimension num_samples x num_states - b_surf = np.stack([sample_b[:, surf] for - sample_b, surf in zip(b, surfs)]) + b_surf = np.stack([sample_b[:, surf] for sample_b, surf in zip(b, surfs)]) # p has dimension num_samples x num_states, for the # hopping probability of each sample to all other @@ -112,8 +87,7 @@ def get_tully_p(c, p = np.real(dt * b_surf / a_surf) # no hopping from current state to self - p = remove_self_hop(p=p, - surfs=surfs) + p = remove_self_hop(p=p, surfs=surfs) # only hop among adiabatic states of interest p = p[:, :num_adiabat] @@ -121,12 +95,7 @@ def get_tully_p(c, return p -def get_sharc_p(old_c, - new_c, - P, - surfs, - num_adiabat, - **kwargs): +def get_sharc_p(old_c, new_c, P, surfs, num_adiabat, **kwargs): """ P is the propagator. """ @@ -134,52 +103,32 @@ def get_sharc_p(old_c, num_samples = old_c.shape[0] num_states = old_c.shape[1] - other_surfs = get_other_surfs(surfs=surfs, - num_states=num_states, - num_samples=num_samples) + other_surfs = get_other_surfs(surfs=surfs, num_states=num_states, num_samples=num_samples) - c_beta_t = np.take_along_axis(old_c, - surfs.reshape(-1, 1), - axis=-1) - c_beta_dt = np.take_along_axis(new_c, - surfs.reshape(-1, 1), - axis=-1) + c_beta_t = np.take_along_axis(old_c, surfs.reshape(-1, 1), axis=-1) + c_beta_dt = np.take_along_axis(new_c, surfs.reshape(-1, 1), axis=-1) - c_alpha_dt = np.take_along_axis(new_c, - other_surfs, - axis=-1) + c_alpha_dt = np.take_along_axis(new_c, other_surfs, axis=-1) # `P` has dimension num_samples x num_states x num_states - P_alpha_beta = np.take_along_axis(np.take_along_axis( - P, - surfs.reshape(-1, 1, 1), - axis=-1).squeeze(-1), - other_surfs, - axis=-1 + P_alpha_beta = np.take_along_axis( + np.take_along_axis(P, surfs.reshape(-1, 1, 1), axis=-1).squeeze(-1), other_surfs, axis=-1 ) - P_beta_beta = np.take_along_axis(np.take_along_axis( - P, - surfs.reshape(-1, 1, 1), - axis=-1).squeeze(-1), - surfs.reshape(-1, 1), - axis=-1 + P_beta_beta = np.take_along_axis( + np.take_along_axis(P, surfs.reshape(-1, 1, 1), axis=-1).squeeze(-1), surfs.reshape(-1, 1), axis=-1 ) # h_alpha is the transition probability from the current state # to alpha num = np.real(c_alpha_dt * np.conj(P_alpha_beta) * np.conj(c_beta_t)) - denom = abs(c_beta_t) ** 2 - np.real(c_beta_dt * np.conj(P_beta_beta) - * np.conj(c_beta_t)) + denom = abs(c_beta_t) ** 2 - np.real(c_beta_dt * np.conj(P_beta_beta) * np.conj(c_beta_t)) pref = 1 - abs(c_beta_dt) ** 2 / abs(c_beta_t) ** 2 h = np.zeros((num_samples, num_states)) - np.put_along_axis(h, - other_surfs, - pref * num / denom, - axis=-1) + np.put_along_axis(h, other_surfs, pref * num / denom, axis=-1) h[h < 0] = 0 # only hop among adiabatic states of interest @@ -188,12 +137,10 @@ def get_sharc_p(old_c, return h -def get_p_hop(hop_eqn='sharc', - **kwargs): - - if hop_eqn == 'sharc': +def get_p_hop(hop_eqn="sharc", **kwargs): + if hop_eqn == "sharc": p = get_sharc_p(**kwargs) - elif hop_eqn == 'tully': + elif hop_eqn == "tully": p = get_tully_p(**kwargs) else: raise NotImplementedError @@ -201,15 +148,9 @@ def get_p_hop(hop_eqn='sharc', return p -def get_new_surf(p_hop, - surfs, - max_gap_hop, - energy): - +def get_new_surf(p_hop, surfs, max_gap_hop, energy): num_samples = p_hop.shape[0] - lhs = 
np.concatenate([np.zeros(num_samples).reshape(-1, 1), - p_hop.cumsum(axis=-1)], - axis=-1)[:, :-1] + lhs = np.concatenate([np.zeros(num_samples).reshape(-1, 1), p_hop.cumsum(axis=-1)], axis=-1)[:, :-1] rhs = lhs + p_hop r = np.random.rand(num_samples).reshape(-1, 1) hop = (lhs < r) * (r <= rhs) @@ -221,12 +162,8 @@ def get_new_surf(p_hop, if max_gap_hop is None: return new_surfs - old_en = np.take_along_axis(energy, - surfs.reshape(-1, 1), - axis=-1).squeeze(-1) - new_en = np.take_along_axis(energy, - new_surfs.reshape(-1, 1), - axis=-1).squeeze(-1) + old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), axis=-1).squeeze(-1) + new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), axis=-1).squeeze(-1) gaps = abs(old_en - new_en) bad_idx = gaps >= max_gap_hop new_surfs[bad_idx] = surfs[bad_idx] @@ -234,29 +171,20 @@ def get_new_surf(p_hop, return new_surfs -def solve_quadratic(vel, - nac_dir, - old_en, - new_en, - mass): - a = (1 / (2 * mass.reshape(1, -1, 1)) - * nac_dir ** 2).sum((-1, -2)).astype('complex128') - b = (vel * nac_dir).sum((-1, -2)).astype('complex128') - c = (new_en - old_en).astype('complex128') +def solve_quadratic(vel, nac_dir, old_en, new_en, mass): + a = (1 / (2 * mass.reshape(1, -1, 1)) * nac_dir**2).sum((-1, -2)).astype("complex128") + b = (vel * nac_dir).sum((-1, -2)).astype("complex128") + c = (new_en - old_en).astype("complex128") - sqrt = np.sqrt(b ** 2 - 4 * a * c) + sqrt = np.sqrt(b**2 - 4 * a * c) scale_pos = (-b + sqrt) / (2 * a) scale_neg = (-b - sqrt) / (2 * a) # take solution with smallest absolute value of # scaling factor - scales = np.concatenate([scale_pos.reshape(-1, 1), - scale_neg.reshape(-1, 1)], - axis=1) + scales = np.concatenate([scale_pos.reshape(-1, 1), scale_neg.reshape(-1, 1)], axis=1) scale_argmin = np.argmin(abs(scales), axis=1) - scale = np.take_along_axis(scales, - scale_argmin.reshape(-1, 1), - axis=1) + scale = np.take_along_axis(scales, scale_argmin.reshape(-1, 1), axis=1) scale[np.imag(scale) != 0] = np.nan scale = np.real(scale) @@ -264,17 +192,12 @@ def solve_quadratic(vel, return scale -def get_simple_scale(mass, - new_en, - old_en, - vel): - +def get_simple_scale(mass, new_en, old_en, vel): m = mass.reshape(1, -1, 1) gap = old_en - new_en - arg = ((2 * gap + (m * vel ** 2).sum((-1, -2))) - .astype('complex128')) + arg = (2 * gap + (m * vel**2).sum((-1, -2))).astype("complex128") num = np.sqrt(arg) - denom = np.sqrt((m * vel ** 2).sum((-1, -2))) + denom = np.sqrt((m * vel**2).sum((-1, -2))) v_scale = num / denom @@ -285,18 +208,12 @@ def get_simple_scale(mass, return v_scale -def rescale(energy, - vel, - nacv, - mass, - surfs, - new_surfs, - simple_scale): +def rescale(energy, vel, nacv, mass, surfs, new_surfs, simple_scale): """ Velocity re-scaling, from: - Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with - fewest switches surface hopping: Add just a touch of decoherence. The + Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with + fewest switches surface hopping: Add just a touch of decoherence. The Journal of chemical physics, 137(22), p.22A513. If no NACV is available, the KE is simply rescaled to conserve energy. 
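That kinetic-energy-only fallback is simple enough to state on its own. Below is a minimal standalone sketch of the energy-conserving rescale, following this module's shape conventions; `simple_rescale` is a hypothetical name used for illustration (the actual branch is `get_simple_scale` above) and is not part of the patch:

import numpy as np

def simple_rescale(mass, vel, old_en, new_en):
    # mass: (num_atoms,); vel: (num_samples, num_atoms, 3);
    # old_en, new_en: (num_samples,) energies of the old and new surfaces
    m = mass.reshape(1, -1, 1)
    two_ke = (m * vel**2).sum((-1, -2))  # twice the kinetic energy per sample
    arg = (2 * (old_en - new_en) + two_ke).astype("complex128")
    scale = np.sqrt(arg) / np.sqrt(two_ke)
    # an imaginary factor means there was not enough kinetic energy to hop;
    # mark it NaN, mirroring how frustrated hops are reset in try_hop
    scale = np.where(np.imag(scale) != 0, np.nan, np.real(scale))
    return scale.reshape(-1, 1, 1) * vel

# one sample hopping down by 0.01 Ha: the kinetic energy grows by exactly the gap
mass = np.array([1822.89, 1822.89])  # roughly two hydrogen masses in a.u.
vel = np.full((1, 2, 3), 1.0e-3)
new_vel = simple_rescale(mass, vel, old_en=np.array([-0.50]), new_en=np.array([-0.51]))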
@@ -304,28 +221,18 @@ def rescale(energy, """ # old and new energies - old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), - -1).reshape(-1) - new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), - -1).reshape(-1) + old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), -1).reshape(-1) + new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), -1).reshape(-1) if simple_scale or nacv is None: - v_scale = get_simple_scale(mass=mass, - new_en=new_en, - old_en=old_en, - vel=vel) + v_scale = get_simple_scale(mass=mass, new_en=new_en, old_en=old_en, vel=vel) new_vel = v_scale.reshape(-1, 1, 1) * vel return new_vel # nacvs connecting old to new surfaces ones = [1] * 4 - start_nacv = np.take_along_axis(nacv, surfs - .reshape(-1, *ones), - axis=1) - pair_nacv = np.take_along_axis(start_nacv, new_surfs - .reshape(-1, *ones), - axis=2 - ).squeeze(1).squeeze(1) + start_nacv = np.take_along_axis(nacv, surfs.reshape(-1, *ones), axis=1) + pair_nacv = np.take_along_axis(start_nacv, new_surfs.reshape(-1, *ones), axis=2).squeeze(1).squeeze(1) # nacv unit vector norm = np.linalg.norm(pair_nacv, axis=-1) @@ -334,45 +241,24 @@ def rescale(energy, nac_dir = pair_nacv / norm.reshape(*pair_nacv.shape[:-1], 1) # solve quadratic equation for momentum rescaling - scale = solve_quadratic(vel=vel, - nac_dir=nac_dir, - old_en=old_en, - new_en=new_en, - mass=mass) + scale = solve_quadratic(vel=vel, nac_dir=nac_dir, old_en=old_en, new_en=new_en, mass=mass) # scale the velocity - new_vel = (scale.reshape(-1, 1, 1) * nac_dir - / mass.reshape(1, -1, 1) - + vel) + new_vel = scale.reshape(-1, 1, 1) * nac_dir / mass.reshape(1, -1, 1) + vel return new_vel -def try_hop(c, - p_hop, - surfs, - vel, - nacv, - mass, - energy, - max_gap_hop, - simple_scale): +def try_hop(c, p_hop, surfs, vel, nacv, mass, energy, max_gap_hop, simple_scale): """ `energy` has dimension num_samples x num_states """ - new_surfs = get_new_surf(p_hop=p_hop, - surfs=surfs, - max_gap_hop=max_gap_hop, - energy=energy) + new_surfs = get_new_surf(p_hop=p_hop, surfs=surfs, max_gap_hop=max_gap_hop, energy=energy) - new_vel = rescale(energy=energy, - vel=vel, - nacv=nacv, - mass=mass, - surfs=surfs, - new_surfs=new_surfs, - simple_scale=simple_scale) + new_vel = rescale( + energy=energy, vel=vel, nacv=nacv, mass=mass, surfs=surfs, new_surfs=new_surfs, simple_scale=simple_scale + ) # reset any frustrated hops or things that didn't hop frustrated = np.isnan(new_vel).any((-1, -2)).nonzero()[0] @@ -382,21 +268,12 @@ def try_hop(c, return new_surfs, new_vel -def runge_c(c, - vel, - nacv, - energy, - elec_dt, - hbar=1): +def runge_c(c, vel, nacv, energy, elec_dt, hbar=1): """ Runge-Kutta step for c """ - deriv = partial(get_dc_dt, - vel=vel, - nacv=nacv, - energy=energy, - hbar=hbar) + deriv = partial(get_dc_dt, vel=vel, nacv=nacv, energy=energy, hbar=hbar) k1, T1 = deriv(c) k2, T2 = deriv(c + elec_dt * k1 / 2) @@ -412,37 +289,24 @@ def remove_T_nan(T, S): num_states = S.shape[1] nan_idx = np.bitwise_not(np.isfinite(T)) - num_nan = int(nan_idx.nonzero()[0].reshape(-1).shape[0] - / num_states ** 2) - eye = (np.eye(num_states).reshape(-1, num_states, num_states) - .repeat(num_nan, axis=0)).reshape(-1) + num_nan = int(nan_idx.nonzero()[0].reshape(-1).shape[0] / num_states**2) + eye = (np.eye(num_states).reshape(-1, num_states, num_states).repeat(num_nan, axis=0)).reshape(-1) T[nan_idx] = eye return T -def get_implicit_diabat(c, - elec_substeps, - old_H_ad, - new_H_ad, - new_U, - old_U, - dt, - hbar=1): - +def get_implicit_diabat(c, elec_substeps, 
old_H_ad, new_H_ad, new_U, old_U, dt, hbar=1): num_ad = c.shape[1] - S = np.einsum('...ki, ...kj -> ...ij', - old_U, new_U)[:, :num_ad, :num_ad] + S = np.einsum("...ki, ...kj -> ...ij", old_U, new_U)[:, :num_ad, :num_ad] - s_t_s = np.einsum('...ji, ...jk -> ...ik', S, S) - lam, O = np.linalg.eigh(s_t_s) + s_t_s = np.einsum("...ji, ...jk -> ...ik", S, S) + lam, o = np.linalg.eigh(s_t_s) # in case any eigenvalues are 0 or slightly negative - with np.errstate(divide='ignore', invalid='ignore'): - lam_half = np.stack([np.diag(i ** (-1 / 2)) - for i in lam]) - T = np.einsum('...ij, ...jk, ...kl, ...ml -> ...im', - S, O, lam_half, O) + with np.errstate(divide="ignore", invalid="ignore"): + lam_half = np.stack([np.diag(i ** (-1 / 2)) for i in lam]) + T = np.einsum("...ij, ...jk, ...kl, ...ml -> ...im", S, o, lam_half, o) # set T to the identity for any cases in which one of # the eigenvalues is 0 @@ -451,59 +315,48 @@ def get_implicit_diabat(c, T_inv = T.transpose(0, 2, 1) old_H_d = old_H_ad - new_H_d = np.einsum('...ij, ...jk, ...lk -> ...il', - T, new_H_ad, T) + new_H_d = np.einsum("...ij, ...jk, ...lk -> ...il", T, new_H_ad, T) return old_H_d, new_H_d, T_inv -def adiabatic_c(c, - elec_substeps, - old_H_plus_nacv, - new_H_plus_nacv, - dt, - hbar=1, - **kwargs): - +def adiabatic_c(c, elec_substeps, old_H_plus_nacv, new_H_plus_nacv, dt, hbar=1, **kwargs): num_samples = old_H_plus_nacv.shape[0] num_states = old_H_plus_nacv.shape[1] n = elec_substeps - exp = (np.eye(num_states, num_states) - .reshape(1, num_states, num_states) - .repeat(num_samples, axis=0)) + exp = np.eye(num_states, num_states).reshape(1, num_states, num_states).repeat(num_samples, axis=0) delta_tau = dt / n for i in range(1, n + 1): - new_exp = torch.tensor( - -1j / hbar * ( - old_H_plus_nacv + i / n * - (new_H_plus_nacv - old_H_plus_nacv) - ) - * delta_tau - ).matrix_exp().numpy() - exp = np.einsum('ijk, ikl -> ijl', exp, new_exp) + new_exp = ( + torch.tensor(-1j / hbar * (old_H_plus_nacv + i / n * (new_H_plus_nacv - old_H_plus_nacv)) * delta_tau) + .matrix_exp() + .numpy() + ) + exp = np.einsum("ijk, ikl -> ijl", exp, new_exp) P = exp - c_new = np.einsum('ijk, ik -> ij', P, c) + c_new = np.einsum("ijk, ik -> ij", P, c) return c_new, P -def diabatic_c(c, - elec_substeps, - new_U, - old_U, - dt, - explicit_diabat, - hbar=1, - old_H_d=None, - new_H_d=None, - old_H_ad=None, - new_H_ad=None, - **kwargs): - +def diabatic_c( + c, + elec_substeps, + new_U, + old_U, + dt, + explicit_diabat, + hbar=1, + old_H_d=None, + new_H_d=None, + old_H_ad=None, + new_H_ad=None, + **kwargs, +): if not explicit_diabat: old_H_d, new_H_d, T_inv = get_implicit_diabat( c=c, @@ -513,61 +366,43 @@ def diabatic_c(c, new_U=new_U, old_U=old_U, dt=dt, - hbar=hbar) + hbar=hbar, + ) num_samples = old_H_d.shape[0] num_states = old_H_d.shape[1] n = elec_substeps - exp = (np.eye(num_states, num_states) - .reshape(1, num_states, num_states) - .repeat(num_samples, axis=0)) + exp = np.eye(num_states, num_states).reshape(1, num_states, num_states).repeat(num_samples, axis=0) delta_tau = dt / n for i in range(1, n + 1): - - new_exp = torch.tensor( - -1j / hbar * ( - old_H_d + i / n * (new_H_d - old_H_d) - ) - * delta_tau - ).matrix_exp().numpy() - exp = np.einsum('ijk, ikl -> ijl', exp, new_exp) + new_exp = torch.tensor(-1j / hbar * (old_H_d + i / n * (new_H_d - old_H_d)) * delta_tau).matrix_exp().numpy() + exp = np.einsum("ijk, ikl -> ijl", exp, new_exp) if explicit_diabat: # new_U has dimension num_samples x num_states x num_states T = old_U T_inv = new_U.transpose(0, 
2, 1) - P = np.einsum('ijk, ikl, ilm -> ijm', - T_inv, exp, T) + P = np.einsum("ijk, ikl, ilm -> ijm", T_inv, exp, T) else: # if implicit, T(t) = identity - P = np.einsum('ijk, ikl -> ijl', - T_inv, exp) + P = np.einsum("ijk, ikl -> ijl", T_inv, exp) - c_new = np.einsum('ijk, ik -> ij', P, c) + c_new = np.einsum("ijk, ik -> ij", P, c) # print(abs(c_new[30]) ** 2) return c_new, P -def verlet_step_1(forces, - surfs, - vel, - xyz, - mass, - dt): - +def verlet_step_1(forces, surfs, vel, xyz, mass, dt): # `forces` has dimension (num_samples x num_states # x num_atoms x 3) # `surfs` has dimension `num_samples` - surf_forces = np.take_along_axis( - forces, surfs.reshape(-1, 1, 1, 1), - axis=1 - ).squeeze(1) + surf_forces = np.take_along_axis(forces, surfs.reshape(-1, 1, 1, 1), axis=1).squeeze(1) # `surf_forces` has dimension (num_samples x # num_atoms x 3) @@ -577,22 +412,14 @@ def verlet_step_1(forces, # `vel` and `xyz` each have dimension # (num_samples x num_atoms x 3) - new_xyz = xyz + vel * dt + 0.5 * accel * dt ** 2 + new_xyz = xyz + vel * dt + 0.5 * accel * dt**2 new_vel = vel + 0.5 * dt * accel return new_xyz, new_vel -def verlet_step_2(forces, - surfs, - vel, - mass, - dt): - - surf_forces = np.take_along_axis( - forces, surfs.reshape(-1, 1, 1, 1), - axis=1 - ).squeeze(1) +def verlet_step_2(forces, surfs, vel, mass, dt): + surf_forces = np.take_along_axis(forces, surfs.reshape(-1, 1, 1, 1), axis=1).squeeze(1) accel = surf_forces / mass.reshape(1, -1, 1) new_vel = vel + 0.5 * dt * accel @@ -601,7 +428,6 @@ def verlet_step_2(forces, def delta_F_for_tau(forces): - num_samples = forces.shape[0] num_states = forces.shape[1] num_atoms = forces.shape[-2] @@ -614,21 +440,14 @@ def delta_F_for_tau(forces): def get_diag_delta_R(delta_R): - num_states = delta_R.shape[1] - diag_delta_R = np.take_along_axis(delta_R, np.arange(num_states) - .reshape(1, -1, 1, 1, 1), axis=2 - ).repeat(num_states, axis=2) + diag_delta_R = np.take_along_axis(delta_R, np.arange(num_states).reshape(1, -1, 1, 1, 1), axis=2).repeat( + num_states, axis=2 + ) return diag_delta_R -def get_tau_d(forces, - energy, - force_nacv, - delta_R, - hbar=1, - zeta=1): - +def get_tau_d(forces, energy, force_nacv, delta_R, hbar=1, zeta=1): # tau_d^{ni} has shape num_samples x num_states x num_states # delta_R, delta_P, and force_nacv have shape # num_samples x num_states x num_states x num_atoms x 3 @@ -637,24 +456,14 @@ def get_tau_d(forces, diag_delta_R = get_diag_delta_R(delta_R) term_1 = (delta_F * diag_delta_R).sum((-1, -2)) / (2 * hbar) - term_2 = - 2 * abs(zeta / hbar * ( - force_nacv.transpose((0, 2, 1, 3, 4)) - * diag_delta_R) - .sum((-1, -2)) - ) + term_2 = -2 * abs(zeta / hbar * (force_nacv.transpose((0, 2, 1, 3, 4)) * diag_delta_R).sum((-1, -2))) tau = term_1 + term_2 return tau -def get_tau_reset(forces, - energy, - force_nacv, - delta_R, - hbar=1, - zeta=1): - +def get_tau_reset(forces, energy, force_nacv, delta_R, hbar=1, zeta=1): delta_F = delta_F_for_tau(forces) diag_delta_R = get_diag_delta_R(delta_R) @@ -668,7 +477,7 @@ def matmul(a, b): Matrix multiplication in electronic subspace """ - out = np.einsum('ijk..., ikl...-> ijl...', a, b) + out = np.einsum("ijk..., ikl...-> ijl...", a, b) return out @@ -682,52 +491,33 @@ def commute(a, b): return comm -def get_term_3(nacv, - delta, - vel): +def get_term_3(nacv, delta, vel): num_samples = nacv.shape[0] num_states = nacv.shape[1] num_atoms = nacv.shape[3] - d_beta = np.zeros((num_samples, num_states, num_states, - 3 * num_atoms, 3 * num_atoms)) + d_beta = 
np.zeros((num_samples, num_states, num_states, 3 * num_atoms, 3 * num_atoms)) - d_beta += nacv.reshape(num_samples, num_states, - num_states, 1, 3 * num_atoms,) + d_beta += nacv.reshape( + num_samples, + num_states, + num_states, + 1, + 3 * num_atoms, + ) - delta_R_alpha = np.zeros((num_samples, - num_states, - num_states, - 3 * num_atoms, - 3 * num_atoms), - dtype='complex128') + delta_R_alpha = np.zeros((num_samples, num_states, num_states, 3 * num_atoms, 3 * num_atoms), dtype="complex128") - delta_R_alpha += delta.reshape(num_samples, - num_states, - num_states, - 3 * num_atoms, 1) + delta_R_alpha += delta.reshape(num_samples, num_states, num_states, 3 * num_atoms, 1) vel_reshape = vel.reshape(num_samples, 1, 1, 1, 3 * num_atoms) - term_3 = -(commute(d_beta, delta_R_alpha) - * vel_reshape - ).sum(-1) - term_3 = term_3.reshape(num_samples, - num_states, - num_states, - num_atoms, - 3) + term_3 = -(commute(d_beta, delta_R_alpha) * vel_reshape).sum(-1) + term_3 = term_3.reshape(num_samples, num_states, num_states, num_atoms, 3) return term_3 -def decoherence_T_R(pot_V, - delta_R, - delta_P, - nacv, - mass, - vel, - hbar=1): - +def decoherence_T_R(pot_V, delta_R, delta_P, nacv, mass, vel, hbar=1): # pot_V has dimension num_samples x num_states x num_states term_1 = -1j / hbar * commute(pot_V, delta_R) @@ -738,70 +528,38 @@ def decoherence_T_R(pot_V, term_2 = delta_P / mass.reshape(1, 1, 1, -1, 1) - term_3 = get_term_3(nacv=nacv, - delta=delta_R, - vel=vel) + term_3 = get_term_3(nacv=nacv, delta=delta_R, vel=vel) T_R = term_1 + term_2 + term_3 return T_R -def decoherence_T_ii(T, - surfs): - T_ii = np.take_along_axis(arr=T, - indices=surfs.reshape(-1, 1, 1, 1, 1), - axis=1 - ).squeeze(1) +def decoherence_T_ii(T, surfs): + T_ii = np.take_along_axis(arr=T, indices=surfs.reshape(-1, 1, 1, 1, 1), axis=1).squeeze(1) - T_ii = np.take_along_axis(arr=T_ii, - indices=surfs.reshape(-1, 1, 1, 1), - axis=1 - ).squeeze(1) + T_ii = np.take_along_axis(arr=T_ii, indices=surfs.reshape(-1, 1, 1, 1), axis=1).squeeze(1) num_states = T.shape[1] num_samples = T.shape[0] - delta = np.eye(num_states).reshape( - 1, - num_states, - num_states, - 1, - 1) + delta = np.eye(num_states).reshape(1, num_states, num_states, 1, 1) T_ii_delta = T_ii.reshape(num_samples, 1, 1, -1, 3) * delta return T_ii_delta -def deriv_delta_R(pot_V, - delta_R, - delta_P, - nacv, - mass, - vel, - surfs, - hbar=1, - **kwargs): - - T_R = decoherence_T_R(pot_V=pot_V, - delta_R=delta_R, - delta_P=delta_P, - nacv=nacv, - mass=mass, - vel=vel, - hbar=hbar) +def deriv_delta_R(pot_V, delta_R, delta_P, nacv, mass, vel, surfs, hbar=1, **kwargs): + T_R = decoherence_T_R(pot_V=pot_V, delta_R=delta_R, delta_P=delta_P, nacv=nacv, mass=mass, vel=vel, hbar=hbar) - T_ii_delta = decoherence_T_ii(T=T_R, - surfs=surfs) + T_ii_delta = decoherence_T_ii(T=T_R, surfs=surfs) deriv = T_R - T_ii_delta return deriv -def get_F_alpha(force_nacv, - forces): - +def get_F_alpha(force_nacv, forces): num_samples = force_nacv.shape[0] num_states = force_nacv.shape[1] num_atoms = force_nacv.shape[3] @@ -810,21 +568,11 @@ def get_F_alpha(force_nacv, row_idx = np.arange(num_states) * num_states idx = diag_idx + row_idx - F_alpha = np.zeros((num_samples, - num_states * num_states, - num_atoms, - 3)) + F_alpha = np.zeros((num_samples, num_states * num_states, num_atoms, 3)) # forces on diagonal - np.put_along_axis(arr=F_alpha, - indices=idx.reshape(1, -1, 1, 1), - values=forces, - axis=1) + np.put_along_axis(arr=F_alpha, indices=idx.reshape(1, -1, 1, 1), values=forces, axis=1) - 
F_alpha = F_alpha.reshape(num_samples, - num_states, - num_states, - num_atoms, - 3) + F_alpha = F_alpha.reshape(num_samples, num_states, num_states, num_atoms, 3) # - force nacv on off-diagonal (force nacv is the # positive gradient so it needs a negative in front) @@ -835,119 +583,67 @@ def get_F_alpha(force_nacv, return F_alpha -def get_F_alpha_sh(forces, - surfs): - +def get_F_alpha_sh(forces, surfs): num_samples = forces.shape[0] num_states = forces.shape[1] num_atoms = forces.shape[-2] - F_sh = np.take_along_axis(arr=forces, - indices=surfs.reshape(-1, 1, 1, 1), - axis=1) - F_sh = F_sh.reshape(num_samples, - 1, - 1, - num_atoms, - 3) - id_elec = np.eye(num_states, num_states).reshape(1, - num_states, - num_states, - 1, - 1) + F_sh = np.take_along_axis(arr=forces, indices=surfs.reshape(-1, 1, 1, 1), axis=1) + F_sh = F_sh.reshape(num_samples, 1, 1, num_atoms, 3) + id_elec = np.eye(num_states, num_states).reshape(1, num_states, num_states, 1, 1) F_sh_id = F_sh * id_elec return F_sh_id -def get_delta_F(force_nacv, - forces, - surfs): +def get_delta_F(force_nacv, forces, surfs): + F_alpha = get_F_alpha(force_nacv=force_nacv, forces=forces) - F_alpha = get_F_alpha(force_nacv=force_nacv, - forces=forces) - - F_alpha_sh = get_F_alpha_sh(forces=forces, - surfs=surfs) + F_alpha_sh = get_F_alpha_sh(forces=forces, surfs=surfs) delta_F = F_alpha - F_alpha_sh return delta_F -def decoherence_T_P(pot_V, - delta_P, - nacv, - force_nacv, - forces, - surfs, - vel, - sigma, - hbar=1): - +def decoherence_T_P(pot_V, delta_P, nacv, force_nacv, forces, surfs, vel, sigma, hbar=1): term_1 = -1j / hbar * commute(pot_V, delta_P) - delta_F = get_delta_F(force_nacv=force_nacv, - forces=forces, - surfs=surfs) - term_2 = 1 / 2 * (matmul(delta_F, sigma) - + matmul(sigma, delta_F)) + delta_F = get_delta_F(force_nacv=force_nacv, forces=forces, surfs=surfs) + term_2 = 1 / 2 * (matmul(delta_F, sigma) + matmul(sigma, delta_F)) - term_3 = get_term_3(nacv=nacv, - delta=delta_P, - vel=vel) + term_3 = get_term_3(nacv=nacv, delta=delta_P, vel=vel) T_P = term_1 + term_2 + term_3 return T_P # , term_1, term_2, term_3 -def deriv_delta_P(pot_V, - delta_P, - nacv, - force_nacv, - forces, - surfs, - vel, - sigma, - hbar=1, - **kwargs): - - T_P = decoherence_T_P(pot_V=pot_V, - delta_P=delta_P, - nacv=nacv, - force_nacv=force_nacv, - forces=forces, - surfs=surfs, - vel=vel, - sigma=sigma, - hbar=hbar) - - T_ii_delta = decoherence_T_ii(T=T_P, - surfs=surfs) +def deriv_delta_P(pot_V, delta_P, nacv, force_nacv, forces, surfs, vel, sigma, hbar=1, **kwargs): + T_P = decoherence_T_P( + pot_V=pot_V, + delta_P=delta_P, + nacv=nacv, + force_nacv=force_nacv, + forces=forces, + surfs=surfs, + vel=vel, + sigma=sigma, + hbar=hbar, + ) + + T_ii_delta = decoherence_T_ii(T=T_P, surfs=surfs) deriv = T_P - T_ii_delta return deriv -def deriv_sigma(pot_V, - delta_R, - nacv, - force_nacv, - forces, - surfs, - vel, - sigma, - hbar=1, - **kwargs): - - F_alpha = get_F_alpha(force_nacv=force_nacv, - forces=forces) +def deriv_sigma(pot_V, delta_R, nacv, force_nacv, forces, surfs, vel, sigma, hbar=1, **kwargs): + F_alpha = get_F_alpha(force_nacv=force_nacv, forces=forces) term_1 = -1j / hbar * commute(pot_V, sigma) - term_2 = 1j / hbar * commute(F_alpha, - delta_R).sum((-1, -2)) + term_2 = 1j / hbar * commute(F_alpha, delta_R).sum((-1, -2)) # `vel` has shape num_samples x num_atoms x 3 # `nacv` has shape num_samples x num_states x # num_states x num_atoms x 3 @@ -955,88 +651,46 @@ def deriv_sigma(pot_V, num_samples = nacv.shape[0] num_atoms = nacv.shape[-2] - 
vel_reshape = vel.reshape(num_samples, - 1, - 1, - num_atoms, - 3) - term_3 = (-commute(vel_reshape * nacv, sigma) - .sum((-1, -2))) + vel_reshape = vel.reshape(num_samples, 1, 1, num_atoms, 3) + term_3 = -commute(vel_reshape * nacv, sigma).sum((-1, -2)) deriv = term_1 + term_2 + term_3 return deriv -def get_delta_partials(pot_V, - delta_P, - delta_R, - nacv, - force_nacv, - forces, - surfs, - vel, - sigma, - mass, - hbar=1): - - partial_P = partial(deriv_delta_P, - pot_V=pot_V, - nacv=nacv, - force_nacv=force_nacv, - forces=forces, - surfs=surfs, - vel=vel, - hbar=hbar) +def get_delta_partials(pot_V, delta_P, delta_R, nacv, force_nacv, forces, surfs, vel, sigma, mass, hbar=1): + partial_P = partial( + deriv_delta_P, pot_V=pot_V, nacv=nacv, force_nacv=force_nacv, forces=forces, surfs=surfs, vel=vel, hbar=hbar + ) # missing: delta_R, delta_P - partial_R = partial(deriv_delta_R, - pot_V=pot_V, - nacv=nacv, - mass=mass, - vel=vel, - surfs=surfs, - hbar=hbar) + partial_R = partial(deriv_delta_R, pot_V=pot_V, nacv=nacv, mass=mass, vel=vel, surfs=surfs, hbar=hbar) # missing: delta_R, sigma - partial_sigma = partial(deriv_sigma, - pot_V=pot_V, - nacv=nacv, - force_nacv=force_nacv, - forces=forces, - surfs=surfs, - vel=vel, - hbar=hbar) + partial_sigma = partial( + deriv_sigma, pot_V=pot_V, nacv=nacv, force_nacv=force_nacv, forces=forces, surfs=surfs, vel=vel, hbar=hbar + ) return partial_P, partial_R, partial_sigma -def runge_delta(pot_V, - delta_P, - delta_R, - nacv, - force_nacv, - forces, - surfs, - vel, - sigma, - mass, - elec_dt, - hbar=1): - - derivs = get_delta_partials(pot_V=pot_V, - delta_P=delta_P, - delta_R=delta_R, - nacv=nacv, - force_nacv=force_nacv, - forces=forces, - surfs=surfs, - vel=vel, - sigma=sigma, - mass=mass, - hbar=hbar) +def runge_delta(pot_V, delta_P, delta_R, nacv, force_nacv, forces, surfs, vel, sigma, mass, elec_dt, hbar=1): + derivs = get_delta_partials( + pot_V=pot_V, + delta_P=delta_P, + delta_R=delta_R, + nacv=nacv, + force_nacv=force_nacv, + forces=forces, + surfs=surfs, + vel=vel, + sigma=sigma, + mass=mass, + hbar=hbar, + ) init_vals = [delta_P, delta_R, sigma] intermed_vals = copy.deepcopy(init_vals) @@ -1048,63 +702,39 @@ def runge_delta(pot_V, names = ["delta_P", "delta_R", "sigma"] for i in range(4): - kwargs = {name: val for name, val in - zip(names, intermed_vals)} - k_i = [deriv(**kwargs) - for deriv in derivs] + kwargs = dict(zip(names, intermed_vals)) + k_i = [deriv(**kwargs) for deriv in derivs] intermed_vals = [] for n in range(num_vals): if isinstance(final_vals[n], np.ndarray): - final_vals[n] = final_vals[n].astype('complex128') + final_vals[n] = final_vals[n].astype("complex128") final_vals[n] += k_i[n] * elec_dt * final_weight[i] if i == 3: continue - intermed_vals.append(init_vals[n] + ( - k_i[n] * elec_dt * step_size[i])) + intermed_vals.append(init_vals[n] + (k_i[n] * elec_dt * step_size[i])) return final_vals -def add_decoherence(c, - surfs, - new_surfs, - delta_P, - delta_R, - nacv, - energy, - forces, - mass): +def add_decoherence(c, surfs, new_surfs, delta_P, delta_R, nacv, energy, forces, mass): """ - Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with - fewest switches surface hopping: Add just a touch of decoherence. The + Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with + fewest switches surface hopping: Add just a touch of decoherence. The Journal of chemical physics, 137(22), p.22A513. 
""" - pass - -def get_other_surfs(surfs, - num_states, - num_samples): - all_surfs = (np.arange(num_states).reshape(-1, 1) - .repeat(num_samples, 1).transpose()) +def get_other_surfs(surfs, num_states, num_samples): + all_surfs = np.arange(num_states).reshape(-1, 1).repeat(num_samples, 1).transpose() other_idx = all_surfs != surfs.reshape(-1, 1) other_surfs = all_surfs[other_idx].reshape(num_samples, -1) return other_surfs -def truhlar_decoherence(c, - surfs, - energy, - vel, - dt, - mass, - hbar=1, - C=0.1, - **kwargs): +def truhlar_decoherence(c, surfs, energy, vel, dt, mass, hbar=1, C=0.1, **kwargs): """ Originally attributed to Truhlar, cited from G. Granucci and M. Persico. "Critical appraisal of the @@ -1115,28 +745,18 @@ def truhlar_decoherence(c, num_samples = c.shape[0] num_states = c.shape[1] - other_surfs = get_other_surfs(surfs=surfs, - num_states=num_states, - num_samples=num_samples) + other_surfs = get_other_surfs(surfs=surfs, num_states=num_states, num_samples=num_samples) - c_m = np.take_along_axis(c, - surfs.reshape(-1, 1), - axis=-1) + c_m = np.take_along_axis(c, surfs.reshape(-1, 1), axis=-1) - E_m = np.take_along_axis(energy, - surfs.reshape(-1, 1), - axis=-1) + E_m = np.take_along_axis(energy, surfs.reshape(-1, 1), axis=-1) - c_k = np.take_along_axis(c, - other_surfs, - axis=-1) + c_k = np.take_along_axis(c, other_surfs, axis=-1) - E_k = np.take_along_axis(energy, - other_surfs, - axis=-1) + E_k = np.take_along_axis(energy, other_surfs, axis=-1) # vel has shape num_samples x num_atoms x 3 - E_kin = (1 / 2 * mass.reshape(1, -1, 1) * vel ** 2).sum((-1, -2)) + E_kin = (1 / 2 * mass.reshape(1, -1, 1) * vel**2).sum((-1, -2)) tau_km = hbar / abs(E_k - E_m) * (1 + C / E_kin.reshape(-1, 1)) c_k_prime = c_k * np.exp(-dt / tau_km) @@ -1146,20 +766,11 @@ def truhlar_decoherence(c, num[num < 0] = 0 - c_m_prime = c_m * ( - num.reshape(-1, 1) - / abs(c_m) ** 2 - ) ** 0.5 + c_m_prime = c_m * (num.reshape(-1, 1) / abs(c_m) ** 2) ** 0.5 new_c = np.zeros_like(c) - np.put_along_axis(new_c, - surfs.reshape(-1, 1), - c_m_prime, - axis=-1) - - np.put_along_axis(new_c, - other_surfs, - c_k_prime, - axis=-1) + np.put_along_axis(new_c, surfs.reshape(-1, 1), c_m_prime, axis=-1) + + np.put_along_axis(new_c, other_surfs, c_k_prime, axis=-1) return new_c diff --git a/nff/md/tully_multiplicity/dynamics.py b/nff/md/tully_multiplicity/dynamics.py index 496dd420..9f3046db 100644 --- a/nff/md/tully_multiplicity/dynamics.py +++ b/nff/md/tully_multiplicity/dynamics.py @@ -8,87 +8,80 @@ time, so we need to use PyTorch to do it efficiently. 
""" -import numpy as np -import pickle -import os import copy import json -import random import math -import argparse -from functools import partial +import os +import pickle +import random import shutil +from functools import partial +from typing import List -from tqdm import tqdm - -from typing import * - -from ase.io.trajectory import Trajectory +import numpy as np from ase import Atoms +from ase.io.trajectory import Trajectory +from tqdm import tqdm +from nff.md.nvt_ax import NoseHoover, NoseHooverChain +from nff.md.tully_multiplicity.io import get_atoms, get_results, load_json +from nff.md.tully_multiplicity.step import ( + adiabatic_c, + get_p_hop, + truhlar_decoherence, + try_hop, + verlet_step_1, + verlet_step_2, +) +from nff.md.utils_ax import atoms_to_nxyz from nff.train import load_model from nff.utils import constants as const -from nff.md.utils_ax import atoms_to_nxyz -from nff.md.tully_multiplicity.io import get_results, load_json, get_atoms -from nff.md.tully_multiplicity.step import (try_hop, - verlet_step_1, verlet_step_2, - truhlar_decoherence, - adiabatic_c, - compute_T, - get_p_hop) - -from nff.md.nvt_ax import NoseHoover, NoseHooverChain -METHOD_DIC = { - "nosehoover": NoseHoover, - "nosehooverchain": NoseHooverChain -} +METHOD_DIC = {"nosehoover": NoseHoover, "nosehooverchain": NoseHooverChain} DECOHERENCE_DIC = {"truhlar": truhlar_decoherence} -TULLY_LOG_FILE = 'tully.log' -TULLY_SAVE_FILE = 'tully.pickle' - -MODEL_KWARGS = {"add_nacv": False, - "add_grad": True, - "inference": True, - "en_keys_for_grad": ["energy_0"]} +TULLY_LOG_FILE = "tully.log" +TULLY_SAVE_FILE = "tully.pickle" +MODEL_KWARGS = {"add_nacv": False, "add_grad": True, "inference": True, "en_keys_for_grad": ["energy_0"]} class NeuralTully: - def __init__(self, - atoms_list, - device, - batch_size, - adiabatic_keys: List[str], - initial_surf: str, - dt, - elec_substeps: int, - max_time, - cutoff, - model_paths: List, - simple_vel_scale, - hop_eqn, - cutoff_skin, - max_gap_hop, - nbr_update_period, - save_period, - decoherence, - **kwargs): + def __init__( + self, + atoms_list, + device, + batch_size, + adiabatic_keys: List[str], + initial_surf: str, + dt, + elec_substeps: int, + max_time, + cutoff, + model_paths: List, + simple_vel_scale, + hop_eqn, + cutoff_skin, + max_gap_hop, + nbr_update_period, + save_period, + decoherence, + **kwargs, + ): """ `max_gap_hop` in a.u. 
""" self.atoms_list = atoms_list self.vel = self.get_vel() - + self.device = device self.models = [self.load_model(model_path).to(device) for model_path in model_paths] - + self.T = None self.U_old = None - + self.t = 0 self.props = {} self.num_atoms = len(self.atoms_list[0]) @@ -99,61 +92,61 @@ def __init__(self, self.spinadiabatic_to_statenum = {} self.spinadiabatic_keys = [] for key in adiabatic_keys: - if 'S' in key: - new_key = key+"_ms+0" + if "S" in key: + new_key = key + "_ms+0" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - elif 'D' in key: - new_key = key+"_ms-1/2" + elif "D" in key: + new_key = key + "_ms-1/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms+1/2" + new_key = key + "_ms+1/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - elif 'T' in key: - new_key = key+"_ms-1" + elif "T" in key: + new_key = key + "_ms-1" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms+0" + new_key = key + "_ms+0" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms+1" + new_key = key + "_ms+1" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - elif 'Q' in key: - new_key = key+"_ms-3/2" + elif "Q" in key: + new_key = key + "_ms-3/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms-1/2" + new_key = key + "_ms-1/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms+1/2" + new_key = key + "_ms+1/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - new_key = key+"_ms+3/2" + new_key = key + "_ms+3/2" self.repeated_keys.append(key) self.spinadiabatic_to_statenum[new_key] = len(self.spinadiabatic_keys) self.spinadiabatic_keys.append(new_key) self.spinadiabatic_to_adiabatic[new_key] = key - + self.num_spinadibat = len(self.spinadiabatic_keys) self.num_states = len(self.spinadiabatic_keys) self.initial_surf_num = self.spinadiabatic_to_statenum[initial_surf] @@ -183,14 +176,13 @@ def __init__(self, self.log_template = self.setup_logging() self.p_hop = np.zeros((self.num_samples, self.num_states)) self.just_hopped = None - self.surfs = np.ones(self.num_samples, - dtype=np.int) * self.initial_surf_num + self.surfs = np.ones(self.num_samples, dtype=np.int) * self.initial_surf_num self.c_hmc = self.init_c() - + self.update_props(needs_nbrs=True) self.c_diag = 
self.get_c_diag() - - #sanity check + + # sanity check if not (self.surfs == np.argmax(np.abs(self.c_diag), axis=1)).all(): print("The states in the diagonal basis got reordered! Adjusting surfs!") self.surfs = np.argmax(np.abs(self.c_diag), axis=1) @@ -198,13 +190,12 @@ def __init__(self, self.setup_save() self.decoherence = self.init_decoherence(params=decoherence) - self.decoherence_type = decoherence['name'] + self.decoherence_type = decoherence["name"] self.hop_eqn = hop_eqn self.simple_vel_scale = simple_vel_scale if os.path.isfile(TULLY_SAVE_FILE): self.restart() - def setup_save(self): if os.path.isfile(self.save_file): @@ -212,37 +203,34 @@ def setup_save(self): def init_decoherence(self, params): if not params: - return + return None - name = params['name'] - kwargs = params.get('kwargs', {}) + name = params["name"] + kwargs = params.get("kwargs", {}) method = DECOHERENCE_DIC[name] - func = partial(method, - **kwargs) + func = partial(method, **kwargs) return func def load_model(self, model_path): - param_path = os.path.join(model_path, 'params.json') - with open(param_path, 'r') as f: + param_path = os.path.join(model_path, "params.json") + with open(param_path, "r") as f: params = json.load(f) - model = load_model(model_path, params, params['model_type']) + model = load_model(model_path, params, params["model_type"]) return model - + @property def mass(self): - _mass = (self.atoms_list[0].get_masses() - * const.AMU_TO_AU) + _mass = self.atoms_list[0].get_masses() * const.AMU_TO_AU return _mass @property def nxyz(self): - _nxyz = np.stack([atoms_to_nxyz(atoms) for atoms in - self.atoms_list]) + _nxyz = np.stack([atoms_to_nxyz(atoms) for atoms in self.atoms_list]) return _nxyz @@ -250,7 +238,7 @@ def nxyz(self): def nxyz(self, _nxyz): for atoms, this_nxyz in zip(self.atoms_list, _nxyz): atoms.set_positions(this_nxyz[:, 1:]) - + @property def xyz(self): _xyz = self.nxyz[..., 1:] @@ -263,80 +251,64 @@ def xyz(self, val): atoms.set_positions(xyz) def get_vel(self): - vel = np.stack([atoms.get_velocities() - for atoms in self.atoms_list]) + vel = np.stack([atoms.get_velocities() for atoms in self.atoms_list]) vel /= const.BOHR_RADIUS * const.ASE_TO_FS * const.FS_TO_AU return vel def init_c(self): - c = np.zeros((self.num_samples, - self.num_states), - dtype='complex128') + c = np.zeros((self.num_samples, self.num_states), dtype="complex128") c[:, self.surfs[0]] = 1 return c - + def get_c_hmc(self): """ state coefficients in the HMC basis """ - c_hmc = np.einsum('ijk,ik->ij', - self.U, - self.c_diag) - + c_hmc = np.einsum("ijk,ik->ij", self.U, self.c_diag) + return c_hmc - + def get_c_diag(self): """ state coefficients in the diagonal basis """ - c_diag = np.einsum('ijk,ik->ij', - self.U.conj().transpose(0,2,1), - self.c_hmc) - + c_diag = np.einsum("ijk,ik->ij", self.U.conj().transpose(0, 2, 1), self.c_hmc) + return c_diag def get_forces(self): - _forces = np.stack([-self.props[f'energy_{key}_grad'] - for key in self.repeated_keys], - axis=1) - _forces = (_forces.reshape(self.num_samples, - -1, - self.num_states, - 3) - .transpose(0, 2, 1, 3)) + _forces = np.stack([-self.props[f"energy_{key}_grad"] for key in self.repeated_keys], axis=1) + _forces = _forces.reshape(self.num_samples, -1, self.num_states, 3).transpose(0, 2, 1, 3) return _forces def get_energy(self): - _energy = np.stack([self.props[f'energy_{key}'].reshape(-1) - for key in self.repeated_keys], - axis=1) + _energy = np.stack([self.props[f"energy_{key}"].reshape(-1) for key in self.repeated_keys], axis=1) return _energy 
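The basis-change helpers defined above (`get_c_hmc` and `get_c_diag`) are mutual inverses whenever `U` is unitary, which is what makes switching between the HMC and diagonal coefficients safe. A quick standalone check with random data (hypothetical snippet, not part of the patch; batched `np.linalg.qr` needs NumPy >= 1.22):

import numpy as np

rng = np.random.default_rng(0)
num_samples, num_states = 4, 3

# one random unitary per sample from a batched QR decomposition
A = rng.normal(size=(num_samples, num_states, num_states)) + 1j * rng.normal(
    size=(num_samples, num_states, num_states)
)
U, _ = np.linalg.qr(A)

c_diag = rng.normal(size=(num_samples, num_states)) + 1j * rng.normal(size=(num_samples, num_states))

# same einsum patterns as get_c_hmc / get_c_diag
c_hmc = np.einsum("ijk,ik->ij", U, c_diag)
back = np.einsum("ijk,ik->ij", U.conj().transpose(0, 2, 1), c_hmc)
assert np.allclose(back, c_diag)  # the round trip recovers the coefficients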
    def get_nacv(self):
-        _nacv = np.zeros((self.num_samples, self.num_states,
-                          self.num_states, self.num_atoms, 3))
+        _nacv = np.zeros((self.num_samples, self.num_states, self.num_states, self.num_atoms, 3))

         for state_n1 in range(self.num_states):
             state1 = self.statenum_to_spinadiabatic[state_n1]
             splits = state1.split("_")
             adiabat1 = splits[0]
-            spin_ms1 = splits[1]
-
+
             for state_n2 in range(self.num_states):
                 state2 = self.statenum_to_spinadiabatic[state_n2]
                 splits = state2.split("_")
                 adiabat2 = splits[0]
-                spin_ms2 = splits[1]
-
+
                 if adiabat1 == adiabat2:
                     continue
-                elif adiabat1[0] != adiabat2[0]:
+                if adiabat1[0] != adiabat2[0]:
                     # checks for the same degeneracy
                     continue
-
-                key = f'NACV_{adiabat1}_to_{adiabat2}_grad'
+
+                key = f"NACV_{adiabat1}_to_{adiabat2}_grad"
                 if key not in self.props:
                     continue
                 _nacv[:, state_n1, state_n2, :] = self.props[key]
@@ -355,20 +327,15 @@ def get_gap(self):
         return _gap

     def get_force_nacv(self):
-
         # self.gap has shape num_samples x num_states x num_states
         # `nacv` has shape num_samples x num_states x num_states
         # x num_atoms x 3

         nacv = self.nacv
         if nacv is None:
-            return
+            return None

-        gap = self.gap.reshape(self.num_samples,
-                               self.num_states,
-                               self.num_states,
-                               1,
-                               1)
+        gap = self.gap.reshape(self.num_samples, self.num_states, self.num_states, 1, 1)

         _force_nacv = -nacv * gap

@@ -379,191 +346,156 @@ def get_pot_V(self):
         """
         Potential energy matrix in
         n_adiabat x n_adiabat space
         """
-        V = np.zeros((self.num_samples,
-                      self.num_states,
-                      self.num_states))
+        V = np.zeros((self.num_samples, self.num_states, self.num_states))
         idx = np.arange(self.num_states)
-        np.put_along_axis(
-            V,
-            idx.reshape(1, -1, 1),
-            self.energy.reshape(self.num_samples,
-                                self.num_states,
-                                1),
-            axis=2
-        )
+        np.put_along_axis(V, idx.reshape(1, -1, 1), self.energy.reshape(self.num_samples, self.num_states, 1), axis=2)

         return V
-
+
     def get_SOC_mat(self):
         """
         Matrix with SOCs in HMC basis
         """
-
-        H_soc = np.zeros((self.num_samples,
-                          self.num_states,
-                          self.num_states),
-                         dtype=np.complex128)
-
+
+        H_soc = np.zeros((self.num_samples, self.num_states, self.num_states), dtype=np.complex128)
+
         for state_n1 in range(self.num_states):
             state1 = self.statenum_to_spinadiabatic[state_n1]
             splits = state1.split("_")
             adiabat1 = splits[0]
             spin_ms1 = splits[1]
-
+
             for state_n2 in range(self.num_states):
                 state2 = self.statenum_to_spinadiabatic[state_n2]
                 splits = state2.split("_")
                 adiabat2 = splits[0]
                 spin_ms2 = splits[1]
-
+
                 try:
                     a = self.props[f"SOC_{adiabat1}_to_{adiabat2}_a"]
                     b = self.props[f"SOC_{adiabat1}_to_{adiabat2}_b"]
                     c = self.props[f"SOC_{adiabat1}_to_{adiabat2}_c"]
-                except:
+                except KeyError:  # only the self.props lookups above can fail here
                     continue
-
+
                 ST_soc = False
                 TT_soc = False
-                if 'S' in adiabat1:
+                if "S" in adiabat1:
                     ST_soc = True
-                elif 'T' in adiabat1 and 'T' in adiabat2:
+                elif "T" in adiabat1 and "T" in adiabat2:
                     TT_soc = True
                 else:
                     raise NotImplementedError
-
+
                 if ST_soc:
-                    if spin_ms2 == 'ms-1':
+                    if spin_ms2 == "ms-1":
                         soc_val = a + 1j * b
-                    elif spin_ms2 == 'ms+0':
-                        soc_val = 0. + 1j * c
-                    elif spin_ms2 == 'ms+1':
+                    elif spin_ms2 == "ms+0":
+                        soc_val = 0.0 + 1j * c
+                    elif spin_ms2 == "ms+1":
                         soc_val = a - 1j * b
-
+
                 elif TT_soc:
-                    if spin_ms1 == 'ms-1' and spin_ms2 == 'ms-1':
-                        soc_val = 0. + 1j * c
-                    elif spin_ms1 == 'ms-1' and spin_ms2 == 'ms+0':
+                    if spin_ms1 == "ms-1" and spin_ms2 == "ms-1":
+                        soc_val = 0.0 + 1j * c
+                    elif spin_ms1 == "ms-1" and spin_ms2 == "ms+0":
                         soc_val = -a + 1j * b
-                    elif spin_ms1 == 'ms+0' and spin_ms2 == 'ms-1':
+                    elif spin_ms1 == "ms+0" and spin_ms2 == "ms-1":
                         soc_val = a + 1j * b
-                    elif spin_ms1 == 'ms+0' and spin_ms2 == 'ms+1':
+                    elif spin_ms1 == "ms+0" and spin_ms2 == "ms+1":
                         soc_val = -a + 1j * b
-                    elif spin_ms1 == 'ms+1' and spin_ms2 == 'ms+0':
+                    elif spin_ms1 == "ms+1" and spin_ms2 == "ms+0":
                         soc_val = a + 1j * b
-                    elif spin_ms1 == 'ms+1' and spin_ms2 == 'ms+1':
-                        soc_val = 0. - 1j * c
-
+                    elif spin_ms1 == "ms+1" and spin_ms2 == "ms+1":
+                        soc_val = 0.0 - 1j * c
+                    else:
+                        # guard added here: |Delta Ms| = 2 elements vanish for
+                        # the spin-orbit operator, and without this branch
+                        # soc_val would carry over from the previous state pair
+                        soc_val = 0.0
+
                 H_soc[:, state_n1, state_n2] = soc_val
                 H_soc[:, state_n2, state_n1] = soc_val.conj()
-
+
         return H_soc
-
+
     def get_H_hmc(self):
         """
         Sum of potential energy matrix and SOCs
         """
-
+
         V = self.pot_V.astype(np.complex128)
         H_hmc = V + self.SOC_mat
-
-        return H_hmc
-
+
+        return H_hmc
+
     def get_U(self):
         """
         Diagonalizes H^total
         """
-
+
         eVals, U = np.linalg.eigh(self.H_hmc)
-
+
         return U, eVals

     def get_H_plus_nacv(self):
         if self.nacv is None:
-            return
-        #pot_V = self.pot_V
+            return None
+        # pot_V = self.pot_V
         H_hmc = self.H_hmc
-        nac_term = -1j * (self.nacv *
-                          self.vel.reshape(self.num_samples,
-                                           1,
-                                           1,
-                                           self.num_atoms,
-                                           3)
-                          ).sum((-1, -2))
-
-        #return pot_V + nac_term
+        nac_term = -1j * (self.nacv * self.vel.reshape(self.num_samples, 1, 1, self.num_atoms, 3)).sum((-1, -2))
+
+        # return pot_V + nac_term
         return H_hmc + nac_term

     def get_neg_G_hmc(self):
-
-        neg_G = np.zeros((self.num_samples,
-                          self.num_states,
-                          self.num_states,
-                          self.num_atoms,
-                          3))
+        neg_G = np.zeros((self.num_samples, self.num_states, self.num_states, self.num_atoms, 3))
         idx = np.arange(self.num_states)
         np.put_along_axis(
             neg_G,
             idx.reshape(1, -1, 1, 1, 1),
-            self.forces.reshape(self.num_samples,
-                                self.num_states,
-                                1,
-                                self.num_atoms,
-                                3),
-            axis=2
+            self.forces.reshape(self.num_samples, self.num_states, 1, self.num_atoms, 3),
+            axis=2,
         )
         neg_G += self.force_nacv
-
+
         return neg_G
-
+
     def get_neg_G_diag(self):
-
-        neg_G_diag = np.einsum('ijk,ikl...,ilm->ijm...',
-                               self.U.conj().transpose((0,2,1)),
-                               self.neg_G_hmc,
-                               self.U)
-
+        neg_G_diag = np.einsum("ijk,ikl...,ilm->ijm...", self.U.conj().transpose((0, 2, 1)), self.neg_G_hmc, self.U)
+
         return neg_G_diag
-
+
     def get_diag_energy(self):
-        H_diag = np.einsum('ijk,ikl,ilm->ijm',
-                           self.U.conj().transpose((0,2,1)),
-                           self.H_hmc,
-                           self.U)
+        H_diag = np.einsum("ijk,ikl,ilm->ijm", self.U.conj().transpose((0, 2, 1)), self.H_hmc, self.U)
         idxs = np.arange(self.num_states)
-        _energy = np.take_along_axis(np.real(H_diag),
-                                     idxs.reshape(1, -1, 1), axis=2)
-
+        _energy = np.take_along_axis(np.real(H_diag), idxs.reshape(1, -1, 1), axis=2)
+
         return _energy.reshape(self.num_samples, self.num_states)
-
+
     def get_diag_forces(self):
-
-        diag_forces = np.diagonal(self.neg_G_diag,
-                                  axis1=1, axis2=2).transpose((0,3,1,2))
-
+        diag_forces = np.diagonal(self.neg_G_diag, axis1=1, axis2=2).transpose((0, 3, 1, 2))
+
         return np.real(diag_forces)

     @property
     def state_dict(self):
-        _state_dict = {"nxyz": self.nxyz,
-                       "nacv": self.nacv,
-                       #"force_nacv": self.force_nacv,
-                       "energy": self.energy,
-                       "forces": self.forces,
-                       #"H_d": self.H_d,
-                       "U": self.U,
-                       "t": self.t / const.FS_TO_AU,
-                       "vel": self.vel,
-                       "c_hmc": self.c_hmc,
-                       "c_diag": self.c_diag,
-                       #"T": self.T,
-                       "surfs": self.surfs}
+        _state_dict = {
+            "nxyz": self.nxyz,
+            "nacv": self.nacv,
+            # "force_nacv": self.force_nacv,
+            "energy": 
self.energy, + "forces": self.forces, + # "H_d": self.H_d, + "U": self.U, + "t": self.t / const.FS_TO_AU, + "vel": self.vel, + "c_hmc": self.c_hmc, + "c_diag": self.c_diag, + # "T": self.T, + "surfs": self.surfs, + } return _state_dict - + @state_dict.setter def state_dict(self, dic): for key, val in dic.items(): - if key in ['force_nacv']: + if key in ["force_nacv"]: continue setattr(self, key, val) self.t *= const.FS_TO_AU @@ -582,7 +514,7 @@ def save(self, idx=None): idx = set(idx) for key, val in state_dict.items(): - if key == 't': + if key == "t": continue if val is None: continue @@ -592,7 +524,7 @@ def save(self, idx=None): this_val = v if (i in idx) else None use_val.append(this_val) use_dict[key] = use_val - use_dict['t'] = state_dict['t'] + use_dict["t"] = state_dict["t"] with open(self.save_file, "ab") as f: pickle.dump(use_dict, f) @@ -602,7 +534,6 @@ def restart(self): self.state_dict = state_dicts[-1] def setup_logging(self, remove_old=True): - states = [f"State {i}" for i in range(self.num_states)] hdr = "%-9s " % "Time [fs]" for state in states: @@ -611,11 +542,11 @@ def setup_logging(self, remove_old=True): hdr += "%15s " % "Hop prob." if not os.path.isfile(self.log_file) or remove_old: - with open(self.log_file, 'w') as f: + with open(self.log_file, "w") as f: f.write(hdr) template = "%-10.2f " - for i, state in enumerate(states): + for _ in states: template += "%15.6f" template += "%15.4f" template += "%15.4f" @@ -624,8 +555,7 @@ def setup_logging(self, remove_old=True): def clean_c_p(self): c_states = self.c.shape[-1] - c = (self.c[np.bitwise_not(np.isnan(self.c))] - .reshape(-1, c_states)) + c = self.c[np.bitwise_not(np.isnan(self.c))].reshape(-1, c_states) p_states = self.p_hop.shape[-1] p_nan_idx = np.isnan(self.p_hop).any(-1) @@ -636,67 +566,66 @@ def clean_c_p(self): def log(self): time = self.t / const.FS_TO_AU -# pcts = [] -# for i in range(self.num_states): -# num_surf = (self.surfs == i).sum() -# pct = num_surf / self.num_samples * 100 -# pcts.append(pct) - + # pcts = [] + # for i in range(self.num_states): + # num_surf = (self.surfs == i).sum() + # pct = num_surf / self.num_samples * 100 + # pcts.append(pct) + pcts_hmc = [] argmax = np.argmax(np.abs(self.c_hmc), axis=1) for i in range(self.num_states): num_surf = (argmax == i).sum() - pct = num_surf / self.num_samples #* 100 + pct = num_surf / self.num_samples # * 100 pcts_hmc.append(pct) - #c, p = self.clean_c_p() + # c, p = self.clean_c_p() norm_c = np.mean(np.linalg.norm(self.c_hmc, axis=1)) p_avg = np.mean(np.max(self.p_hop, axis=1)) - text = self.log_template % (time, #*pcts, - *pcts_hmc, - norm_c, p_avg) + text = self.log_template % ( + time, # *pcts, + *pcts_hmc, + norm_c, + p_avg, + ) - with open(self.log_file, 'a') as f: + with open(self.log_file, "a") as f: f.write("\n" + text) - + # sanity check if norm_c > 2.0: print("Norm of coefficients too large!!") exit(1) @classmethod - def from_pickle(cls, - file, - max_time=None): - + def from_pickle(cls, file, max_time=None): state_dicts = [] - with open(file, 'rb') as f: + with open(file, "rb") as f: while True: try: state_dict = pickle.load(f) state_dicts.append(state_dict) except EOFError: break - time = state_dict['t'] + time = state_dict["t"] if max_time is not None and time > max_time: break - sample_nxyz = state_dicts[0]['nxyz'] + sample_nxyz = state_dicts[0]["nxyz"] # whether this is a single trajectory or multiple single = len(sample_nxyz.shape) == 2 num_samples = 1 if single else sample_nxyz.shape[0] trjs = [[] for _ in range(num_samples)] for 
state_dict in state_dicts: - nxyz = state_dict['nxyz'] + nxyz = state_dict["nxyz"] if single: nxyz = [nxyz] - for i, nxyz in enumerate(nxyz): + for i, nxyz in enumerate(nxyz): # noqa if nxyz is None: trjs[i].append(None) continue - atoms = Atoms(nxyz[:, 0], - positions=nxyz[:, 1:]) + atoms = Atoms(nxyz[:, 0], positions=nxyz[:, 1:]) trjs[i].append(atoms) if single: @@ -704,196 +633,170 @@ def from_pickle(cls, return state_dicts, trjs - def update_props(self, - needs_nbrs): - - props = get_results(models=self.models, - nxyz=self.nxyz, - nbr_list=self.nbr_list, - num_atoms=self.num_atoms, - needs_nbrs=needs_nbrs, - cutoff=self.cutoff, - cutoff_skin=self.cutoff_skin, - device=self.device, - batch_size=self.batch_size,) + def update_props(self, needs_nbrs): + props = get_results( + models=self.models, + nxyz=self.nxyz, + nbr_list=self.nbr_list, + num_atoms=self.num_atoms, + needs_nbrs=needs_nbrs, + cutoff=self.cutoff, + cutoff_skin=self.cutoff_skin, + device=self.device, + batch_size=self.batch_size, + ) self.props = props self.update_selfs() - + def update_selfs(self): # simple reorganiation of NN outputs - self.energy = self.get_energy() - self.forces = self.get_forces() - self.nacv = self.get_nacv() - self.gap = self.get_gap() - self.force_nacv = self.get_force_nacv() - self.pot_V = self.get_pot_V() - + self.energy = self.get_energy() + self.forces = self.get_forces() + self.nacv = self.get_nacv() + self.gap = self.get_gap() + self.force_nacv = self.get_force_nacv() + self.pot_V = self.get_pot_V() + # assembly of complicated matrices - self.SOC_mat = self.get_SOC_mat() - self.H_hmc = self.get_H_hmc() + self.SOC_mat = self.get_SOC_mat() + self.H_hmc = self.get_H_hmc() self.H_plus_nacv = self.get_H_plus_nacv() - + # diagonalization of HMC representation - self.U, evals = self.get_U() - if type(self.U_old) != type(None): + self.U, evals = self.get_U() + + if self.U_old is not None: # the following is an implementation of Appendix B - # from Mai, Marquetand, Gonzalez + # from Mai, Marquetand, Gonzalez # Int.J. Quant. Chem. 
2015, 115, 1215-1231 - V = np.einsum('ijk,ikl->ijl', - self.U.conj().transpose((0,2,1)), - self.U_old) - + V = np.einsum("ijk,ikl->ijl", self.U.conj().transpose((0, 2, 1)), self.U_old) + # attempt to make V diagonally dominant for replica in range(self.num_samples): - abs_v = np.abs(V[replica]) + abs_v = np.abs(V[replica]) arg_max = np.argmax(abs_v, axis=1) # sanity check print statement -# if len(np.unique(arg_max)) < len(arg_max): -# print("V could not be made diagonal dominant!") - + # if len(np.unique(arg_max)) < len(arg_max): + # print("V could not be made diagonal dominant!") + for column in range(self.num_states): curr_col = copy.deepcopy(V[replica][:, column]) - new_col = copy.deepcopy(V[replica][:, arg_max[column]]) + new_col = copy.deepcopy(V[replica][:, arg_max[column]]) # switch columns V[replica][:, column] = new_col V[replica][:, arg_max[column]] = curr_col - - # (CV)_{ab} = V_{ab} delta(Hdiag_aa - Hdiag_bb) - # setting everything to zero where + + # (CV)_{ab} = V_{ab} delta(Hdiag_aa - Hdiag_bb) + # setting everything to zero where # the difference in diagonal elements is NOT zero # for replica, hdiag in enumerate(evals): hdiag = evals.reshape((self.num_samples, self.num_states, 1)) - diff = hdiag - hdiag.transpose((0, 2, 1)) - preserved_idxs = np.isclose(diff, np.zeros(shape=diff.shape), - atol=1e-8, rtol=0.0) + diff = hdiag - hdiag.transpose((0, 2, 1)) + preserved_idxs = np.isclose(diff, np.zeros(shape=diff.shape), atol=1e-8, rtol=0.0) V[~preserved_idxs] = 0.0 - + # Loewdin symmetric orthonormalization u, s, vh = np.linalg.svd(V) - Phi_adj = np.einsum('ijk, ikl->ijl', u, vh) - - corrected_U = np.einsum('ijk, ikl->ijl', self.U, Phi_adj) + Phi_adj = np.einsum("ijk, ikl->ijl", u, vh) + + corrected_U = np.einsum("ijk, ikl->ijl", self.U, Phi_adj) self.U = copy.deepcopy(corrected_U)
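(An aside on the SVD step above, not part of the diff: for any square complex matrix V with SVD V = u S vh, the product u @ vh is the unitary matrix closest to V in the Frobenius norm, i.e. Loewdin / symmetric orthonormalization, which is what the batched einsum for Phi_adj computes. A minimal standalone illustration with a hypothetical 2x2 input:

    import numpy as np

    V = np.array([[0.9, 0.1], [-0.2, 1.1]], dtype=np.complex128)  # near-unitary overlap
    u, s, vh = np.linalg.svd(V)
    phi = u @ vh  # Loewdin-orthonormalized V: exactly unitary
    assert np.allclose(phi @ phi.conj().T, np.eye(2))

Replacing V by phi preserves the column ordering fixed above while restoring unitarity.)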
- + # check eq B11 - epsilon = 0.1 # hardcoded for now - diagonals = np.einsum('ijj->ij', np.einsum('ijk,ikl->ijl', - self.U.conj().transpose((0,2,1)), - self.U_old)) - anti_hermitian = ((1 - epsilon) < diagonals).all() -# if not anti_hermitian: -# print("WARNING: Time step likely too large! At least one new unitary matrix ", -# "does not fulfill anti-hermicity!") -# print(f"epsilon = {epsilon}") -# print("diagonal elements:\n", diagonals) -# print("H_diag:\n", evals) -# print(V) -# print(preserved_idxs) - - self.U_old = copy.deepcopy(self.U) - - self.neg_G_hmc = self.get_neg_G_hmc() - self.neg_G_diag = self.get_neg_G_diag() + epsilon = 0.1 # hardcoded for now + diagonals = np.einsum("ijj->ij", np.einsum("ijk,ikl->ijl", self.U.conj().transpose((0, 2, 1)), self.U_old)) + anti_hermitian = ((1 - epsilon) < diagonals).all() # noqa: F841 -- feeds only the commented-out warning below + # if not anti_hermitian: + # print("WARNING: Time step likely too large! At least one new unitary matrix ", + # "does not fulfill anti-hermiticity!") + # print(f"epsilon = {epsilon}") + # print("diagonal elements:\n", diagonals) + # print("H_diag:\n", evals) + # print(V) + # print(preserved_idxs) + + self.U_old = copy.deepcopy(self.U) + + self.neg_G_hmc = self.get_neg_G_hmc() + self.neg_G_diag = self.get_neg_G_diag() self.diag_energy = self.get_diag_energy() self.diag_forces = self.get_diag_forces() - - def do_hop(self, - old_c, - new_c, - P): - - self.p_hop = get_p_hop(hop_eqn=self.hop_eqn, - old_c=old_c, - new_c=new_c, - P=P, - surfs=self.surfs) - - new_surfs, new_vel = try_hop(p_hop=self.p_hop, - surfs=self.surfs, - vel=self.vel, - nacv=self.nacv, - mass=self.mass, - energy=self.energy, - max_gap_hop=self.max_gap_hop, - simple_scale=self.simple_vel_scale) + + def do_hop(self, old_c, new_c, P): + self.p_hop = get_p_hop(hop_eqn=self.hop_eqn, old_c=old_c, new_c=new_c, P=P, surfs=self.surfs) + + new_surfs, new_vel = try_hop( + p_hop=self.p_hop, + surfs=self.surfs, + vel=self.vel, + nacv=self.nacv, + mass=self.mass, + energy=self.energy, + max_gap_hop=self.max_gap_hop, + simple_scale=self.simple_vel_scale, + ) return new_surfs, new_vel def add_decoherence(self): - if not self.decoherence: return - - self.c_diag = self.decoherence(c=self.c_diag, - surfs=self.surfs, - energy=self.diag_energy, - vel=self.vel, - dt=self.dt, - mass=self.mass) - + + self.c_diag = self.decoherence( + c=self.c_diag, surfs=self.surfs, energy=self.diag_energy, vel=self.vel, dt=self.dt, mass=self.mass + ) + self.c_hmc = self.get_c_hmc() def step(self, needs_nbrs): - - old_V = copy.deepcopy(self.pot_V) - old_H_hmc = copy.deepcopy(self.H_hmc) old_H_plus_nacv = copy.deepcopy(self.H_plus_nacv) - old_U = copy.deepcopy(self.U) + old_U = copy.deepcopy(self.U) - old_c_hmc = copy.deepcopy(self.c_hmc) - old_c_diag= copy.deepcopy(self.c_diag) + old_c_diag = copy.deepcopy(self.c_diag) # xyz converted to a.u.
for the step and then # back to Angstrom after - new_xyz, new_vel = verlet_step_1(self.diag_forces, - self.surfs, - vel=self.vel, - xyz=self.xyz / const.BOHR_RADIUS, - mass=self.mass, - dt=self.dt) + new_xyz, new_vel = verlet_step_1( + self.diag_forces, self.surfs, vel=self.vel, xyz=self.xyz / const.BOHR_RADIUS, mass=self.mass, dt=self.dt + ) self.xyz = new_xyz * const.BOHR_RADIUS self.vel = new_vel # from here on everything is "new" self.update_props(needs_nbrs) - - new_vel = verlet_step_2(forces=self.diag_forces, - surfs=self.surfs, - vel=self.vel, - mass=self.mass, - dt=self.dt) + + new_vel = verlet_step_2(forces=self.diag_forces, surfs=self.surfs, vel=self.vel, mass=self.mass, dt=self.dt) self.vel = new_vel - self.c_hmc, self.P_hmc = adiabatic_c(c=self.c_hmc, - elec_substeps=self.elec_substeps, - old_H_plus_nacv=old_H_plus_nacv, - new_H_plus_nacv=self.H_plus_nacv, - dt=self.dt) - -# print("Norm before/after elec substeps (hop):\n", -# np.linalg.norm(old_c_hmc, axis=1), -# np.linalg.norm(self.c_hmc, axis=1)) - + self.c_hmc, self.P_hmc = adiabatic_c( + c=self.c_hmc, + elec_substeps=self.elec_substeps, + old_H_plus_nacv=old_H_plus_nacv, + new_H_plus_nacv=self.H_plus_nacv, + dt=self.dt, + ) + + # print("Norm before/after elec substeps (hop):\n", + # np.linalg.norm(old_c_hmc, axis=1), + # np.linalg.norm(self.c_hmc, axis=1)) + self.c_diag = self.get_c_diag() - self.P_diag = np.einsum('ijk,ikl,ilm->ijm', - self.U.conj().transpose((0,2,1)), - self.P_hmc, - old_U) - -# if self.nacv is not None: -# self.T, _ = compute_T(nacv=self.nacv, -# vel=self.vel, -# c=self.c_hmc) - - new_surfs, new_vel = self.do_hop(old_c=old_c_diag, - new_c=self.c_diag, - P=self.P_diag) + self.P_diag = np.einsum("ijk,ikl,ilm->ijm", self.U.conj().transpose((0, 2, 1)), self.P_hmc, old_U) + + # if self.nacv is not None: + # self.T, _ = compute_T(nacv=self.nacv, + # vel=self.vel, + # c=self.c_hmc) + + new_surfs, new_vel = self.do_hop(old_c=old_c_diag, new_c=self.c_diag, P=self.P_diag) self.just_hopped = (new_surfs != self.surfs).nonzero()[0] - #if self.just_hopped.any(): - - + # if self.just_hopped.any(): + self.surfs = new_surfs self.vel = new_vel self.t += self.dt @@ -904,60 +807,46 @@ def step(self, needs_nbrs): def run(self): steps = math.ceil((self.max_time - self.t) / self.dt) epochs = math.ceil(steps / self.nbr_update_period) - + self.save() self.log() counter = 0 -# self.model.to(self.device) -# if self.t == 0: -# self.update_props(needs_nbrs=True) + # self.model.to(self.device) + # if self.t == 0: + # self.update_props(needs_nbrs=True) for _ in tqdm(range(epochs)): for i in range(self.nbr_update_period): - needs_nbrs = (i == 0) + needs_nbrs = i == 0 self.step(needs_nbrs=needs_nbrs) counter += 1 - + if counter % self.save_period == 0: self.save() -# else: -# # save any geoms that just hopped -# self.save(idx=self.just_hopped) + # else: + # # save any geoms that just hopped + # self.save(idx=self.just_hopped) - + with open(self.log_file, "a") as f: + f.write("\nNeural Tully terminated normally.") - with open(self.log_file, 'a') as f: - f.write('\nNeural Tully terminated normally.') - - - - - class CombinedNeuralTully: - def __init__(self, - atoms, - ground_params, - tully_params): - + def __init__(self, atoms, ground_params, tully_params): self.reload_ground = tully_params.get("reload_ground", False) - self.ground_dynamics = self.init_ground(atoms=atoms, - ground_params=ground_params) + self.ground_dynamics = self.init_ground(atoms=atoms, ground_params=ground_params) self.ground_params = ground_params 
= ground_params self.ground_savefile = ground_params.get("savefile") self.tully_params = tully_params - self.num_trj = tully_params['num_trj'] - - def init_ground(self, - atoms, - ground_params): + self.num_trj = tully_params["num_trj"] + def init_ground(self, atoms, ground_params): ase_ground_params = copy.deepcopy(ground_params) ase_ground_params["trajectory"] = ground_params.get("savefile") - logfile = ase_ground_params['logfile'] + logfile = ase_ground_params["logfile"] trj_file = ase_ground_params["trajectory"] if os.path.isfile(logfile): @@ -977,15 +866,12 @@ def init_ground, return ground_dynamics def sample_ground_geoms(self): - steps = math.ceil(self.ground_params["max_time"] / - self.ground_params["timestep"]) - equil_steps = math.ceil(self.ground_params["equil_time"] / - self.ground_params["timestep"]) + steps = math.ceil(self.ground_params["max_time"] / self.ground_params["timestep"]) + equil_steps = math.ceil(self.ground_params["equil_time"] / self.ground_params["timestep"]) loginterval = self.ground_params.get("loginterval", 1) if self.ground_dynamics is not None: - old_trj_file = (str(self.ground_savefile) - .replace(".trj", "_old.trj")) + old_trj_file = str(self.ground_savefile).replace(".trj", "_old.trj") if self.reload_ground and os.path.isfile(old_trj_file): trj = Trajectory(old_trj_file) atoms = next(iter(reversed(trj))) @@ -994,49 +880,39 @@ # set positions and velocities. Don't overwrite atoms because # then you lose their calculator self.ground_dynamics.atoms.set_positions(atoms.get_positions()) - self.ground_dynamics.atoms.set_velocities( - atoms.get_velocities()) + self.ground_dynamics.atoms.set_velocities(atoms.get_velocities()) self.ground_dynamics.run(steps=steps) trj = Trajectory(self.ground_savefile) logged_equil = math.ceil(equil_steps / loginterval) - possible_states = [trj[index] for index in - range(logged_equil, len(trj))] - random_indices = random.sample(range(len(possible_states)), - self.num_trj) + possible_states = [trj[index] for index in range(logged_equil, len(trj))] + random_indices = random.sample(range(len(possible_states)), self.num_trj) actual_states = [possible_states[index] for index in random_indices] return actual_states def run(self): atoms_list = self.sample_ground_geoms() - tully = NeuralTully(atoms_list=atoms_list, - **self.tully_params) + tully = NeuralTully(atoms_list=atoms_list, **self.tully_params) tully.run() @classmethod - def from_file(cls, - file): - + def from_file(cls, file): all_params = load_json(file) - ground_params = all_params['ground_params'] - atomsbatch = get_atoms(all_params=all_params, - ground_params=ground_params) + ground_params = all_params["ground_params"] + atomsbatch = get_atoms(all_params=all_params, ground_params=ground_params) atomsbatch.calc.model_kwargs = MODEL_KWARGS - tully_params = all_params['tully_params'] - if 'weightpath' in all_params: - model_path = os.path.join(all_params['weightpath'], - str(all_params["nnid"])) + tully_params = all_params["tully_params"] + if "weightpath" in all_params: + model_path = os.path.join(all_params["weightpath"], str(all_params["nnid"])) else: - model_path = all_params['model_path'] + model_path = all_params["model_path"] tully_params.update({"model_path": model_path}) - instance = cls(atoms=atomsbatch, - ground_params=ground_params, - tully_params=tully_params) + instance = cls(atoms=atomsbatch, ground_params=ground_params, tully_params=tully_params) return instance diff --git a/nff/md/tully_multiplicity/io.py b/nff/md/tully_multiplicity/io.py
index 37e21370..ec423cae 100644 --- a/nff/md/tully_multiplicity/io.py +++ b/nff/md/tully_multiplicity/io.py @@ -2,92 +2,69 @@ Link between Tully surface hopping and both NFF models and JSON parameter files. """ + import json import os +from typing import List, Union -from typing import * - -import torch -from torch.utils.data import DataLoader import numpy as np - -from rdkit import Chem +import torch from ase import Atoms +from rdkit import Chem +from torch.utils.data import DataLoader -from nff.train import batch_to, batch_detach -from nff.nn.utils import single_spec_nbrs from nff.data import Dataset, collate_dicts +from nff.io.ase_ax import AtomsBatch, NeuralFF +from nff.nn.utils import single_spec_nbrs +from nff.train import batch_detach, batch_to from nff.utils import constants as const -from nff.utils.scatter import compute_grad -from nff.io.ase_ax import NeuralFF, AtomsBatch PERIODICTABLE = Chem.GetPeriodicTable() ANGLE_MODELS = ["DimeNet", "DimeNetDiabat", "DimeNetDiabatDelta"] -def make_loader(nxyz, - nbr_list, - num_atoms, - needs_nbrs, - cutoff, - cutoff_skin, - device, - batch_size): - - props = {"nxyz": [torch.Tensor(i) - for i in nxyz]} +def make_loader(nxyz, nbr_list, num_atoms, needs_nbrs, cutoff, cutoff_skin, device, batch_size): + props = {"nxyz": [torch.Tensor(i) for i in nxyz]} - dataset = Dataset(props=props, - units='kcal/mol', - check_props=True) + dataset = Dataset(props=props, units="kcal/mol", check_props=True) if needs_nbrs or nbr_list is None: - nbrs = single_spec_nbrs(dset=dataset, - cutoff=(cutoff + - cutoff_skin), - device=device, - directed=True) - dataset.props['nbr_list'] = nbrs + nbrs = single_spec_nbrs(dset=dataset, cutoff=(cutoff + cutoff_skin), device=device, directed=True) + dataset.props["nbr_list"] = nbrs else: - dataset.props['nbr_list'] = nbr_list + dataset.props["nbr_list"] = nbr_list - loader = DataLoader(dataset, - batch_size=batch_size, - collate_fn=collate_dicts) + loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_dicts) return loader -def run_models(models: List, - batch, - device: Union[str, int]): +def run_models(models: List, batch, device: Union[str, int]): """ - Gets a list of models, which contains X models that - collectively predict Energies/Forces, NACVs, SOCs - - Args: - models (list): list of torch models - batch: torch batch to do inference for - device: device on which all tensors are located + Gets a list of models, which contains X models that + collectively predict Energies/Forces, NACVs, SOCs + + Args: + models (list): list of torch models + batch: torch batch to do inference for + device: device on which all tensors are located """ batch = batch_to(batch, device) - + results = {} for model in models: - result = model(batch, - inference=True) + result = model(batch, inference=True) result = batch_detach(result) - + # merge dictionaries - for key in result.keys(): + for key in result: results[key] = result[key] return results -def concat_and_conv(results_list, - num_atoms): +def concat_and_conv(results_list, num_atoms): """ Concatenate results from separate batches and convert to atomic units @@ -102,19 +79,16 @@ def concat_and_conv(results_list, for key in keys: val = torch.cat([i[key] for i in results_list]) - if 'energy_grad' in key or 'force_nacv' in key: - val *= conv['energy'] * conv['_grad'] - val = val.reshape(*grad_shape) - elif 'energy' in key: - val *= conv['energy'] - elif ('nacv' in key or 'NACV' in key) and 'grad' in key: - val *= conv['_grad'] + if "energy_grad" in key or "force_nacv" in 
key: + val *= conv["energy"] * conv["_grad"] val = val.reshape(*grad_shape) - elif 'NACP' in key and 'grad' in key: - val *= conv['_grad'] + elif "energy" in key: + val *= conv["energy"] + elif (("nacv" in key or "NACV" in key) and "grad" in key) or ("NACP" in key and "grad" in key): + val *= conv["_grad"] val = val.reshape(*grad_shape) - elif 'soc' in key or 'SOC' in key: - val *= 0.0000045563353 # cm-1 to Ha + elif "soc" in key or "SOC" in key: + val *= 0.0000045563353 # cm-1 to Ha # else: # msg = f"{key} has no known conversion" # raise NotImplementedError(msg) @@ -124,36 +98,37 @@ def concat_and_conv(results_list, return all_results -def get_results(models, - nxyz, - nbr_list, - num_atoms, - needs_nbrs, - cutoff, - cutoff_skin, - device, - batch_size,): +def get_results( + models, + nxyz, + nbr_list, + num_atoms, + needs_nbrs, + cutoff, + cutoff_skin, + device, + batch_size, +): """ `nxyz_list` assumed to be in Angstroms """ - loader = make_loader(nxyz=nxyz, - nbr_list=nbr_list, - num_atoms=num_atoms, - needs_nbrs=needs_nbrs, - cutoff=cutoff, - cutoff_skin=cutoff_skin, - device=device, - batch_size=batch_size) + loader = make_loader( + nxyz=nxyz, + nbr_list=nbr_list, + num_atoms=num_atoms, + needs_nbrs=needs_nbrs, + cutoff=cutoff, + cutoff_skin=cutoff_skin, + device=device, + batch_size=batch_size, + ) results_list = [] for batch in loader: - results = run_models(models=models, - batch=batch, - device=device) + results = run_models(models=models, batch=batch, device=device) results_list.append(results) - all_results = concat_and_conv(results_list=results_list, - num_atoms=num_atoms) + all_results = concat_and_conv(results_list=results_list, num_atoms=num_atoms) return all_results @@ -161,7 +136,7 @@ def get_results(models, def coords_to_nxyz(coords): nxyz = [] for dic in coords: - directions = ['x', 'y', 'z'] + directions = ["x", "y", "z"] n = float(PERIODICTABLE.GetAtomicNumber(dic["element"])) xyz = [dic[i] for i in directions] nxyz.append([n, *xyz]) @@ -169,36 +144,27 @@ def coords_to_nxyz(coords): def load_json(file): - - with open(file, 'r') as f: + with open(file, "r") as f: info = json.load(f) - if 'details' in info: - details = info['details'] - else: - details = {} - all_params = {key: val for key, val in info.items() - if key != "details"} + details = info.get("details", {}) + all_params = {key: val for key, val in info.items() if key != "details"} all_params.update(details) return all_params -def make_dataset(nxyz, - ground_params): - props = { - 'nxyz': [torch.Tensor(nxyz)] - } +def make_dataset(nxyz, ground_params): + props = {"nxyz": [torch.Tensor(nxyz)]} cutoff = ground_params["cutoff"] cutoff_skin = ground_params["cutoff_skin"] - dataset = Dataset(props.copy(), units='kcal/mol') - dataset.generate_neighbor_list(cutoff=(cutoff + cutoff_skin), - undirected=False) + dataset = Dataset(props.copy(), units="kcal/mol") + dataset.generate_neighbor_list(cutoff=(cutoff + cutoff_skin), undirected=False) model_type = ground_params["model_type"] - needs_angles = (model_type in ANGLE_MODELS) + needs_angles = model_type in ANGLE_MODELS if needs_angles: dataset.generate_angle_list() @@ -215,14 +181,8 @@ def get_batched_props(dataset): return batched_props -def add_calculator(atomsbatch, - model_path, - model_type, - device, - batched_props, - output_keys = ["energy_0"]): - - needs_angles = (model_type in ANGLE_MODELS) +def add_calculator(atomsbatch, model_path, model_type, device, batched_props, output_keys=["energy_0"]): + needs_angles = model_type in ANGLE_MODELS nff_ase = 
NeuralFF.from_file( model_path=model_path, @@ -232,45 +192,41 @@ def add_calculator(atomsbatch, params=None, model_type=model_type, needs_angles=needs_angles, - dataset_props=batched_props + dataset_props=batched_props, ) atomsbatch.set_calculator(nff_ase) -def get_atoms(ground_params, - all_params): - +def get_atoms(ground_params, all_params): coords = all_params["coords"] nxyz = coords_to_nxyz(coords) - atoms = Atoms(nxyz[:, 0], - positions=nxyz[:, 1:]) + atoms = Atoms(nxyz[:, 0], positions=nxyz[:, 1:]) - dataset, needs_angles = make_dataset(nxyz=nxyz, - ground_params=ground_params) + dataset, needs_angles = make_dataset(nxyz=nxyz, ground_params=ground_params) batched_props = get_batched_props(dataset) - device = ground_params.get('device', 'cuda') + device = ground_params.get("device", "cuda") - atomsbatch = AtomsBatch.from_atoms(atoms=atoms, - props=batched_props, - needs_angles=needs_angles, - device=device, - undirected=False, - cutoff_skin=ground_params['cutoff_skin']) + atomsbatch = AtomsBatch.from_atoms( + atoms=atoms, + props=batched_props, + needs_angles=needs_angles, + device=device, + undirected=False, + cutoff_skin=ground_params["cutoff_skin"], + ) - if 'model_path' in all_params: - model_path = all_params['model_path'] + if "model_path" in all_params: + model_path = all_params["model_path"] else: - model_path = os.path.join(all_params['weightpath'], - str(all_params["nnid"])) - add_calculator(atomsbatch=atomsbatch, - model_path=model_path, - model_type=ground_params["model_type"], - device=device, - batched_props=batched_props, - output_keys = [ground_params['energy_key']]) + model_path = os.path.join(all_params["weightpath"], str(all_params["nnid"])) + add_calculator( + atomsbatch=atomsbatch, + model_path=model_path, + model_type=ground_params["model_type"], + device=device, + batched_props=batched_props, + output_keys=[ground_params["energy_key"]], + ) return atomsbatch - - - diff --git a/nff/md/tully_multiplicity/step.py b/nff/md/tully_multiplicity/step.py index 87bb9eac..45469f75 100644 --- a/nff/md/tully_multiplicity/step.py +++ b/nff/md/tully_multiplicity/step.py @@ -3,28 +3,19 @@ """ import copy -from functools import partial import numpy as np import torch +from nff.md.tully.step import solve_quadratic -def verlet_step_1(forces, - surfs, - vel, - xyz, - mass, - dt): - +def verlet_step_1(forces, surfs, vel, xyz, mass, dt): # `forces` has dimension (num_samples x num_states # x num_atoms x 3) # `surfs` has dimension `num_samples` - surf_forces = np.take_along_axis( - forces, surfs.reshape(-1, 1, 1, 1), - axis=1 - ).squeeze(1) + surf_forces = np.take_along_axis(forces, surfs.reshape(-1, 1, 1, 1), axis=1).squeeze(1) # `surf_forces` has dimension (num_samples x # num_atoms x 3) @@ -34,22 +25,14 @@ def verlet_step_1(forces, # `vel` and `xyz` each have dimension # (num_samples x num_atoms x 3) - new_xyz = xyz + vel * dt + 0.5 * accel * dt ** 2 + new_xyz = xyz + vel * dt + 0.5 * accel * dt**2 new_vel = vel + 0.5 * dt * accel return new_xyz, new_vel -def verlet_step_2(forces, - surfs, - vel, - mass, - dt): - - surf_forces = np.take_along_axis( - forces, surfs.reshape(-1, 1, 1, 1), - axis=1 - ).squeeze(1) +def verlet_step_2(forces, surfs, vel, mass, dt): + surf_forces = np.take_along_axis(forces, surfs.reshape(-1, 1, 1, 1), axis=1).squeeze(1) accel = surf_forces / mass.reshape(1, -1, 1) new_vel = vel + 0.5 * dt * accel @@ -57,65 +40,49 @@ def verlet_step_2(forces, return new_vel -def adiabatic_c(c, - elec_substeps, - old_H_plus_nacv, - new_H_plus_nacv, - dt, - **kwargs): - 
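(A note on the consolidated `adiabatic_c` below: it integrates the electronic coefficients over `elec_substeps` substeps, linearly interpolating the effective Hamiltonian from H(t) to H(t+dt), accumulating the propagator P as a product of matrix exponentials, and finally applying c(t+dt) = P c(t). A minimal single-sample sketch of the same scheme; this helper is illustrative only and does not appear in the diff, which only has the batched version:

    import numpy as np
    import torch

    def propagate_coeffs(c, h_old, h_new, dt, n_substeps):
        """Return (P @ c, P) with P = prod_i exp(-1j * H_i * dt / n)."""
        num_states = h_old.shape[0]
        P = np.eye(num_states, dtype=np.complex128)
        delta_tau = dt / n_substeps
        for i in range(1, n_substeps + 1):
            # linear interpolation between the old and new Hamiltonians
            h_i = h_old + (i / n_substeps) * (h_new - h_old)
            step = torch.tensor(-1j * h_i * delta_tau).matrix_exp().numpy()
            P = P @ step
        return P @ c, P

The batched routine below does the same with an extra sample dimension handled by einsum.)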
+def adiabatic_c(c, elec_substeps, old_H_plus_nacv, new_H_plus_nacv, dt, **kwargs): num_samples = old_H_plus_nacv.shape[0] num_states = old_H_plus_nacv.shape[1] n = elec_substeps - exp = (np.eye(num_states, num_states) - .reshape(1, num_states, num_states) - .repeat(num_samples, axis=0)) + exp = np.eye(num_states, num_states).reshape(1, num_states, num_states).repeat(num_samples, axis=0) delta_tau = dt / n for i in range(1, n + 1): - new_exp = torch.tensor( - -1j * ( - old_H_plus_nacv + i / n * - (new_H_plus_nacv - old_H_plus_nacv) - ) - * delta_tau - ).matrix_exp().numpy() - exp = np.einsum('ijk, ikl -> ijl', exp, new_exp) + new_exp = ( + torch.tensor(-1j * (old_H_plus_nacv + i / n * (new_H_plus_nacv - old_H_plus_nacv)) * delta_tau) + .matrix_exp() + .numpy() + ) + exp = np.einsum("ijk, ikl -> ijl", exp, new_exp) P = exp - c_new = np.einsum('ijk, ik -> ij', P, c) + c_new = np.einsum("ijk, ik -> ij", P, c) return c_new, P -def compute_T(nacv, - vel, - c): - +def compute_T(nacv, vel, c): # vel has shape num_samples x num_atoms x 3 # nacv has shape num_samples x num_states x num_states # x num_atoms x 3 # T has shape num_samples x (num_states x num_states) - T = (vel.reshape(vel.shape[0], 1, 1, -1, 3) - * nacv).sum((-1, -2)) + T = (vel.reshape(vel.shape[0], 1, 1, -1, 3) * nacv).sum((-1, -2)) # anything that's nan has too big a gap # for hopping and should therefore have T=0 T[np.isnan(T)] = 0 num_states = nacv.shape[1] - coupling = np.einsum('nij, nj-> ni', T, c[:, :num_states]) + coupling = np.einsum("nij, nj-> ni", T, c[:, :num_states]) return T, coupling -def get_p_hop(hop_eqn='sharc', - **kwargs): - - if hop_eqn == 'sharc': +def get_p_hop(hop_eqn="sharc", **kwargs): + if hop_eqn == "sharc": p = get_sharc_p(**kwargs) else: raise NotImplementedError @@ -123,12 +90,7 @@ def get_p_hop(hop_eqn='sharc', return p - -def get_sharc_p(old_c, - new_c, - P, - surfs, - **kwargs): +def get_sharc_p(old_c, new_c, P, surfs, **kwargs): """ P is the propagator. 
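    With beta the active surface, the code below computes the hop probability
    into surface alpha as (a gloss reconstructed from the code, not from the
    original docstring):

        h_alpha = max(0, (1 - |c_beta(t+dt)|^2 / |c_beta(t)|^2)
                      * Re[c_alpha(t+dt) conj(P_alpha,beta) conj(c_beta(t))]
                      / (|c_beta(t)|^2 - Re[c_beta(t+dt) conj(P_beta,beta) conj(c_beta(t))]))

    with 1e-8 added to both denominators for numerical stability.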
""" @@ -136,92 +98,55 @@ def get_sharc_p(old_c, num_samples = old_c.shape[0] num_states = old_c.shape[1] - other_surfs = get_other_surfs(surfs=surfs, - num_states=num_states, - num_samples=num_samples) + other_surfs = get_other_surfs(surfs=surfs, num_states=num_states, num_samples=num_samples) - c_beta_t = np.take_along_axis(old_c, - surfs.reshape(-1, 1), - axis=-1) - c_beta_dt = np.take_along_axis(new_c, - surfs.reshape(-1, 1), - axis=-1) + c_beta_t = np.take_along_axis(old_c, surfs.reshape(-1, 1), axis=-1) + c_beta_dt = np.take_along_axis(new_c, surfs.reshape(-1, 1), axis=-1) - c_alpha_dt = np.take_along_axis(new_c, - other_surfs, - axis=-1) + c_alpha_dt = np.take_along_axis(new_c, other_surfs, axis=-1) # `P` has dimension num_samples x num_states x num_states - P_alpha_beta = np.take_along_axis(np.take_along_axis( - P, - surfs.reshape(-1, 1, 1), - axis=-1).squeeze(-1), - other_surfs, - axis=-1 + P_alpha_beta = np.take_along_axis( + np.take_along_axis(P, surfs.reshape(-1, 1, 1), axis=-1).squeeze(-1), other_surfs, axis=-1 ) - P_beta_beta = np.take_along_axis(np.take_along_axis( - P, - surfs.reshape(-1, 1, 1), - axis=-1).squeeze(-1), - surfs.reshape(-1, 1), - axis=-1 + P_beta_beta = np.take_along_axis( + np.take_along_axis(P, surfs.reshape(-1, 1, 1), axis=-1).squeeze(-1), surfs.reshape(-1, 1), axis=-1 ) # h_alpha is the transition probability from the current state # to alpha num = np.real(c_alpha_dt * np.conj(P_alpha_beta) * np.conj(c_beta_t)) - denom = np.power(np.abs(c_beta_t), 2) - np.real(c_beta_dt * np.conj(P_beta_beta) - * np.conj(c_beta_t)) - pref = 1. - np.power(np.abs(c_beta_dt), 2) / (np.power(np.abs(c_beta_t), 2) + 1.e-8) + denom = np.power(np.abs(c_beta_t), 2) - np.real(c_beta_dt * np.conj(P_beta_beta) * np.conj(c_beta_t)) + pref = 1.0 - np.power(np.abs(c_beta_dt), 2) / (np.power(np.abs(c_beta_t), 2) + 1.0e-8) h = np.zeros((num_samples, num_states)) - np.put_along_axis(h, - other_surfs, - pref * num / (denom + 1.e-8), - axis=-1) + np.put_along_axis(h, other_surfs, pref * num / (denom + 1.0e-8), axis=-1) h[h < 0] = 0 return h -def get_other_surfs(surfs, - num_states, - num_samples): - all_surfs = (np.arange(num_states).reshape(-1, 1) - .repeat(num_samples, 1).transpose()) +def get_other_surfs(surfs, num_states, num_samples): + all_surfs = np.arange(num_states).reshape(-1, 1).repeat(num_samples, 1).transpose() other_idx = all_surfs != surfs.reshape(-1, 1) other_surfs = all_surfs[other_idx].reshape(num_samples, -1) return other_surfs -def try_hop(p_hop, - surfs, - vel, - nacv, - mass, - energy, - max_gap_hop, - simple_scale): +def try_hop(p_hop, surfs, vel, nacv, mass, energy, max_gap_hop, simple_scale): """ `energy` has dimension num_samples x num_states """ - new_surfs = get_new_surf(p_hop=p_hop, - surfs=surfs, - max_gap_hop=max_gap_hop, - energy=energy) + new_surfs = get_new_surf(p_hop=p_hop, surfs=surfs, max_gap_hop=max_gap_hop, energy=energy) - new_vel = rescale(energy=energy, - vel=vel, - nacv=nacv, - mass=mass, - surfs=surfs, - new_surfs=new_surfs, - simple_scale=simple_scale) + new_vel = rescale( + energy=energy, vel=vel, nacv=nacv, mass=mass, surfs=surfs, new_surfs=new_surfs, simple_scale=simple_scale + ) # reset any frustrated hops or things that didn't hop frustrated = np.isnan(new_vel).any((-1, -2)).nonzero()[0] @@ -231,15 +156,9 @@ def try_hop(p_hop, return new_surfs, new_vel -def get_new_surf(p_hop, - surfs, - max_gap_hop, - energy): - +def get_new_surf(p_hop, surfs, max_gap_hop, energy): num_samples = p_hop.shape[0] - lhs = 
np.concatenate([np.zeros(num_samples).reshape(-1, 1), - p_hop.cumsum(axis=-1)], - axis=-1)[:, :-1] + lhs = np.concatenate([np.zeros(num_samples).reshape(-1, 1), p_hop.cumsum(axis=-1)], axis=-1)[:, :-1] rhs = lhs + p_hop r = np.random.rand(num_samples).reshape(-1, 1) hop = (lhs < r) * (r <= rhs) @@ -251,12 +170,8 @@ def get_new_surf(p_hop, if max_gap_hop is None: return new_surfs - old_en = np.take_along_axis(energy, - surfs.reshape(-1, 1), - axis=-1).squeeze(-1) - new_en = np.take_along_axis(energy, - new_surfs.reshape(-1, 1), - axis=-1).squeeze(-1) + old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), axis=-1).squeeze(-1) + new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), axis=-1).squeeze(-1) gaps = abs(old_en - new_en) bad_idx = gaps >= max_gap_hop new_surfs[bad_idx] = surfs[bad_idx] @@ -264,18 +179,12 @@ def get_new_surf(p_hop, return new_surfs -def rescale(energy, - vel, - nacv, - mass, - surfs, - new_surfs, - simple_scale): +def rescale(energy, vel, nacv, mass, surfs, new_surfs, simple_scale): """ Velocity re-scaling, from: - Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with - fewest switches surface hopping: Add just a touch of decoherence. The + Landry, B.R. and Subotnik, J.E., 2012. How to recover Marcus theory with + fewest switches surface hopping: Add just a touch of decoherence. The Journal of chemical physics, 137(22), p.22A513. If no NACV is available, the KE is simply rescaled to conserve energy. @@ -283,28 +192,18 @@ def rescale(energy, """ # old and new energies - old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), - -1).reshape(-1) - new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), - -1).reshape(-1) + old_en = np.take_along_axis(energy, surfs.reshape(-1, 1), -1).reshape(-1) + new_en = np.take_along_axis(energy, new_surfs.reshape(-1, 1), -1).reshape(-1) if simple_scale or nacv is None: - v_scale = get_simple_scale(mass=mass, - new_en=new_en, - old_en=old_en, - vel=vel) + v_scale = get_simple_scale(mass=mass, new_en=new_en, old_en=old_en, vel=vel) new_vel = v_scale.reshape(-1, 1, 1) * vel return new_vel # nacvs connecting old to new surfaces ones = [1] * 4 - start_nacv = np.take_along_axis(nacv, surfs - .reshape(-1, *ones), - axis=1) - pair_nacv = np.take_along_axis(start_nacv, new_surfs - .reshape(-1, *ones), - axis=2 - ).squeeze(1).squeeze(1) + start_nacv = np.take_along_axis(nacv, surfs.reshape(-1, *ones), axis=1) + pair_nacv = np.take_along_axis(start_nacv, new_surfs.reshape(-1, *ones), axis=2).squeeze(1).squeeze(1) # nacv unit vector norm = np.linalg.norm(pair_nacv, axis=-1) @@ -313,31 +212,20 @@ def rescale(energy, nac_dir = pair_nacv / norm.reshape(*pair_nacv.shape[:-1], 1) # solve quadratic equation for momentum rescaling - scale = solve_quadratic(vel=vel, - nac_dir=nac_dir, - old_en=old_en, - new_en=new_en, - mass=mass) + scale = solve_quadratic(vel=vel, nac_dir=nac_dir, old_en=old_en, new_en=new_en, mass=mass) # scale the velocity - new_vel = (scale.reshape(-1, 1, 1) * nac_dir - / mass.reshape(1, -1, 1) - + vel) + new_vel = scale.reshape(-1, 1, 1) * nac_dir / mass.reshape(1, -1, 1) + vel return new_vel -def get_simple_scale(mass, - new_en, - old_en, - vel): - +def get_simple_scale(mass, new_en, old_en, vel): m = mass.reshape(1, -1, 1) gap = old_en - new_en - arg = ((2 * gap + (m * vel ** 2).sum((-1, -2))) - .astype('complex128')) + arg = (2 * gap + (m * vel**2).sum((-1, -2))).astype("complex128") num = np.sqrt(arg) - denom = np.sqrt((m * vel ** 2).sum((-1, -2))) + denom = np.sqrt((m * vel**2).sum((-1, 
-2))) v_scale = num / denom @@ -348,16 +236,7 @@ def get_simple_scale(mass, return v_scale -def truhlar_decoherence(c, - surfs, - energy, - vel, - dt, - mass, - hbar=1, - C=0.1, - eps=1.e-12, - **kwargs): +def truhlar_decoherence(c, surfs, energy, vel, dt, mass, hbar=1, C=0.1, eps=1.0e-12, **kwargs): """ Originally attributed to Truhlar, cited from G. Granucci and M. Persico. "Critical appraisal of the @@ -368,25 +247,15 @@ def truhlar_decoherence(c, num_samples = c.shape[0] num_states = c.shape[1] - other_surfs = get_other_surfs(surfs=surfs, - num_states=num_states, - num_samples=num_samples) + other_surfs = get_other_surfs(surfs=surfs, num_states=num_states, num_samples=num_samples) - c_m = np.take_along_axis(c, - surfs.reshape(-1, 1), - axis=-1) + c_m = np.take_along_axis(c, surfs.reshape(-1, 1), axis=-1) - E_m = np.take_along_axis(energy, - surfs.reshape(-1, 1), - axis=-1) + E_m = np.take_along_axis(energy, surfs.reshape(-1, 1), axis=-1) - c_k = np.take_along_axis(c, - other_surfs, - axis=-1) + c_k = np.take_along_axis(c, other_surfs, axis=-1) - E_k = np.take_along_axis(energy, - other_surfs, - axis=-1) + E_k = np.take_along_axis(energy, other_surfs, axis=-1) # vel has shape num_samples x num_atoms x 3 E_kin = (1 / 2 * mass.reshape(1, -1, 1) * np.power(vel, 2)).sum((-1, -2)) @@ -401,20 +270,11 @@ def truhlar_decoherence(c, num[num < 0] = 0 - c_m_prime = c_m * np.sqrt( - num.reshape(-1, 1) - / np.power(np.abs(c_m), 2) - ) + c_m_prime = c_m * np.sqrt(num.reshape(-1, 1) / np.power(np.abs(c_m), 2)) new_c = np.zeros_like(c) - np.put_along_axis(new_c, - surfs.reshape(-1, 1), - c_m_prime, - axis=-1) - - np.put_along_axis(new_c, - other_surfs, - c_k_prime, - axis=-1) + np.put_along_axis(new_c, surfs.reshape(-1, 1), c_m_prime, axis=-1) + + np.put_along_axis(new_c, other_surfs, c_k_prime, axis=-1) return new_c diff --git a/nff/md/utils.py b/nff/md/utils.py index 26b3b46a..a3ff34da 100644 --- a/nff/md/utils.py +++ b/nff/md/utils.py @@ -157,7 +157,7 @@ def __init__(self, dyn, atoms, logfile, header=True, peratom=False, verbose=Fals self.hdr += "%12s %12s %12s " % ("U0+bias[eV]", "U0[eV]", "AbsGradPot") self.fmt += "%12.5f %12.5f %12.4f " - for i in range(self.num_cv): + for _i in range(self.num_cv): self.hdr += "%12s %12s %12s %12s %12s " % ( "CV", "Lambda", @@ -167,7 +167,7 @@ def __init__(self, dyn, atoms, logfile, header=True, peratom=False, verbose=Fals ) self.fmt += "%12.4f %12.4f %12.4f %12.4f %12.4f " - for i in range(self.n_const): + for _i in range(self.n_const): self.hdr += "%12s " % ("Const") self.fmt += "%12.5f " @@ -249,22 +249,22 @@ def write_traj(filename, frames): traj2write = trajconv(n_mol, n_atom, box_len, path) write_traj(path, traj2write) """ - file = open(filename, "w") - atom_no = frames.shape[1] - for i, frame in enumerate(frames): - file.write(str(atom_no) + "\n") - file.write("Atoms. Timestep: " + str(i) + "\n") - for atom in frame: - if atom.shape[0] == 4: - try: - file.write(str(int(atom[0])) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + "\n") - except: - file.write(str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + "\n") - elif atom.shape[0] == 3: - file.write("1" + " " + str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + "\n") - else: - raise ValueError("wrong format") - file.close() + with open(filename, "w") as file: + atom_no = frames.shape[1] + for i, frame in enumerate(frames): + file.write(str(atom_no) + "\n") + file.write("Atoms. 
Timestep: " + str(i) + "\n") + for atom in frame: + if atom.shape[0] == 4: + try: + file.write(str(int(atom[0])) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + + "\n") + except BaseException: + file.write(str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + "\n") + elif atom.shape[0] == 3: + file.write("1" + " " + str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + "\n") + else: + raise ValueError("wrong format") def csv_read(out_file): @@ -295,7 +295,7 @@ def csv_read(out_file): new_dic_list = [] for regular_dic, key_dic in zip(dic_list, dic_keys): new_dic = copy.deepcopy(regular_dic) - for key in regular_dic.keys(): + for key in regular_dic: new_dic[key_dic[key]] = regular_dic[key] new_dic_list.append(new_dic) diff --git a/nff/md/utils_ax.py b/nff/md/utils_ax.py index e94cd959..19e58b63 100644 --- a/nff/md/utils_ax.py +++ b/nff/md/utils_ax.py @@ -1,18 +1,12 @@ -import os -import numpy as np +import copy import csv import json -import logging -import copy -import pdb - +import os -import ase -from ase import Atoms, units +import numpy as np +from ase import units from ase.md import MDLogger -from nff.utils.scatter import compute_grad -from nff.data.graphs import * import nff.utils.constants as const @@ -29,68 +23,64 @@ def get_energy(atoms): # ekin = (0.5 * (vel * 1e-10 * fs * 1e15).pow(2).sum(1) * (mass * 1.66053904e-27) * 6.241509e+18).sum() # ekin = ekin.item() #* ev_to_kcal - #ekin = ekin.detach().numpy() + # ekin = ekin.detach().numpy() - print(('Energy per atom: Epot = %.2fkcal/mol ' - 'Ekin = %.2fkcal/mol (T=%3.0fK) ' - 'Etot = %.2fkcal/mol' - % (epot, ekin, Temperature, epot + ekin))) + print( + "Energy per atom: Epot = %.2fkcal/mol " + "Ekin = %.2fkcal/mol (T=%3.0fK) " + "Etot = %.2fkcal/mol" % (epot, ekin, Temperature, epot + ekin) + ) # print('Energy per atom: Epot = %.5feV Ekin = %.5feV (T=%3.0fK) ' # 'Etot = %.5feV' % (epot, ekin, Temperature, (epot + ekin))) return epot, ekin, Temperature def write_traj(filename, frames): - ''' - Write trajectory dataframes into .xyz format for VMD visualization - to do: include multiple atom types - - example: - path = "../../sim/topotools_ethane/ethane-nvt_unwrap.xyz" - traj2write = trajconv(n_mol, n_atom, box_len, path) - write_traj(path, traj2write) - ''' - file = open(filename, 'w') - atom_no = frames.shape[1] - for i, frame in enumerate(frames): - file.write(str(atom_no) + '\n') - file.write('Atoms. Timestep: ' + str(i)+'\n') - for atom in frame: - if atom.shape[0] == 4: - try: - file.write(str(int(atom[0])) + " " + str(atom[1]) + - " " + str(atom[2]) + " " + str(atom[3]) + "\n") - except: - file.write(str(atom[0]) + " " + str(atom[1]) + - " " + str(atom[2]) + " " + str(atom[3]) + "\n") - elif atom.shape[0] == 3: - file.write(("1" + " " + str(atom[0]) + " " - + str(atom[1]) + " " + str(atom[2]) + "\n")) - else: - raise ValueError("wrong format") - file.close() + """ + Write trajectory dataframes into .xyz format for VMD visualization + to do: include multiple atom types + + example: + path = "../../sim/topotools_ethane/ethane-nvt_unwrap.xyz" + traj2write = trajconv(n_mol, n_atom, box_len, path) + write_traj(path, traj2write) + """ + with open(filename, "w") as file: + atom_no = frames.shape[1] + for i, frame in enumerate(frames): + file.write(str(atom_no) + "\n") + file.write("Atoms. 
Timestep: " + str(i) + "\n") + for atom in frame: + if atom.shape[0] == 4: + try: + file.write(str(int(atom[0])) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + + "\n") + except BaseException: + file.write(str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + " " + str(atom[3]) + "\n") + elif atom.shape[0] == 3: + file.write("1" + " " + str(atom[0]) + " " + str(atom[1]) + " " + str(atom[2]) + "\n") + else: + raise ValueError("wrong format") def mol_dot(vec1, vec2): - """ Say we have two vectors, each of which has the form + """Say we have two vectors, each of which has the form [[fx1, fy1, fz1], [fx2, fy2, fz2], ...]. - mol_dot returns an array of dot products between each - element of the two vectors. """ + mol_dot returns an array of dot products between each + element of the two vectors.""" v1 = np.array(vec1) v2 = np.array(vec2) - out = np.transpose([np.dot(element1, element2) for - element1, element2 in zip(v1, v2)]) + out = np.transpose([np.dot(element1, element2) for element1, element2 in zip(v1, v2)]) return out def mol_norm(vec): """Square root of mol_dot(vec, vec).""" - return mol_dot(vec, vec)**0.5 + return mol_dot(vec, vec) ** 0.5 def atoms_to_nxyz(atoms, positions=None): - atomic_numbers = atoms.get_atomic_numbers() if positions is None: positions = atoms.get_positions() @@ -98,8 +88,7 @@ def atoms_to_nxyz(atoms, positions=None): # don't make this a numpy array or it'll become type float64, # which will mess up the tensor computation. Need it to be # type float32. - nxyz = [[symbol, *position] for - symbol, position in zip(atomic_numbers, positions)] + nxyz = [[symbol, *position] for symbol, position in zip(atomic_numbers, positions)] return nxyz @@ -115,12 +104,12 @@ def zhu_dic_to_list(dic): """ lst = [] - first_key = list(dic.keys())[0] + first_key = next(iter(dic.keys())) for i in range(len(dic[first_key])): sub_dic = dict() - for key in dic.keys(): + for key in dic: sub_dic[key.split("_list")[0]] = dic[key][i] - if (key == "time_list"): + if key == "time_list": sub_dic[key.split("_list")[0]] /= const.FS_TO_AU lst.append(sub_dic) @@ -135,9 +124,9 @@ def append_to_csv(lst, out_file): Returns: None """ - with open(out_file, 'a+') as csvfile: + with open(out_file, "a+") as csvfile: for item in lst: - fieldnames = sorted(list(item.keys())) + fieldnames = sorted(item.keys()) writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() writer.writerow({key: item[key] for key in fieldnames}) @@ -147,9 +136,9 @@ def write_to_new_csv(lst, out_file): """ Same as `append_to_csv`, but writes a new file. """ - with open(out_file, 'w') as csvfile: + with open(out_file, "w") as csvfile: for item in lst: - fieldnames = sorted(list(item.keys())) + fieldnames = sorted(item.keys()) writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() writer.writerow({key: item[key] for key in fieldnames}) @@ -176,7 +165,7 @@ def csv_read(out_file): dic_list (list): list of dictionaries """ - with open(out_file, newline='') as csvfile: + with open(out_file, newline="") as csvfile: # get the keys and the corresponding dictionaries # being outputted dic_list = list(csv.DictReader(csvfile))[0::2] @@ -186,7 +175,7 @@ def csv_read(out_file): # to the key order on every other line. 
# (Also, weird things happen if you define `dic_keys` and # `dic_list` within the same context manager, so must do it separately) - with open(out_file, newline='') as csvfile: + with open(out_file, newline="") as csvfile: # this dictionary gives you a key: value pair # of the form supposed key: actual key dic_keys = list(csv.DictReader(csvfile))[1::2] @@ -195,25 +184,21 @@ def csv_read(out_file): new_dic_list = [] for regular_dic, key_dic in zip(dic_list, dic_keys): new_dic = copy.deepcopy(regular_dic) - for key in regular_dic.keys(): + for key in regular_dic: new_dic[key_dic[key]] = regular_dic[key] new_dic_list.append(new_dic) for dic in new_dic_list: for key, value in dic.items(): - if 'nan' in value: - value = value.replace('nan', "float('nan')") + if "nan" in value: + value = value.replace("nan", "float('nan')") dic[key] = eval(value) return new_dic_list class NeuralMDLogger(MDLogger): - def __init__(self, - *args, - verbose=True, - **kwargs): - + def __init__(self, *args, verbose=True, **kwargs): super().__init__(*args, **kwargs) self.natoms = len(self.atoms) self.verbose = verbose @@ -229,11 +214,11 @@ def __call__(self): epot /= self.natoms ekin /= self.natoms if self.dyn is not None: - t = self.dyn.get_time() / (1000*units.fs) + t = self.dyn.get_time() / (1000 * units.fs) dat = (t,) else: dat = () - dat += (epot+ekin, epot, ekin, temp) + dat += (epot + ekin, epot, ekin, temp) if self.stress: dat += tuple(self.atoms.get_stress() / units.GPa) self.logfile.write(self.fmt % dat) @@ -244,7 +229,6 @@ def __call__(self): class ZhuNakamuraLogger: - """ Base class for Zhu Nakamura dynamics. Properties: @@ -254,7 +238,6 @@ class ZhuNakamuraLogger: """ def __init__(self, out_file, log_file, save_keys, **kwargs): - self.out_file = out_file self.log_file = log_file self.save_keys = save_keys @@ -307,8 +290,7 @@ def create_save_list(self): else: save_dic[key] = val - save_dic["nxyz_list"] = [atoms_to_nxyz(self.atoms, positions) for - positions in save_dic["position_list"]] + save_dic["nxyz_list"] = [atoms_to_nxyz(self.atoms, positions) for positions in save_dic["position_list"]] save_list = zhu_dic_to_list(save_dic) return save_list @@ -319,16 +301,12 @@ def save(self): """ save_list = self.create_save_list() - csv_write(out_file=self.out_file, - lst=save_list[-1:], - method="append") + csv_write(out_file=self.out_file, lst=save_list[-1:], method="append") for key in self.save_keys: setattr(self, key, getattr(self, key)[-5:]) - def ac_present(self, - old_list, - new_list): + def ac_present(self, old_list, new_list): """ Check if the previous AC step, whose properties you're updating, was actually saved. 
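# (A note on the csv helpers above, with a hypothetical two-frame file for
# illustration: append_to_csv calls writeheader() before every writerow(),
# so a saved file alternates header and data rows,
#
#     energy,time
#     -1.50,0.0
#     energy,time
#     -1.49,0.5
#
# and csv.DictReader keys every later row by the *first* header row. That is
# why csv_read slices [0::2] for the data rows and [1::2] for the rows that
# map the first header's keys onto each frame's actual keys.)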
@@ -360,8 +338,7 @@ def ac_present(self, return present, freq_gt_2 def modify_hop(self, new_list): - - key = 'hopping_probability' + key = "hopping_probability" new_list[-3][key] = copy.deepcopy(new_list[-2][key]) new_list[-2][key] = [] @@ -402,9 +379,7 @@ def modify_save(self): # check to see if [AC on old surface] was # actually saved - may not be if we're not # saving every frame - ac_present, freq_gt_2 = self.ac_present( - old_list=old_list, - new_list=new_list) + ac_present, freq_gt_2 = self.ac_present(old_list=old_list, new_list=new_list) # update [AC on old surface] with the # fact that it's not in the trj and that @@ -434,9 +409,7 @@ def modify_save(self): if not freq_gt_2: save_list.append(new_list[-1]) - csv_write(out_file=self.out_file, - lst=save_list, - method="new") + csv_write(out_file=self.out_file, lst=save_list, method="new") def output_to_json(self): """ @@ -456,7 +429,7 @@ def log(self, msg): Args: msg (str) """ - output = '{:>12}: {}'.format("Zhu-Nakamura dynamics".upper(), msg) + output = "{:>12}: {}".format("Zhu-Nakamura dynamics".upper(), msg) with open(self.log_file, "a+") as f: f.write(output) f.write("\n") diff --git a/nff/md/zhu_nakamura/dynamics.py b/nff/md/zhu_nakamura/dynamics.py index 8ffef808..2c3a4c5a 100644 --- a/nff/md/zhu_nakamura/dynamics.py +++ b/nff/md/zhu_nakamura/dynamics.py @@ -560,6 +560,7 @@ def rescale_v(self, old_surf, new_surf): if np.isnan(velocities).any(): return "err" self.velocities = velocities + return None def update_probabilities(self): """ @@ -677,6 +678,7 @@ def hop(self, new_surf): self.hopping_probabilities = [] self.time = self.time - self.dt self.modify_save() + return None def full_step(self, compute_internal_forces=True, do_log=True): """ @@ -1059,18 +1061,18 @@ def add_diabat_forces(self): # only store the diagonal diabatic forces diabat_forces = np.zeros((num_states, N[j], 3)) - for l in range(num_states): + for k in range(num_states): for m in range(num_states): - d_key = diabat_keys[l, m] + d_key = diabat_keys[k, m] diabat_en_kcal = results[d_key][j].item() diabat_en_au = diabat_en_kcal * KCAL_TO_AU["energy"] - diabat_ens[l, m] = diabat_en_au + diabat_ens[k, m] = diabat_en_au - if l == m: + if k == m: diabat_force_kcal = -(results[f"{d_key}_grad"][j].detach().cpu().numpy()) diabat_force_au = diabat_force_kcal * KCAL_TO_AU["energy"] * KCAL_TO_AU["_grad"] - diabat_forces[l, :] = diabat_force_au + diabat_forces[k, :] = diabat_force_au trj.diabat_ens = diabat_ens trj.diabat_forces = diabat_forces @@ -1161,7 +1163,7 @@ def run(self): if do_save: print(f"Completed step {num_steps}") - complete = all([trj.time >= self.max_time for trj in self.zhu_trjs]) + complete = all(trj.time >= self.max_time for trj in self.zhu_trjs) num_steps += 1 print("Neural ZN terminated normally.") diff --git a/nff/nn/__init__.py b/nff/nn/__init__.py index 64d1de05..12e3652c 100644 --- a/nff/nn/__init__.py +++ b/nff/nn/__init__.py @@ -3,4 +3,3 @@ from .utils import * from .models import * from .tensorgrad import * - diff --git a/nff/nn/activations.py b/nff/nn/activations.py index d4139c40..ebb90eb4 100644 --- a/nff/nn/activations.py +++ b/nff/nn/activations.py @@ -5,9 +5,8 @@ class shifted_softplus(torch.nn.Module): - def __init__(self): - super(shifted_softplus, self).__init__() + super().__init__() def forward(self, input): return F.softplus(input) - np.log(2.0) @@ -22,20 +21,10 @@ def forward(self, x): class LearnableSwish(torch.nn.Module): - def __init__(self, - alpha=1.0, - beta=1.702): + def __init__(self, alpha=1.0, beta=1.702): 
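# (Inferred from the lines below, not stated in the diff: alpha and beta are
# stored through the inverse softplus, log(exp(x) - 1), so the raw learnable
# parameters stay unconstrained while a softplus recovers a positive value;
# e.g. for alpha = 1.0, alpha_inv = log(e - 1) ~= 0.5413 and
# softplus(0.5413) ~= 1.0, presumably what the `alpha` property returns.)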
super().__init__() - self.alpha_inv = nn.Parameter( - torch.log( - torch.exp(torch.Tensor([alpha])) - 1 - ) - ) - self.beta_inv = nn.Parameter( - torch.log( - torch.exp(torch.Tensor([beta])) - 1 - ) - ) + self.alpha_inv = nn.Parameter(torch.log(torch.exp(torch.Tensor([alpha])) - 1)) + self.beta_inv = nn.Parameter(torch.log(torch.exp(torch.Tensor([beta])) - 1)) @property def alpha(self): diff --git a/nff/nn/glue.py b/nff/nn/glue.py index 484f17fb..0636e84a 100644 --- a/nff/nn/glue.py +++ b/nff/nn/glue.py @@ -1,33 +1,29 @@ import copy + +import numpy as np from torch import nn from torch.nn import ModuleDict, ModuleList -import numpy as np + from nff.train import batch_detach -IMPLEMENTED_MODES = ['sum', 'mean'] +IMPLEMENTED_MODES = ["sum", "mean"] class Stack(nn.Module): - def __init__(self, model_dict, mode='sum'): + def __init__(self, model_dict, mode="sum"): super().__init__() if mode not in IMPLEMENTED_MODES: - raise NotImplementedError( - f'{mode} mode is not implemented for Stack') + raise NotImplementedError(f"{mode} mode is not implemented for Stack") # to implement a check for readout keys self.models = ModuleDict(model_dict) self.mode = mode - def forward(self, - batch, - keys_to_combine=['energy', 'energy_grad'], - **kwargs): - + def forward(self, batch, keys_to_combine=["energy", "energy_grad"], **kwargs): # run models - result_list = [self.models[key](batch, **kwargs) - for key in self.models.keys()] + result_list = [self.models[key](batch, **kwargs) for key in self.models] # perform further operations combine_results = dict() @@ -38,20 +34,14 @@ def forward, combine_results[key] += result[key] else: combine_results[key] = result[key] - if self.mode == 'mean': + if self.mode == "mean": for key in keys_to_combine: combine_results[key] /= len(result_list) return combine_results class DiabatStack(nn.Module): - def __init__(self, - models, - diabat_keys, - energy_keys, - adiabat_mean, - extra_keys=None): - + def __init__(self, models, diabat_keys, energy_keys, adiabat_mean, extra_keys=None): super().__init__() self.models = ModuleList(models) @@ -61,13 +51,9 @@ def __init__, self.adiabat_mean = adiabat_mean # any extra keys you want to be averaged - self.extra_keys = (extra_keys if (extra_keys is not None) - else extra_keys) - - def forward(self, - batch, - **kwargs): + self.extra_keys = extra_keys + def forward(self, batch, **kwargs): # use the same xyz for all the models so you can # compute the gradients @@ -83,14 +69,12 @@ def forward, if this_key not in results: continue if i != 0: - combined_results[this_key] += (results[this_key] - / num_models) + combined_results[this_key] += results[this_key] / num_models else: - combined_results[this_key] = (results[this_key] - / num_models) + combined_results[this_key] = results[this_key] / num_models return combined_results - xyz = batch['nxyz'][:, 1:] + xyz = batch["nxyz"][:, 1:] xyz.requires_grad = True # don't compute any gradients in the initial forward @@ -98,8 +82,7 @@ def forward, # average adiabatic energies and waste time/memory init_kwargs = copy.deepcopy(kwargs) - init_kwargs.update({"add_grad": False, - "add_nacv": False}) + init_kwargs.update({"add_grad": False, "add_nacv": False}) # get diabatic predictions from each model result_list = [] @@ -107,28 +90,20 @@ def forward, # use the initial run before computing all the # adiabatic quantities, if possible if hasattr(model, "run"): - result_list.append(model.run(batch=batch, - xyz=xyz)[0]) +
result_list.append(model.run(batch=batch, xyz=xyz)[0]) else: - result_list.append(model(batch=batch, - xyz=xyz, - **init_kwargs)) + result_list.append(model(batch=batch, xyz=xyz, **init_kwargs)) combined_results = {} - unique_diabat_keys = list(set((np.array(self.diabat_keys) - .reshape(-1).tolist()))) + unique_diabat_keys = list(set(np.array(self.diabat_keys).reshape(-1).tolist())) num_models = len(self.models) for key in unique_diabat_keys: for i, result in enumerate(result_list): if i == 0: combined_results[key] = result[key] / num_models else: - combined_results[key] = (combined_results[key] - + result[key] / num_models) + combined_results[key] = combined_results[key] + result[key] / num_models - combined_results = self.diabatic_readout(batch=batch, - xyz=xyz, - results=combined_results, - **kwargs) + combined_results = self.diabatic_readout(batch=batch, xyz=xyz, results=combined_results, **kwargs) return combined_results diff --git a/nff/nn/graphconv.py b/nff/nn/graphconv.py index d3e97d10..f2d4cdc6 100644 --- a/nff/nn/graphconv.py +++ b/nff/nn/graphconv.py @@ -1,14 +1,13 @@ import torch.nn as nn + from nff.utils.scatter import scatter_add class MessagePassingModule(nn.Module): - - """Convolution constructed as MessagePassing. - """ + """Convolution constructed as MessagePassing.""" def __init__(self): - super(MessagePassingModule, self).__init__() + super().__init__() def message(self, r, e, a, aggr_wgt): # Basic message case @@ -25,17 +24,13 @@ def message(self, r, e, a, aggr_wgt): def aggregate(self, message, index, size): # pdb.set_trace() - new_r = scatter_add(src=message, - index=index, - dim=0, - dim_size=size) + new_r = scatter_add(src=message, index=index, dim=0, dim_size=size) return new_r def update(self, r): return r def forward(self, r, e, a, aggr_wgt=None): - graph_size = r.shape[0] rij, rji = self.message(r, e, a, aggr_wgt) @@ -48,11 +43,10 @@ def forward(self, r, e, a, aggr_wgt=None): class EdgeUpdateModule(nn.Module): - """Update Edge State Based on information from connected nodes - """ + """Update Edge State Based on information from connected nodes""" def __init__(self): - super(EdgeUpdateModule, self).__init__() + super().__init__() def message(self, r, e, a): """Summary @@ -79,8 +73,7 @@ def aggregate(self, message, neighborlist): Returns: TYPE: Description """ - aggregated_edge_feature = message[neighborlist[:, 0] - ] + message[neighborlist[:, 1]] + aggregated_edge_feature = message[neighborlist[:, 0]] + message[neighborlist[:, 1]] return aggregated_edge_feature def update(self, e): @@ -95,18 +88,14 @@ def forward(self, r, e, a): class GeometricOperations(nn.Module): - - """Compute geomtrical properties based on XYZ coordinates - """ + """Compute geometrical properties based on XYZ coordinates""" def __init__(self): - super(GeometricOperations, self).__init__() + super().__init__() class TopologyOperations(nn.Module): - - """Change the topology index given geomtrical properties - """ + """Change the topology index given geometrical properties""" def __init__(self): - super(TopologyOperations, self).__init__() + super().__init__() diff --git a/nff/nn/graphop.py b/nff/nn/graphop.py index f15dddc8..9e6a8e72 100644 --- a/nff/nn/graphop.py +++ b/nff/nn/graphop.py @@ -1,6 +1,8 @@ import torch -from nff.utils.scatter import compute_grad + from nff.nn.modules import ConfAttention +from nff.utils.scatter import compute_grad + EPS = 1e-15 @@ -31,12 +33,7 @@ def update_boltz(conf_fp, weight, boltz_nn): return boltzmann_fp -def conf_pool(mol_size, - boltzmann_weights, -
mol_fp_nn, - boltz_nns, - conf_fps, - head_pool="concatenate"): +def conf_pool(mol_size, boltzmann_weights, mol_fp_nn, boltz_nns, conf_fps, head_pool="concatenate"): """ Pool atomic representations of conformers into molecular fingerprint, and then add those fingerprints together with Boltzmann weights. @@ -65,9 +62,7 @@ def conf_pool(mol_size, # the attention pooler and return if isinstance(boltz_nn, ConfAttention): - final_fp, learned_weights = boltz_nn( - conf_fps=conf_fps, - boltzmann_weights=boltzmann_weights) + final_fp, learned_weights = boltz_nn(conf_fps=conf_fps, boltzmann_weights=boltzmann_weights) else: # otherwise get a new fingerprint for each conformer # based on its Boltzmann weight @@ -75,10 +70,7 @@ def conf_pool(mol_size, boltzmann_fps = [] for i, conf_fp in enumerate(conf_fps): weight = boltzmann_weights[i] - boltzmann_fp = update_boltz( - conf_fp=conf_fp, - weight=weight, - boltz_nn=boltz_nn) + boltzmann_fp = update_boltz(conf_fp=conf_fp, weight=weight, boltz_nn=boltz_nn) boltzmann_fps.append(boltzmann_fp) boltzmann_fps = torch.stack(boltzmann_fps) @@ -148,12 +140,9 @@ def batch_and_sum(dict_input, N, predict_keys, xyz): # split if key in predict_keys and key + "_grad" not in predict_keys: results[key] = split_and_sum(val, N) - elif key in predict_keys and key + "_grad" in predict_keys: - results[key] = split_and_sum(val, N) - grad = compute_grad(inputs=xyz, output=results[key]) - results[key + "_grad"] = grad - # For the case only predicting gradient - elif key not in predict_keys and key + "_grad" in predict_keys: + elif (key in predict_keys and key + "_grad" in predict_keys) or ( + key not in predict_keys and key + "_grad" in predict_keys + ): results[key] = split_and_sum(val, N) grad = compute_grad(inputs=xyz, output=results[key]) results[key + "_grad"] = grad @@ -177,21 +166,15 @@ def get_atoms_inside_cell(r, N, pbc): N = N.to(torch.long).tolist() # make N a list if it is a int - if type(N) == int: + if isinstance(N, int): N = [N] # selecting only the atoms inside the unit cell - atoms_in_cell = [ - set(x.cpu().data.numpy()) - for x in torch.split(pbc, N) - ] + atoms_in_cell = [set(x.cpu().data.numpy()) for x in torch.split(pbc, N)] N = [len(n) for n in atoms_in_cell] - atoms_in_cell = torch.cat([ - torch.LongTensor(list(x)) - for x in atoms_in_cell - ]) + atoms_in_cell = torch.cat([torch.LongTensor(list(x)) for x in atoms_in_cell]) r = r[atoms_in_cell] diff --git a/nff/nn/layers.py b/nff/nn/layers.py index c55c960b..fab23e3e 100644 --- a/nff/nn/layers.py +++ b/nff/nn/layers.py @@ -357,7 +357,8 @@ def __init__(self, l_spher, n_spher, cutoff, envelope_p): for i in range(l_spher): if i == 0: first_sph = sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)(0) - self.sph_funcs.append(lambda tensor: torch.zeros_like(tensor) + first_sph) + self.sph_funcs.append(partial(lambda tensor, fsph: torch.zeros_like(tensor) + fsph, + fsph=first_sph)) else: self.sph_funcs.append(sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)) for j in range(n_spher): @@ -491,7 +492,7 @@ def __init__(self, cutoff): self.cutoff = cutoff def forward(self, d): - output = 0.5 * (torch.cos((np.pi * d / self.cutoff)) + 1) + output = 0.5 * (torch.cos(np.pi * d / self.cutoff) + 1) exclude = d >= self.cutoff output[exclude] = 0 diff --git a/nff/nn/models/__init__.py b/nff/nn/models/__init__.py index b5549239..f9d1e770 100644 --- a/nff/nn/models/__init__.py +++ b/nff/nn/models/__init__.py @@ -1,2 +1,2 @@ from .schnet import * -from .dimenet import * \ No newline at end of file +from 
.dimenet import * diff --git a/nff/nn/models/chgnet.py b/nff/nn/models/chgnet.py index 2786212e..cdb955c3 100644 --- a/nff/nn/models/chgnet.py +++ b/nff/nn/models/chgnet.py @@ -1,10 +1,9 @@ from __future__ import annotations import os -from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Union +from typing import TYPE_CHECKING, Dict, List import chgnet import torch @@ -21,6 +20,10 @@ from nff.io.chgnet import convert_data_batch from nff.utils.misc import cat_props +if TYPE_CHECKING: + from collections.abc import Sequence + + module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -33,7 +36,7 @@ def __init__( units: str = "eV/atom", is_intensive: bool = True, cutoff: float = 5.0, - key_mappings: Dict[str, str] = None, + key_mappings: Dict[str, str] | None = None, device: str = "cpu", requires_embedding: bool = False, **kwargs, @@ -73,7 +76,7 @@ def __init__( for param in self.composition_model.parameters(): param.requires_grad = True - def forward(self, data_batch: Dict[str, List], **kwargs) -> Dict[str, Union[Tensor, List]]: + def forward(self, data_batch: Dict[str, List], **kwargs) -> Dict[str, Tensor | List]: """Convert data_batch to CHGNet format and run forward pass. Args: @@ -112,7 +115,7 @@ def forward(self, data_batch: Dict[str, List], **kwargs) -> Dict[str, Union[Tens # convert to NFF keys and negate energy_grad return cat_props({self.key_mappings[k]: self.negate_value(k, v) for k, v in output.items()}) - def negate_value(self, key: str, value: Union[list, Tensor]) -> Union[list, Tensor]: + def negate_value(self, key: str, value: list | Tensor) -> list | Tensor: """Negate the value if key is in negate_keys. Args: @@ -192,7 +195,7 @@ def load(cls, model_name: str = "0.3.0", **kwargs) -> CHGNetNFF: if Path(checkpoint_path).is_file(): checkpoint_path = model_name elif checkpoint_path is None: - raise ValueError(f"Unknown {model_name=}") from e + raise ValueError(f"Unknown model name {model_name}") from e return cls.from_file( os.path.join(module_dir, checkpoint_path), @@ -210,7 +213,7 @@ def to(self, device: str, **kwargs) -> CHGNetNFF: Returns: CHGNetNFF: Model moved to the specified device. """ - self = super().to(device, **kwargs) + super().to(device, **kwargs) self.device = device if hasattr(self, "composition_model"): self.composition_model = self.composition_model.to(device, **kwargs) diff --git a/nff/nn/models/conformers.py b/nff/nn/models/conformers.py index a7aabeea..51bdd3fa 100644 --- a/nff/nn/models/conformers.py +++ b/nff/nn/models/conformers.py @@ -1,23 +1,18 @@ import torch import torch.nn as nn -from nff.nn.layers import DEFAULT_DROPOUT_RATE -from nff.nn.modules import ( - SchNetConv, - NodeMultiTaskReadOut, - ConfAttention, - LinearConfAttention -) from nff.nn.graphop import conf_pool +from nff.nn.layers import DEFAULT_DROPOUT_RATE +from nff.nn.modules import ConfAttention, LinearConfAttention, NodeMultiTaskReadOut, SchNetConv from nff.nn.utils import construct_sequential -from nff.utils.scatter import compute_grad from nff.utils.confs import split_batch +from nff.utils.scatter import compute_grad class WeightedConformers(nn.Module): """ Model that uses a representation of a molecule in terms of different 3D - conformers to predict properties. The fingerprints of each conformer are + conformers to predict properties. The fingerprints of each conformer are generated using the SchNet model. 
""" @@ -37,21 +32,21 @@ def __init__(self, modelparams): # all the atomic fingerprints get added together, then go through the network created # by `mol_fp_layers` to turn into a molecular fingerprint - mol_fp_layers = [{'name': 'linear', 'param' : { 'in_features': n_atom_basis, - 'out_features': int((n_atom_basis + mol_basis)/2)}}, - {'name': 'shifted_softplus', 'param': {}}, - {'name': 'linear', 'param' : { 'in_features': int((n_atom_basis + mol_basis)/2), - 'out_features': mol_basis}}] - - - readoutdict = { - "covid": [{'name': 'linear', 'param' : { 'in_features': mol_basis, - 'out_features': int(mol_basis / 2)}}, - {'name': 'shifted_softplus', 'param': {}}, - {'name': 'linear', 'param' : { 'in_features': int(mol_basis / 2), - 'out_features': 1}}, - {'name': 'sigmoid', 'param': {}}], - } + mol_fp_layers = [{'name': 'linear', + 'param' : { 'in_features': n_atom_basis, + 'out_features': int((n_atom_basis + mol_basis)/2)}}, + {'name': 'shifted_softplus', 'param': {}}, + {'name': 'linear', 'param' : { 'in_features': int((n_atom_basis + mol_basis)/2), + 'out_features': mol_basis}} + ] + + readout_dict = {"covid": [{'name': 'linear', 'param' : { 'in_features': mol_basis, + 'out_features': int(mol_basis / 2)}}, + {'name': 'shifted_softplus', 'param': {}}, + {'name': 'linear', 'param' : { 'in_features': int(mol_basis / 2), + 'out_features': 1}}, + {'name': 'sigmoid', 'param': {}}], + } # dictionary to tell you what to do with the Boltzmann factors # ex. 1: @@ -168,7 +163,7 @@ def make_boltz_nn(self, boltzmann_dict): # under the key `layers` will be used to create the corresponding # network - elif boltzmann_dict["type"] == "layers": + if boltzmann_dict["type"] == "layers": layers = boltzmann_dict["layers"] networks.append(construct_sequential(layers)) @@ -176,7 +171,6 @@ def make_boltz_nn(self, boltzmann_dict): # network for each of the number of heads elif "attention" in boltzmann_dict["type"]: - if boltzmann_dict["type"] == "attention": module = ConfAttention elif boltzmann_dict["type"] == "linear_attention": @@ -190,27 +184,30 @@ def make_boltz_nn(self, boltzmann_dict): # (useful for ablation studies) equal_weights = boltzmann_dict.get("equal_weights", False) # what function to use to convert the alpha_ij to probabilities - prob_func = boltzmann_dict.get("prob_func", 'softmax') + prob_func = boltzmann_dict.get("prob_func", "softmax") # add a network for each head for _ in range(num_heads): - mol_basis = boltzmann_dict["mol_basis"] boltz_basis = boltzmann_dict["boltz_basis"] final_act = boltzmann_dict["final_act"] - networks.append(module(mol_basis=mol_basis, - boltz_basis=boltz_basis, - final_act=final_act, - equal_weights=equal_weights, - prob_func=prob_func)) + networks.append( + module( + mol_basis=mol_basis, + boltz_basis=boltz_basis, + final_act=final_act, + equal_weights=equal_weights, + prob_func=prob_func, + ) + ) return networks def add_features(self, batch, **kwargs): """ - Get any extra per-species features that were requested for - the dataset. + Get any extra per-species features that were requested for + the dataset. 
Args: batch (dict): batched sample of species Returns: @@ -226,13 +223,12 @@ def add_features(self, batch, **kwargs): if self.extra_feats is None or "species" not in self.ext_feat_types: return [torch.tensor([]) for _ in range(num_mols)] - assert all([feat in batch.keys() for feat in self.extra_feats]) + assert all(feat in batch for feat in self.extra_feats) feats = [] # go through each extra per-species feature for feat_name, feat_type in zip(self.extra_feats, self.ext_feat_types): - if feat_type == "conformer": continue @@ -241,8 +237,7 @@ def add_features(self, batch, **kwargs): # split the batched features up by species and add them # to the list splits = [feat_len] * num_mols - feat = torch.stack(list( - torch.split(batch[feat_name], splits))) + feat = torch.stack(list(torch.split(batch[feat_name], splits))) feats.append(feat) # concatenate the features @@ -250,11 +245,7 @@ def add_features(self, batch, **kwargs): return feats - def convolve_sub_batch(self, - batch, - xyz=None, - xyz_grad=False, - **kwargs): + def convolve_sub_batch(self, batch, xyz=None, xyz_grad=False, **kwargs): """ Apply the convolutional layers to a sub-batch. @@ -284,31 +275,26 @@ def convolve_sub_batch(self, # offsets take care of periodic boundary conditions offsets = batch.get("offsets", 0) # to deal with any shape mismatches - if hasattr(offsets, 'max') and offsets.max() == 0: + if hasattr(offsets, "max") and offsets.max() == 0: offsets = 0 if "distances" in batch: e = batch["distances"][:, None] else: - e = (xyz[a[:, 0]] - xyz[a[:, 1]] - - offsets).pow(2).sum(1).sqrt()[:, None] + e = (xyz[a[:, 0]] - xyz[a[:, 1]] - offsets).pow(2).sum(1).sqrt()[:, None] # ensuring image atoms have the same vectors of their corresponding # atom inside the unit cell r = self.atom_embed(r.long()).squeeze() # update function includes periodic boundary conditions - for i, conv in enumerate(self.convolutions): + for conv in self.convolutions: dr = conv(r=r, e=e, a=a) r = r + dr return r, xyz - def convolve(self, - batch, - sub_batch_size=None, - xyz=None, - xyz_grad=False): + def convolve(self, batch, sub_batch_size=None, xyz=None, xyz_grad=False): """ Apply the convolution layers to the batch. Args: @@ -328,10 +314,7 @@ def convolve(self, self.classifier = True # split batches as necessary - if sub_batch_size is None: - sub_batches = [batch] - else: - sub_batches = split_batch(batch, sub_batch_size) + sub_batches = [batch] if sub_batch_size is None else split_batch(batch, sub_batch_size) # go through each sub-batch, get the xyz and node features, # and concatenate them when done @@ -340,9 +323,7 @@ def convolve(self, xyz_list = [] for sub_batch in sub_batches: - - new_node_feats, xyz = self.convolve_sub_batch( - sub_batch, xyz, xyz_grad) + new_node_feats, xyz = self.convolve_sub_batch(sub_batch, xyz, xyz_grad) new_node_feat_list.append(new_node_feats) xyz_list.append(xyz) @@ -351,33 +332,29 @@ def convolve(self, return new_node_feats, xyz - def get_external_3d(self, - batch, - n_conf_list): + def get_external_3d(self, batch, n_conf_list): """ - Get any extra 3D per-conformer features that were requested for - the dataset. + Get any extra 3D per-conformer features that were requested for + the dataset. Args: batch (dict): batched sample of species - n_conf_list (list[int]): list of number of conformers in each + n_conf_list (list[int]): list of number of conformers in each species. 
Returns: - split_extra (list): list of stacked per-cofnormer feature tensors + split_extra (list): list of stacked per-conformer feature tensors for each species. """ # if you didn't ask for any extra features, or none of the requested # features are per-conformer features, return empty tensors - if (self.extra_feats is None or - "conformer" not in self.ext_feat_types): - return + if self.extra_feats is None or "conformer" not in self.ext_feat_types: + return None # get all the features and split them up by species extra_conf_fps = [] - for feat_name, feat_type in zip(self.extra_feats, - self.ext_feat_types): + for feat_name, feat_type in zip(self.extra_feats, self.ext_feat_types): if feat_type == "conformer": extra_conf_fps.append(batch[feat_name]) @@ -386,12 +363,7 @@ def get_external_3d, return split_extra - def get_conf_fps(self, - smiles_fp, - mol_size, - batch, - split_extra, - idx): + def get_conf_fps(self, smiles_fp, mol_size, batch, split_extra, idx): """ Get per-conformer fingerprints. Args: @@ -403,7 +375,7 @@ def get_conf_fps, mol_size (int): Number of atoms in the molecule batch (dict): batched sample of species split_extra (list): extra 3D fingerprints split by - species + species idx (int): index of the current species in the batch. """ @@ -420,9 +392,9 @@ def get_conf_fps, # split the atomic fingerprints up by conformer for atomic_fps in torch.split(smiles_fp, N): # sum them and then convert to molecular fp - if self.pool_type == 'sum': + if self.pool_type == "sum": summed_atomic_fps = atomic_fps.sum(dim=0) - elif self.pool_type == 'mean': + elif self.pool_type == "mean": summed_atomic_fps = atomic_fps.mean(dim=0) else: raise NotImplementedError @@ -443,11 +415,7 @@ def get_conf_fps, return conf_fps - def post_process(self, - batch, - r, - xyz, - **kwargs): + def post_process(self, batch, r, xyz, **kwargs): """ Split various items up by species, convert atomic fingerprints to molecular fingerprints, and incorporate non-learnable features.
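For context on the hunks above: `get_conf_fps` pools the atomic fingerprints of each conformer into a single per-conformer fingerprint. Below is a minimal standalone sketch of that pooling step, not code from this patch; the function and argument names are illustrative, and it assumes equal-sized conformers and the same `sum`/`mean` options as the patched code:

```python
import torch


def pool_conf_fps(smiles_fp: torch.Tensor, mol_size: int, pool_type: str = "sum") -> torch.Tensor:
    # atomic fingerprints for one species: (n_confs * mol_size, feat_dim)
    n_confs = smiles_fp.shape[0] // mol_size
    conf_fps = []
    # split the atomic fingerprints up by conformer, then reduce over atoms
    for atomic_fps in torch.split(smiles_fp, [mol_size] * n_confs):
        if pool_type == "sum":
            conf_fps.append(atomic_fps.sum(dim=0))
        elif pool_type == "mean":
            conf_fps.append(atomic_fps.mean(dim=0))
        else:
            raise NotImplementedError
    # (n_confs, feat_dim): one fingerprint per conformer
    return torch.stack(conf_fps)


# e.g. 3 conformers of a 5-atom molecule with 128-dim atomic features:
# pool_conf_fps(torch.randn(15, 128), mol_size=5).shape -> torch.Size([3, 128])
```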
@@ -465,17 +433,14 @@ def post_process(self, # split the fingerprints by species fps_by_smiles = torch.split(r, N) # get extra 3D fingerprints - split_extra = self.get_external_3d(batch, - num_confs) + split_extra = self.get_external_3d(batch, num_confs) # get all the conformer fingerprints for each species conf_fps_by_smiles = [] for i, smiles_fp in enumerate(fps_by_smiles): - conf_fps = self.get_conf_fps(smiles_fp=smiles_fp, - mol_size=mol_sizes[i], - batch=batch, - split_extra=split_extra, - idx=i) + conf_fps = self.get_conf_fps( + smiles_fp=smiles_fp, mol_size=mol_sizes[i], batch=batch, split_extra=split_extra, idx=i + ) conf_fps_by_smiles.append(conf_fps) @@ -486,13 +451,15 @@ def post_process(self, extra_feats = self.add_features(batch=batch, **kwargs) # return everything in a dictionary - outputs = dict(r=r, - N=N, - xyz=xyz, - conf_fps_by_smiles=conf_fps_by_smiles, - boltzmann_weights=boltzmann_weights, - mol_sizes=mol_sizes, - extra_feats=extra_feats) + outputs = dict( + r=r, + N=N, + xyz=xyz, + conf_fps_by_smiles=conf_fps_by_smiles, + boltzmann_weights=boltzmann_weights, + mol_sizes=mol_sizes, + extra_feats=extra_feats, + ) return outputs @@ -514,24 +481,22 @@ def fps_no_mpnn(self, batch, **kwargs): n_conf_list = (torch.tensor(N) / torch.tensor(mol_sizes)).tolist() # get the conformer fps for each smiles - conf_fps_by_smiles = self.get_external_3d(batch, - n_conf_list) + conf_fps_by_smiles = self.get_external_3d(batch, n_conf_list) # add any per-species fingerprints boltzmann_weights = torch.split(batch["weights"], n_conf_list) extra_feats = self.add_features(batch=batch, **kwargs) - outputs = {"conf_fps_by_smiles": conf_fps_by_smiles, - "boltzmann_weights": boltzmann_weights, - "mol_sizes": mol_sizes, - "extra_feats": extra_feats} + outputs = { + "conf_fps_by_smiles": conf_fps_by_smiles, + "boltzmann_weights": boltzmann_weights, + "mol_sizes": mol_sizes, + "extra_feats": extra_feats, + } return outputs - def make_embeddings(self, - batch, - xyz=None, - **kwargs): + def make_embeddings(self, batch, xyz=None, **kwargs): """ Make all conformer fingerprints. Args: @@ -549,13 +514,8 @@ def make_embeddings(self, # if using an MPNN, apply the convolution layers # and then post-process if self.use_mpnn: - r, xyz = self.convolve(batch=batch, - xyz=xyz, - **kwargs) - outputs = self.post_process(batch=batch, - r=r, - xyz=xyz, - **kwargs) + r, xyz = self.convolve(batch=batch, xyz=xyz, **kwargs) + outputs = self.post_process(batch=batch, r=r, xyz=xyz, **kwargs) # otherwise just use the non-learnable features else: @@ -571,8 +531,8 @@ def pool(self, outputs): Here, the atomic fingerprints for each geometry get converted into a molecular fingerprint. Then, the molecular fingerprints for the different conformers of a given species - get multiplied by the Boltzmann weights or learned weights of - those conformers and added together to make a final fingerprint + get multiplied by the Boltzmann weights or learned weights of + those conformers and added together to make a final fingerprint for the species. 
Args: @@ -614,7 +574,8 @@ def pool(self, outputs): mol_fp_nn=self.mol_fp_nn, boltz_nns=self.boltz_nns, conf_fps=conf_fps, - head_pool=self.head_pool) + head_pool=self.head_pool, + ) # add extra features if there are any if extra_feats is not None: @@ -642,21 +603,17 @@ def add_grad(self, batch, results, xyz): batch_keys = batch.keys() # names of the gradients of each property - result_grad_keys = [key + "_grad" for key in results.keys()] + result_grad_keys = [key + "_grad" for key in results] for key in batch_keys: # if the batch with the ground truth contains one of # these keys, then compute its predicted value if key in result_grad_keys: base_result = results[key.replace("_grad", "")] - results[key] = compute_grad(inputs=xyz, - output=base_result) + results[key] = compute_grad(inputs=xyz, output=base_result) return results - def forward(self, - batch, - xyz=None, - **kwargs): + def forward(self, batch, xyz=None, **kwargs): """ Call the model. Args: diff --git a/nff/nn/models/cp3d.py b/nff/nn/models/cp3d.py index acf44902..9f0977d7 100644 --- a/nff/nn/models/cp3d.py +++ b/nff/nn/models/cp3d.py @@ -1,15 +1,10 @@ -from torch import nn import torch -import numpy as np -import math -from nff.data.graphs import get_bond_idx +from torch import nn +from nff.data.graphs import get_bond_idx from nff.nn.models.conformers import WeightedConformers -from nff.nn.modules import (ChemPropConv, ChemPropMsgToNode, - ChemPropInit, SchNetEdgeFilter, - CpSchNetConv) +from nff.nn.modules import ChemPropConv, ChemPropInit, ChemPropMsgToNode, CpSchNetConv, SchNetEdgeFilter from nff.utils.tools import make_directed -from nff.utils.confs import split_batch REINDEX_KEYS = ["nbr_list", "bonded_nbr_list"] @@ -60,7 +55,8 @@ def __init__(self, modelparams): trainable_gauss=modelparams["trainable_gauss"], n_filters=modelparams["n_filters"], dropout_rate=modelparams["dropout_rate"], - activation=modelparams["activation"]) + activation=modelparams["activation"], + ) def make_convs(self, modelparams): """ @@ -75,16 +71,11 @@ def make_convs(self, modelparams): modelparams.update({"n_edge_hidden": modelparams["mol_basis"]}) # call `CpSchNetConv` to make the convolution layers - convs = nn.ModuleList([ChemPropConv(**modelparams) - for _ in range(num_conv)]) + convs = nn.ModuleList([ChemPropConv(**modelparams) for _ in range(num_conv)]) return convs - def get_distance_feats(self, - batch, - xyz, - offsets, - bond_nbrs): + def get_distance_feats(self, batch, xyz, offsets, bond_nbrs): """ Get distance features. Args: @@ -102,8 +93,7 @@ def get_distance_feats(self, # get directed neighbor list nbr_list, nbr_was_directed = make_directed(batch["nbr_list"]) # distances - distances = (xyz[nbr_list[:, 0]] - xyz[nbr_list[:, 1]] - - offsets).pow(2).sum(1).sqrt()[:, None] + distances = (xyz[nbr_list[:, 0]] - xyz[nbr_list[:, 1]] - offsets).pow(2).sum(1).sqrt()[:, None] # put through Gaussian filter and dense layer to get features distance_feats = self.edge_filter(distances) @@ -114,18 +104,13 @@ def get_distance_feats(self, bond_idx = batch["bond_idx"] if not nbr_was_directed: nbr_dim = nbr_list.shape[0] - bond_idx = torch.cat([bond_idx, - bond_idx + nbr_dim // 2]) + bond_idx = torch.cat([bond_idx, bond_idx + nbr_dim // 2]) else: bond_idx = get_bond_idx(bond_nbrs, nbr_list) return nbr_list, distance_feats, bond_idx - def make_h(self, - batch, - r, - xyz, - offsets): + def make_h(self, batch, r, xyz, offsets): """ Initialize the hidden edge features. 
Args: @@ -149,17 +134,13 @@ def make_h(self, # get the distance-based edge features nbr_list, distance_feats, bond_idx = self.get_distance_feats( - batch=batch, - xyz=xyz, - offsets=offsets, - bond_nbrs=bond_nbrs) + batch=batch, xyz=xyz, offsets=offsets, bond_nbrs=bond_nbrs + ) # combine node and bonded edge features to get the bond component # of h_0 - cp_bond_feats = self.W_i_cp(r=r, - bond_feats=bond_feats, - bond_nbrs=bond_nbrs) + cp_bond_feats = self.W_i_cp(r=r, bond_feats=bond_feats, bond_nbrs=bond_nbrs) h_0_bond = torch.zeros((nbr_list.shape[0], cp_bond_feats.shape[1])) h_0_bond = h_0_bond.to(device) h_0_bond[bond_idx] = cp_bond_feats @@ -167,9 +148,7 @@ def make_h(self, # combine node and distance edge features to get the schnet component # of h_0 - h_0_distance = self.W_i_schnet(r=r, - bond_feats=distance_feats, - bond_nbrs=nbr_list) + h_0_distance = self.W_i_schnet(r=r, bond_feats=distance_feats, bond_nbrs=nbr_list) # concatenate the two together @@ -177,10 +156,7 @@ def make_h(self, return h_0 - def convolve_sub_batch(self, - batch, - xyz=None, - xyz_grad=False): + def convolve_sub_batch(self, batch, xyz=None, xyz_grad=False): """ Apply the convolution layers to a sub-batch. Args: @@ -206,34 +182,24 @@ def convolve_sub_batch(self, # offsets for periodic boundary conditions offsets = batch.get("offsets", 0) # to deal with any shape mismatches - if hasattr(offsets, 'max') and offsets.max() == 0: + if hasattr(offsets, "max") and offsets.max() == 0: offsets = 0 # initialize hidden bond features - h_0 = self.make_h(batch=batch, - r=r, - xyz=xyz, - offsets=offsets) + h_0 = self.make_h(batch=batch, r=r, xyz=xyz, offsets=offsets) h_new = h_0.clone() # update edge features for conv in self.convolutions: - h_new = conv(h_0=h_0, - h_new=h_new, - nbrs=a, - kj_idx=batch.get("kj_idx"), - ji_idx=batch.get("ji_idx")) + h_new = conv(h_0=h_0, h_new=h_new, nbrs=a, kj_idx=batch.get("kj_idx"), ji_idx=batch.get("ji_idx")) # convert back to node features - new_node_feats = self.W_o(r=r, - h=h_new, - nbrs=a) + new_node_feats = self.W_o(r=r, h=h_new, nbrs=a) return new_node_feats, xyz class OnlyBondUpdateCP3D(ChemProp3D): - def __init__(self, modelparams): """ Initialize model. @@ -253,8 +219,7 @@ def __init__(self, modelparams): self.W_i = ChemPropInit(input_layers=input_layers) self.convolutions = self.make_convs(modelparams) - self.W_o = ChemPropMsgToNode( - output_layers=output_layers) + self.W_o = ChemPropMsgToNode(output_layers=output_layers) # dimension of the hidden bond vector self.n_bond_hidden = modelparams["n_bond_hidden"] @@ -272,8 +237,7 @@ def make_convs(self, modelparams): same_filters = modelparams["same_filters"] # call `CpSchNetConv` to make the convolution layers - convs = nn.ModuleList([CpSchNetConv(**modelparams) - for _ in range(num_conv)]) + convs = nn.ModuleList([CpSchNetConv(**modelparams) for _ in range(num_conv)]) # if you want to use the same filters for every convolution, repeat # the initial network and delete all the others @@ -282,11 +246,7 @@ def make_convs(self, modelparams): return convs - def make_h(self, - batch, - nbr_list, - r, - nbr_was_directed): + def make_h(self, batch, nbr_list, r, nbr_was_directed): """ Initialize the hidden bond features. Args: @@ -300,7 +260,7 @@ def make_h(self, bond_nbrs (torch.LongTensor): bonded neighbor list bond_idx (torch.LongTensor): indices that map an element of `bond_nbrs` to the corresponding - element in `nbr_list`. + element in `nbr_list`. 
""" # get the directed bond list and bond features @@ -315,15 +275,13 @@ def make_h(self, # initialize hidden bond features - h_0_bond = self.W_i(r=r, - bond_feats=bond_feats, - bond_nbrs=bond_nbrs) + h_0_bond = self.W_i(r=r, bond_feats=bond_feats, bond_nbrs=bond_nbrs) # initialize `h_0`, the features of all edges # (including bonded ones), to zero nbr_dim = nbr_list.shape[0] - h_0 = torch.zeros((nbr_dim, self.n_bond_hidden)) + h_0 = torch.zeros((nbr_dim, self.n_bond_hidden)) h_0 = h_0.to(device) # set the features of bonded edges equal to the bond @@ -333,8 +291,7 @@ def make_h(self, bond_idx = batch["bond_idx"] if not nbr_was_directed: nbr_dim = nbr_list.shape[0] - bond_idx = torch.cat([bond_idx, - bond_idx + nbr_dim // 2]) + bond_idx = torch.cat([bond_idx, bond_idx + nbr_dim // 2]) else: bond_idx = get_bond_idx(bond_nbrs, nbr_list) bond_idx = bond_idx.to(device) @@ -343,10 +300,7 @@ def make_h(self, return h_0, bond_nbrs, bond_idx - def convolve_sub_batch(self, - batch, - xyz=None, - xyz_grad=False): + def convolve_sub_batch(self, batch, xyz=None, xyz_grad=False): """ Apply the convolution layers to a sub-batch. Args: @@ -374,41 +328,28 @@ def convolve_sub_batch(self, offsets = 0 # get the distances between neighbors - e = (xyz[a[:, 0]] - xyz[a[:, 1]] - - offsets).pow(2).sum(1).sqrt()[:, None] + e = (xyz[a[:, 0]] - xyz[a[:, 1]] - offsets).pow(2).sum(1).sqrt()[:, None] # initialize hidden bond features - h_0, bond_nbrs, bond_idx = self.make_h( - batch=batch, - nbr_list=a, - r=r, - nbr_was_directed=nbr_was_directed) + h_0, bond_nbrs, bond_idx = self.make_h(batch=batch, nbr_list=a, r=r, nbr_was_directed=nbr_was_directed) h_new = h_0.clone() # update edge features for conv in self.convolutions: - # don't use any kj_idx or ji_idx # because they are only relevant when - # you're doing updates with all neighbors, + # you're doing updates with all neighbors, # not with just the bonded neighbors like # we do here - - h_new = conv(h_0=h_0, - h_new=h_new, - all_nbrs=a, - bond_nbrs=bond_nbrs, - bond_idx=bond_idx, - e=e, - kj_idx=None, - ji_idx=None) + + h_new = conv( + h_0=h_0, h_new=h_new, all_nbrs=a, bond_nbrs=bond_nbrs, bond_idx=bond_idx, e=e, kj_idx=None, ji_idx=None + ) # convert back to node features - new_node_feats = self.W_o(r=r, - h=h_new, - nbrs=a) + new_node_feats = self.W_o(r=r, h=h_new, nbrs=a) return new_node_feats, xyz diff --git a/nff/nn/models/dimenet.py b/nff/nn/models/dimenet.py index c1cc36e8..1ebbfe01 100644 --- a/nff/nn/models/dimenet.py +++ b/nff/nn/models/dimenet.py @@ -253,7 +253,7 @@ def forward(self, batch, xyz=None): else: periodic = bool(offsets.abs().max() != 0) else: - raise Exception("Don't know how to interpret offsets of type {}".format(type(offsets))) + raise Exception(f"Don't know how to interpret offsets of type {type(offsets)}") if periodic: raise NotImplementedError("DimeNet not implemented for PBC.") diff --git a/nff/nn/models/dispersion_models.py b/nff/nn/models/dispersion_models.py index 2989f51a..53d436b2 100644 --- a/nff/nn/models/dispersion_models.py +++ b/nff/nn/models/dispersion_models.py @@ -2,23 +2,19 @@ Models with added empirical dispersion on top """ -import torch import numpy as np +import torch from torch import nn -from nff.nn.models.painn import add_stress -from nff.utils.scatter import compute_grad +from nff.nn.models.painn import Painn, PainnDiabat, add_stress from nff.utils import constants as const -from nff.utils.dispersion import get_dispersion as base_dispersion, grimme_dispersion -from nff.nn.models.painn import Painn -from 
nff.nn.models.painn import PainnDiabat +from nff.utils.dispersion import get_dispersion as base_dispersion +from nff.utils.dispersion import grimme_dispersion +from nff.utils.scatter import compute_grad class PainnDispersion(nn.Module): - - def __init__(self, - modelparams, - painn_model=None): + def __init__(self, modelparams, painn_model=None): """ `modelparams` has the same keys as in a regular PaiNN model, plus the required keys "functional" and "disp_type" for the added dispersion. @@ -38,44 +34,32 @@ def __init__(self, else: self.painn_model = Painn(modelparams=modelparams) - def get_dispersion(self, - batch, - xyz): - - e_disp, r_ij_T, nbrs_T = base_dispersion(batch=batch, - xyz=xyz, - disp_type=self.disp_type, - functional=self.functional, - nbrs=batch.get('mol_nbrs'), - mol_idx=batch.get('mol_idx')) + def get_dispersion(self, batch, xyz): + e_disp, r_ij_T, nbrs_T = base_dispersion( + batch=batch, + xyz=xyz, + disp_type=self.disp_type, + functional=self.functional, + nbrs=batch.get("mol_nbrs"), + mol_idx=batch.get("mol_idx"), + ) # convert to kcal / mol e_disp = e_disp * const.HARTREE_TO_KCAL_MOL return e_disp, r_ij_T, nbrs_T - def get_grimme_dispersion(self, - batch, - xyz): - + def get_grimme_dispersion(self, batch, xyz): # all units are output in ASE units (eV and Angs) - e_disp, stress_disp, forces_disp = grimme_dispersion(batch=batch, - xyz=xyz, - disp_type=self.disp_type, - functional=self.functional) + e_disp, stress_disp, forces_disp = grimme_dispersion( + batch=batch, xyz=xyz, disp_type=self.disp_type, functional=self.functional + ) return e_disp, stress_disp, forces_disp - def run(self, - batch, - xyz=None, - requires_stress=False, - grimme_disp=False, - inference=False): - + def run(self, batch, xyz=None, requires_stress=False, grimme_disp=False, inference=False): # Normal painn stuff, part 1 - atomwise_out, xyz, r_ij, nbrs = self.painn_model.atomwise(batch=batch, - xyz=xyz) + atomwise_out, xyz, r_ij, nbrs = self.painn_model.atomwise(batch=batch, xyz=xyz) if getattr(self.painn_model, "excl_vol", None): # Excluded Volume interactions @@ -83,18 +67,12 @@ def run, for key in self.output_keys: atomwise_out[key] += r_ex - all_results, xyz = self.painn_model.pool(batch=batch, - atomwise_out=atomwise_out, - xyz=xyz, - r_ij=r_ij, - nbrs=nbrs, - inference=inference) + all_results, xyz = self.painn_model.pool( + batch=batch, atomwise_out=atomwise_out, xyz=xyz, r_ij=r_ij, nbrs=nbrs, inference=inference + ) if requires_stress: - all_results = add_stress(batch=batch, - all_results=all_results, - nbrs=nbrs, - r_ij=r_ij) + all_results = add_stress(batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij) # add dispersion and gradients associated with it @@ -102,24 +80,19 @@ def run, fallback_to_grimme = getattr(self, "fallback_to_grimme", True) if grimme_disp: - pass + e_disp = r_ij_T = nbrs_T = None else: - e_disp, r_ij_T, nbrs_T = self.get_dispersion(batch=batch, - xyz=xyz) + e_disp, r_ij_T, nbrs_T = self.get_dispersion(batch=batch, xyz=xyz) - for key in self.painn_model.pool_dic.keys(): + for key in self.painn_model.pool_dic: # add dispersion energy - if inference: - add_e = e_disp.detach().cpu() - else: - add_e = e_disp + add_e = e_disp.detach().cpu() if inference else e_disp # add gradient for forces grad_key = "%s_grad" % key if grad_key in self.painn_model.grad_keys: if disp_grad is None: - disp_grad = compute_grad(inputs=xyz, - output=e_disp) + disp_grad = compute_grad(inputs=xyz, output=e_disp) if inference: disp_grad = disp_grad.detach().cpu() @@ -131,26 +104,23 @@
def run(self, all_results[grad_key] = all_results[grad_key] + disp_grad if requires_stress and not grimme_disp: + if e_disp is None or r_ij_T is None or nbrs_T is None: + raise RuntimeError("Should not be reached, something went wrong") # add gradient for stress - disp_rij_grad = compute_grad(inputs=r_ij_T, - output=e_disp) + disp_rij_grad = compute_grad(inputs=r_ij_T, output=e_disp) - if batch['num_atoms'].shape[0] == 1: + if batch["num_atoms"].shape[0] == 1: disp_stress_volume = torch.matmul(disp_rij_grad.t(), r_ij_T) else: - allstress = [] - for j in range(batch['nxyz'].shape[0]): - allstress.append( - torch.matmul( - disp_rij_grad[torch.where(nbrs_T[:, 0] == j)].t(), - r_ij_T[torch.where(nbrs_T[:, 0] == j)] - ) + allstress = torch.stack([ + torch.matmul( + disp_rij_grad[torch.where(nbrs_T[:, 0] == j)].t(), r_ij_T[torch.where(nbrs_T[:, 0] == j)] ) - allstress = torch.stack(allstress) + for j in range(batch["nxyz"].shape[0]) + ]) N = batch["num_atoms"].detach().cpu().tolist() split_val = torch.split(allstress, N) - disp_stress_volume = torch.stack([i.sum(0) - for i in split_val]) + disp_stress_volume = torch.stack([i.sum(0) for i in split_val]) if inference: disp_stress_volume = disp_stress_volume.detach().cpu() @@ -158,18 +128,16 @@ def run(self, if disp_stress_volume.isnan().any() and fallback_to_grimme: grimme_disp = True else: - all_results['stress_volume'] = all_results['stress_volume'] + \ - disp_stress_volume + all_results["stress_volume"] = all_results["stress_volume"] + disp_stress_volume # if there was numerical instability with disp_grad pytorch # re-calculate everything with Grimme dispersion instead # requires dftd3 executable if grimme_disp: - e_disp, stress_disp, forces_disp = self.get_grimme_dispersion(batch=batch, - xyz=xyz) - all_results['e_disp'] = e_disp - all_results['stress_disp'] = stress_disp - all_results['forces_disp'] = forces_disp + e_disp, stress_disp, forces_disp = self.get_grimme_dispersion(batch=batch, xyz=xyz) + all_results["e_disp"] = e_disp + all_results["stress_disp"] = stress_disp + all_results["forces_disp"] = forces_disp # Normal painn stuff, part 2 @@ -178,13 +146,7 @@ def run(self, return all_results, xyz - def forward(self, - batch, - xyz=None, - requires_stress=False, - grimme_disp=False, - inference=False, - **kwargs): + def forward(self, batch, xyz=None, requires_stress=False, grimme_disp=False, inference=False, **kwargs): """ Call the model Args: @@ -193,11 +155,9 @@ def forward(self, results (dict): dictionary of predictions """ - results, _ = self.run(batch=batch, - xyz=xyz, - requires_stress=requires_stress, - grimme_disp=grimme_disp, - inference=inference) + results, _ = self.run( + batch=batch, xyz=xyz, requires_stress=requires_stress, grimme_disp=grimme_disp, inference=inference + ) return results @@ -208,39 +168,44 @@ def __init__(self, modelparams): self.functional = modelparams["functional"] self.disp_type = modelparams["disp_type"] - def forward(self, - batch, - xyz=None, - add_nacv=True, - add_grad=True, - add_gap=True, - add_u=False, - inference=False, - do_nan=True, - en_keys_for_grad=None): - + def forward( + self, + batch, + xyz=None, + add_nacv=True, + add_grad=True, + add_gap=True, + add_u=False, + inference=False, + do_nan=True, + en_keys_for_grad=None, + ): # get diabatic results - results = super().forward(batch=batch, - xyz=xyz, - add_nacv=add_nacv, - add_grad=add_grad, - add_gap=add_gap, - add_u=add_u, - inference=inference, - do_nan=do_nan, - en_keys_for_grad=en_keys_for_grad) + results = super().forward( + batch=batch, 
+ xyz=xyz, + add_nacv=add_nacv, + add_grad=add_grad, + add_gap=add_gap, + add_u=add_u, + inference=inference, + do_nan=do_nan, + en_keys_for_grad=en_keys_for_grad, + ) xyz = results["xyz"] # get dispersion energy (I couldn't figure out how to sub-class # PainnDiabatDispersion with PainnDispersion without getting errors, # unless I put it before PainnDiabat, which isn't what I want. So # instead I just copied the logic for getting the disperson energy) - e_disp, _, _ = base_dispersion(batch=batch, - xyz=xyz, - disp_type=self.disp_type, - functional=self.functional, - nbrs=batch.get('mol_nbrs'), - mol_idx=batch.get('mol_idx')) + e_disp, _, _ = base_dispersion( + batch=batch, + xyz=xyz, + disp_type=self.disp_type, + functional=self.functional, + nbrs=batch.get("mol_nbrs"), + mol_idx=batch.get("mol_idx"), + ) # convert to kcal / mol e_disp = e_disp * const.HARTREE_TO_KCAL_MOL @@ -253,20 +218,17 @@ def forward(self, # "energy_1"], we won't have updated "energy_2" properly energy_keys = ["energy_%d" % i for i in range(len(diabat_keys))] - for key in (diagonal_diabat_keys + energy_keys): + for key in diagonal_diabat_keys + energy_keys: results[key] = results[key] + e_disp.reshape(results[key].shape) # add dispersion grads to diabatic diagonal gradients and # adiabatic gradients - disp_grad = compute_grad(inputs=xyz, - output=e_disp) + disp_grad = compute_grad(inputs=xyz, output=e_disp) - grad_keys = [key + "_grad" for key in - (diagonal_diabat_keys + energy_keys)] + grad_keys = [key + "_grad" for key in (diagonal_diabat_keys + energy_keys)] for key in grad_keys: if key in results: - results[key] = (results[key] + - disp_grad.reshape(results[key].shape)) + results[key] = results[key] + disp_grad.reshape(results[key].shape) return results diff --git a/nff/nn/models/graphconvintegration.py b/nff/nn/models/graphconvintegration.py index b920d4e1..b1cb6f65 100644 --- a/nff/nn/models/graphconvintegration.py +++ b/nff/nn/models/graphconvintegration.py @@ -1,19 +1,13 @@ -import torch import torch.nn as nn -import copy -import torch.nn.functional as F -from nff.nn.layers import Dense, GaussianSmearing -from nff.nn.modules import GraphDis, SchNetConv, BondEnergyModule, SchNetEdgeUpdate, NodeMultiTaskReadOut -from nff.nn.activations import shifted_softplus -from nff.nn.graphop import batch_and_sum, get_atoms_inside_cell +from nff.nn.graphop import batch_and_sum +from nff.nn.modules import NodeMultiTaskReadOut, SchNetConv from nff.nn.utils import get_default_readout class GraphConvIntegration(nn.Module): - """SchNet with optional aggr_weight for thermodynamic intergration - + Attributes: atom_embed (torch.nn.Embedding): Convert atomic number into an embedding vector of size n_atom_basis @@ -27,60 +21,63 @@ class GraphConvIntegration(nn.Module): {name: mod_list}, where name is the name of a property object and mod_list is a ModuleList of layers to predict that property. """ - + def __init__(self, modelparams): """Constructs a SchNet model. 
- + Args: modelparams (TYPE): Description """ super().__init__() - n_atom_basis = modelparams['n_atom_basis'] - n_filters = modelparams['n_filters'] - n_gaussians = modelparams['n_gaussians'] - n_convolutions = modelparams['n_convolutions'] - cutoff = modelparams['cutoff'] - trainable_gauss = modelparams.get('trainable_gauss', False) + n_atom_basis = modelparams["n_atom_basis"] + n_filters = modelparams["n_filters"] + n_gaussians = modelparams["n_gaussians"] + n_convolutions = modelparams["n_convolutions"] + cutoff = modelparams["cutoff"] + trainable_gauss = modelparams.get("trainable_gauss", False) # default predict var - readoutdict = modelparams.get('readoutdict', get_default_readout(n_atom_basis)) - post_readout = modelparams.get('post_readout', None) + readoutdict = modelparams.get("readoutdict", get_default_readout(n_atom_basis)) + post_readout = modelparams.get("post_readout", None) self.atom_embed = nn.Embedding(100, n_atom_basis, padding_idx=0) - self.convolutions = nn.ModuleList([ - SchNetConv(n_atom_basis=n_atom_basis, - n_filters=n_filters, - n_gaussians=n_gaussians, - cutoff=cutoff, - trainable_gauss=trainable_gauss) - for _ in range(n_convolutions) - ]) + self.convolutions = nn.ModuleList( + [ + SchNetConv( + n_atom_basis=n_atom_basis, + n_filters=n_filters, + n_gaussians=n_gaussians, + cutoff=cutoff, + trainable_gauss=trainable_gauss, + ) + for _ in range(n_convolutions) + ] + ) # ReadOut - self.atomwisereadout = NodeMultiTaskReadOut(multitaskdict=readoutdict, post_readout=post_readout) + self.atomwisereadout = NodeMultiTaskReadOut(multitaskdict=readoutdict, post_readout=post_readout) self.device = None def forward(self, batch, **kwargs): - """Summary - + Args: batch (dict): dictionary of props - + Returns: - dict: dionary of results + dict: dictionary of results """ - r = batch['nxyz'][:, 0] - xyz = batch['nxyz'][:, 1:4] - N = batch['num_atoms'].reshape(-1).tolist() - a = batch['nbr_list'] - aggr_wgt = batch['aggr_wgt'] + r = batch["nxyz"][:, 0] + xyz = batch["nxyz"][:, 1:4] + N = batch["num_atoms"].reshape(-1).tolist() + a = batch["nbr_list"] + aggr_wgt = batch["aggr_wgt"] # offsets take care of periodic boundary conditions - offsets = batch.get('offsets', 0) + offsets = batch.get("offsets", 0) xyz.requires_grad = True @@ -92,11 +89,11 @@ def forward(self, batch, **kwargs): r = self.atom_embed(r.long()).squeeze() # update function includes periodic boundary conditions - for i, conv in enumerate(self.convolutions): + for conv in self.convolutions: dr = conv(r=r, e=e, a=a, aggr_wgt=aggr_wgt) r = r + dr r = self.atomwisereadout(r) results = batch_and_sum(r, N, list(batch.keys()), xyz) - - return results + + return results diff --git a/nff/nn/models/hybridgraph.py b/nff/nn/models/hybridgraph.py index c6cfbb67..f835486e 100644 --- a/nff/nn/models/hybridgraph.py +++ b/nff/nn/models/hybridgraph.py @@ -1,96 +1,99 @@ import torch import torch.nn as nn -import copy -import torch.nn.functional as F -from nff.nn.layers import Dense, GaussianSmearing -from nff.nn.modules import SchNetConv, SchNetEdgeUpdate, NodeMultiTaskReadOut -from nff.nn.activations import shifted_softplus from nff.nn.graphop import batch_and_sum +from nff.nn.modules import NodeMultiTaskReadOut, SchNetConv from nff.nn.utils import get_default_readout - from nff.utils.scatter import scatter_add -class HybridGraphConv(nn.Module): +class HybridGraphConv(nn.Module): def __init__(self, modelparams): super().__init__() - n_atom_basis = modelparams['n_atom_basis'] - n_filters = modelparams['n_filters'] - n_gaussians =
modelparams['n_gaussians'] - trainable_gauss = modelparams.get('trainable_gauss', False) - mol_n_convolutions = modelparams['mol_n_convolutions'] - mol_cutoff = modelparams['mol_cutoff'] - sys_n_convolutions = modelparams['sys_n_convolutions'] - sys_cutoff = modelparams['sys_cutoff'] - + n_atom_basis = modelparams["n_atom_basis"] + n_filters = modelparams["n_filters"] + n_gaussians = modelparams["n_gaussians"] + trainable_gauss = modelparams.get("trainable_gauss", False) + mol_n_convolutions = modelparams["mol_n_convolutions"] + mol_cutoff = modelparams["mol_cutoff"] + sys_n_convolutions = modelparams["sys_n_convolutions"] + sys_cutoff = modelparams["sys_cutoff"] + self.power = modelparams["V_ex_power"] self.sigma = torch.nn.Parameter(torch.Tensor([modelparams["V_ex_sigma"]])) # default predict var - readoutdict = modelparams.get('readoutdict', get_default_readout(n_atom_basis)) - post_readout = modelparams.get('post_readout', None) + readoutdict = modelparams.get("readoutdict", get_default_readout(n_atom_basis)) + post_readout = modelparams.get("post_readout", None) self.atom_embed = nn.Embedding(100, n_atom_basis, padding_idx=0) - self.molecule_convolutions = nn.ModuleList([ - SchNetConv(n_atom_basis=n_atom_basis, - n_filters=n_filters, - n_gaussians=n_gaussians, - cutoff=mol_cutoff, - trainable_gauss=trainable_gauss, - dropout_rate=0.0) - for _ in range(mol_n_convolutions) - ]) - - self.system_convolutions = nn.ModuleList([ - SchNetConv(n_atom_basis=n_atom_basis, - n_filters=n_filters, - n_gaussians=n_gaussians, - cutoff=sys_cutoff, - trainable_gauss=trainable_gauss, - dropout_rate=0.0) - for _ in range(sys_n_convolutions) - ]) + self.molecule_convolutions = nn.ModuleList( + [ + SchNetConv( + n_atom_basis=n_atom_basis, + n_filters=n_filters, + n_gaussians=n_gaussians, + cutoff=mol_cutoff, + trainable_gauss=trainable_gauss, + dropout_rate=0.0, + ) + for _ in range(mol_n_convolutions) + ] + ) + + self.system_convolutions = nn.ModuleList( + [ + SchNetConv( + n_atom_basis=n_atom_basis, + n_filters=n_filters, + n_gaussians=n_gaussians, + cutoff=sys_cutoff, + trainable_gauss=trainable_gauss, + dropout_rate=0.0, + ) + for _ in range(sys_n_convolutions) + ] + ) # ReadOut - self.atomwisereadout = NodeMultiTaskReadOut(multitaskdict=readoutdict, post_readout=post_readout) + self.atomwisereadout = NodeMultiTaskReadOut(multitaskdict=readoutdict, post_readout=post_readout) self.device = None - + def SeqConv(self, node, xyz, nbr_list, conv_module, pbc_offsets=None): if pbc_offsets is None: pbc_offsets = 0 e = (xyz[nbr_list[:, 1]] - xyz[nbr_list[:, 0]] + pbc_offsets).pow(2).sum(1).sqrt()[:, None] - for i, conv in enumerate(conv_module): + for _i, conv in enumerate(conv_module): dr = conv(r=node, e=e, a=nbr_list) node = node + dr return node - + def V_ex(self, xyz, nbr_list, pbc_offsets): dist = (xyz[nbr_list[:, 1]] - xyz[nbr_list[:, 0]] + pbc_offsets).pow(2).sum(1).sqrt() - potential = ((dist.reciprocal() * self.sigma).pow(self.power)) + potential = (dist.reciprocal() * self.sigma).pow(self.power) return scatter_add(potential, nbr_list[:, 0], dim_size=xyz.shape[0])[:, None] - + def forward(self, batch, **kwargs): - r = batch['nxyz'][:, 0] - xyz = batch['nxyz'][:, 1:4] - N = batch['num_atoms'].reshape(-1).tolist() - a_mol = batch['atoms_nbr_list'] - a_sys = batch['nbr_list'] + r = batch["nxyz"][:, 0] + xyz = batch["nxyz"][:, 1:4] + N = batch["num_atoms"].reshape(-1).tolist() + a_mol = batch["atoms_nbr_list"] + a_sys = batch["nbr_list"] # offsets take care of periodic boundary conditions - offsets = 
batch.get('offsets', 0) # offsets only affect nbr_list + offsets = batch.get("offsets", 0) # offsets only affect nbr_list xyz.requires_grad = True node_input = self.atom_embed(r.long()).squeeze() - - # system convolution + + # system convolution r_sys = self.SeqConv(node_input, xyz, a_sys, self.system_convolutions, offsets) r_mol = self.SeqConv(node_input, xyz, a_mol, self.molecule_convolutions) - # Excluded Volume interactions - #r_ex = self.V_ex(xyz, a_sys, offsets) + # Excluded Volume interactions + # r_ex = self.V_ex(xyz, a_sys, offsets) results = self.atomwisereadout(r_sys + r_mol) - # add excluded volume interactions - #results['energy'] += r_ex + # add excluded volume interactions + # results['energy'] += r_ex results = batch_and_sum(results, N, list(batch.keys()), xyz) return results diff --git a/nff/nn/models/mace.py b/nff/nn/models/mace.py index 329a54f2..2a332c1d 100644 --- a/nff/nn/models/mace.py +++ b/nff/nn/models/mace.py @@ -8,7 +8,7 @@ from __future__ import annotations from pathlib import Path -from typing import List, Literal, Union +from typing import List, Literal import torch from e3nn import o3 @@ -70,7 +70,7 @@ def forward( compute_virials: bool = False, compute_displacement: bool = False, **kwargs, - ) -> dict: # noqa: W0221 + ) -> dict: """Forward pass through the model and ouput in NFF format Args: @@ -85,10 +85,7 @@ def forward( Returns: dict: dict of output from the forward pass in NFF format """ - if isinstance(batch, dict): - data = self.convert_batch_to_data(batch) - else: - data = batch + data = self.convert_batch_to_data(batch) if isinstance(batch, dict) else batch output = super().forward( data, training=training, # set the training mode to the value of the wrapper @@ -123,11 +120,8 @@ def convert_batch_to_data(self, batch: dict) -> torch_geometric.data.Data: props = batch else: raise ValueError("Batch must be a dictionary") - if props["num_atoms"].dim() == 0: - num_atoms = props["num_atoms"].unsqueeze(0) - else: - num_atoms = props["num_atoms"] - cum_idx_list = [0] + torch.cumsum(num_atoms, 0).tolist() + num_atoms = props["num_atoms"].unsqueeze(0) if props["num_atoms"].dim() == 0 else props["num_atoms"] + cum_idx_list = [0, *torch.cumsum(num_atoms, 0).tolist()] z_table = AtomicNumberTable([int(z) for z in self.atomic_numbers]) dataset = [] @@ -137,9 +131,9 @@ def convert_batch_to_data(self, batch: dict) -> torch_geometric.data.Data: positions = props.get("nxyz")[node_idx, 1:].detach().cpu().numpy() numbers = props.get("nxyz")[node_idx, 0].long().detach().cpu().numpy() - if "cell" in props.keys(): + if "cell" in props: cell = props["cell"][3 * i : 3 * i + 3].detach().cpu().numpy() - elif "lattice" in props.keys(): + elif "lattice" in props: cell = props["lattice"][3 * i : 3 * i + 3].detach().cpu().numpy() else: raise ValueError("No cell or lattice found in batch") @@ -235,7 +229,7 @@ def from_dict(cls, state_dict: dict, **hparams) -> NffScaleMACE: return model @classmethod - def from_file(cls, path: str, map_location: str = None, **kwargs) -> NffScaleMACE: + def from_file(cls, path: str, map_location: str | None = None, **kwargs) -> NffScaleMACE: """Load the model from checkpoint created by pytorch lightning. Args: @@ -279,7 +273,8 @@ def load_foundations( default_dtype = model_dtype if model_dtype != default_dtype: print( - f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}." + f"Default dtype {default_dtype} does not match model dtype {model_dtype}, " + f"converting models to {default_dtype}." 
) if default_dtype == "float64": mace_model.double() @@ -293,7 +288,7 @@ def load_foundations( def load( cls, model_name: str = "medium", - map_location: str = None, + map_location: str | None = None, **kwargs, ) -> NffScaleMACE: """Load the model from checkpoint created by pytorch lightning. @@ -316,7 +311,7 @@ def load( def reduce_foundations( model_foundations: NffScaleMACE, - table: Union[List, AtomicNumberTable], + table: List | AtomicNumberTable, load_readout=False, use_shift=True, use_scale=True, @@ -324,7 +319,7 @@ def reduce_foundations( num_conv_tp_weights=4, num_products=2, num_contraction=2, -) -> "NffScaleMACE": +) -> NffScaleMACE: """Reducing the model by extracting elements of interests Refer to the original paper to understand the architecture: "https://openreview.net/forum?id=YPpSngE-ZU" @@ -349,7 +344,7 @@ def reduce_foundations( reduced_atomic_numbers = table table = get_atomic_number_table_from_zs(table) elif isinstance(AtomicNumberTable): - reduced_atomic_numbers = [n for n in table.zs] + reduced_atomic_numbers = list(table.zs) z_table = AtomicNumberTable([int(z) for z in model_foundations.atomic_numbers]) new_z_table = table num_species_foundations = len(z_table.zs) @@ -459,7 +454,7 @@ def restore_foundations( num_conv_tp_weights=4, num_products=2, num_contraction=2, -) -> "NffScaleMACE": +) -> NffScaleMACE: """Restore back to foundational model from reduced model Refer to the original paper to understand the architecture: "https://openreview.net/forum?id=YPpSngE-ZU" @@ -510,7 +505,8 @@ def restore_foundations( model.interactions[i].linear.weight.clone() ) # Assuming 'model' and 'model_foundations' are instances of some torch.nn.Module - # And assuming the other variables (num_channels_foundation, num_species_foundations, etc.) are correctly defined + # And assuming the other variables (num_channels_foundation, + # num_species_foundations, etc.) 
are correctly defined if model.interactions[i].__class__.__name__ == "RealAgnosticResidualInteractionBlock": for k, index in enumerate(indices_weights): @@ -551,10 +547,10 @@ def restore_foundations( original_weights_max.data[index, :, :] = torch.nn.Parameter(new_weights_max) original_weights_list = model_foundations.products[i].symmetric_contractions.contractions[j].weights - for l in range(num_contraction): # Assuming 2 weights in each contractions - original_weights = original_weights_list[l] + for n in range(num_contraction): # Assuming 2 weights in each contractions + original_weights = original_weights_list[n] for k, index in enumerate(indices_weights): - new_weights = model.products[i].symmetric_contractions.contractions[j].weights[l][k, :, :].clone() + new_weights = model.products[i].symmetric_contractions.contractions[j].weights[n][k, :, :].clone() original_weights.data[index, :, :] = torch.nn.Parameter(new_weights) model_foundations.products[i].linear.weight = torch.nn.Parameter(model.products[i].linear.weight.clone()) diff --git a/nff/nn/models/painn.py b/nff/nn/models/painn.py index 66ad8711..bccee0c6 100644 --- a/nff/nn/models/painn.py +++ b/nff/nn/models/painn.py @@ -1,34 +1,33 @@ -from torch import nn -import numpy as np import copy -from nff.utils.tools import make_directed + +import numpy as np +import torch +from torch import nn + +from nff.nn.layers import Diagonalize, ExpNormalBasis +from nff.nn.modules.diabat import AdiabaticReadout, DiabaticReadout from nff.nn.modules.painn import ( - MessageBlock, - UpdateBlock, EmbeddingBlock, + MessageBlock, + NbrEmbeddingBlock, ReadoutBlock, - ReadoutBlock_Vec, - ReadoutBlock_Tuple, ReadoutBlock_Complex, + ReadoutBlock_Tuple, + ReadoutBlock_Vec, TransformerMessageBlock, - NbrEmbeddingBlock, + UpdateBlock, ) from nff.nn.modules.schnet import ( AttentionPool, - SumPool, - MolFpPool, MeanPool, - get_rij, + MolFpPool, + SumPool, add_embedding, add_stress, + get_rij, ) - -from nff.nn.modules.diabat import DiabaticReadout, AdiabaticReadout -from nff.nn.layers import Diagonalize, ExpNormalBasis -from nff.utils.scatter import scatter_add, compute_grad -import torch - -import pdb +from nff.utils.scatter import compute_grad, scatter_add +from nff.utils.tools import make_directed POOL_DIC = { "sum": SumPool, @@ -84,19 +83,12 @@ def __init__(self, modelparams): ] ) self.update_blocks = nn.ModuleList( - [ - UpdateBlock( - feat_dim=feat_dim, activation=activation, dropout=conv_dropout - ) - for _ in range(num_conv) - ] + [UpdateBlock(feat_dim=feat_dim, activation=activation, dropout=conv_dropout) for _ in range(num_conv)] ) self.output_keys = output_keys # no skip connection in original paper - self.skip = modelparams.get( - "skip_connection", {key: False for key in self.output_keys} - ) + self.skip = modelparams.get("skip_connection", {key: False for key in self.output_keys}) self.num_readouts = num_conv if any(self.skip.values()) else 1 self.readout_blocks = nn.ModuleList( @@ -159,9 +151,7 @@ def atomwise(self, batch, xyz=None): for i, message_block in enumerate(self.message_blocks): update_block = self.update_blocks[i] - ds_message, dv_message = message_block( - s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs - ) + ds_message, dv_message = message_block(s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs) s_i = s_i + ds_message v_i = v_i + dv_message @@ -214,7 +204,7 @@ def pool(self, batch, atomwise_out, xyz, r_ij, nbrs, inference=False): for key, pool_obj in self.pool_dic.items(): grad_key = f"{key}_grad" grad_keys = [grad_key] if (grad_key in self.grad_keys) 
else [] - if "stress" in self.grad_keys and not "stress" in all_results: + if "stress" in self.grad_keys and "stress" not in all_results: grad_keys.append("stress") results = pool_obj( batch=batch, @@ -233,8 +223,8 @@ def pool(self, batch, atomwise_out, xyz, r_ij, nbrs, inference=False): # transfer those results that don't get pooled if inference: atomwise_out = batch_detach(atomwise_out) - for key in atomwise_out.keys(): - if key not in all_results.keys(): + for key in atomwise_out: + if key not in all_results: all_results[key] = atomwise_out[key] return all_results, xyz @@ -248,10 +238,8 @@ def add_delta(self, all_results): all_results[key] = all_results[e_i] - all_results[e_j] grad_keys = [e_i + "_grad", e_j + "_grad"] delta_grad_key = "_".join(grad_keys) + "_delta" - if all([grad_key in all_results for grad_key in grad_keys]): - all_results[delta_grad_key] = ( - all_results[grad_keys[0]] - all_results[grad_keys[1]] - ) + if all(grad_key in all_results for grad_key in grad_keys): + all_results[delta_grad_key] = all_results[grad_keys[0]] - all_results[grad_keys[1]] return all_results def V_ex(self, r_ij, nbr_list, xyz): @@ -286,14 +274,10 @@ def run( ) if requires_embedding: - all_results = add_embedding( - atomwise_out=atomwise_out, all_results=all_results - ) + all_results = add_embedding(atomwise_out=atomwise_out, all_results=all_results) if requires_stress: - all_results = add_stress( - batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij - ) + all_results = add_stress(batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij) if getattr(self, "compute_delta", False): all_results = self.add_delta(all_results) @@ -359,13 +343,9 @@ def __init__(self, modelparams): ) if same_message_blocks: - self.message_blocks = nn.ModuleList( - [self.message_blocks[0]] * len(self.message_blocks) - ) + self.message_blocks = nn.ModuleList([self.message_blocks[0]] * len(self.message_blocks)) - self.embed_block = NbrEmbeddingBlock( - feat_dim=feat_dim, dropout=conv_dropout, rbf=rbf - ) + self.embed_block = NbrEmbeddingBlock(feat_dim=feat_dim, dropout=conv_dropout, rbf=rbf) class PainnDiabat(Painn): @@ -471,7 +451,7 @@ class PainnGapToAbs(nn.Module): """ def __init__(self, ground_model, gap_model, subtract_gap): - super(PainnGapToAbs, self).__init__() + super().__init__() self.ground_model = ground_model self.gap_model = gap_model @@ -484,18 +464,12 @@ def get_model_attr(self, model, key): return getattr(model, key) def set_model_attr(self, model, key, val): - if hasattr(model, "painn_model"): - sub_model = model.painn_model - else: - sub_model = model + sub_model = model.painn_model if hasattr(model, "painn_model") else model setattr(sub_model, key, val) def get_grad_keys(self, model): - if hasattr(model, "painn_model"): - grad_keys = model.painn_model.grad_keys - else: - grad_keys = model.grad_keys + grad_keys = model.painn_model.grad_keys if hasattr(model, "painn_model") else model.grad_keys return set(grad_keys) @property @@ -521,17 +495,10 @@ def forward(self, *args, **kwargs): combined_results = {} for key in common_keys: - pool_dics = [ - self.get_model_attr(model, "pool_dic") for model in self.models - ] + pool_dics = [self.get_model_attr(model, "pool_dic") for model in self.models] - in_pool = all([key in dic for dic in pool_dics]) - in_grad = all( - [ - key in self.get_model_attr(model, "grad_keys") - for model in self.models - ] - ) + in_pool = all(key in dic for dic in pool_dics) + in_grad = all(key in self.get_model_attr(model, "grad_keys") for model in self.models) common = in_pool or 
in_grad @@ -565,9 +532,7 @@ def __init__(self, modelparams): self.output_vec_keys = output_vec_keys # no skip connection in original paper - self.skip_vec = modelparams.get( - "skip_vec_connection", {key: False for key in self.output_vec_keys} - ) + self.skip_vec = modelparams.get("skip_vec_connection", {key: False for key in self.output_vec_keys}) num_vec_readouts = modelparams["num_conv"] if any(self.skip.values()) else 1 self.readout_vec_blocks = nn.ModuleList( @@ -608,9 +573,7 @@ def atomwise(self, batch, xyz=None): for i, message_block in enumerate(self.message_blocks): update_block = self.update_blocks[i] - ds_message, dv_message = message_block( - s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs - ) + ds_message, dv_message = message_block(s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs) s_i = s_i + ds_message v_i = v_i + dv_message @@ -659,11 +622,10 @@ def atomwise(self, batch, xyz=None): return results, xyz, r_ij, nbrs - + class Painn_VecOut2(Painn_VecOut): # unlike Painn_VecOut this uses 2 equivariant blocks for each output - def __init__(self, - modelparams): + def __init__(self, modelparams): """ Args: modelparams (dict): dictionary of model parameters @@ -678,32 +640,32 @@ def __init__(self, readout_dropout = modelparams.get("readout_dropout", 0) means = modelparams.get("means") stddevs = modelparams.get("stddevs") - + self.output_vec_keys = output_vec_keys # no skip connection in original paper - self.skip_vec = modelparams.get("skip_vec_connection", - {key: False for key - in self.output_vec_keys}) + self.skip_vec = modelparams.get("skip_vec_connection", {key: False for key in self.output_vec_keys}) - num_vec_readouts = (modelparams["num_conv"] if any(self.skip.values()) - else 1) + num_vec_readouts = modelparams["num_conv"] if any(self.skip.values()) else 1 self.readout_vec_blocks = nn.ModuleList( - [ReadoutBlock_Vec2(feat_dim=feat_dim, - output_keys=output_vec_keys, - activation=activation, - dropout=readout_dropout, - means=means, - stddevs=stddevs) - for _ in range(num_vec_readouts)] + [ + ReadoutBlock_Vec2( + feat_dim=feat_dim, + output_keys=output_vec_keys, + activation=activation, + dropout=readout_dropout, + means=means, + stddevs=stddevs, + ) + for _ in range(num_vec_readouts) + ] ) - - + + class Painn_NAC_OuterProd(Painn_VecOut2): # This model attempts to learn non-adiabatic coupling vectors # as suggested by Jeremy Richardson, as eigenvector # of an outer product matrix - def __init__(self, - modelparams): + def __init__(self, modelparams): """ Args: modelparams (dict): dictionary of model parameters @@ -711,25 +673,19 @@ def __init__(self, """ super().__init__(modelparams) - - def get_nac(self, - all_results, - batch, - xyz): - + + def get_nac(self, all_results, batch, xyz): N = batch["num_atoms"].detach().cpu().tolist() xyz_s = torch.split(xyz, N) - + for key in self.output_vec_keys: - mats = [] nacs = [] nu_s = torch.split(all_results[key], N) for nu, r in zip(nu_s, xyz_s): - mat = (torch.outer(r.reshape(-1), nu.reshape(-1)) - + torch.outer(nu.reshape(-1), r.reshape(-1))) + mat = torch.outer(r.reshape(-1), nu.reshape(-1)) + torch.outer(nu.reshape(-1), r.reshape(-1)) mats.append(mat) - + eigvals, eigvecs = torch.linalg.eigh(mat) real_vals = torch.abs(eigvals) phase = eigvals[0] / real_vals[0] @@ -738,23 +694,16 @@ def get_nac(self, max_idx = torch.argmax(real_vals) nac = real_vecs[:, max_idx] * torch.sqrt(real_vals[max_idx]) nacs.append(nac.reshape(-1, 3)) - + all_results[key] = torch.cat(nacs) - all_results[key+"_mat"] = tuple(mats) - + all_results[key + "_mat"] = tuple(mats) + 
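[Editor's note] The `get_nac` construction above is compact, so here is a minimal sketch of the same eigenvector trick on a toy system. All names (`r`, `nu`, `n_atoms`) are illustrative, and the batching and the global phase fix applied in the model are omitted:

```python
# Toy version of the outer-product NAC construction in Painn_NAC_OuterProd.
import torch

n_atoms = 4
r = torch.randn(n_atoms, 3)    # stands in for the atomic positions xyz
nu = torch.randn(n_atoms, 3)   # stands in for the per-atom vector output

# symmetric rank-2 matrix built from the two flattened 3N-vectors
mat = torch.outer(r.reshape(-1), nu.reshape(-1)) + torch.outer(nu.reshape(-1), r.reshape(-1))

eigvals, eigvecs = torch.linalg.eigh(mat)   # eigenvalues in ascending order
real_vals = torch.abs(eigvals)
max_idx = torch.argmax(real_vals)

# dominant eigenvector, rescaled so its squared norm matches |lambda_max|
nac = eigvecs[:, max_idx] * torch.sqrt(real_vals[max_idx])
print(nac.reshape(-1, 3).shape)                      # (n_atoms, 3)

# mat = r nu^T + nu r^T has rank 2: only two eigenvalues are non-negligible
print((real_vals > 1e-4 * real_vals.max()).sum())    # tensor(2)
```

The model additionally fixes an overall sign via the phase of the lowest eigenvalue; that bookkeeping is left out of the sketch.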
return all_results, xyz - - def run(self, - batch, - xyz=None, - requires_embedding=False, - requires_stress=False, - inference=False): + def run(self, batch, xyz=None, requires_embedding=False, requires_stress=False, inference=False): from nff.train import batch_detach - - atomwise_out, xyz, r_ij, nbrs = self.atomwise(batch=batch, - xyz=xyz) + + atomwise_out, xyz, r_ij, nbrs = self.atomwise(batch=batch, xyz=xyz) if getattr(self, "excl_vol", None): # Excluded Volume interactions @@ -762,30 +711,21 @@ def run(self, for key in self.output_keys: atomwise_out[key] += r_ex - pooled_results, xyz = self.pool(batch=batch, - atomwise_out=atomwise_out, - xyz=xyz, - r_ij=r_ij, - nbrs=nbrs, - inference=False) - - all_results, xyz = self.get_nac(all_results=pooled_results, - batch=batch, - xyz=xyz) + pooled_results, xyz = self.pool( + batch=batch, atomwise_out=atomwise_out, xyz=xyz, r_ij=r_ij, nbrs=nbrs, inference=False + ) + + all_results, xyz = self.get_nac(all_results=pooled_results, batch=batch, xyz=xyz) if requires_embedding: - all_results = add_embedding(atomwise_out=atomwise_out, - all_results=all_results) + all_results = add_embedding(atomwise_out=atomwise_out, all_results=all_results) if requires_stress: - all_results = add_stress(batch=batch, - all_results=all_results, - nbrs=nbrs, - r_ij=r_ij) + all_results = add_stress(batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij) if getattr(self, "compute_delta", False): all_results = self.add_delta(all_results) - + if inference: batch_detach(all_results) @@ -793,9 +733,7 @@ def run(self, class Painn_Complex(Painn): - - def __init__(self, - modelparams): + def __init__(self, modelparams): """ Args: modelparams (dict): dictionary of model parameters @@ -811,27 +749,26 @@ def __init__(self, activation = modelparams["activation"] readout_dropout = modelparams.get("readout_dropout", 0) - num_cmplx_readouts = (modelparams["num_conv"] if any(self.skip.values()) - else 1) + num_cmplx_readouts = modelparams["num_conv"] if any(self.skip.values()) else 1 self.readout_cmplx_blocks = nn.ModuleList( - [ReadoutBlock_Complex(feat_dim=feat_dim, - output_keys=self.output_cmplx_keys, - activation=activation, - dropout=readout_dropout) - for _ in range(num_cmplx_readouts)] + [ + ReadoutBlock_Complex( + feat_dim=feat_dim, + output_keys=self.output_cmplx_keys, + activation=activation, + dropout=readout_dropout, + ) + for _ in range(num_cmplx_readouts) + ] ) - def atomwise(self, - batch, - xyz=None): - + def atomwise(self, batch, xyz=None): # for backwards compatability if isinstance(self.skip, bool): - self.skip = {key: self.skip - for key in self.output_keys} + self.skip = {key: self.skip for key in self.output_keys} - nbrs, _ = make_directed(batch['nbr_list']) - nxyz = batch['nxyz'] + nbrs, _ = make_directed(batch["nbr_list"]) + nxyz = batch["nxyz"] if xyz is None: xyz = nxyz[:, 1:] @@ -842,28 +779,19 @@ def atomwise(self, # get r_ij including offsets and excluding # anything in the neighbor skin self.set_cutoff() - r_ij, nbrs = get_rij(xyz=xyz, - batch=batch, - nbrs=nbrs, - cutoff=self.cutoff) - - s_i, v_i = self.embed_block(z_numbers, - nbrs=nbrs, - r_ij=r_ij) + r_ij, nbrs = get_rij(xyz=xyz, batch=batch, nbrs=nbrs, cutoff=self.cutoff) + + s_i, v_i = self.embed_block(z_numbers, nbrs=nbrs, r_ij=r_ij) results = {} for i, message_block in enumerate(self.message_blocks): update_block = self.update_blocks[i] - ds_message, dv_message = message_block(s_j=s_i, - v_j=v_i, - r_ij=r_ij, - nbrs=nbrs) + ds_message, dv_message = message_block(s_j=s_i, v_j=v_i, r_ij=r_ij, 
nbrs=nbrs) s_i = s_i + ds_message v_i = v_i + dv_message - ds_update, dv_update = update_block(s_i=s_i, - v_i=v_i) + ds_update, dv_update = update_block(s_i=s_i, v_i=v_i) s_i = s_i + ds_update v_i = v_i + dv_update @@ -902,8 +830,8 @@ def atomwise(self, if not skip: results[key] = new_cmplx_results[key] - results['features'] = s_i - results['features_vec'] = v_i + results["features"] = s_i + results["features_vec"] = v_i return results, xyz, r_ij, nbrs @@ -927,13 +855,9 @@ def __init__(self, modelparams): readout_dropout = modelparams.get("readout_dropout", 0) # no skip connection in original paper - self.skip_tuple = modelparams.get( - "skip_tuple_connection", {key: False for key in self.output_tuple_keys} - ) + self.skip_tuple = modelparams.get("skip_tuple_connection", {key: False for key in self.output_tuple_keys}) - num_tuple_readouts = ( - modelparams["num_conv"] if any(self.skip_tuple.values()) else 1 - ) + num_tuple_readouts = modelparams["num_conv"] if any(self.skip_tuple.values()) else 1 self.readout_tuple_blocks = nn.ModuleList( [ ReadoutBlock_Tuple( @@ -974,9 +898,7 @@ def atomwise(self, batch, xyz=None): for i, message_block in enumerate(self.message_blocks): update_block = self.update_blocks[i] - ds_message, dv_message = message_block( - s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs - ) + ds_message, dv_message = message_block(s_j=s_i, v_j=v_i, r_ij=r_ij, nbrs=nbrs) s_i = s_i + ds_message v_i = v_i + dv_message @@ -1060,13 +982,9 @@ def __init__(self, modelparams): readout_dropout = modelparams.get("readout_dropout", 0) # no skip connection in original paper - self.skip_tuple = modelparams.get( - "skip_tuple_connection", {key: False for key in self.output_tuple_keys} - ) + self.skip_tuple = modelparams.get("skip_tuple_connection", {key: False for key in self.output_tuple_keys}) - num_tuple_readouts = ( - modelparams["num_conv"] if any(self.skip_tuple.values()) else 1 - ) + num_tuple_readouts = modelparams["num_conv"] if any(self.skip_tuple.values()) else 1 # overwrite what has been done before self.readout_tuple_blocks = nn.ModuleList( [ @@ -1102,15 +1020,11 @@ def adibatic_energies( omega = all_results[wCP_keys[0]] batch_size = len(omega) - C_mat0 = torch.zeros( - batch_size, num_states, num_states, device=omega.device - ) + C_mat0 = torch.zeros(batch_size, num_states, num_states, device=omega.device) for mat in C_mat0: mat.fill_diagonal_(1) - C_mat = C_mat0 * omega.reshape(-1, 1, 1) + torch.diag( - torch.ones(num_states - 1), -1 - ).to(omega.device) + C_mat = C_mat0 * omega.reshape(-1, 1, 1) + torch.diag(torch.ones(num_states - 1), -1).to(omega.device) for idx, coef in enumerate(wCP_keys[1:]): C_mat[:, idx, -1] = -all_results[coef] @@ -1128,7 +1042,7 @@ def adibatic_energies( output = all_results[key] grad = compute_grad(output=output, inputs=xyz) all_results[grad_key] = grad - + return all_results, xyz def run( @@ -1158,19 +1072,13 @@ def run( inference=False, ) - all_results, xyz = self.adibatic_energies( - all_results=intermediate_results, xyz=xyz - ) + all_results, xyz = self.adibatic_energies(all_results=intermediate_results, xyz=xyz) if requires_embedding: - all_results = add_embedding( - atomwise_out=atomwise_out, all_results=all_results - ) + all_results = add_embedding(atomwise_out=atomwise_out, all_results=all_results) if requires_stress: - all_results = add_stress( - batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij - ) + all_results = add_stress(batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij) if getattr(self, "compute_delta", False): all_results = 
self.add_delta(all_results) @@ -1384,9 +1292,7 @@ def forward( continue val = results[key] split_vals = torch.split(val, batch["num_atoms"].tolist()) - final_vals = torch.stack( - [split_val.sum(0).reshape(3) for split_val in split_vals] - ) + final_vals = torch.stack([split_val.sum(0).reshape(3) for split_val in split_vals]) results[key] = final_vals return results diff --git a/nff/nn/models/schnet.py b/nff/nn/models/schnet.py index d8d79ce2..913b241b 100644 --- a/nff/nn/models/schnet.py +++ b/nff/nn/models/schnet.py @@ -1,23 +1,14 @@ from torch import nn +from nff.nn.graphop import batch_and_sum from nff.nn.layers import DEFAULT_DROPOUT_RATE -from nff.nn.modules import ( - SchNetConv, - NodeMultiTaskReadOut, - get_rij, - add_stress -) - - +from nff.nn.modules import NodeMultiTaskReadOut, SchNetConv, add_stress, get_rij from nff.nn.modules.diabat import DiabaticReadout -from nff.nn.graphop import batch_and_sum from nff.nn.utils import get_default_readout - from nff.utils.scatter import scatter_add class SchNet(nn.Module): - """SchNet implementation with continous filter. Attributes: @@ -62,7 +53,7 @@ def __init__(self, modelparams): 'n_convolutions': 4, 'cutoff': 5.0, 'trainable_gauss': True, - 'readoutdict': readoutdict, + 'readoutdict': readoutdict, 'dropout_rate': 0.2 } @@ -87,8 +78,7 @@ def __init__(self, modelparams): self.atom_embed = nn.Embedding(100, n_atom_basis, padding_idx=0) - readoutdict = modelparams.get( - "readoutdict", get_default_readout(n_atom_basis)) + readoutdict = modelparams.get("readoutdict", get_default_readout(n_atom_basis)) post_readout = modelparams.get("post_readout", None) # convolutions @@ -107,22 +97,17 @@ def __init__(self, modelparams): ) # ReadOut - self.atomwisereadout = NodeMultiTaskReadOut( - multitaskdict=readoutdict, post_readout=post_readout - ) + self.atomwisereadout = NodeMultiTaskReadOut(multitaskdict=readoutdict, post_readout=post_readout) self.device = None self.cutoff = cutoff def set_cutoff(self): if hasattr(self, "cutoff"): return - gauss_centers = (self.convolutions[0].moduledict - ['message_edge_filter'][0].offsets) + gauss_centers = self.convolutions[0].moduledict["message_edge_filter"][0].offsets self.cutoff = gauss_centers[-1] - gauss_centers[0] - def convolve(self, - batch, - xyz=None): + def convolve(self, batch, xyz=None): """ Apply the convolutional layers to the batch. 
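[Editor's note] An aside on `adibatic_energies` from the painn.py hunks above: one way to read `C_mat` (subdiagonal of ones, learned coefficients written into the last column) is as a batched, shifted companion matrix whose eigenvalues are the roots of a polynomial with network-predicted coefficients. A minimal sketch of that reading, not the model's exact matrix:

```python
# Companion matrix: subdiagonal of ones, coefficients in the last column;
# its eigenvalues are the roots of the monic polynomial they define.
import torch

# p(x) = x^3 - 6x^2 + 11x - 6 = (x - 1)(x - 2)(x - 3)
a = torch.tensor([-6.0, 11.0, -6.0])     # a0, a1, a2

C = torch.diag(torch.ones(a.numel() - 1), -1)   # subdiagonal of ones, cf. C_mat
C[:, -1] = -a                                    # cf. C_mat[:, idx, -1] = -coef
roots = torch.linalg.eigvals(C)
print(sorted(roots.real.tolist()))               # ~[1.0, 2.0, 3.0]
```

In the model the identity part is additionally scaled by the learned `omega` and the construction is repeated per batch entry.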
@@ -142,7 +127,7 @@ def convolve(self, if xyz is None: xyz = batch["nxyz"][:, 1:4] - if xyz.requires_grad == False: + if not xyz.requires_grad: xyz.requires_grad = True r = batch["nxyz"][:, 0] @@ -152,34 +137,26 @@ def convolve(self, # get r_ij including offsets and excluding # anything in the neighbor skin self.set_cutoff() - r_ij, a = get_rij(xyz=xyz, - batch=batch, - nbrs=a, - cutoff=self.cutoff) + r_ij, a = get_rij(xyz=xyz, batch=batch, nbrs=a, cutoff=self.cutoff) dist = r_ij.pow(2).sum(1).sqrt() e = dist[:, None] r = self.atom_embed(r.long()).squeeze() # update function includes periodic boundary conditions - for i, conv in enumerate(self.convolutions): + for conv in self.convolutions: dr = conv(r=r, e=e, a=a) r = r + dr return r, N, xyz, r_ij, a def V_ex(self, r_ij, nbr_list, xyz): - dist = (r_ij).pow(2).sum(1).sqrt() - potential = ((dist.reciprocal() * self.sigma).pow(self.power)) + potential = (dist.reciprocal() * self.sigma).pow(self.power) return scatter_add(potential, nbr_list[:, 0], dim_size=xyz.shape[0])[:, None] - def forward(self, - batch, - xyz=None, - requires_stress=False, - **kwargs): + def forward(self, batch, xyz=None, requires_stress=False, **kwargs): """Summary Args: batch (dict): dictionary of props @@ -196,47 +173,47 @@ def forward(self, if getattr(self, "excl_vol", None): # Excluded Volume interactions r_ex = self.V_ex(r_ij, a, xyz) - r['energy'] += r_ex + r["energy"] += r_ex results = batch_and_sum(r, N, list(batch.keys()), xyz) if requires_stress: - results = add_stress(batch=batch, - all_results=results, - nbrs=a, - r_ij=r_ij) + results = add_stress(batch=batch, all_results=results, nbrs=a, r_ij=r_ij) return results class SchNetDiabat(SchNet): def __init__(self, modelparams): - super().__init__(modelparams) self.diabatic_readout = DiabaticReadout( diabat_keys=modelparams["diabat_keys"], grad_keys=modelparams["grad_keys"], - energy_keys=modelparams["output_keys"]) - - def forward(self, - batch, - xyz=None, - add_nacv=False, - add_grad=True, - add_gap=True, - extra_grads=None, - try_speedup=False, - **kwargs): + energy_keys=modelparams["output_keys"], + ) + def forward( + self, + batch, + xyz=None, + add_nacv=False, + add_grad=True, + add_gap=True, + extra_grads=None, + try_speedup=False, + **kwargs, + ): r, N, xyz = self.convolve(batch, xyz) output = self.atomwisereadout(r) - results = self.diabatic_readout(batch=batch, - output=output, - xyz=xyz, - add_nacv=add_nacv, - add_grad=add_grad, - add_gap=add_gap, - extra_grads=extra_grads, - try_speedup=try_speedup) + results = self.diabatic_readout( + batch=batch, + output=output, + xyz=xyz, + add_nacv=add_nacv, + add_grad=add_grad, + add_gap=add_gap, + extra_grads=extra_grads, + try_speedup=try_speedup, + ) return results diff --git a/nff/nn/models/schnet_features.py b/nff/nn/models/schnet_features.py index d8fe6165..d6169873 100644 --- a/nff/nn/models/schnet_features.py +++ b/nff/nn/models/schnet_features.py @@ -3,9 +3,9 @@ from torch.nn import Sequential from nff.data.graphs import get_bond_idx -from nff.nn.models.conformers import WeightedConformers -from nff.nn.modules import SchNetEdgeFilter, MixedSchNetConv from nff.nn.layers import Dense +from nff.nn.models.conformers import WeightedConformers +from nff.nn.modules import MixedSchNetConv, SchNetEdgeFilter from nff.utils.tools import layer_types, make_directed @@ -52,7 +52,7 @@ def __init__(self, modelparams): n_filters=n_filters, dropout_rate=dropout_rate, n_bond_hidden=n_bond_hidden, - activation=activation + activation=activation, ) for _ in 
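[Editor's note] The excluded-volume term `V_ex` in the schnet.py hunk above is a plain (sigma / r)^p repulsion accumulated onto the first atom of each directed pair. A self-contained sketch, with `torch.index_add_` standing in for `nff.utils.scatter.scatter_add` and all sizes illustrative:

```python
import torch

xyz = torch.randn(5, 3)                                  # toy coordinates
nbrs = torch.tensor([[0, 1], [1, 0], [2, 3], [3, 2]])    # directed neighbor pairs
sigma, power = 1.0, 12

r_ij = xyz[nbrs[:, 0]] - xyz[nbrs[:, 1]]
dist = r_ij.pow(2).sum(1).sqrt()
pair_pot = (dist.reciprocal() * sigma).pow(power)        # (sigma / r)^p per pair

# accumulate each pair's contribution onto its first atom
v_ex = torch.zeros(xyz.shape[0]).index_add_(0, nbrs[:, 0], pair_pot)
print(v_ex[:, None].shape)                               # (5, 1), as returned by V_ex
```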
range(n_convolutions) ] @@ -66,35 +66,23 @@ def __init__(self, modelparams): trainable_gauss=trainable_gauss, n_filters=n_filters, dropout_rate=dropout_rate, - activation=activation) + activation=activation, + ) # for converting bond features to hidden feature vectors self.bond_filter = Sequential( - Dense( - in_features=n_bond_features, - out_features=n_bond_hidden, - dropout_rate=dropout_rate), + Dense(in_features=n_bond_features, out_features=n_bond_hidden, dropout_rate=dropout_rate), layer_types[activation](), - Dense( - in_features=n_bond_hidden, - out_features=n_bond_hidden, - dropout_rate=dropout_rate) + Dense(in_features=n_bond_hidden, out_features=n_bond_hidden, dropout_rate=dropout_rate), ) self.atom_filter = Sequential( - Dense( - in_features=n_atom_basis, - out_features=n_atom_hidden, - dropout_rate=dropout_rate), + Dense(in_features=n_atom_basis, out_features=n_atom_hidden, dropout_rate=dropout_rate), layer_types[activation](), - Dense( - in_features=n_atom_hidden, - out_features=n_atom_hidden, - dropout_rate=dropout_rate) + Dense(in_features=n_atom_hidden, out_features=n_atom_hidden, dropout_rate=dropout_rate), ) - def find_bond_idx(self, - batch): + def find_bond_idx(self, batch): """ Get `bond_idx`, which map bond indices to indices in the neighbor list. @@ -110,19 +98,14 @@ def find_bond_idx(self, bond_idx = batch["bond_idx"] if not was_directed: nbr_dim = nbr_list.shape[0] - bond_idx = torch.cat([bond_idx, - bond_idx + nbr_dim // 2]) + bond_idx = torch.cat([bond_idx, bond_idx + nbr_dim // 2]) else: bonded_nbr_list = batch["bonded_nbr_list"] bonded_nbr_list, _ = make_directed(bonded_nbr_list) bond_idx = get_bond_idx(bonded_nbr_list, nbr_list) return bond_idx - def convolve_sub_batch(self, - batch, - xyz=None, - xyz_grad=False, - **kwargs): + def convolve_sub_batch(self, batch, xyz=None, xyz_grad=False, **kwargs): """ Apply the convolutional layers to a sub-batch. 
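[Editor's note] `find_bond_idx` above leans on how the doubled neighbor list is laid out: if `make_directed` appends the reversed pairs as a second half, a bond found at row `i` of the undirected list reappears at row `i + n/2`. A rough sketch of that bookkeeping; the concatenation below approximates what `make_directed` does:

```python
import torch

undirected = torch.tensor([[0, 1], [1, 2], [2, 3]])
directed = torch.cat([undirected, undirected.flip(1)])   # reversed copies appended
bond_idx = torch.tensor([0, 2])                          # bonds among undirected rows

bond_idx = torch.cat([bond_idx, bond_idx + directed.shape[0] // 2])
print(bond_idx)              # tensor([0, 2, 3, 5])
print(directed[bond_idx])    # both directions of each bonded pair
```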
@@ -152,8 +135,7 @@ def convolve_sub_batch(self, bond_dim = bond_features.shape[1] num_pairs = a.shape[0] - bond_edge_features = torch.zeros(num_pairs, bond_dim - ).to(a.device) + bond_edge_features = torch.zeros(num_pairs, bond_dim).to(a.device) bond_idx = self.find_bond_idx(batch) bond_edge_features[bond_idx] = bond_features @@ -161,24 +143,22 @@ def convolve_sub_batch(self, # offsets take care of periodic boundary conditions offsets = batch.get("offsets", 0) # to deal with any shape mismatches - if hasattr(offsets, 'max') and offsets.max() == 0: + if hasattr(offsets, "max") and offsets.max() == 0: offsets = 0 if "distances" in batch: distances = batch["distances"][:, None] else: - distances = (xyz[a[:, 0]] - xyz[a[:, 1]] - - offsets).pow(2).sum(1).sqrt()[:, None] + distances = (xyz[a[:, 0]] - xyz[a[:, 1]] - offsets).pow(2).sum(1).sqrt()[:, None] distance_feats = self.distance_filter(distances) - e = torch.cat([bond_edge_features, distance_feats], - dim=-1) + e = torch.cat([bond_edge_features, distance_feats], dim=-1) r = self.atom_filter(batch["atom_features"]) # update function includes periodic boundary conditions - for i, conv in enumerate(self.convolutions): + for conv in self.convolutions: dr = conv(r=r, e=e, a=a) r = r + dr diff --git a/nff/nn/models/spooky.py b/nff/nn/models/spooky.py index 445aea33..ea680ff1 100644 --- a/nff/nn/models/spooky.py +++ b/nff/nn/models/spooky.py @@ -1,17 +1,23 @@ import torch from torch import nn -from nff.utils.scatter import compute_grad -from nff.utils.tools import make_directed -from nff.utils import constants as const -from nff.nn.modules.spooky import (DEFAULT_DROPOUT, DEFAULT_ACTIVATION, - DEFAULT_MAX_Z, DEFAULT_RES_LAYERS, - CombinedEmbedding, InteractionBlock, - AtomwiseReadout, Electrostatics, - NuclearRepulsion, get_dipole) -from nff.nn.modules.schnet import get_rij, get_offsets - from nff.nn.models.spooky_net_source.spookynet import SpookyNet as SourceSpooky +from nff.nn.modules.schnet import get_offsets, get_rij +from nff.nn.modules.spooky import ( + DEFAULT_ACTIVATION, + DEFAULT_DROPOUT, + DEFAULT_MAX_Z, + DEFAULT_RES_LAYERS, + AtomwiseReadout, + CombinedEmbedding, + Electrostatics, + InteractionBlock, + NuclearRepulsion, + get_dipole, +) +from nff.utils import constants as const +from nff.utils.scatter import compute_grad +from nff.utils.tools import make_directed def default(val, def_val): @@ -20,23 +26,17 @@ def default(val, def_val): def parse_optional(modelparams): - dropout = default(modelparams.get('dropout'), - DEFAULT_DROPOUT) - activation = default(modelparams.get('activation'), - DEFAULT_ACTIVATION) - max_z = default(modelparams.get('max_z'), DEFAULT_MAX_Z) - residual_layers = default(modelparams.get('residual_layers'), - DEFAULT_RES_LAYERS) + dropout = default(modelparams.get("dropout"), DEFAULT_DROPOUT) + activation = default(modelparams.get("activation"), DEFAULT_ACTIVATION) + max_z = default(modelparams.get("max_z"), DEFAULT_MAX_Z) + residual_layers = default(modelparams.get("residual_layers"), DEFAULT_RES_LAYERS) return dropout, activation, max_z, residual_layers def parse_add_ons(modelparams): - add_nuc_keys = default(modelparams.get('add_nuc_keys'), - modelparams['output_keys']) - add_elec_keys = default(modelparams.get('add_elec_keys'), - modelparams['output_keys']) - add_disp_keys = default(modelparams.get('add_disp_keys'), - []) + add_nuc_keys = default(modelparams.get("add_nuc_keys"), modelparams["output_keys"]) + add_elec_keys = default(modelparams.get("add_elec_keys"), modelparams["output_keys"]) + add_disp_keys 
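[Editor's note] Back in `convolve_sub_batch` above, the bond features are scattered into a zero-initialized per-edge tensor at the `bond_idx` positions and then concatenated with distance features, so non-bonded edges simply carry zeros. Schematically, with illustrative sizes:

```python
import torch

num_pairs, bond_dim, dist_dim = 6, 4, 3
bond_idx = torch.tensor([0, 2, 5])            # neighbor-list rows that are bonds
bond_features = torch.randn(len(bond_idx), bond_dim)

bond_edge_features = torch.zeros(num_pairs, bond_dim)
bond_edge_features[bond_idx] = bond_features  # non-bonded edges keep zeros

distance_feats = torch.randn(num_pairs, dist_dim)
e = torch.cat([bond_edge_features, distance_feats], dim=-1)
print(e.shape)                                # (num_pairs, bond_dim + dist_dim)
```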
= default(modelparams.get("add_disp_keys"), []) return add_nuc_keys, add_elec_keys, add_disp_keys @@ -47,125 +47,95 @@ class SpookyNet(nn.Module): work properly """ - def __init__(self, - modelparams): - + def __init__(self, modelparams): super().__init__() - feat_dim = modelparams['feat_dim'] - r_cut = modelparams['r_cut'] + feat_dim = modelparams["feat_dim"] + r_cut = modelparams["r_cut"] optional = parse_optional(modelparams) dropout, activation, max_z, residual_layers = optional add_ons = parse_add_ons(modelparams) add_nuc_keys, add_elec_keys, add_disp_keys = add_ons - self.output_keys = modelparams['output_keys'] - self.grad_keys = modelparams['grad_keys'] - self.embedding = CombinedEmbedding(feat_dim=feat_dim, - activation=activation, - max_z=max_z, - residual_layers=residual_layers) - self.interactions = nn.ModuleList([ - InteractionBlock(feat_dim=feat_dim, - r_cut=r_cut, - gamma=modelparams['gamma'], - bern_k=modelparams['bern_k'], - activation=activation, - dropout=dropout, - max_z=max_z, - residual_layers=residual_layers, - l_max=default(modelparams.get("l_max"), 2), - fast_feats=modelparams.get("fast_feats")) - for _ in range(modelparams['num_conv']) - ]) - - self.atomwise_readout = nn.ModuleDict({ - key: AtomwiseReadout(feat_dim=feat_dim) - for key in self.output_keys - }) - - self.electrostatics = nn.ModuleDict({ - key: Electrostatics(feat_dim=feat_dim, - r_cut=r_cut, - max_z=max_z) - for key in add_elec_keys - }) - - self.nuc_repulsion = nn.ModuleDict({ - key: NuclearRepulsion(r_cut=r_cut) - for key in add_nuc_keys - }) + self.output_keys = modelparams["output_keys"] + self.grad_keys = modelparams["grad_keys"] + self.embedding = CombinedEmbedding( + feat_dim=feat_dim, activation=activation, max_z=max_z, residual_layers=residual_layers + ) + self.interactions = nn.ModuleList( + [ + InteractionBlock( + feat_dim=feat_dim, + r_cut=r_cut, + gamma=modelparams["gamma"], + bern_k=modelparams["bern_k"], + activation=activation, + dropout=dropout, + max_z=max_z, + residual_layers=residual_layers, + l_max=default(modelparams.get("l_max"), 2), + fast_feats=modelparams.get("fast_feats"), + ) + for _ in range(modelparams["num_conv"]) + ] + ) + + self.atomwise_readout = nn.ModuleDict({key: AtomwiseReadout(feat_dim=feat_dim) for key in self.output_keys}) + + self.electrostatics = nn.ModuleDict( + {key: Electrostatics(feat_dim=feat_dim, r_cut=r_cut, max_z=max_z) for key in add_elec_keys} + ) + + self.nuc_repulsion = nn.ModuleDict({key: NuclearRepulsion(r_cut=r_cut) for key in add_nuc_keys}) if add_disp_keys: raise NotImplementedError("Dispersion not implemented") self.r_cut = r_cut - def get_results(self, - z, - f, - num_atoms, - xyz, - charge, - nbrs, - offsets, - mol_offsets, - mol_nbrs): - + def get_results(self, z, f, num_atoms, xyz, charge, nbrs, offsets, mol_offsets, mol_nbrs): results = {} for key in self.output_keys: atomwise_readout = self.atomwise_readout[key] - energy = atomwise_readout(z=z, - f=f, - num_atoms=num_atoms) + energy = atomwise_readout(z=z, f=f, num_atoms=num_atoms) if key in self.electrostatics: electrostatics = self.electrostatics[key] - elec_e, q = electrostatics(f=f, - z=z, - xyz=xyz, - total_charge=charge, - num_atoms=num_atoms, - mol_nbrs=mol_nbrs, - mol_offsets=mol_offsets) + elec_e, q = electrostatics( + f=f, + z=z, + xyz=xyz, + total_charge=charge, + num_atoms=num_atoms, + mol_nbrs=mol_nbrs, + mol_offsets=mol_offsets, + ) energy += elec_e if key in self.nuc_repulsion: nuc_repulsion = self.nuc_repulsion[key] - nuc_e = nuc_repulsion(xyz=xyz, - z=z, - nbrs=nbrs, 
- num_atoms=num_atoms, - offsets=offsets) + nuc_e = nuc_repulsion(xyz=xyz, z=z, nbrs=nbrs, num_atoms=num_atoms, offsets=offsets) energy += nuc_e results.update({key: energy}) if key in self.electrostatics: - dipole = get_dipole(xyz=xyz, - q=q, - num_atoms=num_atoms) + dipole = get_dipole(xyz=xyz, q=q, num_atoms=num_atoms) suffix = "_" + key.split("_")[-1] - if not any([i.isdigit() for i in suffix]): + if not any(i.isdigit() for i in suffix): suffix = "" - results.update({f"dipole{suffix}": dipole, - f"q{suffix}": q}) + results.update({f"dipole{suffix}": dipole, f"q{suffix}": q}) return results - def add_grad(self, - xyz, - grad_keys, - results): - + def add_grad(self, xyz, grad_keys, results): if grad_keys is None: grad_keys = self.grad_keys for key in grad_keys: base_key = key.replace("_grad", "") - grad = compute_grad(inputs=xyz, - output=results[base_key]) + grad = compute_grad(inputs=xyz, output=results[base_key]) results[key] = grad return results @@ -176,60 +146,46 @@ def set_cutoff(self): interac = self.interactions[0] self.cutoff = interac.local.g_0.r_cut - def fwd(self, - batch, - xyz=None, - grad_keys=None): - - nxyz = batch['nxyz'] - nbrs, _ = make_directed(batch['nbr_list']) + def fwd(self, batch, xyz=None, grad_keys=None): + nxyz = batch["nxyz"] + nbrs, _ = make_directed(batch["nbr_list"]) z = nxyz[:, 0].long() if xyz is None: xyz = nxyz[:, 1:] xyz.requires_grad = True - charge = batch['charge'] - spin = batch['spin'] - num_atoms = batch['num_atoms'] - offsets = get_offsets(batch, 'offsets') - mol_offsets = get_offsets(batch, 'mol_offsets') - mol_nbrs = batch.get('mol_nbrs') + charge = batch["charge"] + spin = batch["spin"] + num_atoms = batch["num_atoms"] + offsets = get_offsets(batch, "offsets") + mol_offsets = get_offsets(batch, "mol_offsets") + mol_nbrs = batch.get("mol_nbrs") - x = self.embedding(charge=charge, - spin=spin, - z=z, - num_atoms=num_atoms) + x = self.embedding(charge=charge, spin=spin, z=z, num_atoms=num_atoms) # get r_ij including offsets and removing neighbor skin self.set_cutoff() - r_ij, nbrs = get_rij(xyz=xyz, - batch=batch, - nbrs=nbrs, - cutoff=self.cutoff) + r_ij, nbrs = get_rij(xyz=xyz, batch=batch, nbrs=nbrs, cutoff=self.cutoff) f = torch.zeros_like(x) - for i, interaction in enumerate(self.interactions): - x, y_t = interaction(x=x, - xyz=xyz, - nbrs=nbrs, - num_atoms=num_atoms, - r_ij=r_ij) + for interaction in self.interactions: + x, y_t = interaction(x=x, xyz=xyz, nbrs=nbrs, num_atoms=num_atoms, r_ij=r_ij) f = f + y_t - results = self.get_results(z=z, - f=f, - num_atoms=num_atoms, - xyz=xyz, - charge=charge, - nbrs=nbrs, - offsets=offsets, - mol_offsets=mol_offsets, - mol_nbrs=mol_nbrs) + results = self.get_results( + z=z, + f=f, + num_atoms=num_atoms, + xyz=xyz, + charge=charge, + nbrs=nbrs, + offsets=offsets, + mol_offsets=mol_offsets, + mol_nbrs=mol_nbrs, + ) - results = self.add_grad(xyz=xyz, - grad_keys=grad_keys, - results=results) + results = self.add_grad(xyz=xyz, grad_keys=grad_keys, results=results) return results @@ -239,6 +195,7 @@ def forward(self, *args, **kwargs): except Exception as e: print(e) import pdb + pdb.post_mortem() @@ -247,9 +204,7 @@ class RealSpookyNet(SourceSpooky): Wrapper around the real source code for SpookyNet, so we can use it in NFF """ - def __init__(self, - params): - + def __init__(self, params): super().__init__(**params) self.int_dtype = torch.long @@ -260,16 +215,14 @@ def __init__(self, self.dip_key = params["dip_key"] self.charge_key = params["charge_key"] - def get_full_nbrs(self, - batch): - - idx_i 
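[Editor's note] The suffix logic in `get_results` above deserves a gloss: state-resolved output keys such as `energy_1` produce `dipole_1` and `q_1`, while a bare `energy` falls back to plain `dipole` and `q`:

```python
def dipole_suffix(key: str) -> str:
    suffix = "_" + key.split("_")[-1]
    return suffix if any(c.isdigit() for c in suffix) else ""

print(dipole_suffix("energy_1"))   # "_1"  -> results hold "dipole_1" and "q_1"
print(dipole_suffix("energy"))     # ""    -> results hold "dipole" and "q"
```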
= batch['mol_nbrs'][:, 0] - idx_j = batch['mol_nbrs'][:, 1] + def get_full_nbrs(self, batch): + idx_i = batch["mol_nbrs"][:, 0] + idx_j = batch["mol_nbrs"][:, 1] return idx_i, idx_j def get_regular_nbrs(self, batch): - nbrs = batch['nbr_list'] + nbrs = batch["nbr_list"] idx_i = nbrs[:, 0] idx_j = nbrs[:, 1] @@ -284,39 +237,42 @@ def device(self, val): self.to(val) def forward(self, batch): - full_nbrs = any([self.use_d4_dispersion, - self.use_electrostatics]) + full_nbrs = any([self.use_d4_dispersion, self.use_electrostatics]) if full_nbrs: idx_i, idx_j = self.get_full_nbrs(batch) - cell_offsets = batch.get('mol_offsets') + cell_offsets = batch.get("mol_offsets") else: idx_i, idx_j = self.get_regular_nbrs(batch) - cell_offsets = batch.get('offsets') + cell_offsets = batch.get("offsets") - nxyz = batch['nxyz'] + nxyz = batch["nxyz"] xyz = nxyz[:, 1:].to(self.float_dtype) xyz.requires_grad = True device = xyz.device Z = nxyz[:, 0].to(self.int_dtype) - num_atoms = batch['num_atoms'] + num_atoms = batch["num_atoms"] num_batch = len(num_atoms) - batch_seg = torch.cat([torch.ones(int(num_atoms)) * i for i, num_atoms in - enumerate(num_atoms)]).to(self.int_dtype - ).to(device) - - out = super().forward(Z=Z, - Q=batch['charge'].to(self.float_dtype), - S=batch['spin'].to(self.float_dtype), - R=xyz, - idx_i=idx_i, - idx_j=idx_j, - num_batch=num_batch, - batch_seg=batch_seg, - cell=batch.get('cell'), - cell_offsets=cell_offsets) + batch_seg = ( + torch.cat([torch.ones(int(num_atoms)) * i for i, num_atoms in enumerate(num_atoms)]) + .to(self.int_dtype) + .to(device) + ) + + out = super().forward( + Z=Z, + Q=batch["charge"].to(self.float_dtype), + S=batch["spin"].to(self.float_dtype), + R=xyz, + idx_i=idx_i, + idx_j=idx_j, + num_batch=num_batch, + batch_seg=batch_seg, + cell=batch.get("cell"), + cell_offsets=cell_offsets, + ) energy, forces, dipole, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6 = out @@ -335,7 +291,7 @@ def forward(self, batch): "atomic_zbl": ea_rep * const.EV_TO_KCAL_MOL, "atom_vwd": ea_vdw * const.EV_TO_KCAL_MOL, "polarizabilities": pa, - "c6": c6 + "c6": c6, } return results diff --git a/nff/nn/models/spooky_fast.py b/nff/nn/models/spooky_fast.py index d281db4f..3a8565fe 100644 --- a/nff/nn/models/spooky_fast.py +++ b/nff/nn/models/spooky_fast.py @@ -1,12 +1,20 @@ import torch from torch import nn + +from nff.nn.modules.spooky_fast import ( + DEFAULT_ACTIVATION, + DEFAULT_DROPOUT, + DEFAULT_MAX_Z, + DEFAULT_RES_LAYERS, + AtomwiseReadout, + CombinedEmbedding, + Electrostatics, + InteractionBlock, + NuclearRepulsion, + get_dipole, +) from nff.utils.scatter import compute_grad from nff.utils.tools import make_directed, make_undirected -from nff.nn.modules.spooky_fast import (DEFAULT_DROPOUT, DEFAULT_ACTIVATION, - DEFAULT_MAX_Z, DEFAULT_RES_LAYERS, - CombinedEmbedding, InteractionBlock, - AtomwiseReadout, Electrostatics, - NuclearRepulsion, get_dipole) def default(val, def_val): @@ -15,182 +23,127 @@ def default(val, def_val): def parse_optional(modelparams): - dropout = default(modelparams.get('dropout'), - DEFAULT_DROPOUT) - activation = default(modelparams.get('activation'), - DEFAULT_ACTIVATION) - max_z = default(modelparams.get('max_z'), DEFAULT_MAX_Z) - residual_layers = default(modelparams.get('residual_layers'), - DEFAULT_RES_LAYERS) + dropout = default(modelparams.get("dropout"), DEFAULT_DROPOUT) + activation = default(modelparams.get("activation"), DEFAULT_ACTIVATION) + max_z = default(modelparams.get("max_z"), DEFAULT_MAX_Z) + residual_layers = 
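[Editor's note] `RealSpookyNet.forward` above builds `batch_seg`, the per-atom molecule index, with a list comprehension over `num_atoms`; `torch.repeat_interleave` is the idiomatic one-liner for the same thing and a simpler mental model:

```python
import torch

num_atoms = torch.tensor([3, 2, 4])   # three molecules in the batch
batch_seg = torch.repeat_interleave(torch.arange(len(num_atoms)), num_atoms)
print(batch_seg)                      # tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
```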
default(modelparams.get("residual_layers"), DEFAULT_RES_LAYERS) return dropout, activation, max_z, residual_layers def parse_add_ons(modelparams): - add_nuc_keys = default(modelparams.get('add_nuc_keys'), - modelparams['output_keys']) - add_elec_keys = default(modelparams.get('add_elec_keys'), - modelparams['output_keys']) - add_disp_keys = default(modelparams.get('add_disp_keys'), - []) + add_nuc_keys = default(modelparams.get("add_nuc_keys"), modelparams["output_keys"]) + add_elec_keys = default(modelparams.get("add_elec_keys"), modelparams["output_keys"]) + add_disp_keys = default(modelparams.get("add_disp_keys"), []) return add_nuc_keys, add_elec_keys, add_disp_keys class SpookyNet(nn.Module): - def __init__(self, - modelparams): - + def __init__(self, modelparams): super().__init__() - feat_dim = modelparams['feat_dim'] - r_cut = modelparams['r_cut'] + feat_dim = modelparams["feat_dim"] + r_cut = modelparams["r_cut"] optional = parse_optional(modelparams) dropout, activation, max_z, residual_layers = optional add_ons = parse_add_ons(modelparams) add_nuc_keys, add_elec_keys, add_disp_keys = add_ons - self.output_keys = modelparams['output_keys'] - self.grad_keys = modelparams['grad_keys'] - self.embedding = CombinedEmbedding(feat_dim=feat_dim, - activation=activation, - max_z=max_z, - residual_layers=residual_layers) - self.interactions = nn.ModuleList([ - InteractionBlock(feat_dim=feat_dim, - r_cut=r_cut, - gamma=modelparams['gamma'], - bern_k=modelparams['bern_k'], - activation=activation, - dropout=dropout, - max_z=max_z, - residual_layers=residual_layers) - for _ in range(modelparams['num_conv']) - ]) - self.atomwise_readout = nn.ModuleDict({ - key: AtomwiseReadout(feat_dim=feat_dim) - for key in self.output_keys - }) - - self.electrostatics = nn.ModuleDict({ - key: Electrostatics(feat_dim=feat_dim, - r_cut=r_cut, - max_z=max_z) - for key in add_elec_keys - }) - - self.nuc_repulsion = nn.ModuleDict({ - key: NuclearRepulsion(r_cut=r_cut) - for key in add_nuc_keys - }) + self.output_keys = modelparams["output_keys"] + self.grad_keys = modelparams["grad_keys"] + self.embedding = CombinedEmbedding( + feat_dim=feat_dim, activation=activation, max_z=max_z, residual_layers=residual_layers + ) + self.interactions = nn.ModuleList( + [ + InteractionBlock( + feat_dim=feat_dim, + r_cut=r_cut, + gamma=modelparams["gamma"], + bern_k=modelparams["bern_k"], + activation=activation, + dropout=dropout, + max_z=max_z, + residual_layers=residual_layers, + ) + for _ in range(modelparams["num_conv"]) + ] + ) + self.atomwise_readout = nn.ModuleDict({key: AtomwiseReadout(feat_dim=feat_dim) for key in self.output_keys}) + + self.electrostatics = nn.ModuleDict( + {key: Electrostatics(feat_dim=feat_dim, r_cut=r_cut, max_z=max_z) for key in add_elec_keys} + ) + + self.nuc_repulsion = nn.ModuleDict({key: NuclearRepulsion(r_cut=r_cut) for key in add_nuc_keys}) if add_disp_keys: raise NotImplementedError("Dispersion not implemented") - def get_results(self, - z, - f, - num_atoms, - xyz, - charge, - mol_nbrs, - nbrs): - + def get_results(self, z, f, num_atoms, xyz, charge, mol_nbrs, nbrs): results = {} for key in self.output_keys: atomwise_readout = self.atomwise_readout[key] - energy = atomwise_readout(z=z, - f=f, - num_atoms=num_atoms) + energy = atomwise_readout(z=z, f=f, num_atoms=num_atoms) if key in self.electrostatics: electrostatics = self.electrostatics[key] - elec_e, q = electrostatics(f=f, - z=z, - xyz=xyz, - total_charge=charge, - num_atoms=num_atoms, - mol_nbrs=mol_nbrs) + elec_e, q = electrostatics( 
+ f=f, z=z, xyz=xyz, total_charge=charge, num_atoms=num_atoms, mol_nbrs=mol_nbrs + ) energy += elec_e if key in self.nuc_repulsion: nuc_repulsion = self.nuc_repulsion[key] - nuc_e = nuc_repulsion(xyz=xyz, - z=z, - nbrs=nbrs, - num_atoms=num_atoms) + nuc_e = nuc_repulsion(xyz=xyz, z=z, nbrs=nbrs, num_atoms=num_atoms) energy += nuc_e results.update({key: energy}) if key in self.electrostatics: - dipole = get_dipole(xyz=xyz, - q=q, - num_atoms=num_atoms) + dipole = get_dipole(xyz=xyz, q=q, num_atoms=num_atoms) suffix = "_" + key.split("_")[-1] - if not any([i.isdigit() for i in suffix]): + if not any(i.isdigit() for i in suffix): suffix = "" - results.update({f"dipole{suffix}": dipole, - f"q{suffix}": q}) + results.update({f"dipole{suffix}": dipole, f"q{suffix}": q}) return results - def add_grad(self, - xyz, - grad_keys, - results): + def add_grad(self, xyz, grad_keys, results): if grad_keys is None: grad_keys = self.grad_keys for key in grad_keys: base_key = key.replace("_grad", "") - grad = compute_grad(inputs=xyz, - output=results[base_key]) + grad = compute_grad(inputs=xyz, output=results[base_key]) results[key] = grad return results - def fwd(self, - batch, - xyz=None, - grad_keys=None): - - nxyz = batch['nxyz'] - nbrs, _ = make_directed(batch['nbr_list']) - mol_nbrs = make_undirected(batch['mol_nbrs']) + def fwd(self, batch, xyz=None, grad_keys=None): + nxyz = batch["nxyz"] + nbrs, _ = make_directed(batch["nbr_list"]) + mol_nbrs = make_undirected(batch["mol_nbrs"]) z = nxyz[:, 0].long() if xyz is None: xyz = nxyz[:, 1:] xyz.requires_grad = True - charge = batch['charge'] - spin = batch['spin'] - num_atoms = batch['num_atoms'] + charge = batch["charge"] + spin = batch["spin"] + num_atoms = batch["num_atoms"] - x = self.embedding(charge=charge, - spin=spin, - z=z, - num_atoms=num_atoms) + x = self.embedding(charge=charge, spin=spin, z=z, num_atoms=num_atoms) f = torch.zeros_like(x) - for i, interaction in enumerate(self.interactions): - x, y_t = interaction(x=x, - xyz=xyz, - nbrs=nbrs, - num_atoms=num_atoms) + for interaction in self.interactions: + x, y_t = interaction(x=x, xyz=xyz, nbrs=nbrs, num_atoms=num_atoms) f += y_t - results = self.get_results(z=z, - f=f, - num_atoms=num_atoms, - xyz=xyz, - charge=charge, - mol_nbrs=mol_nbrs, - nbrs=nbrs) + results = self.get_results(z=z, f=f, num_atoms=num_atoms, xyz=xyz, charge=charge, mol_nbrs=mol_nbrs, nbrs=nbrs) - results = self.add_grad(xyz=xyz, - grad_keys=grad_keys, - results=results) + results = self.add_grad(xyz=xyz, grad_keys=grad_keys, results=results) return results @@ -200,4 +153,5 @@ def forward(self, *args, **kwargs): except Exception as e: print(e) import pdb + pdb.post_mortem() diff --git a/nff/nn/models/spooky_net_source/functional.py b/nff/nn/models/spooky_net_source/functional.py index dcac75f8..767ac708 100644 --- a/nff/nn/models/spooky_net_source/functional.py +++ b/nff/nn/models/spooky_net_source/functional.py @@ -1,4 +1,5 @@ import math + import torch """ @@ -11,8 +12,9 @@ unnecessary, but they are crucial for autograd to work properly. """ + def shifted_softplus(x: torch.Tensor) -> torch.Tensor: - """ Shifted softplus activation function. 
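[Editor's note] The `*_grad` convention handled by `add_grad` in both SpookyNet variants, restated in plain autograd terms; `compute_grad` from `nff.utils.scatter` is assumed to wrap `torch.autograd.grad` in roughly this way:

```python
import torch

xyz = torch.randn(4, 3, requires_grad=True)
energy = xyz.pow(2).sum()                    # toy stand-in for a model output

results = {"energy": energy}
for key in ["energy_grad"]:                  # grad_keys
    base_key = key.replace("_grad", "")
    (grad,) = torch.autograd.grad(results[base_key], xyz, create_graph=True)
    results[key] = grad

print(results["energy_grad"].shape)          # (4, 3); forces are the negative of this
```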
""" + """Shifted softplus activation function.""" return torch.nn.functional.softplus(x) - math.log(2) @@ -24,15 +26,11 @@ def cutoff_function(x: torch.Tensor, cutoff: float) -> torch.Tensor: """ zeros = torch.zeros_like(x) x_ = torch.where(x < cutoff, x, zeros) # prevent nan in backprop - return torch.where( - x < cutoff, torch.exp(-(x_ ** 2) / ((cutoff - x_) * (cutoff + x_))), zeros - ) + return torch.where(x < cutoff, torch.exp(-(x_**2) / ((cutoff - x_) * (cutoff + x_))), zeros) -def _switch_component( - x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor -) -> torch.Tensor: - """ Component of the switch function, only for internal use. """ +def _switch_component(x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor) -> torch.Tensor: + """Component of the switch function, only for internal use.""" x_ = torch.where(x <= 0, ones, x) # prevent nan in backprop return torch.where(x <= 0, zeros, torch.exp(-ones / x_)) @@ -44,7 +42,7 @@ def switch_function(x: torch.Tensor, cuton: float, cutoff: float) -> torch.Tenso f(x) = 1 and for x >= cutoff, f(x) = 0. This switch function has infinitely many smooth derivatives. NOTE: The implementation with the "_switch_component" function is - numerically more stable than a simplified version, it is not recommended + numerically more stable than a simplified version, it is not recommended to change this! """ x = (x - cuton) / (cutoff - cuton) diff --git a/nff/nn/models/spooky_net_source/modules/attention.py b/nff/nn/models/spooky_net_source/modules/attention.py index 007629f4..e0efcfb5 100644 --- a/nff/nn/models/spooky_net_source/modules/attention.py +++ b/nff/nn/models/spooky_net_source/modules/attention.py @@ -1,9 +1,10 @@ -import torch -import torch.nn as nn import math -import numpy as np from typing import Optional +import numpy as np +import torch +import torch.nn as nn + class Attention(nn.Module): """ @@ -20,28 +21,22 @@ class Attention(nn.Module): this is 0, the exact attention matrix is computed. """ - def __init__( - self, dim_qk: int, dim_v: int, num_random_features: Optional[int] = None - ) -> None: - """ Initializes the Attention class. """ - super(Attention, self).__init__() + def __init__(self, dim_qk: int, dim_v: int, num_random_features: Optional[int] = None) -> None: + """Initializes the Attention class.""" + super().__init__() self.num_random_features = num_random_features - if self.num_random_features is not None: - omega = self._omega(num_random_features, dim_qk) - else: - omega = [] + omega = self._omega(num_random_features, dim_qk) if self.num_random_features is not None else [] self.register_buffer("omega", torch.tensor(omega, dtype=torch.float64)) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def _omega(self, nrows: int, ncols: int) -> np.ndarray: - """ Return a (nrows x ncols) random feature matrix. """ + """Return a (nrows x ncols) random feature matrix.""" nblocks = int(nrows / ncols) blocks = [] - for i in range(nblocks): + for _ in range(nblocks): block = np.random.normal(size=(ncols, ncols)) q, _ = np.linalg.qr(block) blocks.append(np.transpose(q)) @@ -63,11 +58,11 @@ def _phi( batch_seg: torch.Tensor, eps: float = 1e-4, ) -> torch.Tensor: - """ Normalize X and project into random feature space. 
""" + """Normalize X and project into random feature space.""" d = X.shape[-1] m = self.omega.shape[-1] - U = torch.matmul(X / d ** 0.25, self.omega) - h = torch.sum(X ** 2, dim=-1, keepdim=True) / (2 * d ** 0.5) # OLD + U = torch.matmul(X / d**0.25, self.omega) + h = torch.sum(X**2, dim=-1, keepdim=True) / (2 * d**0.5) # OLD # determine maximum (is subtracted to prevent numerical overflow) if is_query: maximum, _ = torch.max(U, dim=-1, keepdim=True) @@ -75,15 +70,11 @@ def _phi( if num_batch > 1: brow = batch_seg.view(1, -1, 1).expand(num_batch, -1, U.shape[-1]) bcol = ( - torch.arange( - num_batch, dtype=batch_seg.dtype, device=batch_seg.device - ) + torch.arange(num_batch, dtype=batch_seg.dtype, device=batch_seg.device) .view(-1, 1, 1) .expand(-1, U.shape[-2], U.shape[-1]) ) - mask = torch.where( - brow == bcol, torch.ones_like(U), torch.zeros_like(U) - ) + mask = torch.where(brow == bcol, torch.ones_like(U), torch.zeros_like(U)) tmp = U.unsqueeze(0).expand(num_batch, -1, -1) tmp, _ = torch.max(mask * tmp, dim=-1) tmp, _ = torch.max(tmp, dim=-1) @@ -104,10 +95,10 @@ def _exact_attention( batch_seg: torch.Tensor, eps: float = 1e-8, ): - """ Compute exact attention. """ + """Compute exact attention.""" d = Q.shape[-1] dot = Q @ K.T # dot product - A = torch.exp((dot - torch.max(dot)) / d ** 0.5) # attention matrix + A = torch.exp((dot - torch.max(dot)) / d**0.5) # attention matrix if num_batch > 1: # mask out entries of different batches brow = batch_seg.view(1, -1).expand(A.shape[-2], -1) bcol = batch_seg.view(-1, 1).expand(-1, A.shape[-1]) @@ -126,12 +117,12 @@ def _approximate_attention( mask: Optional[torch.Tensor] = None, eps: float = 1e-8, ) -> torch.Tensor: - """ Compute approximate attention. """ + """Compute approximate attention.""" Q = self._phi(Q, True, num_batch, batch_seg) # random projection of Q K = self._phi(K, False, num_batch, batch_seg) # random projection of K if num_batch > 1: d = Q.shape[-1] - n = batch_seg.shape[0] + batch_seg.shape[0] # compute norm idx = batch_seg.unsqueeze(-1).expand(-1, d) @@ -143,14 +134,11 @@ def _approximate_attention( # K[b==batch_seg].transpose(-1,-2)@V[b==batch_seg]) # for b in range(num_batch)])/norm if mask is None: # mask can be shared across multiple attentions - one_hot = nn.functional.one_hot(batch_seg).to( - dtype=V.dtype, device=V.device - ) + one_hot = nn.functional.one_hot(batch_seg).to(dtype=V.dtype, device=V.device) mask = one_hot @ one_hot.transpose(-1, -2) return ((mask * (K @ Q.transpose(-1, -2))).transpose(-1, -2) @ V) / norm - else: - norm = Q @ torch.sum(K, 0, keepdim=True).T + eps - return (Q @ (K.T @ V)) / norm + norm = Q @ torch.sum(K, 0, keepdim=True).T + eps + return (Q @ (K.T @ V)) / norm def forward( self, @@ -187,5 +175,4 @@ def forward( """ if self.num_random_features is None: return self._exact_attention(Q, K, V, num_batch, batch_seg) - else: - return self._approximate_attention(Q, K, V, num_batch, batch_seg, mask) + return self._approximate_attention(Q, K, V, num_batch, batch_seg, mask) diff --git a/nff/nn/models/spooky_net_source/modules/bernstein_polynomials.py b/nff/nn/models/spooky_net_source/modules/bernstein_polynomials.py index 192205a3..37d81c90 100644 --- a/nff/nn/models/spooky_net_source/modules/bernstein_polynomials.py +++ b/nff/nn/models/spooky_net_source/modules/bernstein_polynomials.py @@ -1,8 +1,6 @@ +import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from ..functional import softplus_inverse class BernsteinPolynomials(nn.Module): @@ 
-21,10 +19,10 @@ class BernsteinPolynomials(nn.Module): """ def __init__(self, num_basis_functions: int, cutoff: float) -> None: - """ Initializes the BernsteinPolynomials class. """ - super(BernsteinPolynomials, self).__init__() + """Initializes the BernsteinPolynomials class.""" + super().__init__() # compute values to initialize buffers - logfactorial = np.zeros((num_basis_functions)) + logfactorial = np.zeros(num_basis_functions) for i in range(2, num_basis_functions): logfactorial[i] = logfactorial[i - 1] + np.log(i) v = np.arange(0, num_basis_functions) @@ -38,8 +36,7 @@ def __init__(self, num_basis_functions: int, cutoff: float) -> None: self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: """ diff --git a/nff/nn/models/spooky_net_source/modules/d4_dispersion_energy.py b/nff/nn/models/spooky_net_source/modules/d4_dispersion_energy.py index f2a57b66..8618a13c 100644 --- a/nff/nn/models/spooky_net_source/modules/d4_dispersion_energy.py +++ b/nff/nn/models/spooky_net_source/modules/d4_dispersion_energy.py @@ -1,10 +1,12 @@ -import os import math +import os +from typing import Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F -from ..functional import softplus_inverse, switch_function -from typing import Tuple, Optional + +from nff.nn.models.spooky_net_source.functional import softplus_inverse, switch_function """ computes D4 dispersion energy @@ -33,23 +35,17 @@ def __init__( Hartree: float = 27.211386024367243, # conversion to Hartree dtype: torch.dtype = torch.float32, ) -> None: - """ Initializes the D4DispersionEnergy class. 
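[Editor's note] The `logfactorial` buffer set up in `BernsteinPolynomials.__init__` above is the standard route to evaluating Bernstein bases in log space. A sketch of the idea; the module's mapping from distances to a variable in (0, 1) is not shown in this hunk, so `x` below is just a point in that interval:

```python
import numpy as np

K = 8                                         # number of basis functions
logfact = np.zeros(K)
for i in range(2, K):
    logfact[i] = logfact[i - 1] + np.log(i)   # logfact[i] = log(i!), as in the hunk

v = np.arange(K)
n = K - 1
logbinom = logfact[n] - logfact[v] - logfact[n - v]   # log C(n, v)

x = 0.3
basis = np.exp(logbinom + v * np.log(x) + (n - v) * np.log(1 - x))
print(basis.sum())                            # ~1.0: Bernstein bases sum to one
```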
""" - super(D4DispersionEnergy, self).__init__() + """Initializes the D4DispersionEnergy class.""" + super().__init__() # Grimme's D4 dispersion is only parametrized up to Rn (Z=86) assert Zmax <= 87 # trainable parameters self.register_parameter( "_s6", nn.Parameter(softplus_inverse(s6), requires_grad=False) ) # s6 is usually not fitted (correct long-range) - self.register_parameter( - "_s8", nn.Parameter(softplus_inverse(s8), requires_grad=True) - ) - self.register_parameter( - "_a1", nn.Parameter(softplus_inverse(a1), requires_grad=True) - ) - self.register_parameter( - "_a2", nn.Parameter(softplus_inverse(a2), requires_grad=True) - ) + self.register_parameter("_s8", nn.Parameter(softplus_inverse(s8), requires_grad=True)) + self.register_parameter("_a1", nn.Parameter(softplus_inverse(a1), requires_grad=True)) + self.register_parameter("_a2", nn.Parameter(softplus_inverse(a2), requires_grad=True)) self.register_parameter( "_scaleq", nn.Parameter(softplus_inverse(1.0), requires_grad=True) ) # for scaling charges of reference systems @@ -57,8 +53,8 @@ def __init__( self.Zmax = Zmax self.convert2Bohr = 1 / Bohr self.convert2eV = 0.5 * Hartree # factor of 0.5 prevents double counting - self.convert2Angstrom3 = Bohr ** 3 - self.convert2eVAngstrom6 = Hartree * Bohr ** 6 + self.convert2Angstrom3 = Bohr**3 + self.convert2eVAngstrom6 = Hartree * Bohr**6 self.set_cutoff(cutoff) self.g_a = g_a self.g_c = g_c @@ -75,20 +71,24 @@ def __init__( torch.load(os.path.join(directory, "refsys.pth"))[:Zmax], ) self.register_buffer( - "zeff", torch.load(os.path.join(directory, "zeff.pth"))[:Zmax] # [Zmax] + "zeff", + torch.load(os.path.join(directory, "zeff.pth"))[:Zmax], # [Zmax] ) self.register_buffer( "refh", # [Zmax,max_nref] torch.load(os.path.join(directory, "refh.pth"))[:Zmax], ) self.register_buffer( - "sscale", torch.load(os.path.join(directory, "sscale.pth")) # [18] + "sscale", + torch.load(os.path.join(directory, "sscale.pth")), # [18] ) self.register_buffer( - "secaiw", torch.load(os.path.join(directory, "secaiw.pth")) # [18,23] + "secaiw", + torch.load(os.path.join(directory, "secaiw.pth")), # [18,23] ) self.register_buffer( - "gam", torch.load(os.path.join(directory, "gam.pth"))[:Zmax] # [Zmax] + "gam", + torch.load(os.path.join(directory, "gam.pth"))[:Zmax], # [Zmax] ) self.register_buffer( "ascale", # [Zmax,max_nref] @@ -107,10 +107,12 @@ def __init__( torch.load(os.path.join(directory, "casimir_polder_weights.pth"))[:Zmax], ) self.register_buffer( - "rcov", torch.load(os.path.join(directory, "rcov.pth"))[:Zmax] # [Zmax] + "rcov", + torch.load(os.path.join(directory, "rcov.pth"))[:Zmax], # [Zmax] ) self.register_buffer( - "en", torch.load(os.path.join(directory, "en.pth"))[:Zmax] # [Zmax] + "en", + torch.load(os.path.join(directory, "en.pth"))[:Zmax], # [Zmax] ) self.register_buffer( "ncount_mask", # [Zmax,max_nref,max_ncount] @@ -145,11 +147,10 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def set_cutoff(self, cutoff: Optional[float] = None) -> None: - """ Can be used to change the cutoff. 
""" + """Can be used to change the cutoff.""" if cutoff is None: self.cutoff = None self.cuton = None @@ -178,24 +179,18 @@ def _compute_refc6(self) -> None: * self.secaiw[is_] * torch.where( qmod > 1e-8, - torch.exp( - self.g_a - * (1 - torch.exp(self.gam[is_] * self.g_c * (1 - qref / qmod_))) - ), + torch.exp(self.g_a * (1 - torch.exp(self.gam[is_] * self.g_c * (1 - qref / qmod_)))), math.exp(self.g_a) * ones_like_qmod, ).view(-1, self.max_nref, 1) ) alpha = torch.max( self.ascale[allZ, :].view(-1, self.max_nref, 1) - * ( - self.alphaiw[allZ, :, :] - - self.hcount[allZ, :].view(-1, self.max_nref, 1) * alpha - ), + * (self.alphaiw[allZ, :, :] - self.hcount[allZ, :].view(-1, self.max_nref, 1) * alpha), torch.zeros_like(alpha), ) - alpha_expanded = alpha.view( - alpha.size(0), 1, alpha.size(1), 1, -1 - ) * alpha.view(1, alpha.size(0), 1, alpha.size(1), -1) + alpha_expanded = alpha.view(alpha.size(0), 1, alpha.size(1), 1, -1) * alpha.view( + 1, alpha.size(0), 1, alpha.size(1), -1 + ) self.register_buffer( "refc6", 3.0 @@ -227,9 +222,7 @@ def forward( # calculate coordination numbers rco = self.k2 * (self.rcov[Zi] + self.rcov[Zj]) - den = self.k4 * torch.exp( - -((torch.abs(self.en[Zi] - self.en[Zj]) + self.k5) ** 2) / self.k6 - ) + den = self.k4 * torch.exp(-((torch.abs(self.en[Zi] - self.en[Zj]) + self.k5) ** 2) / self.k6) tmp = den * 0.5 * (1.0 + torch.erf(-self.kn * (rij - rco) / rco)) if self.cutoff is not None: tmp = tmp * switch_function(rij, self.cuton, self.cutoff) @@ -239,11 +232,7 @@ def forward( # calculate gaussian weights gweights = torch.sum( self.ncount_mask[Z] - * torch.exp( - -self.wf - * self.ncount_weight[Z] - * (covcn.view(-1, 1, 1) - self.cn[Z]) ** 2 - ), + * torch.exp(-self.wf * self.ncount_weight[Z] * (covcn.view(-1, 1, 1) - self.cn[Z]) ** 2), -1, ) norm = torch.sum(gweights, -1, True) @@ -261,15 +250,7 @@ def forward( zeta = ( torch.where( qmod > 1e-8, - torch.exp( - self.g_a - * ( - 1 - - torch.exp( - self.gam[Z].view(-1, 1) * self.g_c * (1 - qref / qmod_) - ) - ) - ), + torch.exp(self.g_a * (1 - torch.exp(self.gam[Z].view(-1, 1) * self.g_c * (1 - qref / qmod_)))), math.exp(self.g_a) * ones_like_qmod, ) * gweights @@ -281,49 +262,38 @@ def forward( zetai = torch.gather(zeta, 0, idx_i.view(-1, 1).expand(-1, zeta.size(1))) zetaj = torch.gather(zeta, 0, idx_j.view(-1, 1).expand(-1, zeta.size(1))) refc6ij = self.refc6[Zi, Zj, :, :] - zetaij = zetai.view(zetai.size(0), zetai.size(1), 1) * zetaj.view( - zetaj.size(0), 1, zetaj.size(1) - ) + zetaij = zetai.view(zetai.size(0), zetai.size(1), 1) * zetaj.view(zetaj.size(0), 1, zetaj.size(1)) c6ij = torch.sum((refc6ij * zetaij).view(refc6ij.size(0), -1), -1) sqrt_r4r2ij = math.sqrt(3) * self.sqrt_r4r2[Zi] * self.sqrt_r4r2[Zj] a1 = F.softplus(self._a1) a2 = F.softplus(self._a2) r0 = a1 * sqrt_r4r2ij + a2 if self.cutoff is None: - oor6 = 1 / (rij ** 6 + r0 ** 6) - oor8 = 1 / (rij ** 8 + r0 ** 8) + oor6 = 1 / (rij**6 + r0**6) + oor8 = 1 / (rij**8 + r0**8) else: - cut2 = self.cutoff ** 2 - cut6 = cut2 ** 3 + cut2 = self.cutoff**2 + cut6 = cut2**3 cut8 = cut2 * cut6 - tmp6 = r0 ** 6 - tmp8 = r0 ** 8 + tmp6 = r0**6 + tmp8 = r0**8 cut6tmp6 = cut6 + tmp6 cut8tmp8 = cut8 + tmp8 tmpc = rij / self.cutoff - 1 - oor6 = ( - 1 / (rij ** 6 + tmp6) - 1 / cut6tmp6 + 6 * cut6 / cut6tmp6 ** 2 * tmpc - ) - oor8 = ( - 1 / (rij ** 8 + tmp8) - 1 / cut8tmp8 + 8 * cut8 / cut8tmp8 ** 2 * tmpc - ) + oor6 = 1 / (rij**6 + tmp6) - 1 / cut6tmp6 + 6 * cut6 / cut6tmp6**2 * tmpc + oor8 = 1 / (rij**8 + tmp8) - 1 / cut8tmp8 + 8 * cut8 / cut8tmp8**2 * tmpc oor6 = 
torch.where(rij < self.cutoff, oor6, torch.zeros_like(oor6)) oor8 = torch.where(rij < self.cutoff, oor8, torch.zeros_like(oor8)) s6 = F.softplus(self._s6) s8 = F.softplus(self._s8) - pairwise = -c6ij * (s6 * oor6 + s8 * sqrt_r4r2ij ** 2 * oor8) * self.convert2eV + pairwise = -c6ij * (s6 * oor6 + s8 * sqrt_r4r2ij**2 * oor8) * self.convert2eV edisp = rij.new_zeros(N).index_add_(0, idx_i, pairwise) if compute_atomic_quantities: alpha = self.alpha[Z, :, 0] polarizabilities = torch.sum(zeta * alpha, -1) * self.convert2Angstrom3 refc6ii = self.refc6[Z, Z, :, :] - zetaii = zeta.view(zeta.size(0), zeta.size(1), 1) * zeta.view( - zeta.size(0), 1, zeta.size(1) - ) - c6_coefficients = ( - torch.sum((refc6ii * zetaii).view(refc6ii.size(0), -1), -1) - * self.convert2eVAngstrom6 - ) + zetaii = zeta.view(zeta.size(0), zeta.size(1), 1) * zeta.view(zeta.size(0), 1, zeta.size(1)) + c6_coefficients = torch.sum((refc6ii * zetaii).view(refc6ii.size(0), -1), -1) * self.convert2eVAngstrom6 else: polarizabilities = rij.new_zeros(N) c6_coefficients = rij.new_zeros(N) diff --git a/nff/nn/models/spooky_net_source/modules/electron_configurations.py b/nff/nn/models/spooky_net_source/modules/electron_configurations.py index 7cc32689..e1360d80 100644 --- a/nff/nn/models/spooky_net_source/modules/electron_configurations.py +++ b/nff/nn/models/spooky_net_source/modules/electron_configurations.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import numpy as np # fmt: off diff --git a/nff/nn/models/spooky_net_source/modules/electronic_embedding.py b/nff/nn/models/spooky_net_source/modules/electronic_embedding.py index 82f047bb..92178b1f 100644 --- a/nff/nn/models/spooky_net_source/modules/electronic_embedding.py +++ b/nff/nn/models/spooky_net_source/modules/electronic_embedding.py @@ -1,10 +1,10 @@ +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F + from .residual_mlp import ResidualMLP -from .shifted_softplus import ShiftedSoftplus -from .swish import Swish -from typing import Optional class ElectronicEmbedding(nn.Module): @@ -38,8 +38,8 @@ def __init__( activation: str = "swish", is_charge: bool = False, ) -> None: - """ Initializes the ElectronicEmbedding class. """ - super(ElectronicEmbedding, self).__init__() + """Initializes the ElectronicEmbedding class.""" + super().__init__() self.is_charge = is_charge self.linear_q = nn.Linear(num_features, num_features) if is_charge: # charges are duplicated to use separate weights for +/- @@ -58,7 +58,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. 
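[Editor's note] Returning to the d4_dispersion_energy.py hunks above: with the cutoff-smoothing branch stripped away, the pairwise kernel is plain Becke-Johnson-damped dispersion, finite at r -> 0 and approaching -s6*C6/r^6 at long range. All constants below are illustrative, not fitted D4 values:

```python
import torch

r = torch.linspace(0.5, 10.0, 20)      # pair distances
c6ij = 10.0                            # illustrative C6 coefficient
sqrt_r4r2ij = 3.5                      # illustrative sqrt(3)*sqrt_r4r2 product
s6, s8, a1, a2 = 1.0, 1.2, 0.4, 4.0    # assumed, typical-order parameters

r0 = a1 * sqrt_r4r2ij + a2             # Becke-Johnson radius, as in the hunk
oor6 = 1 / (r**6 + r0**6)              # damped 1/r^6
oor8 = 1 / (r**8 + r0**8)              # damped 1/r^8
pairwise = -c6ij * (s6 * oor6 + s8 * sqrt_r4r2ij**2 * oor8)
print(pairwise[0].item(), pairwise[-1].item())   # finite at short range, tiny at long range
```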
""" + """Initialize parameters.""" nn.init.orthogonal_(self.linear_k.weight) nn.init.orthogonal_(self.linear_v.weight) nn.init.orthogonal_(self.linear_q.weight) @@ -83,7 +83,7 @@ def forward( if batch_seg is None: # assume a single batch batch_seg = torch.zeros(x.size(0), dtype=torch.int64, device=x.device) q = self.linear_q(x) # queries - if self.is_charge: + if self.is_charge: # noqa e = F.relu(torch.stack([E, -E], dim=-1)) else: e = torch.abs(E).unsqueeze(-1) # +/- spin is the same => abs @@ -93,7 +93,7 @@ def forward( dot = torch.sum(k * q, dim=-1) / k.shape[-1] ** 0.5 # scaled dot product a = nn.functional.softplus(dot) # unnormalized attention weights anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a) - if a.device.type == "cpu": # indexing is faster on CPUs + if a.device.type == "cpu": # indexing is faster on CPUs # noqa anorm = anorm[batch_seg] else: # gathering is faster on GPUs anorm = torch.gather(anorm, 0, batch_seg) diff --git a/nff/nn/models/spooky_net_source/modules/electrostatic_energy.py b/nff/nn/models/spooky_net_source/modules/electrostatic_energy.py index 242859fe..0fae7e8d 100644 --- a/nff/nn/models/spooky_net_source/modules/electrostatic_energy.py +++ b/nff/nn/models/spooky_net_source/modules/electrostatic_energy.py @@ -1,9 +1,10 @@ import math +from typing import Optional + import torch import torch.nn as nn -import torch.nn.functional as F -from ..functional import switch_function -from typing import Optional + +from nff.nn.models.spooky_net_source.functional import switch_function """ computes electrostatic energy, switches between a constant value @@ -19,7 +20,7 @@ def __init__( cutoff: float = 1.0, lr_cutoff: Optional[float] = None, ) -> None: - super(ElectrostaticEnergy, self).__init__() + super().__init__() self.ke = ke self.kehalf = ke / 2 self.cuton = cuton @@ -32,25 +33,20 @@ def __init__( self.alpha2 = 0.0 self.two_pi = 2.0 * math.pi self.one_over_sqrtpi = 1 / math.sqrt(math.pi) - self.register_buffer( - "kmul", torch.Tensor(), persistent=False - ) + self.register_buffer("kmul", torch.Tensor(), persistent=False) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def set_lr_cutoff(self, lr_cutoff: Optional[float] = None) -> None: - """ Change the long range cutoff. 
""" + """Change the long range cutoff.""" self.lr_cutoff = lr_cutoff if self.lr_cutoff is not None: - self.lr_cutoff2 = lr_cutoff ** 2 + self.lr_cutoff2 = lr_cutoff**2 self.two_div_cut = 2.0 / lr_cutoff - self.rcutconstant = lr_cutoff / (lr_cutoff ** 2 + 1.0) ** (3.0 / 2.0) - self.cutconstant = (2 * lr_cutoff ** 2 + 1.0) / (lr_cutoff ** 2 + 1.0) ** ( - 3.0 / 2.0 - ) + self.rcutconstant = lr_cutoff / (lr_cutoff**2 + 1.0) ** (3.0 / 2.0) + self.cutconstant = (2 * lr_cutoff**2 + 1.0) / (lr_cutoff**2 + 1.0) ** (3.0 / 2.0) else: self.lr_cutoff2 = None self.two_div_cut = None @@ -58,7 +54,7 @@ def set_lr_cutoff(self, lr_cutoff: Optional[float] = None) -> None: self.cutconstant = None def set_kmax(self, Nxmax: int, Nymax: int, Nzmax: int) -> None: - """ Set integer reciprocal space cutoff for Ewald summation """ + """Set integer reciprocal space cutoff for Ewald summation""" kx = torch.arange(0, Nxmax + 1) kx = torch.cat([kx, -kx[1:]]) ky = torch.arange(0, Nymax + 1) @@ -66,17 +62,15 @@ def set_kmax(self, Nxmax: int, Nymax: int, Nzmax: int) -> None: kz = torch.arange(0, Nzmax + 1) kz = torch.cat([kz, -kz[1:]]) kmul = torch.cartesian_prod(kx, ky, kz)[1:] # 0th entry is 0 0 0 - kmax = max(max(Nxmax, Nymax), Nzmax) - self.register_buffer( - "kmul", kmul[torch.sum(kmul ** 2, dim=-1) <= kmax ** 2], persistent=False - ) + kmax = max(Nxmax, Nymax, Nzmax) + self.register_buffer("kmul", kmul[torch.sum(kmul**2, dim=-1) <= kmax**2], persistent=False) def set_alpha(self, alpha: Optional[float] = None) -> None: - """ Set real space damping parameter for Ewald summation """ + """Set real space damping parameter for Ewald summation""" if alpha is None: # automatically determine alpha alpha = 4.0 / self.cutoff + 1e-3 self.alpha = alpha - self.alpha2 = alpha ** 2 + self.alpha2 = alpha**2 self.two_pi = 2.0 * math.pi self.one_over_sqrtpi = 1 / math.sqrt(math.pi) # print a warning if alpha is so small that the reciprocal space sum @@ -103,7 +97,7 @@ def _real_space( fac = self.kehalf * torch.gather(q, 0, idx_i) * torch.gather(q, 0, idx_j) f = switch_function(rij, self.cuton, self.cutoff) coulomb = 1.0 / rij - damped = 1.0 / (rij ** 2 + 1.0) ** (1.0 / 2.0) + damped = 1.0 / (rij**2 + 1.0) ** (1.0 / 2.0) pairwise = fac * (f * damped + (1 - f) * coulomb) * torch.erfc(self.alpha * rij) return q.new_zeros(N).index_add_(0, idx_i, pairwise) @@ -128,17 +122,11 @@ def _reciprocal_space( else: # gathering is faster on GPUs b = batch_seg.view(-1, 1, 1).expand(-1, k.shape[-2], k.shape[-1]) dot = torch.sum(torch.gather(k, 0, b) * R.unsqueeze(-2), dim=-1) - q_real = q.new_zeros(num_batch, dot.shape[-1]).index_add_( - 0, batch_seg, q.unsqueeze(-1) * torch.cos(dot) - ) - q_imag = q.new_zeros(num_batch, dot.shape[-1]).index_add_( - 0, batch_seg, q.unsqueeze(-1) * torch.sin(dot) - ) - qf = q_real ** 2 + q_imag ** 2 + q_real = q.new_zeros(num_batch, dot.shape[-1]).index_add_(0, batch_seg, q.unsqueeze(-1) * torch.cos(dot)) + q_imag = q.new_zeros(num_batch, dot.shape[-1]).index_add_(0, batch_seg, q.unsqueeze(-1) * torch.sin(dot)) + qf = q_real**2 + q_imag**2 # reciprocal energy - e_reciprocal = ( - self.two_pi / torch.prod(box_length, dim=1) * torch.sum(qf * qg, dim=-1) - ) + e_reciprocal = self.two_pi / torch.prod(box_length, dim=1) * torch.sum(qf * qg, dim=-1) # self interaction correction q2 = q * q e_self = self.alpha * self.one_over_sqrtpi * q2 @@ -184,7 +172,7 @@ def _coulomb( f = switch_function(rij, self.cuton, self.cutoff) if self.lr_cutoff is None: coulomb = 1.0 / rij - damped = 1.0 / (rij ** 2 + 1.0) ** (1.0 / 2.0) + damped = 
1.0 / (rij**2 + 1.0) ** (1.0 / 2.0) else: coulomb = torch.where( rij < self.lr_cutoff, @@ -193,9 +181,7 @@ def _coulomb( ) damped = torch.where( rij < self.lr_cutoff, - 1.0 / (rij ** 2 + 1.0) ** (1.0 / 2.0) - + rij * self.rcutconstant - - self.cutconstant, + 1.0 / (rij**2 + 1.0) ** (1.0 / 2.0) + rij * self.rcutconstant - self.cutconstant, torch.zeros_like(rij), ) pairwise = fac * (f * damped + (1 - f) * coulomb) @@ -218,5 +204,4 @@ def forward( assert cell is not None assert batch_seg is not None return self._ewald(N, q, R, rij, idx_i, idx_j, cell, num_batch, batch_seg) - else: - return self._coulomb(N, q, rij, idx_i, idx_j) + return self._coulomb(N, q, rij, idx_i, idx_j) diff --git a/nff/nn/models/spooky_net_source/modules/exponential_bernstein_polynomials.py b/nff/nn/models/spooky_net_source/modules/exponential_bernstein_polynomials.py index 7e4a3a9c..3a5d0bd4 100644 --- a/nff/nn/models/spooky_net_source/modules/exponential_bernstein_polynomials.py +++ b/nff/nn/models/spooky_net_source/modules/exponential_bernstein_polynomials.py @@ -1,8 +1,9 @@ +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from ..functional import softplus_inverse + +from nff.nn.models.spooky_net_source.functional import softplus_inverse class ExponentialBernsteinPolynomials(nn.Module): @@ -40,14 +41,14 @@ def __init__( ini_alpha: float = 0.9448630629184640, exp_weighting: bool = False, ) -> None: - """ Initializes the ExponentialBernsteinPolynomials class. """ - super(ExponentialBernsteinPolynomials, self).__init__() + """Initializes the ExponentialBernsteinPolynomials class.""" + super().__init__() self.ini_alpha = ini_alpha self.exp_weighting = exp_weighting if no_basis_function_at_infinity: # increase number of basis functions by one num_basis_functions += 1 # compute values to initialize buffers - logfactorial = np.zeros((num_basis_functions)) + logfactorial = np.zeros(num_basis_functions) for i in range(2, num_basis_functions): logfactorial[i] = logfactorial[i - 1] + np.log(i) v = np.arange(0, num_basis_functions) @@ -61,13 +62,11 @@ def __init__( self.register_buffer("logc", torch.tensor(logbinomial, dtype=torch.float64)) self.register_buffer("n", torch.tensor(n, dtype=torch.float64)) self.register_buffer("v", torch.tensor(v, dtype=torch.float64)) - self.register_parameter( - "_alpha", nn.Parameter(torch.tensor(1.0, dtype=torch.float64)) - ) + self.register_parameter("_alpha", nn.Parameter(torch.tensor(1.0, dtype=torch.float64))) self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize exponential scaling parameter alpha. 
""" + """Initialize exponential scaling parameter alpha.""" nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: @@ -93,5 +92,4 @@ def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: rbf = cutoff_values.view(-1, 1) * torch.exp(x) if self.exp_weighting: return rbf * torch.exp(alphar) - else: - return rbf + return rbf diff --git a/nff/nn/models/spooky_net_source/modules/exponential_gaussian_functions.py b/nff/nn/models/spooky_net_source/modules/exponential_gaussian_functions.py index 5c188740..7d0e1fb0 100644 --- a/nff/nn/models/spooky_net_source/modules/exponential_gaussian_functions.py +++ b/nff/nn/models/spooky_net_source/modules/exponential_gaussian_functions.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..functional import softplus_inverse + +from nff.nn.models.spooky_net_source.functional import softplus_inverse class ExponentialGaussianFunctions(nn.Module): @@ -32,8 +33,8 @@ def __init__( ini_alpha: float = 0.9448630629184640, exp_weighting: bool = False, ) -> None: - """ Initializes the ExponentialGaussianFunctions class. """ - super(ExponentialGaussianFunctions, self).__init__() + """Initializes the ExponentialGaussianFunctions class.""" + super().__init__() self.ini_alpha = ini_alpha self.exp_weighting = exp_weighting if no_basis_function_at_infinity: @@ -46,19 +47,13 @@ def __init__( torch.tensor(1.0 * (num_basis_functions + 1), dtype=torch.float64), ) else: - self.register_buffer( - "center", torch.linspace(1, 0, num_basis_functions, dtype=torch.float64) - ) - self.register_buffer( - "width", torch.tensor(1.0 * num_basis_functions, dtype=torch.float64) - ) - self.register_parameter( - "_alpha", nn.Parameter(torch.tensor(1.0, dtype=torch.float64)) - ) + self.register_buffer("center", torch.linspace(1, 0, num_basis_functions, dtype=torch.float64)) + self.register_buffer("width", torch.tensor(1.0 * num_basis_functions, dtype=torch.float64)) + self.register_parameter("_alpha", nn.Parameter(torch.tensor(1.0, dtype=torch.float64))) self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize exponential scaling parameter alpha. """ + """Initialize exponential scaling parameter alpha.""" nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: @@ -79,10 +74,7 @@ def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: Values of the radial basis functions for the distances r. 
""" expalphar = torch.exp(-F.softplus(self._alpha) * r.view(-1, 1)) - rbf = cutoff_values.view(-1, 1) * torch.exp( - -self.width * (expalphar - self.center) ** 2 - ) + rbf = cutoff_values.view(-1, 1) * torch.exp(-self.width * (expalphar - self.center) ** 2) if self.exp_weighting: return rbf * expalphar - else: - return rbf + return rbf diff --git a/nff/nn/models/spooky_net_source/modules/gaussian_functions.py b/nff/nn/models/spooky_net_source/modules/gaussian_functions.py index 527be8d2..66235119 100644 --- a/nff/nn/models/spooky_net_source/modules/gaussian_functions.py +++ b/nff/nn/models/spooky_net_source/modules/gaussian_functions.py @@ -1,7 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from ..functional import softplus_inverse class GaussianFunctions(nn.Module): @@ -20,21 +18,18 @@ class GaussianFunctions(nn.Module): """ def __init__(self, num_basis_functions: int, cutoff: float) -> None: - """ Initializes the GaussianFunctions class. """ - super(GaussianFunctions, self).__init__() + """Initializes the GaussianFunctions class.""" + super().__init__() self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.float64)) self.register_buffer( "center", torch.linspace(0, cutoff, num_basis_functions, dtype=torch.float64), ) - self.register_buffer( - "width", torch.tensor(num_basis_functions / cutoff, dtype=torch.float64) - ) + self.register_buffer("width", torch.tensor(num_basis_functions / cutoff, dtype=torch.float64)) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: """ @@ -54,7 +49,5 @@ def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: rbf (FloatTensor [N, num_basis_functions]): Values of the radial basis functions for the distances r. """ - rbf = cutoff_values.view(-1, 1) * torch.exp( - -self.width * (r.view(-1, 1) - self.center) ** 2 - ) + rbf = cutoff_values.view(-1, 1) * torch.exp(-self.width * (r.view(-1, 1) - self.center) ** 2) return rbf diff --git a/nff/nn/models/spooky_net_source/modules/interaction_module.py b/nff/nn/models/spooky_net_source/modules/interaction_module.py index 0d23f018..0ad8247f 100644 --- a/nff/nn/models/spooky_net_source/modules/interaction_module.py +++ b/nff/nn/models/spooky_net_source/modules/interaction_module.py @@ -1,11 +1,12 @@ +from typing import Optional, Tuple + import torch import torch.nn as nn -import torch.nn.functional as F -from .residual_stack import ResidualStack + from .local_interaction import LocalInteraction from .nonlocal_interaction import NonlocalInteraction from .residual_mlp import ResidualMLP -from typing import Tuple, Optional +from .residual_stack import ResidualStack class InteractionModule(nn.Module): @@ -57,8 +58,8 @@ def __init__( num_residual_output: int, activation: str = "swish", ) -> None: - """ Initializes the InteractionModule class. 
""" - super(InteractionModule, self).__init__() + """Initializes the InteractionModule class.""" + super().__init__() # initialize modules self.local_interaction = LocalInteraction( num_features=num_features, @@ -79,14 +80,11 @@ def __init__( ) self.residual_pre = ResidualStack(num_features, num_residual_pre, activation) self.residual_post = ResidualStack(num_features, num_residual_post, activation) - self.resblock = ResidualMLP( - num_features, num_residual_output, activation=activation - ) + self.resblock = ResidualMLP(num_features, num_residual_output, activation=activation) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def forward( self, @@ -129,7 +127,7 @@ def forward( descriptors). """ x = self.residual_pre(x) - l = self.local_interaction(x, rbf, pij, dij, idx_i, idx_j) + local = self.local_interaction(x, rbf, pij, dij, idx_i, idx_j) n = self.nonlocal_interaction(x, num_batch, batch_seg, mask) - x = self.residual_post(x + l + n) + x = self.residual_post(x + local + n) return x, self.resblock(x) diff --git a/nff/nn/models/spooky_net_source/modules/local_interaction.py b/nff/nn/models/spooky_net_source/modules/local_interaction.py index 55115dcc..1855e967 100644 --- a/nff/nn/models/spooky_net_source/modules/local_interaction.py +++ b/nff/nn/models/spooky_net_source/modules/local_interaction.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F + from .residual_mlp import ResidualMLP @@ -39,8 +39,8 @@ def __init__( num_residual: int, activation: str = "swish", ) -> None: - """ Initializes the LocalInteraction class. """ - super(LocalInteraction, self).__init__() + """Initializes the LocalInteraction class.""" + super().__init__() self.radial_s = nn.Linear(num_basis_functions, num_features, bias=False) self.radial_p = nn.Linear(num_basis_functions, num_features, bias=False) self.radial_d = nn.Linear(num_basis_functions, num_features, bias=False) @@ -50,13 +50,11 @@ def __init__( self.resblock_d = ResidualMLP(num_features, num_residual_d, activation) self.projection_p = nn.Linear(num_features, 2 * num_features, bias=False) self.projection_d = nn.Linear(num_features, 2 * num_features, bias=False) - self.resblock = ResidualMLP( - num_features, num_residual, activation, zero_init=True - ) + self.resblock = ResidualMLP(num_features, num_residual, activation, zero_init=True) self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" nn.init.orthogonal_(self.radial_s.weight) nn.init.orthogonal_(self.radial_p.weight) nn.init.orthogonal_(self.radial_d.weight) diff --git a/nff/nn/models/spooky_net_source/modules/nonlinear_electronic_embedding.py b/nff/nn/models/spooky_net_source/modules/nonlinear_electronic_embedding.py index 67306df7..7dc8efcd 100644 --- a/nff/nn/models/spooky_net_source/modules/nonlinear_electronic_embedding.py +++ b/nff/nn/models/spooky_net_source/modules/nonlinear_electronic_embedding.py @@ -1,11 +1,9 @@ +from typing import Optional + import torch import torch.nn as nn -import torch.nn.functional as F -from .attention import Attention + from .residual_mlp import ResidualMLP -from .shifted_softplus import ShiftedSoftplus -from .swish import Swish -from typing import Optional class NonlinearElectronicEmbedding(nn.Module): @@ -32,16 +30,12 @@ class NonlinearElectronicEmbedding(nn.Module): 'ssp': Shifted softplus activation function. 
""" - def __init__( - self, num_features: int, num_residual: int, activation: str = "swish" - ) -> None: - """ Initializes the NonlinearElectronicEmbedding class. """ - super(NonlinearElectronicEmbedding, self).__init__() + def __init__(self, num_features: int, num_residual: int, activation: str = "swish") -> None: + """Initializes the NonlinearElectronicEmbedding class.""" + super().__init__() self.linear_q = nn.Linear(num_features, num_features, bias=False) self.featurize_k = nn.Linear(1, num_features) - self.resblock_k = ResidualMLP( - num_features, num_residual, activation=activation, zero_init=True - ) + self.resblock_k = ResidualMLP(num_features, num_residual, activation=activation, zero_init=True) self.featurize_v = nn.Linear(1, num_features, bias=False) self.resblock_v = ResidualMLP( num_features, @@ -53,7 +47,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" nn.init.orthogonal_(self.linear_q.weight) nn.init.orthogonal_(self.featurize_k.weight) nn.init.zeros_(self.featurize_k.bias) @@ -84,14 +78,10 @@ def forward( # determine maximum dot product (for numerics) if num_batch > 1: if mask is None: - mask = ( - nn.functional.one_hot(batch_seg) - .to(dtype=x.dtype, device=x.device) - .transpose(-1, -2) - ) + mask = nn.functional.one_hot(batch_seg).to(dtype=x.dtype, device=x.device).transpose(-1, -2) tmp = dot.view(1, -1).expand(num_batch, -1) tmp, _ = torch.max(mask * tmp, dim=-1) - if tmp.device.type == "cpu": # indexing is faster on CPUs + if tmp.device.type == "cpu": # indexing is faster on CPUs # noqa maximum = tmp[batch_seg] else: # gathering is faster on GPUs maximum = torch.gather(tmp, 0, batch_seg) @@ -99,10 +89,10 @@ def forward( maximum = torch.max(dot) # attention d = k.shape[-1] - a = torch.exp((dot - maximum) / d ** 0.5) + a = torch.exp((dot - maximum) / d**0.5) anorm = a.new_zeros(num_batch).index_add_(0, batch_seg, a) - if a.device.type == "cpu": # indexing is faster on CPUs + if a.device.type == "cpu": # indexing is faster on CPUs # noqa anorm = anorm[batch_seg] else: # gathering is faster on GPUs anorm = torch.gather(anorm, 0, batch_seg) diff --git a/nff/nn/models/spooky_net_source/modules/nonlocal_interaction.py b/nff/nn/models/spooky_net_source/modules/nonlocal_interaction.py index 73c14d32..06afde86 100644 --- a/nff/nn/models/spooky_net_source/modules/nonlocal_interaction.py +++ b/nff/nn/models/spooky_net_source/modules/nonlocal_interaction.py @@ -1,9 +1,10 @@ +from typing import Optional + import torch import torch.nn as nn -import torch.nn.functional as F + from .attention import Attention from .residual_mlp import ResidualMLP -from typing import Optional class NonlocalInteraction(nn.Module): @@ -38,23 +39,16 @@ def __init__( num_residual_v: int, activation: str = "swish", ) -> None: - """ Initializes the NonlocalInteraction class. 
""" - super(NonlocalInteraction, self).__init__() - self.resblock_q = ResidualMLP( - num_features, num_residual_q, activation=activation, zero_init=True - ) - self.resblock_k = ResidualMLP( - num_features, num_residual_k, activation=activation, zero_init=True - ) - self.resblock_v = ResidualMLP( - num_features, num_residual_v, activation=activation, zero_init=True - ) + """Initializes the NonlocalInteraction class.""" + super().__init__() + self.resblock_q = ResidualMLP(num_features, num_residual_q, activation=activation, zero_init=True) + self.resblock_k = ResidualMLP(num_features, num_residual_k, activation=activation, zero_init=True) + self.resblock_v = ResidualMLP(num_features, num_residual_v, activation=activation, zero_init=True) self.attention = Attention(num_features, num_features, num_features) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def forward( self, diff --git a/nff/nn/models/spooky_net_source/modules/nuclear_embedding.py b/nff/nn/models/spooky_net_source/modules/nuclear_embedding.py index 89003e29..3d0a68f5 100644 --- a/nff/nn/models/spooky_net_source/modules/nuclear_embedding.py +++ b/nff/nn/models/spooky_net_source/modules/nuclear_embedding.py @@ -1,8 +1,8 @@ import math -import numpy as np + import torch import torch.nn as nn -import torch.nn.functional as F + from .electron_configurations import electron_config @@ -24,26 +24,18 @@ class NuclearEmbedding(nn.Module): value (has minimal memory impact). """ - def __init__( - self, num_features: int, Zmax: int = 87, zero_init: bool = True - ) -> None: - """ Initializes the NuclearEmbedding class. """ - super(NuclearEmbedding, self).__init__() + def __init__(self, num_features: int, Zmax: int = 87, zero_init: bool = True) -> None: + """Initializes the NuclearEmbedding class.""" + super().__init__() self.num_features = num_features self.register_buffer("electron_config", torch.tensor(electron_config)) - self.register_parameter( - "element_embedding", nn.Parameter(torch.Tensor(Zmax, self.num_features)) - ) - self.register_buffer( - "embedding", torch.Tensor(Zmax, self.num_features), persistent=False - ) - self.config_linear = nn.Linear( - self.electron_config.size(1), self.num_features, bias=False - ) + self.register_parameter("element_embedding", nn.Parameter(torch.Tensor(Zmax, self.num_features))) + self.register_buffer("embedding", torch.Tensor(Zmax, self.num_features), persistent=False) + self.config_linear = nn.Linear(self.electron_config.size(1), self.num_features, bias=False) self.reset_parameters(zero_init) def reset_parameters(self, zero_init: bool = True) -> None: - """ Initialize parameters. """ + """Initialize parameters.""" if zero_init: nn.init.zeros_(self.element_embedding) nn.init.zeros_(self.config_linear.weight) @@ -52,13 +44,11 @@ def reset_parameters(self, zero_init: bool = True) -> None: nn.init.orthogonal_(self.config_linear.weight) def train(self, mode: bool = True) -> None: - """ Switch between training and evaluation mode. 
""" - super(NuclearEmbedding, self).train(mode=mode) + """Switch between training and evaluation mode.""" + super().train(mode=mode) if not self.training: with torch.no_grad(): - self.embedding = self.element_embedding + self.config_linear( - self.electron_config - ) + self.embedding = self.element_embedding + self.config_linear(self.electron_config) def forward(self, Z: torch.Tensor) -> torch.Tensor: """ @@ -75,12 +65,8 @@ def forward(self, Z: torch.Tensor) -> torch.Tensor: Embeddings of all atoms. """ if self.training: # during training, the embedding needs to be recomputed - self.embedding = self.element_embedding + self.config_linear( - self.electron_config - ) + self.embedding = self.element_embedding + self.config_linear(self.electron_config) if self.embedding.device.type == "cpu": # indexing is faster on CPUs return self.embedding[Z] - else: # gathering is faster on GPUs - return torch.gather( - self.embedding, 0, Z.view(-1, 1).expand(-1, self.num_features) - ) + # gathering is faster on GPUs + return torch.gather(self.embedding, 0, Z.view(-1, 1).expand(-1, self.num_features)) diff --git a/nff/nn/models/spooky_net_source/modules/residual.py b/nff/nn/models/spooky_net_source/modules/residual.py index 6a0b93c2..237a319e 100644 --- a/nff/nn/models/spooky_net_source/modules/residual.py +++ b/nff/nn/models/spooky_net_source/modules/residual.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F + from .shifted_softplus import ShiftedSoftplus from .swish import Swish @@ -26,8 +26,8 @@ def __init__( bias: bool = True, zero_init: bool = True, ) -> None: - """ Initializes the Residual class. """ - super(Residual, self).__init__() + """Initializes the Residual class.""" + super().__init__() # initialize attributes if activation == "ssp": Activation = ShiftedSoftplus @@ -45,7 +45,7 @@ def __init__( self.reset_parameters(bias, zero_init) def reset_parameters(self, bias: bool = True, zero_init: bool = True) -> None: - """ Initialize parameters to compute an identity mapping. 
""" + """Initialize parameters to compute an identity mapping.""" nn.init.orthogonal_(self.linear1.weight) if zero_init: nn.init.zeros_(self.linear2.weight) diff --git a/nff/nn/models/spooky_net_source/modules/residual_mlp.py b/nff/nn/models/spooky_net_source/modules/residual_mlp.py index 2961195a..afb5f786 100644 --- a/nff/nn/models/spooky_net_source/modules/residual_mlp.py +++ b/nff/nn/models/spooky_net_source/modules/residual_mlp.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F + from .residual_stack import ResidualStack from .shifted_softplus import ShiftedSoftplus from .swish import Swish @@ -15,10 +15,8 @@ def __init__( bias: bool = True, zero_init: bool = False, ) -> None: - super(ResidualMLP, self).__init__() - self.residual = ResidualStack( - num_features, num_residual, activation=activation, bias=bias, zero_init=True - ) + super().__init__() + self.residual = ResidualStack(num_features, num_residual, activation=activation, bias=bias, zero_init=True) # initialize activation function if activation == "ssp": self.activation = ShiftedSoftplus(num_features) diff --git a/nff/nn/models/spooky_net_source/modules/residual_stack.py b/nff/nn/models/spooky_net_source/modules/residual_stack.py index 5c10838a..2deb80f7 100644 --- a/nff/nn/models/spooky_net_source/modules/residual_stack.py +++ b/nff/nn/models/spooky_net_source/modules/residual_stack.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F + from .residual import Residual @@ -27,14 +27,9 @@ def __init__( bias: bool = True, zero_init: bool = True, ) -> None: - """ Initializes the ResidualStack class. """ - super(ResidualStack, self).__init__() - self.stack = nn.ModuleList( - [ - Residual(num_features, activation, bias, zero_init) - for i in range(num_residual) - ] - ) + """Initializes the ResidualStack class.""" + super().__init__() + self.stack = nn.ModuleList([Residual(num_features, activation, bias, zero_init) for i in range(num_residual)]) def forward(self, x: torch.Tensor) -> torch.Tensor: """ diff --git a/nff/nn/models/spooky_net_source/modules/shifted_softplus.py b/nff/nn/models/spooky_net_source/modules/shifted_softplus.py index a8750cf8..9e9b1a7a 100644 --- a/nff/nn/models/spooky_net_source/modules/shifted_softplus.py +++ b/nff/nn/models/spooky_net_source/modules/shifted_softplus.py @@ -1,4 +1,5 @@ import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -21,11 +22,9 @@ class ShiftedSoftplus(nn.Module): Initial "temperature" beta of the softplus function. """ - def __init__( - self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.0 - ) -> None: - """ Initializes the ShiftedSoftplus class. """ - super(ShiftedSoftplus, self).__init__() + def __init__(self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.0) -> None: + """Initializes the ShiftedSoftplus class.""" + super().__init__() self._log2 = math.log(2) self.initial_alpha = initial_alpha self.initial_beta = initial_beta @@ -34,7 +33,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters alpha and beta. 
""" + """Initialize parameters alpha and beta.""" nn.init.constant_(self.alpha, self.initial_alpha) nn.init.constant_(self.beta, self.initial_beta) diff --git a/nff/nn/models/spooky_net_source/modules/sinc_functions.py b/nff/nn/models/spooky_net_source/modules/sinc_functions.py index 0f884904..a6134d9b 100644 --- a/nff/nn/models/spooky_net_source/modules/sinc_functions.py +++ b/nff/nn/models/spooky_net_source/modules/sinc_functions.py @@ -1,13 +1,12 @@ import math + import torch import torch.nn as nn -import torch.nn.functional as F -from ..functional import softplus_inverse # backwards compatibility with older versions of torch try: from torch import sinc -except: +except BaseException: def sinc(x): x = x * math.pi @@ -28,20 +27,16 @@ class SincFunctions(nn.Module): """ def __init__(self, num_basis_functions: int, cutoff: float) -> None: - """ Initializes the SincFunctions class. """ - super(SincFunctions, self).__init__() + """Initializes the SincFunctions class.""" + super().__init__() self.register_buffer( "factor", - torch.linspace( - 1, num_basis_functions, num_basis_functions, dtype=torch.float64 - ) - / cutoff, + torch.linspace(1, num_basis_functions, num_basis_functions, dtype=torch.float64) / cutoff, ) self.reset_parameters() def reset_parameters(self) -> None: - """ For compatibility with other modules. """ - pass + """For compatibility with other modules.""" def forward(self, r: torch.Tensor, cutoff_values: torch.Tensor) -> torch.Tensor: """ diff --git a/nff/nn/models/spooky_net_source/modules/swish.py b/nff/nn/models/spooky_net_source/modules/swish.py index 086796f8..5bb7096e 100644 --- a/nff/nn/models/spooky_net_source/modules/swish.py +++ b/nff/nn/models/spooky_net_source/modules/swish.py @@ -23,11 +23,9 @@ class Swish(nn.Module): (GELUs)." """ - def __init__( - self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702 - ) -> None: - """ Initializes the Swish class. """ - super(Swish, self).__init__() + def __init__(self, num_features: int, initial_alpha: float = 1.0, initial_beta: float = 1.702) -> None: + """Initializes the Swish class.""" + super().__init__() self.initial_alpha = initial_alpha self.initial_beta = initial_beta self.register_parameter("alpha", nn.Parameter(torch.Tensor(num_features))) @@ -35,7 +33,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters alpha and beta. """ + """Initialize parameters alpha and beta.""" nn.init.constant_(self.alpha, self.initial_alpha) nn.init.constant_(self.beta, self.initial_beta) diff --git a/nff/nn/models/spooky_net_source/modules/zbl_repulsion_energy.py b/nff/nn/models/spooky_net_source/modules/zbl_repulsion_energy.py index 396fb23a..1317ed1a 100644 --- a/nff/nn/models/spooky_net_source/modules/zbl_repulsion_energy.py +++ b/nff/nn/models/spooky_net_source/modules/zbl_repulsion_energy.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..functional import softplus_inverse + +from nff.nn.models.spooky_net_source.functional import softplus_inverse class ZBLRepulsionEnergy(nn.Module): @@ -20,11 +21,9 @@ class ZBLRepulsionEnergy(nn.Module): lengths in Angstrom and energy in electronvolt). """ - def __init__( - self, a0: float = 0.5291772105638411, ke: float = 14.399645351950548 - ) -> None: - """ Initializes the ZBLRepulsionEnergy class. 
""" - super(ZBLRepulsionEnergy, self).__init__() + def __init__(self, a0: float = 0.5291772105638411, ke: float = 14.399645351950548) -> None: + """Initializes the ZBLRepulsionEnergy class.""" + super().__init__() self.a0 = a0 self.ke = ke self.kehalf = ke / 2 @@ -41,7 +40,7 @@ def __init__( self.reset_parameters() def reset_parameters(self) -> None: - """ Initialize parameters to the default ZBL potential. """ + """Initialize parameters to the default ZBL potential.""" nn.init.constant_(self._adiv, softplus_inverse(1 / (0.8854 * self.a0))) nn.init.constant_(self._apow, softplus_inverse(0.23)) nn.init.constant_(self._c1, softplus_inverse(0.18180)) diff --git a/nff/nn/models/spooky_net_source/spookynet.py b/nff/nn/models/spooky_net_source/spookynet.py index 7f09e9cb..18330189 100644 --- a/nff/nn/models/spooky_net_source/spookynet.py +++ b/nff/nn/models/spooky_net_source/spookynet.py @@ -1,14 +1,28 @@ import math +from typing import Optional, Tuple + import torch import torch.nn as nn + from .functional import cutoff_function -from .modules import * -from typing import Tuple, Optional +from .modules import ( + D4DispersionEnergy, + ElectronicEmbedding, + ElectrostaticEnergy, + ExponentialBernsteinPolynomials, + ExponentialGaussianFunctions, + GaussianFunctions, + InteractionModule, + NonlinearElectronicEmbedding, + NuclearEmbedding, + SincFunctions, + ZBLRepulsionEnergy, +) # backwards compatibility with old versions of pytorch try: from torch.linalg import norm -except: +except BaseException: from torch import norm @@ -36,7 +50,7 @@ class SpookyNet(nn.Module): num_modules (int): Number of modules (iterations) for constructing atomic features. num_residual_electron (int): - Number of residual blocks applied to features encoding the electronic + Number of residual blocks applied to features encoding the electronic state. num_residual_pre (int): Number of residual blocks applied to atomic features in each module @@ -45,16 +59,16 @@ class SpookyNet(nn.Module): Number of residual blocks applied to atomic features after interaction with neighbouring atoms (per module). num_residual_pre_local_x (int): - Number of residual blocks (per module) applied to atomic features in + Number of residual blocks (per module) applied to atomic features in local interaction. num_residual_pre_local_s (int): - Number of residual blocks (per module) applied to s-type interaction features + Number of residual blocks (per module) applied to s-type interaction features in local interaction. num_residual_pre_local_p (int): - Number of residual blocks (per module) applied to p-type interaction features + Number of residual blocks (per module) applied to p-type interaction features in local interaction. num_residual_pre_local_d (int): - Number of residual blocks (per module) applied to d-type interaction features + Number of residual blocks (per module) applied to d-type interaction features in local interaction. num_residual_post (int): Number of residual blocks applied to atomic features in each module @@ -145,8 +159,8 @@ def __init__( zero_init=True, **kwargs, ) -> None: - """ Initializes the SpookyNet class. """ - super(SpookyNet, self).__init__() + """Initializes the SpookyNet class.""" + super().__init__() # load state from a file (if load_from is not None) and overwrite # the given arguments. 
@@ -179,14 +193,8 @@ def __init__( module_keep_prob = saved_state["module_keep_prob"] Zmax = saved_state["Zmax"] # compatibility with older code - if "use_irreps" in saved_state: - use_irreps = saved_state["use_irreps"] - else: - use_irreps = False - if "use_nonlinear_embedding" in saved_state: - use_nonlinear_embedding = saved_state["use_nonlinear_embedding"] - else: - use_nonlinear_embedding = True + use_irreps = saved_state.get("use_irreps", False) + use_nonlinear_embedding = saved_state.get("use_nonlinear_embedding", True) # store argument values as attributes self.activation = activation @@ -234,14 +242,10 @@ def __init__( # declare modules and parameters # element specific energy and charge bias - self.register_parameter( - "element_bias", nn.Parameter(torch.Tensor(self.Zmax, 2)) - ) + self.register_parameter("element_bias", nn.Parameter(torch.Tensor(self.Zmax, 2))) # embeddings - self.nuclear_embedding = NuclearEmbedding( - self.num_features, self.Zmax, zero_init=zero_init - ) + self.nuclear_embedding = NuclearEmbedding(self.num_features, self.Zmax, zero_init=zero_init) if self.use_nonlinear_embedding: self.charge_embedding = NonlinearElectronicEmbedding( self.num_features, self.num_residual_electron, activation @@ -273,17 +277,11 @@ def __init__( self.num_basis_functions, exp_weighting=self.exp_weighting ) elif self.basis_functions == "gaussian": - self.radial_basis_functions = GaussianFunctions( - self.num_basis_functions, self.cutoff - ) + self.radial_basis_functions = GaussianFunctions(self.num_basis_functions, self.cutoff) elif self.basis_functions == "bernstein": - self.radial_basis_functions = BernsteinPolynomials( - self.num_basis_functions, self.cutoff - ) + self.radial_basis_functions = BernsteinPolynomials(self.num_basis_functions, self.cutoff) elif self.basis_functions == "sinc": - self.radial_basis_functions = SincFunctions( - self.num_basis_functions, self.cutoff - ) + self.radial_basis_functions = SincFunctions(self.num_basis_functions, self.cutoff) else: raise ValueError( "Argument 'basis_functions' may only take the " @@ -347,9 +345,7 @@ def __init__( # runtime exception may happen if state_dict was saved with an older # version of the code, but it should be possible to convert it except RuntimeError: - self.load_state_dict( - self._convert_state_dict(saved_state["state_dict"]) - ) + self.load_state_dict(self._convert_state_dict(saved_state["state_dict"])) if use_d4_dispersion: self.d4_dispersion_energy._compute_refc6() @@ -357,7 +353,7 @@ def __init__( self.build_requires_grad_dict() def reset_parameters(self) -> None: - """ Initialize parameters randomly. """ + """Initialize parameters randomly.""" nn.init.orthogonal_(self.output.weight) nn.init.zeros_(self.element_bias) @@ -374,24 +370,24 @@ def set_lr_cutoff(self, lr_cutoff: Optional[float] = None) -> None: @property def dtype(self) -> torch.dtype: - """ Return torch.dtype of parameters (input tensors must match). """ + """Return torch.dtype of parameters (input tensors must match).""" return self.nuclear_embedding.element_embedding.dtype @property def device(self) -> torch.device: - """ Return torch.device of parameters (input tensors must match). """ + """Return torch.device of parameters (input tensors must match).""" return self.nuclear_embedding.element_embedding.device def train(self, mode: bool = True) -> None: - """ Turn on training mode. 
""" - super(SpookyNet, self).train(mode=mode) + """Turn on training mode.""" + super().train(mode=mode) for name, param in self.named_parameters(): param.requires_grad = self.requires_grad_dict[name] def eval(self) -> None: - """ Turn on evaluation mode (smaller memory footprint).""" - super(SpookyNet, self).eval() - for name, param in self.named_parameters(): + """Turn on evaluation mode (smaller memory footprint).""" + super().eval() + for _name, param in self.named_parameters(): param.requires_grad = False def build_requires_grad_dict(self) -> None: @@ -454,8 +450,9 @@ def _convert_state_dict(self, old_state_dict: dict) -> dict: Helper function to convert a state_dict saved with an old version of the code to the current version. """ + def prefix_postfix(string, pattern, prefix="resblock", sep=".", presep="_"): - """ Helper function for converting keys """ + """Helper function for converting keys""" parts = string.split(sep) for i, part in enumerate(parts): if pattern + presep in part: @@ -463,27 +460,16 @@ def prefix_postfix(string, pattern, prefix="resblock", sep=".", presep="_"): return sep.join(parts) new_state_dict = {} - for old_key in old_state_dict: + for old_key, old_value in old_state_dict.items(): if old_key == "idx" or old_key == "mul": continue - if ( - "local_interaction.residual_" in old_key - or "embedding.residual_" in old_key - ): + if "local_interaction.residual_" in old_key or "embedding.residual_" in old_key: new_key = prefix_postfix(old_key, "residual") - elif ( - "local_interaction.activation_" in old_key - or "embedding.activation_" in old_key - ): + elif "local_interaction.activation_" in old_key or "embedding.activation_" in old_key: new_key = prefix_postfix(old_key, "activation") - elif ( - "local_interaction.linear_" in old_key or "embedding.linear_" in old_key - ): - if "embedding.linear_q" in old_key: - new_key = old_key - else: - new_key = prefix_postfix(old_key, "linear") + elif "local_interaction.linear_" in old_key or "embedding.linear_" in old_key: + new_key = old_key if "embedding.linear_q" in old_key else prefix_postfix(old_key, "linear") elif ".local_interaction.residual." in old_key: new_key = old_key.replace(".residual.", ".resblock.residual.") elif ".local_interaction.activation." in old_key: @@ -507,11 +493,11 @@ def prefix_postfix(string, pattern, prefix="resblock", sep=".", presep="_"): new_key = new_key.replace("activation_pre", "activation1") if "activation_post" in new_key: new_key = new_key.replace("activation_post", "activation2") - new_state_dict[new_key] = old_state_dict[old_key] + new_state_dict[new_key] = old_value return new_state_dict def get_number_of_parameters(self) -> int: - """ Returns the total number of parameters. 
""" + """Returns the total number of parameters.""" num = 0 for param in self.parameters(): num += param.numel() @@ -567,9 +553,7 @@ def calculate_distances( else: # gathering is faster on GPUs Ri = torch.gather(R, 0, idx_i.view(-1, 1).expand(-1, 3)) Rj = torch.gather(R, 0, idx_j.view(-1, 1).expand(-1, 3)) - if ( - cell is not None and cell_offsets is not None and batch_seg is not None - ): # apply PBCs + if cell is not None and cell_offsets is not None and batch_seg is not None: # apply PBCs if cell.device.type == "cpu": # indexing is faster on CPUs cells = cell[batch_seg][idx_i] else: # gathering is faster on GPUs @@ -630,8 +614,7 @@ def _atomic_properties_static( self._sqrt3 * pij[:, 0] * pij[:, 2], # xz self._sqrt3 * pij[:, 1] * pij[:, 2], # yz 0.5 * (3 * pij[:, 2] * pij[:, 2] - 1.0), # z2 - self._sqrt3half - * (pij[:, 0] * pij[:, 0] - pij[:, 1] * pij[:, 1]), # x2-y2 + self._sqrt3half * (pij[:, 0] * pij[:, 0] - pij[:, 1] * pij[:, 1]), # x2-y2 ], dim=-1, ) @@ -653,9 +636,7 @@ def _atomic_properties_static( # mask for efficient attention if num_batch > 1 and batch_seg is not None: - one_hot = nn.functional.one_hot(batch_seg).to( - dtype=R.dtype, device=R.device - ) + one_hot = nn.functional.one_hot(batch_seg).to(dtype=R.dtype, device=R.device) mask = one_hot @ one_hot.transpose(-1, -2) else: mask = None @@ -714,11 +695,7 @@ def _atomic_properties_dynamic( # initialize feature vectors z = self.nuclear_embedding(Z) if num_batch > 1: - electronic_mask = ( - nn.functional.one_hot(batch_seg) - .to(dtype=rij.dtype, device=rij.device) - .transpose(-1, -2) - ) + electronic_mask = nn.functional.one_hot(batch_seg).to(dtype=rij.dtype, device=rij.device).transpose(-1, -2) else: electronic_mask = None q = self.charge_embedding(z, Q, num_batch, batch_seg, electronic_mask) @@ -731,9 +708,7 @@ def _atomic_properties_dynamic( # perform iterations over modules f = x.new_zeros(x.size()) # initialize output features to zero for module in self.module: - x, y = module( - x, rbf, pij, dij, sr_idx_i, sr_idx_j, num_batch, batch_seg, mask - ) + x, y = module(x, rbf, pij, dij, sr_idx_i, sr_idx_j, num_batch, batch_seg, mask) # apply dropout mask if self.training and self.module_keep_prob < 1.0: y = y * dropout_mask[batch_seg] @@ -773,16 +748,12 @@ def _atomic_properties_dynamic( # compute ZBL inspired short-range repulsive contributions if self.use_zbl_repulsion: - ea_rep = self.zbl_repulsion_energy( - N, Z.to(self.dtype), sr_rij, cutoff_values, sr_idx_i, sr_idx_j - ) + ea_rep = self.zbl_repulsion_energy(N, Z.to(self.dtype), sr_rij, cutoff_values, sr_idx_i, sr_idx_j) else: ea_rep = ea.new_zeros(N) # optimization when lr_cutoff is used - if self.lr_cutoff is not None and ( - self.use_electrostatics or self.use_d4_dispersion - ): + if self.lr_cutoff is not None and (self.use_electrostatics or self.use_d4_dispersion): mask = rij < self.lr_cutoff # select all entries below lr_cutoff rij = rij[mask] idx_i = idx_i[mask] @@ -790,16 +761,12 @@ def _atomic_properties_dynamic( # compute electrostatic contributions if self.use_electrostatics: - ea_ele = self.electrostatic_energy( - N, qa, rij, idx_i, idx_j, R, cell, num_batch, batch_seg - ) + ea_ele = self.electrostatic_energy(N, qa, rij, idx_i, idx_j, R, cell, num_batch, batch_seg) else: ea_ele = ea.new_zeros(N) # compute dispersion contributions if self.use_d4_dispersion: - ea_vdw, pa, c6 = self.d4_dispersion_energy( - N, Z, qa, rij, idx_i, idx_j, self.compute_d4_atomic - ) + ea_vdw, pa, c6 = self.d4_dispersion_energy(N, Z, qa, rij, idx_i, idx_j, self.compute_d4_atomic) 
else: ea_vdw, pa, c6 = ea.new_zeros(N), ea.new_zeros(N), ea.new_zeros(N) return (f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6) @@ -969,7 +936,7 @@ def energy( Returns: energy (FloatTensor [B]): Potential energy of each molecule in the batch. - + (+ all return values of atomic_properties) """ if batch_seg is None: # assume a single batch @@ -986,9 +953,7 @@ def energy( num_batch=num_batch, batch_seg=batch_seg, ) - energy = ea.new_zeros(num_batch).index_add_( - 0, batch_seg, ea + ea_rep + ea_ele + ea_vdw - ) + energy = ea.new_zeros(num_batch).index_add_(0, batch_seg, ea + ea_rep + ea_ele + ea_vdw) return (energy, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6) @torch.jit.export @@ -1050,10 +1015,8 @@ def energy_and_forces( batch_seg=batch_seg, ) if idx_i.numel() > 0: # autograd will fail if there are no distances - grad = torch.autograd.grad( - [torch.sum(energy)], [R], create_graph=create_graph - )[0] - if grad is not None: # necessary for torch.jit compatibility + grad = torch.autograd.grad([torch.sum(energy)], [R], create_graph=create_graph)[0] + if grad is not None: # necessary for torch.jit compatibility # noqa forces = -grad else: forces = torch.zeros_like(R) @@ -1104,7 +1067,7 @@ def energy_and_forces_and_hessian( Hessian matrix. If more than one molecule is in the batch, the appropriate entries need to be collected from the matrix manually for each molecule. - + (+ all return values of atomic_properties) """ ( @@ -1244,9 +1207,7 @@ def forward( ) forces = torch.zeros_like(R) if use_dipole: - dipole = qa.new_zeros((num_batch, 3)).index_add_( - 0, batch_seg, qa.view(-1, 1) * R - ) + dipole = qa.new_zeros((num_batch, 3)).index_add_(0, batch_seg, qa.view(-1, 1) * R) else: dipole = qa.new_zeros((num_batch, 3)) return energy, forces, dipole, f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6 diff --git a/nff/nn/models/spooky_net_source/spookynet_calculator.py b/nff/nn/models/spooky_net_source/spookynet_calculator.py index 0175bd0c..195306c7 100644 --- a/nff/nn/models/spooky_net_source/spookynet_calculator.py +++ b/nff/nn/models/spooky_net_source/spookynet_calculator.py @@ -1,16 +1,18 @@ -import torch import numpy as np -from ase import Atoms -from ase.neighborlist import neighbor_list +import torch from ase.calculators.calculator import Calculator, all_changes +from ase.neighborlist import neighbor_list from sklearn.neighbors import BallTree + from .spookynet import SpookyNet from .spookynet_ensemble import SpookyNetEnsemble + class SpookyNetCalculator(Calculator): """ This module defines an ASE interface for SpookyNet. 
""" + implemented_properties = ["energy", "forces", "hessian", "dipole", "charges"] default_parameters = dict( @@ -22,32 +24,18 @@ class SpookyNetCalculator(Calculator): skin=0.3, # skin-distance for building neighborlists ) - def __init__( - self, - restart=None, - ignore_bad_restart_file=False, - label=None, - atoms=None, - **kwargs - ): - Calculator.__init__( - self, restart, ignore_bad_restart_file, label, atoms, **kwargs - ) + def __init__(self, restart=None, ignore_bad_restart_file=False, label=None, atoms=None, **kwargs): + Calculator.__init__(self, restart, ignore_bad_restart_file, label, atoms, **kwargs) self.lr_cutoff = self.parameters.lr_cutoff if type(self.parameters.load_from) is list: self.ensemble = True self.spookynet = SpookyNetEnsemble(models=self.parameters.load_from) sr_cutoff = self.spookynet.models[0].cutoff self.cutoff = sr_cutoff - self.use_lr = ( - self.spookynet.models[0].use_d4_dispersion - or self.spookynet.models[0].use_electrostatics - ) + self.use_lr = self.spookynet.models[0].use_d4_dispersion or self.spookynet.models[0].use_electrostatics for model in self.spookynet.models: assert sr_cutoff == model.cutoff - assert self.use_lr == ( - model.use_d4_dispersion or model.use_electrostatics - ) + assert self.use_lr == (model.use_d4_dispersion or model.use_electrostatics) if self.lr_cutoff is not None: # overwrite lr_cutoff if one is given model.set_lr_cutoff(self.lr_cutoff) if model.lr_cutoff is not None: @@ -60,9 +48,7 @@ def __init__( self.ensemble = False self.spookynet = SpookyNet(load_from=self.parameters.load_from) self.cutoff = self.spookynet.cutoff - self.use_lr = ( - self.spookynet.use_d4_dispersion or self.spookynet.use_electrostatics - ) + self.use_lr = self.spookynet.use_d4_dispersion or self.spookynet.use_electrostatics if self.lr_cutoff is not None: # overwrite lr_cutoff if one is given self.spookynet.set_lr_cutoff(self.lr_cutoff) if self.spookynet.lr_cutoff is not None: @@ -77,11 +63,9 @@ def __init__( self.calc_hessian = False self.converged = True # for compatibility with other calculators # for the neighborlist - self.skin2 = self.parameters.skin ** 2 + self.skin2 = self.parameters.skin**2 assert self.parameters.skin >= 0 - self.cutoff += ( - 2 * self.parameters.skin - ) # cutoff needs to be larger when skin is used + self.cutoff += 2 * self.parameters.skin # cutoff needs to be larger when skin is used self.N = 0 self.positions = None self.pbc = np.array([False]) @@ -89,7 +73,7 @@ def __init__( self.cell_offsets = None def _nsquared_neighborlist(self, atoms): - if self.N != len(atoms): + if len(atoms) != self.N: self.N = len(atoms) self.positions = np.copy(atoms.positions) self.pbc = np.array([False]) @@ -104,7 +88,7 @@ def _nsquared_neighborlist(self, atoms): def _periodic_neighborlist(self, atoms): if ( - self.N != len(atoms) + len(atoms) != self.N or (self.pbc != atoms.pbc).any() or (self.cell != atoms.cell).any() or ((self.positions - atoms.positions) ** 2).sum(-1).max() > self.skin2 @@ -119,10 +103,7 @@ def _periodic_neighborlist(self, atoms): self.cell_offsets = torch.tensor(cell_offsets, dtype=self.dtype) def _non_periodic_neighborlist(self, atoms): - if ( - self.N != len(atoms) - or ((self.positions - atoms.positions) ** 2).sum(-1).max() >= self.skin2 - ): + if len(atoms) != self.N or ((self.positions - atoms.positions) ** 2).sum(-1).max() >= self.skin2: self.N = len(atoms) self.positions = np.copy(atoms.positions) self.pbc = np.array([False]) @@ -160,17 +141,15 @@ def calculate(self, atoms=None, properties=["energy"], 
system_changes=all_change "R": torch.tensor(atoms.positions, dtype=self.dtype, requires_grad=True), "idx_i": self.idx_i, "idx_j": self.idx_j, - "cell": None - if not atoms.pbc.any() - else torch.tensor([atoms.cell], dtype=self.dtype), + "cell": None if not atoms.pbc.any() else torch.tensor([atoms.cell], dtype=self.dtype), "cell_offsets": self.cell_offsets, } # send args to GPU if self.use_gpu: - for key in args.keys(): - if isinstance(args[key], torch.Tensor): - args[key] = args[key].cuda() + for key, value in args.items(): + if isinstance(value, torch.Tensor): + args[key] = value.cuda() if self.calc_hessian: ( @@ -216,29 +195,22 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change self.results["forces_std"] = forces[1].detach().cpu().numpy() self.results["charges"] = qa[0].detach().cpu().numpy() self.results["charges_std"] = qa[1].detach().cpu().numpy() - self.results["dipole"] = np.sum( - atoms.get_positions() * self.results["charges"][:, None], 0 - ) - self.results["dipole_std"] = np.sum( - atoms.get_positions() * self.results["charges_std"][:, None], 0 - ) + self.results["dipole"] = np.sum(atoms.get_positions() * self.results["charges"][:, None], 0) + self.results["dipole_std"] = np.sum(atoms.get_positions() * self.results["charges_std"][:, None], 0) else: self.results["features"] = f.detach().cpu().numpy() self.results["energy"] = energy.detach().cpu().item() self.results["forces"] = forces.detach().cpu().numpy() self.results["charges"] = qa.detach().cpu().numpy() - self.results["dipole"] = np.sum( - atoms.get_positions() * self.results["charges"][:, None], 0 - ) + self.results["dipole"] = np.sum(atoms.get_positions() * self.results["charges"][:, None], 0) def set_to_gradient_calculation(self): - """ For compatibility with other calculators. """ + """For compatibility with other calculators.""" self.calc_hessian = False def set_to_hessian_calculation(self): - """ For compatibility with other calculators. """ + """For compatibility with other calculators.""" self.calc_hessian = True def clear_restart_file(self): - """ For compatibility with scripts that use file i/o calculators. """ - pass + """For compatibility with scripts that use file i/o calculators.""" diff --git a/nff/nn/models/spooky_net_source/spookynet_ensemble.py b/nff/nn/models/spooky_net_source/spookynet_ensemble.py index 041b6b99..7693c30f 100644 --- a/nff/nn/models/spooky_net_source/spookynet_ensemble.py +++ b/nff/nn/models/spooky_net_source/spookynet_ensemble.py @@ -1,7 +1,9 @@ +from typing import List, Optional, Tuple + import torch import torch.nn as nn + from .spookynet import SpookyNet -from typing import List, Tuple, Optional class SpookyNetEnsemble(nn.Module): @@ -16,8 +18,8 @@ class SpookyNetEnsemble(nn.Module): """ def __init__(self, models: List[str] = []) -> None: - """ Initializes the SpookyNetEnsemble class. """ - super(SpookyNetEnsemble, self).__init__() + """Initializes the SpookyNetEnsemble class.""" + super().__init__() assert len(models) > 1 self.models = nn.ModuleList([SpookyNet(load_from=model) for model in models]) for model in self.models: @@ -25,12 +27,12 @@ def __init__(self, models: List[str] = []) -> None: @property def dtype(self) -> torch.dtype: - """ Return torch.dtype of parameters (input tensors must match). """ + """Return torch.dtype of parameters (input tensors must match).""" return self.models[0].dtype @property def device(self) -> torch.device: - """ Return torch.device of parameters (input tensors must match). 
""" + """Return torch.device of parameters (input tensors must match).""" return self.models[0].device def train(self, mode=True) -> None: @@ -38,13 +40,13 @@ def train(self, mode=True) -> None: Turn on training mode. This is just for compatibility, the models should be trained individually and only evaluated as ensemble. """ - super(SpookyNetEnsemble, self).train(mode=mode) + super().train(mode=mode) for model in self.models: model.train(mode) def eval(self) -> None: - """ Turn on evaluation mode (smaller memory footprint).""" - super(SpookyNetEnsemble, self).eval() + """Turn on evaluation mode (smaller memory footprint).""" + super().eval() for model in self.models: model.eval() @@ -137,9 +139,7 @@ def atomic_properties( c6.append(c6_) return (f, ea, qa, ea_rep, ea_ele, ea_vdw, pa, c6) - def _mean_std_from_list( - self, x: List[torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + def _mean_std_from_list(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """ Given a list of tensors, computes their mean and standard deviation. Only used internally. @@ -266,15 +266,13 @@ def energy_and_forces( retain_graph=True, create_graph=create_graph, )[0] - if grad_mean is not None: # necessary for torch.jit compatibility + if grad_mean is not None: # necessary for torch.jit compatibility # noqa forces_mean = -grad_mean else: forces_mean = torch.zeros_like(R) if calculate_forces_std: - grad_std = torch.autograd.grad( - [torch.sum(energy[1])], [R], create_graph=create_graph - )[0] - if grad_std is not None: # necessary for torch.jit compatibility + grad_std = torch.autograd.grad([torch.sum(energy[1])], [R], create_graph=create_graph)[0] + if grad_std is not None: # necessary for torch.jit compatibility # noqa forces_std = torch.abs(grad_std) else: forces_std = torch.zeros_like(R) @@ -357,15 +355,11 @@ def energy_and_forces_and_hessian( for idx in range(s): # loop through entries of the hessian # retain graph when the index is smaller than the max index, # else computation fails - tmp = torch.autograd.grad( - [grad_mean[idx]], [R], retain_graph=(idx < s) - )[0] + tmp = torch.autograd.grad([grad_mean[idx]], [R], retain_graph=(idx < s))[0] if tmp is not None: # necessary for torch.jit compatibility hessian_mean[idx] = tmp.view(-1) if calculate_hessian_std and calculate_forces_std: - tmp = torch.autograd.grad( - [grad_std[idx]], [R], retain_graph=(idx < s) - )[0] + tmp = torch.autograd.grad([grad_std[idx]], [R], retain_graph=(idx < s))[0] if tmp is not None: # necessary for torch.jit compatibility hessian_std[idx] = tmp.view(-1) hessian = (hessian_mean, hessian_std) @@ -452,16 +446,8 @@ def forward( ) forces = (torch.zeros_like(R), torch.zeros_like(R)) if use_dipole: - dipole_mean = ( - qa[0] - .new_zeros((num_batch, 3)) - .index_add_(0, batch_seg, qa[0].view(-1, 1) * R) - ) - dipole_std = ( - qa[1] - .new_zeros((num_batch, 3)) - .index_add_(0, batch_seg, qa[1].view(-1, 1) * R) - ) + dipole_mean = qa[0].new_zeros((num_batch, 3)).index_add_(0, batch_seg, qa[0].view(-1, 1) * R) + dipole_std = qa[1].new_zeros((num_batch, 3)).index_add_(0, batch_seg, qa[1].view(-1, 1) * R) dipole = (dipole_mean, torch.abs(dipole_std)) else: dipole = (qa[0].new_zeros((num_batch, 3)), qa[1].new_zeros((num_batch, 3))) diff --git a/nff/nn/models/spooky_painn.py b/nff/nn/models/spooky_painn.py index 60176459..8141f51b 100644 --- a/nff/nn/models/spooky_painn.py +++ b/nff/nn/models/spooky_painn.py @@ -184,7 +184,7 @@ def add_phys(self, results, s_i, v_i, xyz, z, charge, nbrs, num_atoms, offsets, if key in 
electrostatics: suffix = "_" + key.split("_")[-1] - if not any([i.isdigit() for i in suffix]): + if not any(i.isdigit() for i in suffix): suffix = "" results.update({f"dipole{suffix}": full_dip, f"q{suffix}": q, f"dip_atom{suffix}": dip_atom}) @@ -304,10 +304,7 @@ def get_off_diag_keys(self): return off_diag def get_diabat_charge(self, key, charge): - if key in self.off_diag_keys: - total_charge = torch.zeros_like(charge) - else: - total_charge = charge + total_charge = torch.zeros_like(charge) if key in self.off_diag_keys else charge return total_charge def add_phys(self, results, s_i, v_i, xyz, z, charge, nbrs, num_atoms, offsets, mol_offsets, mol_nbrs): @@ -325,7 +322,7 @@ def add_phys(self, results, s_i, v_i, xyz, z, charge, nbrs, num_atoms, offsets, # transition charges sum to 0 - total_charge = self.get_diabat_charge(key=key, charge=charge) + self.get_diabat_charge(key=key, charge=charge) mol_nbrs, _ = make_undirected(batch["mol_nbrs"]) elec_e, q, dip_atom, full_dip = elec_module( @@ -348,7 +345,7 @@ def add_phys(self, results, s_i, v_i, xyz, z, charge, nbrs, num_atoms, offsets, if key in electrostatics: suffix = "_" + key.split("_")[-1] - if not any([i.isdigit() for i in suffix]): + if not any(i.isdigit() for i in suffix): suffix = "" results.update({f"dipole{suffix}": full_dip, f"q{suffix}": q, f"dip_atom{suffix}": dip_atom}) diff --git a/nff/nn/modules/diabat.py b/nff/nn/modules/diabat.py index a59f1fdc..9ab85f5c 100644 --- a/nff/nn/modules/diabat.py +++ b/nff/nn/modules/diabat.py @@ -38,7 +38,7 @@ def __init__( def make_cross_talk(self, cross_talk_dic): if cross_talk_dic is None: - return + return None cross_talk = CrossTalk( diabat_keys=self.diabat_keys, @@ -237,13 +237,11 @@ def quants_to_eig(self, num_atoms, results, u): for j in range(num_states): if j < i: continue - if i == j: - key = f"{base_key}_{i}" - else: - key = f"trans_{base_key}_{i}{j}" + key = f"{base_key}_{i}" if i == j else f"trans_{base_key}_{i}{j}" results[key] = to_eig[..., i, j] return results + return None def add_adiabat_grads(self, xyz, results, inference, en_keys_for_grad): if en_keys_for_grad is None: @@ -284,7 +282,7 @@ def add_gap(self, results, add_grad): lower_grad_key = lower_key + "_grad" grad_keys = [upper_grad_key, lower_grad_key] - if not all([i in results for i in grad_keys]): + if not all(i in results for i in grad_keys): continue gap_grad = results[upper_grad_key] - results[lower_grad_key] @@ -348,10 +346,7 @@ def add_stochastic(self, results): return results def idx_to_grad_idx(self, num_atoms, nan_idx): - if isinstance(num_atoms, torch.Tensor): - atom_tens = num_atoms - else: - atom_tens = torch.LongTensor(num_atoms) + atom_tens = num_atoms if isinstance(num_atoms, torch.Tensor) else torch.LongTensor(num_atoms) end_idx = torch.cumsum(atom_tens, dim=0) start_idx = torch.cat([torch.tensor([0]).to(end_idx.device), end_idx[:-1]]) @@ -746,9 +741,7 @@ def forward(self, results): final_results[key] = pool_val - for key, val in results.items(): - if key not in combined_results: - final_results[key] = val + final_results.update(**{key: val for key, val in results.items() if key not in combined_results}) return final_results @@ -764,12 +757,11 @@ def __init__(self, output_keys, grad_keys, abs_name): def get_abs(self, abs_name): if abs_name == "abs": return abs - elif abs_name is None: + if abs_name is None: return lambda x: x - elif abs_name in layer_types: + if abs_name in layer_types: return layer_types[abs_name]() - else: - raise NotImplementedError + raise NotImplementedError def forward(self, 
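The `quants_to_eig` keys above (`{base}_{i}` on the diagonal, `trans_{base}_{ij}` off it) come from rotating diabatic matrices into the adiabatic eigenbasis. A minimal two-state sketch of that transformation, with illustrative values rather than the module's API:

```python
import torch

# 2-state diabatic Hamiltonian: diagonal energies, off-diagonal coupling.
d = torch.tensor([[0.00, 0.05],
                  [0.05, 0.30]])
evals, U = torch.linalg.eigh(d)   # adiabatic energies and eigenvectors

# Any diabatic matrix quantity rotates the same way: U^T A U.
A_diabat = torch.tensor([[1.0, 0.2],
                         [0.2, -1.0]])
to_eig = U.T @ A_diabat @ U

energy_0, energy_1 = evals        # -> keys "energy_0", "energy_1"
trans_01 = to_eig[0, 1]           # -> key "trans_energy_01"
```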
results, xyz): ordered_keys = sorted(self.output_keys, key=lambda x: int(x.split("_")[-1])) diff --git a/nff/nn/modules/dimenet.py b/nff/nn/modules/dimenet.py index d33b09fb..b661677a 100644 --- a/nff/nn/modules/dimenet.py +++ b/nff/nn/modules/dimenet.py @@ -1,9 +1,9 @@ import torch from torch import nn -from nff.utils.scatter import scatter_add, compute_grad -from nff.utils.tools import layer_types from nff.nn.layers import Dense +from nff.utils.scatter import scatter_add +from nff.utils.tools import layer_types def get_dense(inp_dim, out_dim, activation, bias): @@ -23,7 +23,6 @@ def get_dense(inp_dim, out_dim, activation, bias): class EdgeEmbedding(nn.Module): - """ Class to create an edge embedding from edge features and node emebeddings. @@ -44,11 +43,7 @@ def __init__(self, embed_dim, n_rbf, activation): # 3 * embed_dim (one each for h_i, h_j, and e_ij) # and output dimension embed_dim. - self.dense = get_dense( - 3 * embed_dim, - embed_dim, - activation=activation, - bias=True) + self.dense = get_dense(3 * embed_dim, embed_dim, activation=activation, bias=True) def forward(self, h, e, nbr_list): """ @@ -113,15 +108,10 @@ def __init__(self, n_rbf, embed_dim, activation): # create a dense layer to convert the basis # representation of the distances into a vector # of size embed_dim - self.edge_dense = get_dense(n_rbf, - embed_dim, - activation=None, - bias=False) + self.edge_dense = get_dense(n_rbf, embed_dim, activation=None, bias=False) # make node and edge embedding layers self.node_embedding = NodeEmbedding(embed_dim) - self.edge_embedding = EdgeEmbedding(embed_dim, - n_rbf, - activation) + self.edge_embedding = EdgeEmbedding(embed_dim, n_rbf, activation) def forward(self, e_rbf, z, nbr_list): """ @@ -137,14 +127,12 @@ def forward(self, e_rbf, z, nbr_list): e = self.edge_dense(e_rbf) h = self.node_embedding(z) - m_ji = self.edge_embedding(h=h, - e=e, - nbr_list=nbr_list) + m_ji = self.edge_embedding(h=h, e=e, nbr_list=nbr_list) return m_ji class ResidualBlock(nn.Module): - """ Residual block """ + """Residual block""" def __init__(self, embed_dim, n_rbf, activation): """ @@ -159,11 +147,7 @@ def __init__(self, embed_dim, n_rbf, activation): super().__init__() # create dense layers self.dense_layers = nn.ModuleList( - [get_dense(embed_dim, - embed_dim, - activation=activation, - bias=True) - for _ in range(2)] + [get_dense(embed_dim, embed_dim, activation=activation, bias=True) for _ in range(2)] ) def forward(self, m_ji): @@ -184,19 +168,12 @@ def forward(self, m_ji): class DirectedMessage(nn.Module): - """ Module for passing directed messages based on distances and angles. 
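`EdgeEmbedding` above concatenates the two node embeddings of each edge with the edge feature before a single dense layer of input size 3 * embed_dim. A minimal sketch of the forward pass with illustrative shapes:

```python
import torch
from torch import nn

embed_dim, n_nodes, n_edges = 8, 4, 5
h = torch.randn(n_nodes, embed_dim)                 # node embeddings
e = torch.randn(n_edges, embed_dim)                 # edge (distance) embeddings
nbr_list = torch.randint(0, n_nodes, (n_edges, 2))  # (i, j) pairs

# 3 * embed_dim in, embed_dim out: one slot each for h_i, h_j, and e_ij.
dense = nn.Linear(3 * embed_dim, embed_dim)
m_ji = dense(torch.cat([h[nbr_list[:, 0]], h[nbr_list[:, 1]], e], dim=-1))
print(m_ji.shape)  # (5, 8)
```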
""" - def __init__(self, - activation, - embed_dim, - n_rbf, - n_spher, - l_spher, - n_bilinear): + def __init__(self, activation, embed_dim, n_rbf, n_spher, l_spher, n_bilinear): """ Args: activation (str): name of activation layer @@ -217,38 +194,23 @@ def __init__(self, # dense layer to apply to m's that are in the # neighborhood of those in your neighborhood - self.m_kj_dense = get_dense(embed_dim, - embed_dim, - activation=activation, - bias=True) + self.m_kj_dense = get_dense(embed_dim, embed_dim, activation=activation, bias=True) # dense layer to apply to the rbf representation of # the distances - self.e_dense = get_dense(n_rbf, - embed_dim, - activation=None, - bias=False) + self.e_dense = get_dense(n_rbf, embed_dim, activation=None, bias=False) # dense layer to apply to the sbf representation of # the angles and distances - self.a_dense = get_dense(n_spher * l_spher, - n_bilinear, - activation=None, - bias=False) + self.a_dense = get_dense(n_spher * l_spher, n_bilinear, activation=None, bias=False) # matrix that is used to aggregate the distance # and angle information - self.w = nn.Parameter(torch.empty( - embed_dim, n_bilinear, embed_dim)) + self.w = nn.Parameter(torch.empty(embed_dim, n_bilinear, embed_dim)) nn.init.xavier_uniform_(self.w) - def forward(self, - m_ji, - e_rbf, - a_sbf, - kj_idx, - ji_idx): + def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx): """ Args: m_ji (torch.Tensor): edge vector @@ -296,64 +258,31 @@ def forward(self, # use `scatter_add` with indices `ji_idx`, and give the resulting # vector the same dimension as m_ji. - out = scatter_add(aggr.transpose(0, 1), - ji_idx, - dim_size=m_ji.shape[0] - ).transpose(0, 1) + out = scatter_add(aggr.transpose(0, 1), ji_idx, dim_size=m_ji.shape[0]).transpose(0, 1) return out class DirectedMessagePP(nn.Module): - def __init__(self, - activation, - embed_dim, - n_rbf, - n_spher, - l_spher, - int_dim, - basis_emb_dim): - + def __init__(self, activation, embed_dim, n_rbf, n_spher, l_spher, int_dim, basis_emb_dim): super().__init__() - self.m_kj_dense = get_dense(embed_dim, - embed_dim, - activation=activation, - bias=True) - self.e_dense = nn.Sequential(get_dense(n_rbf, - basis_emb_dim, - activation=None, - bias=False), - get_dense(basis_emb_dim, - embed_dim, - activation=None, - bias=False)) - - self.a_dense = nn.Sequential(get_dense(n_spher * l_spher, - basis_emb_dim, - activation=None, - bias=False), - get_dense(basis_emb_dim, - int_dim, - activation=None, - bias=False)) - - self.down_conv = get_dense(embed_dim, - int_dim, - activation=activation, - bias=False) - - self.up_conv = get_dense(int_dim, - embed_dim, - activation=activation, - bias=False) - - def forward(self, - m_ji, - e_rbf, - a_sbf, - kj_idx, - ji_idx): + self.m_kj_dense = get_dense(embed_dim, embed_dim, activation=activation, bias=True) + self.e_dense = nn.Sequential( + get_dense(n_rbf, basis_emb_dim, activation=None, bias=False), + get_dense(basis_emb_dim, embed_dim, activation=None, bias=False), + ) + + self.a_dense = nn.Sequential( + get_dense(n_spher * l_spher, basis_emb_dim, activation=None, bias=False), + get_dense(basis_emb_dim, int_dim, activation=None, bias=False), + ) + + self.down_conv = get_dense(embed_dim, int_dim, activation=activation, bias=False) + + self.up_conv = get_dense(int_dim, embed_dim, activation=activation, bias=False) + + def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx): """ Args: m_ji (torch.Tensor): edge vector @@ -376,10 +305,7 @@ def forward(self, edge_message = self.down_conv(m_kj * e_ji) aggr = edge_message * a - 
out = self.up_conv(scatter_add(aggr.transpose(0, 1), - ji_idx, - dim_size=m_ji.shape[0] - ).transpose(0, 1)) + out = self.up_conv(scatter_add(aggr.transpose(0, 1), ji_idx, dim_size=m_ji.shape[0]).transpose(0, 1)) return out @@ -389,16 +315,9 @@ class InteractionBlock(nn.Module): Block for aggregating distance and angle information """ - def __init__(self, - embed_dim, - n_rbf, - activation, - n_spher, - l_spher, - n_bilinear, - int_dim=None, - basis_emb_dim=None, - use_pp=False): + def __init__( + self, embed_dim, n_rbf, activation, n_spher, l_spher, n_bilinear, int_dim=None, basis_emb_dim=None, use_pp=False + ): """ Args: embed_dim (int): embedding size @@ -419,10 +338,7 @@ def __init__(self, # make the three residual blocks self.residual_blocks = nn.ModuleList( - [ResidualBlock( - embed_dim=embed_dim, - n_rbf=n_rbf, - activation=activation) for _ in range(3)] + [ResidualBlock(embed_dim=embed_dim, n_rbf=n_rbf, activation=activation) for _ in range(3)] ) # make a block for getting the directed messages @@ -435,7 +351,8 @@ def __init__(self, n_spher=n_spher, l_spher=l_spher, int_dim=int_dim, - basis_emb_dim=basis_emb_dim) + basis_emb_dim=basis_emb_dim, + ) else: self.directed_block = DirectedMessage( @@ -444,26 +361,16 @@ def __init__(self, n_rbf=n_rbf, n_spher=n_spher, l_spher=l_spher, - n_bilinear=n_bilinear) + n_bilinear=n_bilinear, + ) # dense layers for m_ji and for what comes after # the residual blocks - self.m_ji_dense = get_dense(embed_dim, - embed_dim, - activation=activation, - bias=True) - - self.post_res_dense = get_dense(embed_dim, - embed_dim, - activation=activation, - bias=True) - - def forward(self, - m_ji, - e_rbf, - a_sbf, - kj_idx, - ji_idx): + self.m_ji_dense = get_dense(embed_dim, embed_dim, activation=activation, bias=True) + + self.post_res_dense = get_dense(embed_dim, embed_dim, activation=activation, bias=True) + + def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx): """ Args: m_ji (torch.Tensor): edge vector @@ -482,18 +389,13 @@ def forward(self, """ # get the directed message - directed_out = self.directed_block(m_ji=m_ji, - e_rbf=e_rbf, - a_sbf=a_sbf, - kj_idx=kj_idx, - ji_idx=ji_idx) + directed_out = self.directed_block(m_ji=m_ji, e_rbf=e_rbf, a_sbf=a_sbf, kj_idx=kj_idx, ji_idx=ji_idx) # put m_ji through dense layer and add to directed # message dense_m_ji = self.m_ji_dense(m_ji) output = directed_out + dense_m_ji # put through one dense layer and add back m_ji - output = self.post_res_dense( - self.residual_blocks[0](output)) + m_ji + output = self.post_res_dense(self.residual_blocks[0](output)) + m_ji # put through remaining dense layers for res_block in self.residual_blocks[1:]: output = res_block(output) @@ -506,12 +408,7 @@ class OutputBlock(nn.Module): Block to convert edge messages to atomic fingerprints """ - def __init__(self, - embed_dim, - n_rbf, - activation, - use_pp=False, - out_dim=None): + def __init__(self, embed_dim, n_rbf, activation, use_pp=False, out_dim=None): """ Args: embed_dim (int): embedding size @@ -524,32 +421,18 @@ def __init__(self, # dense layer to convert rbf edge representation # to dimension embed_dim - self.edge_dense = get_dense(n_rbf, - embed_dim, - activation=None, - bias=False) + self.edge_dense = get_dense(n_rbf, embed_dim, activation=None, bias=False) out_dense = [] if use_pp: - out_dense.append(get_dense(embed_dim, - out_dim, - activation=None, - bias=False)) + out_dense.append(get_dense(embed_dim, out_dim, activation=None, bias=False)) else: out_dim = embed_dim - out_dense += [get_dense(out_dim, - out_dim, - 
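`DirectedMessagePP` above is the DimeNet++ bottleneck: messages are down-projected to a small interaction dimension, modulated by the angular basis, then up-projected. A minimal sketch of that cheap mixing step (shapes illustrative):

```python
import torch
from torch import nn

embed_dim, int_dim, n_pairs = 16, 4, 10
down = nn.Linear(embed_dim, int_dim, bias=False)   # down_conv analogue
up = nn.Linear(int_dim, embed_dim, bias=False)     # up_conv analogue

edge_message = torch.randn(n_pairs, embed_dim)
a = torch.randn(n_pairs, int_dim)                  # projected angular (sbf) basis

# The elementwise product happens in the small int_dim space.
out = up(down(edge_message) * a)
print(out.shape)  # (10, 16)
```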
activation=activation, - bias=True) - for _ in range(3)] - out_dense.append(get_dense(out_dim, - out_dim, - activation=None, - bias=False)) + out_dense += [get_dense(out_dim, out_dim, activation=activation, bias=True) for _ in range(3)] + out_dense.append(get_dense(out_dim, out_dim, activation=None, bias=False)) self.out_dense = nn.Sequential(*out_dense) def forward(self, m_ji, e_rbf, nbr_list, num_atoms): - # product of e and m prod = self.edge_dense(e_rbf) * m_ji @@ -563,9 +446,7 @@ def forward(self, m_ji, e_rbf, nbr_list, num_atoms): # and the last two to index 0. This means we use # nbr_list[:, 1] in the scatter addition. - node_feats = scatter_add(prod.transpose(0, 1), - nbr_list[:, 1], - dim_size=num_atoms).transpose(0, 1) + node_feats = scatter_add(prod.transpose(0, 1), nbr_list[:, 1], dim_size=num_atoms).transpose(0, 1) # Apply the dense layers node_feats = self.out_dense(node_feats) diff --git a/nff/nn/modules/painn.py b/nff/nn/modules/painn.py index 8fa59783..d1298e47 100644 --- a/nff/nn/modules/painn.py +++ b/nff/nn/modules/painn.py @@ -1,19 +1,18 @@ import torch from torch import nn -from nff.utils.tools import layer_types -from nff.nn.layers import (PainnRadialBasis, CosineEnvelope, - ExpNormalBasis, Dense) -from nff.utils.scatter import scatter_add +from nff.nn.layers import CosineEnvelope, Dense, PainnRadialBasis from nff.nn.modules.schnet import ScaleShift -from nff.nn.modules.torchmd_net import MessageBlock as MDMessage from nff.nn.modules.torchmd_net import EmbeddingBlock as MDEmbedding +from nff.nn.modules.torchmd_net import MessageBlock as MDMessage +from nff.utils.scatter import scatter_add +from nff.utils.tools import layer_types EPS = 1e-15 def norm(vec): - result = ((vec ** 2 + EPS).sum(-1)) ** 0.5 + result = ((vec**2 + EPS).sum(-1)) ** 0.5 return result @@ -33,20 +32,12 @@ def to_module(activation): class InvariantDense(nn.Module): - def __init__(self, - dim, - dropout, - activation='swish'): + def __init__(self, dim, dropout, activation="swish"): super().__init__() - self.layers = nn.Sequential(Dense(in_features=dim, - out_features=dim, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=dim, - out_features=3 * dim, - bias=True, - dropout_rate=dropout)) + self.layers = nn.Sequential( + Dense(in_features=dim, out_features=dim, bias=True, dropout_rate=dropout, activation=to_module(activation)), + Dense(in_features=dim, out_features=3 * dim, bias=True, dropout_rate=dropout), + ) def forward(self, s_j): output = self.layers(s_j) @@ -54,22 +45,11 @@ def forward(self, s_j): class DistanceEmbed(nn.Module): - def __init__(self, - n_rbf, - cutoff, - feat_dim, - learnable_k, - dropout): - + def __init__(self, n_rbf, cutoff, feat_dim, learnable_k, dropout): super().__init__() - rbf = PainnRadialBasis(n_rbf=n_rbf, - cutoff=cutoff, - learnable_k=learnable_k) - - dense = Dense(in_features=n_rbf, - out_features=3 * feat_dim, - bias=True, - dropout_rate=dropout) + rbf = PainnRadialBasis(n_rbf=n_rbf, cutoff=cutoff, learnable_k=learnable_k) + + dense = Dense(in_features=n_rbf, out_features=3 * feat_dim, bias=True, dropout_rate=dropout) self.block = nn.Sequential(rbf, dense) self.f_cut = CosineEnvelope(cutoff=cutoff) @@ -82,29 +62,15 @@ def forward(self, dist): class InvariantMessage(nn.Module): - def __init__(self, - feat_dim, - activation, - n_rbf, - cutoff, - learnable_k, - dropout): + def __init__(self, feat_dim, activation, n_rbf, cutoff, learnable_k, dropout): super().__init__() - self.inv_dense = InvariantDense(dim=feat_dim, - 
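The `norm` helper above adds `EPS` inside the square root because the gradient of ||v|| is undefined at v = 0 and autograd returns NaN there. A minimal sketch of the failure mode and the stabilized version:

```python
import torch

EPS = 1e-15

v = torch.zeros(3, requires_grad=True)
(((v**2 + EPS).sum(-1)) ** 0.5).backward()
print(v.grad)  # finite (zeros)

v2 = torch.zeros(3, requires_grad=True)
torch.linalg.norm(v2).backward()
print(v2.grad)  # nan, nan, nan: d||v||/dv = v/||v|| is 0/0 at the origin
```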
activation=activation, - dropout=dropout) - self.dist_embed = DistanceEmbed(n_rbf=n_rbf, - cutoff=cutoff, - feat_dim=feat_dim, - learnable_k=learnable_k, - dropout=dropout) - - def forward(self, - s_j, - dist, - nbrs): + self.inv_dense = InvariantDense(dim=feat_dim, activation=activation, dropout=dropout) + self.dist_embed = DistanceEmbed( + n_rbf=n_rbf, cutoff=cutoff, feat_dim=feat_dim, learnable_k=learnable_k, dropout=dropout + ) + def forward(self, s_j, dist, nbrs): phi = self.inv_dense(s_j)[nbrs[:, 1]] w_s = self.dist_embed(dist) output = phi * w_s @@ -119,17 +85,9 @@ def forward(self, class MessageBase(nn.Module): - - def forward(self, - s_j, - v_j, - r_ij, - nbrs): - + def forward(self, s_j, v_j, r_ij, nbrs): dist, unit = preprocess_r(r_ij) - inv_out = self.inv_message(s_j=s_j, - dist=dist, - nbrs=nbrs) + inv_out = self.inv_message(s_j=s_j, dist=dist, nbrs=nbrs) split_0 = inv_out[:, 0, :].unsqueeze(-1) split_1 = inv_out[:, 1, :] @@ -142,47 +100,28 @@ def forward(self, # add results from neighbors of each node graph_size = s_j.shape[0] - delta_v_i = scatter_add(src=delta_v_ij, - index=nbrs[:, 0], - dim=0, - dim_size=graph_size) + delta_v_i = scatter_add(src=delta_v_ij, index=nbrs[:, 0], dim=0, dim_size=graph_size) - delta_s_i = scatter_add(src=delta_s_ij, - index=nbrs[:, 0], - dim=0, - dim_size=graph_size) + delta_s_i = scatter_add(src=delta_s_ij, index=nbrs[:, 0], dim=0, dim_size=graph_size) return delta_s_i, delta_v_i class MessageBlock(MessageBase): - def __init__(self, - feat_dim, - activation, - n_rbf, - cutoff, - learnable_k, - dropout, - **kwargs): + def __init__(self, feat_dim, activation, n_rbf, cutoff, learnable_k, dropout, **kwargs): super().__init__() - self.inv_message = InvariantMessage(feat_dim=feat_dim, - activation=activation, - n_rbf=n_rbf, - cutoff=cutoff, - learnable_k=learnable_k, - dropout=dropout) - - def forward(self, - s_j, - v_j, - r_ij, - nbrs, - **kwargs): + self.inv_message = InvariantMessage( + feat_dim=feat_dim, + activation=activation, + n_rbf=n_rbf, + cutoff=cutoff, + learnable_k=learnable_k, + dropout=dropout, + ) + def forward(self, s_j, v_j, r_ij, nbrs, **kwargs): dist, unit = preprocess_r(r_ij) - inv_out = self.inv_message(s_j=s_j, - dist=dist, - nbrs=nbrs) + inv_out = self.inv_message(s_j=s_j, dist=dist, nbrs=nbrs) split_0 = inv_out[:, 0, :].unsqueeze(-1) split_1 = inv_out[:, 1, :] @@ -195,95 +134,54 @@ def forward(self, # add results from neighbors of each node graph_size = s_j.shape[0] - delta_v_i = scatter_add(src=delta_v_ij, - index=nbrs[:, 0], - dim=0, - dim_size=graph_size) + delta_v_i = scatter_add(src=delta_v_ij, index=nbrs[:, 0], dim=0, dim_size=graph_size) - delta_s_i = scatter_add(src=delta_s_ij, - index=nbrs[:, 0], - dim=0, - dim_size=graph_size) + delta_s_i = scatter_add(src=delta_s_ij, index=nbrs[:, 0], dim=0, dim_size=graph_size) return delta_s_i, delta_v_i class InvariantTransformerMessage(nn.Module): - def __init__(self, - rbf, - num_heads, - feat_dim, - activation, - layer_norm): - + def __init__(self, rbf, num_heads, feat_dim, activation, layer_norm): super().__init__() - self.msg_layer = MDMessage(feat_dim=feat_dim, - num_heads=num_heads, - activation=activation, - rbf=rbf) - - self.dense = Dense(in_features=(num_heads * feat_dim), - out_features=(3 * feat_dim), - bias=True, - activation=None) - self.layer_norm = nn.LayerNorm(feat_dim) if (layer_norm) else None + self.msg_layer = MDMessage(feat_dim=feat_dim, num_heads=num_heads, activation=activation, rbf=rbf) - def forward(self, - s_j, - dist, - nbrs): + self.dense = 
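The message blocks above split the invariant filter output into three channels. The diff shows the first two splits; by the PaiNN construction the third scales the unit bond vector, so the vector-update line in this sketch is an assumption from the paper rather than from the lines shown:

```python
import torch

n_nbrs, feat = 5, 8
inv_out = torch.randn(n_nbrs, 3, feat)   # from InvariantMessage, reshaped
v_j = torch.randn(n_nbrs, feat, 3)       # neighbor vector features
unit = torch.randn(n_nbrs, 3)            # unit bond vectors

split_0 = inv_out[:, 0, :].unsqueeze(-1)  # gates the neighbor's vectors
delta_s_ij = inv_out[:, 1, :]             # scalar update
split_2 = inv_out[:, 2, :].unsqueeze(-1)  # scales the bond direction (assumed)
delta_v_ij = v_j * split_0 + unit.unsqueeze(1) * split_2  # (n_nbrs, feat, 3)
```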
Dense(in_features=(num_heads * feat_dim), out_features=(3 * feat_dim), bias=True, activation=None) + self.layer_norm = nn.LayerNorm(feat_dim) if (layer_norm) else None + def forward(self, s_j, dist, nbrs): inp = self.layer_norm(s_j) if self.layer_norm else s_j - output = self.dense(self.msg_layer(dist=dist, - nbrs=nbrs, - x_i=inp)) + output = self.dense(self.msg_layer(dist=dist, nbrs=nbrs, x_i=inp)) out_reshape = output.reshape(output.shape[0], 3, -1) return out_reshape class TransformerMessageBlock(MessageBase): - def __init__(self, - rbf, - num_heads, - feat_dim, - activation, - layer_norm): + def __init__(self, rbf, num_heads, feat_dim, activation, layer_norm): super().__init__() self.inv_message = InvariantTransformerMessage( - rbf=rbf, - num_heads=num_heads, - feat_dim=feat_dim, - activation=activation, - layer_norm=layer_norm) + rbf=rbf, num_heads=num_heads, feat_dim=feat_dim, activation=activation, layer_norm=layer_norm + ) class UpdateBlock(nn.Module): - def __init__(self, - feat_dim, - activation, - dropout): + def __init__(self, feat_dim, activation, dropout): super().__init__() - self.u_mat = Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - self.v_mat = Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - self.s_dense = nn.Sequential(Dense(in_features=2*feat_dim, - out_features=feat_dim, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim, - out_features=3*feat_dim, - bias=True, - dropout_rate=dropout)) - - def forward(self, - s_i, - v_i): + self.u_mat = Dense(in_features=feat_dim, out_features=feat_dim, bias=False) + self.v_mat = Dense(in_features=feat_dim, out_features=feat_dim, bias=False) + self.s_dense = nn.Sequential( + Dense( + in_features=2 * feat_dim, + out_features=feat_dim, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense(in_features=feat_dim, out_features=3 * feat_dim, bias=True, dropout_rate=dropout), + ) + def forward(self, s_i, v_i): # v_i = (num_atoms, num_feats, 3) # v_i.transpose(1, 2).reshape(-1, v_i.shape[1]) # = (num_atoms, 3, num_feats).reshape(-1, num_feats) @@ -298,16 +196,13 @@ def forward(self, # to get (num_atoms, num_feats, 3) num_feats = v_i.shape[1] - u_v = (self.u_mat(v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) - v_v = (self.v_mat(v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) + u_v = self.u_mat(v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) + v_v = self.v_mat(v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) v_v_norm = norm(v_v) s_stack = torch.cat([s_i, v_v_norm], dim=-1) - split = (self.s_dense(s_stack) - .reshape(s_i.shape[0], 3, -1)) + split = self.s_dense(s_stack).reshape(s_i.shape[0], 3, -1) # delta v update a_vv = split[:, 0, :].unsqueeze(-1) @@ -324,96 +219,65 @@ def forward(self, class EmbeddingBlock(nn.Module): - def __init__(self, - feat_dim): - + def __init__(self, feat_dim): super().__init__() self.atom_embed = nn.Embedding(100, feat_dim, padding_idx=0) self.feat_dim = feat_dim - def forward(self, - z_number, - **kwargs): - + def forward(self, z_number, **kwargs): num_atoms = z_number.shape[0] s_i = self.atom_embed(z_number) - v_i = (torch.zeros(num_atoms, self.feat_dim, 3) - .to(s_i.device)) + v_i = torch.zeros(num_atoms, self.feat_dim, 3).to(s_i.device) return s_i, v_i class NbrEmbeddingBlock(nn.Module): - def __init__(self, - feat_dim, - dropout, - rbf): - + def __init__(self, feat_dim, dropout, rbf): super().__init__() - self.embedding = MDEmbedding(feat_dim=feat_dim, - 
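`UpdateBlock` applies feature-space linear maps to vector features of shape (num_atoms, num_feats, 3) by folding the spatial axis into the batch, which is what the transpose-reshape chains above do. A minimal sketch of that round trip:

```python
import torch
from torch import nn

num_atoms, num_feats = 4, 8
v_i = torch.randn(num_atoms, num_feats, 3)
u_mat = nn.Linear(num_feats, num_feats, bias=False)

v_transpose = v_i.transpose(1, 2).reshape(-1, num_feats)  # (num_atoms * 3, F)
u_v = u_mat(v_transpose).reshape(-1, 3, num_feats).transpose(1, 2)
assert u_v.shape == v_i.shape  # back to (num_atoms, num_feats, 3)
```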
dropout=dropout, - rbf=rbf) + self.embedding = MDEmbedding(feat_dim=feat_dim, dropout=dropout, rbf=rbf) self.feat_dim = feat_dim - def forward(self, - z_number, - nbrs, - r_ij): - + def forward(self, z_number, nbrs, r_ij): num_atoms = z_number.shape[0] dist, _ = preprocess_r(r_ij) - s_i = self.embedding(z_number=z_number, - nbrs=nbrs, - dist=dist) + s_i = self.embedding(z_number=z_number, nbrs=nbrs, dist=dist) - v_i = (torch.zeros(num_atoms, self.feat_dim, 3) - .to(s_i.device)) + v_i = torch.zeros(num_atoms, self.feat_dim, 3).to(s_i.device) return s_i, v_i class GatedEquivariantBlock(nn.Module): - def __init__(self, - feat_dim, - activation, - dropout_rate): - + def __init__(self, feat_dim, activation, dropout_rate): super().__init__() - self.W1 = Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - self.W2 = Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - self.s_dense = nn.Sequential(Dense(in_features=2*feat_dim, - out_features=feat_dim, - bias=True, - dropout_rate=dropout_rate, - activation=to_module(activation)), - Dense(in_features=feat_dim, - out_features=2*feat_dim, - bias=True, - dropout_rate=dropout_rate)) - - def forward(self, - sv_tuple): - + self.W1 = Dense(in_features=feat_dim, out_features=feat_dim, bias=False) + self.W2 = Dense(in_features=feat_dim, out_features=feat_dim, bias=False) + self.s_dense = nn.Sequential( + Dense( + in_features=2 * feat_dim, + out_features=feat_dim, + bias=True, + dropout_rate=dropout_rate, + activation=to_module(activation), + ), + Dense(in_features=feat_dim, out_features=2 * feat_dim, bias=True, dropout_rate=dropout_rate), + ) + + def forward(self, sv_tuple): s_i, v_i = sv_tuple v_tranpose = v_i.transpose(1, 2).reshape(-1, v_i.shape[1]) num_feats = v_i.shape[1] - - W1_v = (self.W1(v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) - W2_v = (self.W2(v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) + + W1_v = self.W1(v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) + W2_v = self.W2(v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) W2_v_norm = norm(W2_v) s_stack = torch.cat([s_i, W2_v_norm], dim=-1) - split = (self.s_dense(s_stack) - .reshape(s_i.shape[0], 2, -1)) + split = self.s_dense(s_stack).reshape(s_i.shape[0], 2, -1) # delta v update new_v = W1_v * split[:, 0, :].unsqueeze(-1) @@ -422,34 +286,29 @@ def forward(self, new_s = split[:, 1, :] return (new_s, new_v) - - + + class ReadoutBlock(nn.Module): - def __init__(self, - feat_dim, - output_keys, - activation, - dropout, - means=None, - stddevs=None): + def __init__(self, feat_dim, output_keys, activation, dropout, means=None, stddevs=None): super().__init__() self.readoutdict = nn.ModuleDict( - {key: nn.Sequential( - Dense(in_features=feat_dim, - out_features=feat_dim//2, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim//2, - out_features=1, - bias=True, - dropout_rate=dropout)) - for key in output_keys} + { + key: nn.Sequential( + Dense( + in_features=feat_dim, + out_features=feat_dim // 2, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense(in_features=feat_dim // 2, out_features=1, bias=True, dropout_rate=dropout), + ) + for key in output_keys + } ) - self.scale_shift = ScaleShift(means=means, - stddevs=stddevs) + self.scale_shift = ScaleShift(means=means, stddevs=stddevs) def forward(self, s_i): """ @@ -464,46 +323,50 @@ def forward(self, s_i): results[key] = output return results - + class ReadoutBlock_Tuple(nn.Module): - def __init__(self, 
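Each readout head above finishes with a `ScaleShift`, which maps standardized predictions back to physical units using per-key training statistics. A minimal sketch with hypothetical statistics:

```python
import torch

means = {"energy": -76.4}     # hypothetical training-set mean
stddevs = {"energy": 0.12}    # hypothetical training-set std

def scale_shift(x, key):
    # Identity for keys with no registered statistics.
    return x * stddevs.get(key, 1.0) + means.get(key, 0.0)

raw = torch.tensor([0.3, -1.1])    # standardized network output
print(scale_shift(raw, "energy"))  # back in physical units
```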
- feat_dim, - output_keys, - activation, - dropout, - means=None, - stddevs=None): + def __init__(self, feat_dim, output_keys, activation, dropout, means=None, stddevs=None): super().__init__() - + self.output_keys = output_keys self.readoutdict = nn.ModuleDict( - {key_tuple: nn.Sequential( - Dense(in_features=feat_dim, - out_features=feat_dim//2, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim//2, - out_features=feat_dim//4, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim//4, - out_features=feat_dim//8, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim//8, - out_features=len(key_tuple.split("+")), - bias=True, - dropout_rate=dropout)) - for key_tuple in output_keys} + { + key_tuple: nn.Sequential( + Dense( + in_features=feat_dim, + out_features=feat_dim // 2, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense( + in_features=feat_dim // 2, + out_features=feat_dim // 4, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense( + in_features=feat_dim // 4, + out_features=feat_dim // 8, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense( + in_features=feat_dim // 8, + out_features=len(key_tuple.split("+")), + bias=True, + dropout_rate=dropout, + ), + ) + for key_tuple in output_keys + } ) - self.scale_shift = ScaleShift(means=means, - stddevs=stddevs) + self.scale_shift = ScaleShift(means=means, stddevs=stddevs) def forward(self, s_i): """ @@ -514,7 +377,7 @@ def forward(self, s_i): for keys, readoutdict in self.readoutdict.items(): outputs = readoutdict(s_i) - for ii, key in enumerate(keys.split('+')): + for ii, key in enumerate(keys.split("+")): output = self.scale_shift(outputs[..., ii], key) results[key] = output @@ -522,27 +385,23 @@ def forward(self, s_i): class ReadoutBlock_Complex(nn.Module): - def __init__(self, - feat_dim, - output_keys, - activation, - dropout, - means=None, - stddevs=None): + def __init__(self, feat_dim, output_keys, activation, dropout, means=None, stddevs=None): super().__init__() self.readoutdict = nn.ModuleDict( - {key: nn.Sequential( - Dense(in_features=feat_dim, - out_features=feat_dim//2, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim//2, - out_features=2, - bias=True, - dropout_rate=dropout)) - for key in output_keys} + { + key: nn.Sequential( + Dense( + in_features=feat_dim, + out_features=feat_dim // 2, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense(in_features=feat_dim // 2, out_features=2, bias=True, dropout_rate=dropout), + ) + for key in output_keys + } ) def forward(self, s_i): @@ -560,56 +419,56 @@ def forward(self, s_i): class ReadoutBlock_Vec(nn.Module): - def __init__(self, - feat_dim, - # out_dims, # right now we can only get Natomsx3 but what if we just want 1x3? - output_keys, - activation, - dropout, - means=None, - stddevs=None): + def __init__( + self, + feat_dim, + # out_dims, # right now we can only get Natomsx3 but what if we just want 1x3? 
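`ReadoutBlock_Tuple` above encodes several targets in one "+"-joined key, sizing the final layer by `len(key_tuple.split("+"))` and unpacking the columns back to individual keys. A minimal sketch:

```python
import torch
from torch import nn

key_tuple = "energy_0+energy_1"            # hypothetical joined key
head = nn.Linear(8, len(key_tuple.split("+")))

s_i = torch.randn(5, 8)                    # per-atom scalar features
outputs = head(s_i)                        # (5, 2): one column per sub-key
results = {key: outputs[..., ii] for ii, key in enumerate(key_tuple.split("+"))}
print(sorted(results))  # ['energy_0', 'energy_1']
```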
+ output_keys, + activation, + dropout, + means=None, + stddevs=None, + ): super().__init__() self.umat_dict = nn.ModuleDict( - {key: Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - for key in output_keys} + {key: Dense(in_features=feat_dim, out_features=feat_dim, bias=False) for key in output_keys} ) self.vmat_dict = nn.ModuleDict( - {key: Dense(in_features=feat_dim, - out_features=feat_dim, - bias=False) - for key in output_keys} + {key: Dense(in_features=feat_dim, out_features=feat_dim, bias=False) for key in output_keys} ) self.sdense_dict = nn.ModuleDict( - {key: nn.Sequential(Dense(in_features=2*feat_dim, - out_features=feat_dim, - bias=True, - dropout_rate=dropout, - activation=to_module(activation)), - Dense(in_features=feat_dim, - out_features=feat_dim, - bias=True, - dropout_rate=dropout)) - for key in output_keys} + { + key: nn.Sequential( + Dense( + in_features=2 * feat_dim, + out_features=feat_dim, + bias=True, + dropout_rate=dropout, + activation=to_module(activation), + ), + Dense(in_features=feat_dim, out_features=feat_dim, bias=True, dropout_rate=dropout), + ) + for key in output_keys + } ) # figure out how to do the collapsing self.readoutdict = nn.ModuleDict( - {key: - Dense(in_features=feat_dim, - out_features=1, - bias=False, - dropout_rate=dropout,) - for key in output_keys} + { + key: Dense( + in_features=feat_dim, + out_features=1, + bias=False, + dropout_rate=dropout, + ) + for key in output_keys + } ) - def forward(self, - s_i, - v_i): + def forward(self, s_i, v_i): """ Note: no atomwise summation. That's done in the model itself """ @@ -629,11 +488,8 @@ def forward(self, v_tranpose = v_i.transpose(1, 2).reshape(-1, num_feats) for key, readoutdict in self.readoutdict.items(): - - u_v = (self.umat_dict[key](v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) - v_v = (self.vmat_dict[key](v_tranpose).reshape(-1, 3, num_feats) - .transpose(1, 2)) + u_v = self.umat_dict[key](v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) + v_v = self.vmat_dict[key](v_tranpose).reshape(-1, 3, num_feats).transpose(1, 2) # now reshape it to (num_atoms, 3, num_feats) and transpose # to get (num_atoms, num_feats, 3) @@ -641,13 +497,12 @@ def forward(self, v_v_norm = norm(v_v) s_stack = torch.cat([s_i, v_v_norm], dim=-1) - a_vv = self.sdense_dict[key](s_stack).reshape( - s_i.shape[0], -1).unsqueeze(-1) + a_vv = self.sdense_dict[key](s_stack).reshape(s_i.shape[0], -1).unsqueeze(-1) new_v_i = u_v * a_vv # (num_atoms, num_feats, 3) new_v_i = new_v_i.transpose(1, 2) # (num_atoms, 3, num_feats) - output = readoutdict(new_v_i).sum(dim=2) # (num_atoms, 3, 1) -> (num_atoms, 3) + output = readoutdict(new_v_i).sum(dim=2) # (num_atoms, 3, 1) -> (num_atoms, 3) results[key] = output return results @@ -656,39 +511,42 @@ def forward(self, class ReadoutBlock_Vec2(nn.Module): # this does not use part of the update block but n gated equivariant blocks # as shown in the original PaiNN paper in Fig. 3 - def __init__(self, - feat_dim, - # out_dims, # right now we can only get Natomsx3 but what if we just want 1x3? - output_keys, - activation, - dropout, - means=None, - stddevs=None): + def __init__( + self, + feat_dim, + # out_dims, # right now we can only get Natomsx3 but what if we just want 1x3? 
+ output_keys, + activation, + dropout, + means=None, + stddevs=None, + ): super().__init__() self.gated_dict = nn.ModuleDict( - {key: nn.Sequential(GatedEquivariantBlock(feat_dim=feat_dim, - dropout_rate=dropout, - activation=activation), - GatedEquivariantBlock(feat_dim=feat_dim, - dropout_rate=dropout, - activation=activation)) - for key in output_keys} + { + key: nn.Sequential( + GatedEquivariantBlock(feat_dim=feat_dim, dropout_rate=dropout, activation=activation), + GatedEquivariantBlock(feat_dim=feat_dim, dropout_rate=dropout, activation=activation), + ) + for key in output_keys + } ) # figure out how to do the collapsing self.readoutdict = nn.ModuleDict( - {key: - Dense(in_features=feat_dim, - out_features=1, - bias=False, - dropout_rate=dropout,) - for key in output_keys} + { + key: Dense( + in_features=feat_dim, + out_features=1, + bias=False, + dropout_rate=dropout, + ) + for key in output_keys + } ) - def forward(self, - s_i, - v_i): + def forward(self, s_i, v_i): """ Note: no atomwise summation. That's done in the model itself """ @@ -696,10 +554,9 @@ def forward(self, results = {} for key, readoutdict in self.readoutdict.items(): - new_s_i, new_v_i = self.gated_dict[key]((s_i, v_i)) new_v_i = new_v_i.transpose(1, 2) # (num_atoms, 3, num_feats) - output = readoutdict(new_v_i).sum(dim=2) # (num_atoms, 3, 1) -> (num_atoms, 3) + output = readoutdict(new_v_i).sum(dim=2) # (num_atoms, 3, 1) -> (num_atoms, 3) results[key] = output return results @@ -755,8 +612,8 @@ def forward(self, # new_s_i, new_v_i = self.gated_dict[key]((s_i, v_i)) # new_v_i = new_v_i.transpose(1, 2) # (num_atoms, 3, num_feats) # nu = readoutdict(new_v_i).sum(dim=2) # (num_atoms, 3, 1) -> (num_atoms, 3) -# output = (torch.outer(r_i.reshape(-1), nu.reshape(-1)) +# output = (torch.outer(r_i.reshape(-1), nu.reshape(-1)) # + torch.outer(nu.reshape(-1), r_i.reshape(-1))) # results[key] = output -# return results \ No newline at end of file +# return results diff --git a/nff/nn/modules/schnet.py b/nff/nn/modules/schnet.py index fc5c7127..71d5877c 100644 --- a/nff/nn/modules/schnet.py +++ b/nff/nn/modules/schnet.py @@ -14,7 +14,6 @@ from nff.nn.layers import Dense, GaussianSmearing # for backwards compatability -from nff.nn.modules.diabat import DiabaticReadout from nff.nn.utils import ( chemprop_msg_to_node, chemprop_msg_update, @@ -48,11 +47,9 @@ def get_rij(xyz, batch, nbrs, cutoff): # to catch atoms that become neighbors between nbr # list updates) dist = (r_ij.detach() ** 2).sum(-1) ** 0.5 - - if type(cutoff) == torch.Tensor: + if isinstance(cutoff, torch.Tensor): dist = dist.to(cutoff.device) use_nbrs = dist <= cutoff - r_ij = r_ij[use_nbrs] nbrs = nbrs[use_nbrs] @@ -77,10 +74,10 @@ def add_stress(batch, all_results, nbrs, r_ij): if batch["num_atoms"].shape[0] == 1: all_results["stress_volume"] = torch.matmul(Z.t(), r_ij) else: - allstress = [] - for j in range(batch["nxyz"].shape[0]): - allstress.append(torch.matmul(Z[torch.where(nbrs[:, 0] == j)].t(), r_ij[torch.where(nbrs[:, 0] == j)])) - allstress = torch.stack(allstress) + allstress = torch.stack([ + torch.matmul(Z[torch.where(nbrs[:, 0] == j)].t(), r_ij[torch.where(nbrs[:, 0] == j)]) + for j in range(batch["nxyz"].shape[0]) + ]) N = batch["num_atoms"].detach().cpu().tolist() split_val = torch.split(allstress, N) all_results["stress_volume"] = torch.stack([i.sum(0) for i in split_val]) @@ -96,7 +93,7 @@ class SchNetEdgeUpdate(EdgeUpdateModule): """ def __init__(self, n_atom_basis): - super(SchNetEdgeUpdate, self).__init__() + super().__init__() self.mlp = 
Sequential( Linear(2 * n_atom_basis, n_atom_basis), @@ -116,7 +113,7 @@ def update(self, e): class SchNetEdgeFilter(nn.Module): def __init__(self, cutoff, n_gaussians, trainable_gauss, n_filters, dropout_rate, activation="shifted_softplus"): - super(SchNetEdgeFilter, self).__init__() + super().__init__() self.filter = Sequential( GaussianSmearing( @@ -158,7 +155,7 @@ def __init__( trainable_gauss, dropout_rate, ): - super(SchNetConv, self).__init__() + super().__init__() self.moduledict = ModuleDict( { "message_edge_filter": Sequential( @@ -238,7 +235,7 @@ class GraphAttention(MessagePassingModule): """ def __init__(self, n_atom_basis): - super(GraphAttention, self).__init__() + super().__init__() self.weight = torch.nn.Parameter(torch.rand(1, 2 * n_atom_basis)) self.activation = LeakyReLU() @@ -268,7 +265,7 @@ def message(self, r, e, a): ) a_ij = weight_ij / normalization[a[:, 0]] # the importance of node j’s features to node i - a_ji = weight_ji / normalization[a[:, 1]] # the importance of node i’s features to node j + weight_ji / normalization[a[:, 1]] # the importance of node i’s features to node j a_ii = weight_ii / normalization # self-attention message = ( @@ -330,7 +327,7 @@ def __init__(self, multitaskdict, post_readout=None): Args: multitaskdict (dict): dictionary that contains model information """ - super(NodeMultiTaskReadOut, self).__init__() + super().__init__() # construct moduledict self.readout = construct_module_dict(multitaskdict) self.post_readout = post_readout @@ -409,7 +406,7 @@ def __init__(self, n_atom_hidden, n_filters, dropout_rate, n_bond_hidden, activa Returns: None """ - super(MixedSchNetConv, self).__init__() + super().__init__() self.moduledict = ModuleDict( { # convert the atom features to the dimension @@ -490,7 +487,7 @@ def __init__(self, mol_basis, boltz_basis, final_act, equal_weights=False, prob_ None """ - super(ConfAttention, self).__init__() + super().__init__() """ Xavier initializations from @@ -625,7 +622,7 @@ def __init__(self, mol_basis, boltz_basis, final_act, equal_weights=False, prob_ None """ - super(LinearConfAttention, self).__init__(mol_basis, boltz_basis, final_act, equal_weights) + super().__init__(mol_basis, boltz_basis, final_act, equal_weights) # has dimension mol_basis instead of 2 * mol_basis because we're not # comparing fingerprint pairs @@ -1017,7 +1014,7 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No use_val = val.sum(-1) else: - raise Exception(("Don't know how to handle val shape " "{} for key {}".format(val.shape, key))) + raise Exception("Don't know how to handle val shape " f"{val.shape} for key {key}") pooled_result = scatter_add(use_val, mol_idx, dim_size=dim_size) if mean: @@ -1033,17 +1030,15 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No if key == "stress": output = results["energy"] grad_ = compute_grad(output=output, inputs=r_ij) - allstress = [] - for i in range(batch["nxyz"].shape[0]): - allstress.append( - torch.matmul(grad_[torch.where(nbrs[:, 0] == i)].t(), r_ij[torch.where(nbrs[:, 0] == i)]) - ) - allstress = torch.stack(allstress) + allstress = torch.stack([ + torch.matmul(grad_[torch.where(nbrs[:, 0] == i)].t(), r_ij[torch.where(nbrs[:, 0] == i)]) + for i in range(batch["nxyz"].shape[0]) + ]) split_val = torch.split(allstress, N) grad_ = torch.stack([i.sum(0) for i in split_val]) - if "cell" in batch.keys(): + if "cell" in batch: cell = torch.stack(torch.split(batch["cell"], 3, dim=0)) - elif "lattice" in batch.keys(): + elif 
"lattice" in batch: cell = torch.stack(torch.split(batch["lattice"], 3, dim=0)) volume = torch.Tensor(np.abs(np.linalg.det(cell.cpu().numpy()))).to(grad_.get_device()) grad = grad_ * (1 / volume[:, None, None]) @@ -1253,7 +1248,7 @@ class ScaleShift(nn.Module): """ def __init__(self, means=None, stddevs=None): - super(ScaleShift, self).__init__() + super().__init__() means = means if (means is not None) else {} stddevs = stddevs if (stddevs is not None) else {} @@ -1288,11 +1283,7 @@ def testBaseEdgeUpdate(self): r_in = torch.rand(6, 10) model = MessagePassingModule() r_out = model(r_in, e, a) - self.assertEqual( - r_in.shape, - r_out.shape, - "The node feature dimensions should be same for the base case", - ) + assert r_in.shape == r_out.shape, "The node feature dimensions should be same for the base case" def testBaseMessagePassing(self): # initialize basic graphs @@ -1301,11 +1292,7 @@ def testBaseMessagePassing(self): r = torch.rand(6, 10) model = EdgeUpdateModule() e_out = model(r, e_in, a) - self.assertEqual( - e_in.shape, - e_out.shape, - "The edge feature dimensions should be same for the base case", - ) + assert e_in.shape == e_out.shape, "The edge feature dimensions should be same for the base case" def testSchNetMPNN(self): # contruct a graph @@ -1315,7 +1302,6 @@ def testSchNetMPNN(self): n_filters = 10 n_gaussians = 10 num_nodes = 6 - cutoff = 0.5 e = torch.rand(5, n_atom_basis) r_in = torch.rand(num_nodes, n_atom_basis) @@ -1331,7 +1317,7 @@ def testSchNetMPNN(self): r_out = model(r_in, e, a) - self.assertEqual(r_in.shape, r_out.shape, "The node feature dimensions should be same.") + assert r_in.shape == r_out.shape, "The node feature dimensions should be same." """ Deprecated @@ -1409,11 +1395,7 @@ def testSchNetEdgeUpdate(self): model = SchNetEdgeUpdate(n_atom_basis=n_atom_basis) e_out = model(r, e_in, a) - self.assertEqual( - e_in.shape, - e_out.shape, - ("The edge feature dimensions should be same for the SchNet " "Edge Update case"), - ) + assert e_in.shape == e_out.shape, "The edge feature dimensions should be same for the SchNet Edge Update case" def testGAT(self): n_atom_basis = 10 @@ -1426,7 +1408,7 @@ def testGAT(self): r_out = attention(r_in, e, a) - self.assertEqual(r_out.shape, r_in.shape) + assert r_out.shape == r_in.shape def testmultitask(self): n_atom = 10 @@ -1449,7 +1431,7 @@ def testmultitask(self): } model = NodeMultiTaskReadOut(multitaskdict) - output = model(r) + model(r) if __name__ == "__main__": diff --git a/nff/nn/modules/spooky.py b/nff/nn/modules/spooky.py index 5889ed04..c86e21bd 100644 --- a/nff/nn/modules/spooky.py +++ b/nff/nn/modules/spooky.py @@ -1,3 +1,5 @@ + +# ruff: noqa: E741 import time import torch @@ -55,7 +57,7 @@ def get_elec_config(max_z): # in ELEC_CONFIG elec_config = torch.ones(max_z + 1, 20) * float("nan") for z, val in ELEC_CONFIG.items(): - elec_config[z] = torch.Tensor([z] + val) / max_z_config + elec_config[z] = torch.Tensor([z, *val]) / max_z_config return elec_config @@ -101,14 +103,8 @@ def scatter_mol(atomwise, num_atoms): because it takes a very long time to make the indices that map atom index to molecule. 
""" - - out = [] atom_split = torch.split(atomwise, num_atoms.tolist()) - for split in atom_split: - out.append(split.sum(0)) - out = torch.stack(out) - - return out + return torch.stack([split.sum(0) for split in atom_split]) def scatter_pairwise(pairwise, num_atoms, nbrs): diff --git a/nff/nn/modules/spooky_fast.py b/nff/nn/modules/spooky_fast.py index dfa160d8..ac9634fd 100644 --- a/nff/nn/modules/spooky_fast.py +++ b/nff/nn/modules/spooky_fast.py @@ -1,16 +1,18 @@ + +# ruff: noqa: E741 +import time + import torch from torch import nn -from nff.nn.layers import (PreActivation, Dense, zeros_initializer, - BatchedPreActivation) -from nff.utils.tools import layer_types, make_undirected -from nff.utils.scatter import scatter_add -from nff.utils.constants import ELEC_CONFIG, KE_KCAL, BOHR_RADIUS -from nff.utils import spooky_f_cut, make_y_lm, rho_k -import time +from nff.nn.layers import BatchedPreActivation, Dense, PreActivation, zeros_initializer +from nff.utils import make_y_lm, rho_k, spooky_f_cut +from nff.utils.constants import BOHR_RADIUS, ELEC_CONFIG, KE_KCAL +from nff.utils.scatter import scatter_add +from nff.utils.tools import layer_types, make_undirected EPS = 1e-15 -DEFAULT_ACTIVATION = 'learnable_swish' +DEFAULT_ACTIVATION = "learnable_swish" DEFAULT_MAX_Z = 86 DEFAULT_DROPOUT = 0 DEFAULT_RES_LAYERS = 2 @@ -48,7 +50,7 @@ def norm(vec): For stable norm calculation. PyTorch's implementation can be unstable """ - result = ((vec ** 2 + EPS).sum(-1)) ** 0.5 + result = ((vec**2 + EPS).sum(-1)) ** 0.5 return result @@ -56,9 +58,9 @@ def get_elec_config(max_z): max_z_config = torch.Tensor([max_z] + ELEC_CONFIG[max_z]) # nan ensures we get nan results for any elements not # in ELEC_CONFIG - elec_config = torch.ones(max_z + 1, 20) * float('nan') + elec_config = torch.ones(max_z + 1, 20) * float("nan") for z, val in ELEC_CONFIG.items(): - elec_config[z] = torch.Tensor([z] + val) / max_z_config + elec_config[z] = torch.Tensor([z, *val]) / max_z_config return elec_config @@ -96,30 +98,22 @@ def get_elec_config(max_z): # return out -def scatter_mol(atomwise, - num_atoms): + +def scatter_mol(atomwise, num_atoms): """ Add atomic contributions in a batch to their respective geometries. A simple sum is much faster than doing scatter_add - because it takes a very long time to make the indices that + because it takes a very long time to make the indices that map atom index to molecule. 
""" - - out = [] atom_split = torch.split(atomwise, num_atoms.tolist()) - for split in atom_split: - out.append(split.sum(0)) - out = torch.stack(out) - - return out + return torch.stack([split.sum(0) for split in atom_split]) -def scatter_pairwise(pairwise, - num_atoms, - nbrs): +def scatter_pairwise(pairwise, num_atoms, nbrs): """ Add pair-wise contributions in a batch to their respective - geometries + geometries """ # mol_idx = [] @@ -136,41 +130,35 @@ def scatter_pairwise(pairwise, for i, num in enumerate(num_atoms): mol_idx += [i] * int(num) - mol_idx = (torch.LongTensor(mol_idx) - .to(pairwise.device)) + mol_idx = torch.LongTensor(mol_idx).to(pairwise.device) nbr_to_mol = mol_idx[nbrs[:, 0]] - out = scatter_add(src=pairwise, - index=nbr_to_mol, - dim=0, - dim_size=len(num_atoms)) + out = scatter_add(src=pairwise, index=nbr_to_mol, dim=0, dim_size=len(num_atoms)) return out class Residual(nn.Module): - def __init__(self, - feat_dim, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - num_layers=DEFAULT_RES_LAYERS, - bias=True): - + def __init__( + self, feat_dim, activation=DEFAULT_ACTIVATION, dropout=DEFAULT_DROPOUT, num_layers=DEFAULT_RES_LAYERS, bias=True + ): super().__init__() block = [ - PreActivation(in_features=feat_dim, - out_features=feat_dim, - activation=activation, - dropout_rate=dropout, - bias=bias) + PreActivation( + in_features=feat_dim, out_features=feat_dim, activation=activation, dropout_rate=dropout, bias=bias + ) for _ in range(num_layers - 1) ] - block.append(PreActivation(in_features=feat_dim, - out_features=feat_dim, - activation=activation, - dropout_rate=dropout, - bias=bias, - weight_init=zeros_initializer)) + block.append( + PreActivation( + in_features=feat_dim, + out_features=feat_dim, + activation=activation, + dropout_rate=dropout, + bias=bias, + weight_init=zeros_initializer, + ) + ) self.block = nn.Sequential(*block) def forward(self, x): @@ -179,24 +167,21 @@ def forward(self, x): class ResidualMLP(nn.Module): - def __init__(self, - feat_dim, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - bias=True, - residual_layers=DEFAULT_RES_LAYERS): - + def __init__( + self, + feat_dim, + activation=DEFAULT_ACTIVATION, + dropout=DEFAULT_DROPOUT, + bias=True, + residual_layers=DEFAULT_RES_LAYERS, + ): super().__init__() - residual = Residual(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - bias=bias, - num_layers=residual_layers) - self.block = nn.Sequential(residual, - layer_types[activation](), - Dense(in_features=feat_dim, - out_features=feat_dim, - bias=bias)) + residual = Residual( + feat_dim=feat_dim, activation=activation, dropout=dropout, bias=bias, num_layers=residual_layers + ) + self.block = nn.Sequential( + residual, layer_types[activation](), Dense(in_features=feat_dim, out_features=feat_dim, bias=bias) + ) def forward(self, x): output = self.block(x) @@ -204,31 +189,38 @@ def forward(self, x): class BatchedResidual(nn.Module): - def __init__(self, - feat_dim, - num_out, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - num_layers=DEFAULT_RES_LAYERS, - bias=True): - + def __init__( + self, + feat_dim, + num_out, + activation=DEFAULT_ACTIVATION, + dropout=DEFAULT_DROPOUT, + num_layers=DEFAULT_RES_LAYERS, + bias=True, + ): super().__init__() block = [ - BatchedPreActivation(in_features=feat_dim, - out_features=feat_dim, - num_out=num_out, - activation=activation, - dropout_rate=dropout, - bias=bias) + BatchedPreActivation( + in_features=feat_dim, + out_features=feat_dim, + num_out=num_out, + 
activation=activation, + dropout_rate=dropout, + bias=bias, + ) for _ in range(num_layers - 1) ] - block.append(BatchedPreActivation(in_features=feat_dim, - out_features=feat_dim, - num_out=num_out, - activation=activation, - dropout_rate=dropout, - bias=bias, - weight_init=zeros_initializer)) + block.append( + BatchedPreActivation( + in_features=feat_dim, + out_features=feat_dim, + num_out=num_out, + activation=activation, + dropout_rate=dropout, + bias=bias, + weight_init=zeros_initializer, + ) + ) self.block = nn.Sequential(*block) self.num_out = num_out @@ -238,32 +230,35 @@ def forward(self, x): class BatchedResidualMLP(nn.Module): - def __init__(self, - feat_dim, - num_out, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - bias=True, - residual_layers=DEFAULT_RES_LAYERS): - + def __init__( + self, + feat_dim, + num_out, + activation=DEFAULT_ACTIVATION, + dropout=DEFAULT_DROPOUT, + bias=True, + residual_layers=DEFAULT_RES_LAYERS, + ): super().__init__() - residual = BatchedResidual(feat_dim=feat_dim, - num_out=num_out, - activation=activation, - dropout=dropout, - bias=bias, - num_layers=residual_layers) + residual = BatchedResidual( + feat_dim=feat_dim, + num_out=num_out, + activation=activation, + dropout=dropout, + bias=bias, + num_layers=residual_layers, + ) nonlinear = layer_types[activation]() - linear = BatchedPreActivation(in_features=feat_dim, - out_features=feat_dim, - num_out=num_out, - activation=None, - dropout_rate=dropout, - bias=bias) - - self.block = nn.Sequential(residual, - nonlinear, - linear) + linear = BatchedPreActivation( + in_features=feat_dim, + out_features=feat_dim, + num_out=num_out, + activation=None, + dropout_rate=dropout, + bias=bias, + ) + + self.block = nn.Sequential(residual, nonlinear, linear) self.num_out = num_out def forward(self, x): @@ -272,17 +267,12 @@ def forward(self, x): class NuclearEmbedding(nn.Module): - def __init__(self, - feat_dim, - max_z=DEFAULT_MAX_Z): - + def __init__(self, feat_dim, max_z=DEFAULT_MAX_Z): super().__init__() self.elec_config = get_elec_config(max_z) - self.m_mat = Dense(in_features=20, - out_features=feat_dim, - bias=False, - activation=None, - weight_init=zeros_initializer) + self.m_mat = Dense( + in_features=20, out_features=feat_dim, bias=False, activation=None, weight_init=zeros_initializer + ) self.z_embed = nn.Embedding(max_z, feat_dim, padding_idx=0) def forward(self, z): @@ -294,33 +284,21 @@ def forward(self, z): class ElectronicEmbedding(nn.Module): - def __init__(self, - feat_dim, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - residual_layers=DEFAULT_RES_LAYERS): + def __init__( + self, feat_dim, activation=DEFAULT_ACTIVATION, dropout=DEFAULT_DROPOUT, residual_layers=DEFAULT_RES_LAYERS + ): super().__init__() - self.linear = Dense(in_features=feat_dim, - out_features=feat_dim, - bias=True, - activation=None) + self.linear = Dense(in_features=feat_dim, out_features=feat_dim, bias=True, activation=None) self.feat_dim = feat_dim - self.resmlp = ResidualMLP(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - bias=False, - residual_layers=residual_layers) - names = ['k_plus', 'k_minus', 'v_plus', 'v_minus'] + self.resmlp = ResidualMLP( + feat_dim=feat_dim, activation=activation, dropout=dropout, bias=False, residual_layers=residual_layers + ) + names = ["k_plus", "k_minus", "v_plus", "v_minus"] for name in names: - val = nn.Parameter(torch.zeros(feat_dim, 1, - dtype=torch.float32)) + val = nn.Parameter(torch.zeros(feat_dim, 1, dtype=torch.float32)) setattr(self, name, 
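`Residual` and `BatchedResidual` above zero-initialize their final layer, so each residual branch contributes nothing at the start and the block begins as the identity. A minimal sketch of that property:

```python
import torch
from torch import nn

class ToyResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.l1 = nn.Linear(dim, dim)
        self.l2 = nn.Linear(dim, dim)
        nn.init.zeros_(self.l2.weight)  # analogue of zeros_initializer
        nn.init.zeros_(self.l2.bias)

    def forward(self, x):
        return x + self.l2(torch.nn.functional.silu(self.l1(x)))

x = torch.randn(2, 8)
assert torch.allclose(ToyResidual(8)(x), x)  # identity at initialization
```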
val) - def forward(self, - psi, - e_z, - num_atoms): - + def forward(self, psi, e_z, num_atoms): q = self.linear(e_z) split_qs = torch.split(q, num_atoms.tolist()) e_psi = torch.zeros_like(e_z) @@ -332,8 +310,7 @@ def forward(self, k = self.k_plus if (mol_psi >= 0) else self.k_minus # mol_q has dimension atoms_in_mol x F # k has dimension F x 1 - arg = (torch.einsum('ij, jk -> i', mol_q, k) - / self.feat_dim ** 0.5) + arg = torch.einsum("ij, jk -> i", mol_q, k) / self.feat_dim**0.5 num = torch.log(1 + torch.exp(arg)) denom = num.sum() @@ -345,45 +322,29 @@ def forward(self, # dimension atoms_in_mol x F av = a_i.reshape(-1, 1) * v.reshape(1, -1) this_e_psi = self.resmlp(av) - e_psi[counter: counter + num_atoms[j]] = this_e_psi + e_psi[counter : counter + num_atoms[j]] = this_e_psi counter += num_atoms[j] return e_psi class CombinedEmbedding(nn.Module): - def __init__(self, - feat_dim, - activation=DEFAULT_ACTIVATION, - max_z=DEFAULT_MAX_Z, - residual_layers=DEFAULT_RES_LAYERS): - + def __init__( + self, feat_dim, activation=DEFAULT_ACTIVATION, max_z=DEFAULT_MAX_Z, residual_layers=DEFAULT_RES_LAYERS + ): super().__init__() - self.nuc_embedding = NuclearEmbedding( - feat_dim=feat_dim, - max_z=max_z) + self.nuc_embedding = NuclearEmbedding(feat_dim=feat_dim, max_z=max_z) self.charge_embedding = ElectronicEmbedding( - feat_dim=feat_dim, - activation=activation, - residual_layers=residual_layers) + feat_dim=feat_dim, activation=activation, residual_layers=residual_layers + ) self.spin_embedding = ElectronicEmbedding( - feat_dim=feat_dim, - activation=activation, - residual_layers=residual_layers) - - def forward(self, - charge, - spin, - z, - num_atoms): + feat_dim=feat_dim, activation=activation, residual_layers=residual_layers + ) + def forward(self, charge, spin, z, num_atoms): e_z = self.nuc_embedding(z) - e_q = self.charge_embedding(psi=charge, - e_z=e_z, - num_atoms=num_atoms) - e_s = self.charge_embedding(psi=spin, - e_z=e_z, - num_atoms=num_atoms) + e_q = self.charge_embedding(psi=charge, e_z=e_z, num_atoms=num_atoms) + e_s = self.charge_embedding(psi=spin, e_z=e_z, num_atoms=num_atoms) x_0 = e_z + e_q + e_s @@ -391,11 +352,7 @@ def forward(self, class GBlock(nn.Module): - def __init__(self, - l, - r_cut, - bern_k, - gamma): + def __init__(self, l, r_cut, bern_k, gamma): super().__init__() self.l = l self.r_cut = r_cut @@ -404,21 +361,16 @@ def __init__(self, self.y_lm_fn = make_y_lm(l) def forward(self, r_ij): - r = norm(r_ij).reshape(-1, 1) n_pairs = r_ij.shape[0] device = r_ij.device m_vals = list(range(-self.l, self.l + 1)) # is this for-loop slow? 
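The `ElectronicEmbedding` attention above scores per-atom queries against a sign-dependent key and softplus-normalizes the result; the diff elides the final weighting, so the `a_i` line below (scaling by `psi` and normalizing, as in the SpookyNet paper) is an assumption:

```python
import torch

feat_dim, n_atoms, psi = 8, 5, -1.0
mol_q = torch.randn(n_atoms, feat_dim)  # per-atom queries for one molecule
k = torch.randn(feat_dim, 1)            # k_minus here, since psi < 0

arg = torch.einsum("ij, jk -> i", mol_q, k) / feat_dim**0.5
num = torch.log(1 + torch.exp(arg))     # softplus scores
a_i = psi * num / num.sum()             # assumed: weights sum to psi
```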
- y = torch.stack([self.y_lm_fn(r_ij, r, self.l, m) for m in - m_vals]).transpose(0, 1) + y = torch.stack([self.y_lm_fn(r_ij, r, self.l, m) for m in m_vals]).transpose(0, 1) gamma = self.gamma.clamp(0) rho = rho_k(r, self.r_cut, self.bern_k, gamma) - g = torch.ones(n_pairs, - self.bern_k, - len(m_vals), - device=device) + g = torch.ones(n_pairs, self.bern_k, len(m_vals), device=device) g = g * rho.reshape(n_pairs, -1, 1) g = g * y.reshape(n_pairs, 1, -1) @@ -426,16 +378,17 @@ def forward(self, r_ij): class LocalInteraction(nn.Module): - def __init__(self, - feat_dim, - bern_k, - gamma, - r_cut, - activation=DEFAULT_ACTIVATION, - max_z=DEFAULT_MAX_Z, - dropout=DEFAULT_DROPOUT, - residual_layers=DEFAULT_RES_LAYERS): - + def __init__( + self, + feat_dim, + bern_k, + gamma, + r_cut, + activation=DEFAULT_ACTIVATION, + max_z=DEFAULT_MAX_Z, + dropout=DEFAULT_DROPOUT, + residual_layers=DEFAULT_RES_LAYERS, + ): super().__init__() # for letter in ["c", "s", "p", "d", "l"]: @@ -447,42 +400,35 @@ def __init__(self, # setattr(self, key, val) num_orbs = 3 - self.num_channels = (num_orbs + 1) - self.resmlp = BatchedResidualMLP(feat_dim=feat_dim, - num_out=self.num_channels, - activation=activation, - dropout=dropout, - residual_layers=residual_layers) - self.resmlp_l = ResidualMLP(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - residual_layers=residual_layers) + self.num_channels = num_orbs + 1 + self.resmlp = BatchedResidualMLP( + feat_dim=feat_dim, + num_out=self.num_channels, + activation=activation, + dropout=dropout, + residual_layers=residual_layers, + ) + self.resmlp_l = ResidualMLP( + feat_dim=feat_dim, activation=activation, dropout=dropout, residual_layers=residual_layers + ) for key in ["G_s", "G_p", "G_d"]: - val = nn.Parameter(torch.ones(feat_dim, - bern_k)) + val = nn.Parameter(torch.ones(feat_dim, bern_k)) nn.init.xavier_uniform_(val) setattr(self, key, val) for key in ["P_1", "P_2", "D_1", "D_2"]: - val = nn.Parameter(torch.ones(feat_dim, - feat_dim)) + val = nn.Parameter(torch.ones(feat_dim, feat_dim)) nn.init.xavier_uniform_(val) setattr(self, key, val) letters = {0: "s", 1: "p", 2: "d"} for l, letter in letters.items(): key = f"g_{letter}" - g_block = GBlock(l=l, - r_cut=r_cut, - bern_k=bern_k, - gamma=gamma) + g_block = GBlock(l=l, r_cut=r_cut, bern_k=bern_k, gamma=gamma) setattr(self, key, g_block) - def g_matmul(self, - r_ij, - orbital): - + def g_matmul(self, r_ij, orbital): g_func = getattr(self, f"g_{orbital}") g = g_func(r_ij) G = getattr(self, f"G_{orbital}") @@ -490,17 +436,18 @@ def g_matmul(self, # G: F x K # output: N_nbrs x F x (1, 3, or 5) - out = torch.einsum('ik, jkl -> jil', G, g) + out = torch.einsum("ik, jkl -> jil", G, g) return out - def make_quant(self, - r_ij, - # x_j, - nbrs, - graph_size, - orbital, - res_out): - + def make_quant( + self, + r_ij, + # x_j, + nbrs, + graph_size, + orbital, + res_out, + ): # res_block = getattr(self, f"resmlp_{orbital}") n_nbrs = nbrs.shape[0] @@ -509,20 +456,13 @@ def make_quant(self, # per_nbr = (resmlp.reshape(n_nbrs, -1, 1) # * matmul) - per_nbr = (res_out.reshape(n_nbrs, -1, 1) - * matmul) + per_nbr = res_out.reshape(n_nbrs, -1, 1) * matmul - out = scatter_add(src=per_nbr, - index=nbrs[:, 0], - dim=0, - dim_size=graph_size) + out = scatter_add(src=per_nbr, index=nbrs[:, 0], dim=0, dim_size=graph_size) return out - def take_inner(self, - quant, - orbital): - + def take_inner(self, quant, orbital): name = orbital.upper() # dimensions F x F @@ -532,10 +472,8 @@ def take_inner(self, # quant has dimension n_atoms 
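`GBlock` above builds per-pair feature maps as an outer product of a radial basis and the Y_lm angular basis, and `g_matmul` contracts them with a learned matrix. A minimal sketch of both steps (shapes illustrative):

```python
import torch

n_pairs, K, l, F = 6, 4, 1, 8
rho = torch.rand(n_pairs, K)             # radial (Bernstein-style) basis
y = torch.randn(n_pairs, 2 * l + 1)      # angular Y_lm values

g = rho.unsqueeze(-1) * y.unsqueeze(1)   # (n_pairs, K, 2l+1)

G = torch.randn(F, K)                    # learned per-orbital matrix
out = torch.einsum("ik, jkl -> jil", G, g)  # (n_pairs, F, 2l+1)
```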
x F x (3 or 5) # term has dimension n_atoms x F x (3 or 5) - term_1 = torch.einsum("ij, kjm->kim", - mat_1, quant) - term_2 = torch.einsum("ij, kjm->kim", - mat_2, quant) + term_1 = torch.einsum("ij, kjm->kim", mat_1, quant) + term_2 = torch.einsum("ij, kjm->kim", mat_2, quant) # inner product multiplies elementwise and # sums over the last dimension @@ -544,11 +482,7 @@ def take_inner(self, return inner - def forward(self, - xyz, - x_tilde, - nbrs): - + def forward(self, xyz, x_tilde, nbrs): inp = repeat_inp(x_tilde, self.num_channels) # inp = torch.stack([x_tilde for _ in range(self.num_channels)]) res_outs = self.resmlp(inp) @@ -556,26 +490,26 @@ def forward(self, # import pdb # pdb.set_trace() - orbitals = ['s', 'p', 'd'] + orbitals = ["s", "p", "d"] r_ij = xyz[nbrs[:, 1]] - xyz[nbrs[:, 0]] graph_size = xyz.shape[0] quants = [] for orbital, res_out in zip(orbitals, res_outs[1:]): - quant = self.make_quant(r_ij=r_ij, - # x_j=x_j, - nbrs=nbrs, - graph_size=graph_size, - orbital=orbital, - res_out=res_out[nbrs[:, 1]]) + quant = self.make_quant( + r_ij=r_ij, + # x_j=x_j, + nbrs=nbrs, + graph_size=graph_size, + orbital=orbital, + res_out=res_out[nbrs[:, 1]], + ) quants.append(quant) invariants = [] - for quant, orbital in zip(quants[1:], - orbitals[1:]): - invariant = self.take_inner(quant, - orbital) + for quant, orbital in zip(quants[1:], orbitals[1:]): + invariant = self.take_inner(quant, orbital) invariants.append(invariant) s_i = quants[0].reshape(graph_size, -1) @@ -588,12 +522,9 @@ def forward(self, class NonLocalInteraction(nn.Module): - def __init__(self, - feat_dim, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - residual_layers=DEFAULT_RES_LAYERS): - + def __init__( + self, feat_dim, activation=DEFAULT_ACTIVATION, dropout=DEFAULT_DROPOUT, residual_layers=DEFAULT_RES_LAYERS + ): from performer_pytorch import FastAttention super().__init__() @@ -601,15 +532,11 @@ def __init__(self, # no redraw should happen here - only if # you call self attention of cross attention # as wrappers - self.attn = FastAttention(dim_heads=feat_dim, - nb_features=feat_dim, - causal=False) + self.attn = FastAttention(dim_heads=feat_dim, nb_features=feat_dim, causal=False) self.feat_dim = feat_dim - self.resmlp = BatchedResidualMLP(feat_dim=feat_dim, - num_out=3, - activation=activation, - dropout=dropout, - residual_layers=residual_layers) + self.resmlp = BatchedResidualMLP( + feat_dim=feat_dim, num_out=3, activation=activation, dropout=dropout, residual_layers=residual_layers + ) # for letter in ['q', 'k', 'v']: # key = f'resmlp_{letter}' @@ -619,10 +546,7 @@ def __init__(self, # residual_layers=residual_layers) # setattr(self, key, val) - def forward(self, - x_tilde, - num_atoms): - + def forward(self, x_tilde, num_atoms): # x_tilde has dimension N x F # N = number of nodes, F = feature dimension @@ -655,15 +579,12 @@ def forward(self, # k_stack = torch.stack(k_stack) # v_stack = torch.stack(v_stack) - out = torch.zeros(num_nodes, - self.feat_dim, - device=x_tilde.device) + out = torch.zeros(num_nodes, self.feat_dim, device=x_tilde.device) for i, num in enumerate(num_atoms): q = q_split[i].reshape(1, 1, -1, self.feat_dim) k = k_split[i].reshape(1, 1, -1, self.feat_dim) v = v_split[i].reshape(1, 1, -1, self.feat_dim) - att = (self.attn(q, k, v) - .reshape(-1, self.feat_dim)) + att = self.attn(q, k, v).reshape(-1, self.feat_dim) ### # real_q = q.reshape(-1, self.feat_dim) @@ -691,61 +612,51 @@ def forward(self, ### counter = sum(num_atoms[:i]) - out[counter: counter + num] = att + 
out[counter : counter + num] = att - import pdb # pdb.set_trace() return out class InteractionBlock(nn.Module): - def __init__(self, - feat_dim, - r_cut, - gamma, - bern_k, - activation=DEFAULT_ACTIVATION, - dropout=DEFAULT_DROPOUT, - max_z=DEFAULT_MAX_Z, - residual_layers=DEFAULT_RES_LAYERS): + def __init__( + self, + feat_dim, + r_cut, + gamma, + bern_k, + activation=DEFAULT_ACTIVATION, + dropout=DEFAULT_DROPOUT, + max_z=DEFAULT_MAX_Z, + residual_layers=DEFAULT_RES_LAYERS, + ): super().__init__() - self.residual_1 = Residual(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - num_layers=residual_layers) - self.residual_2 = Residual(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - num_layers=residual_layers) - self.resmlp = ResidualMLP(feat_dim=feat_dim, - activation=activation, - dropout=dropout, - residual_layers=residual_layers) - self.local = LocalInteraction(feat_dim=feat_dim, - bern_k=bern_k, - gamma=gamma, - r_cut=r_cut, - activation=activation, - max_z=max_z, - dropout=dropout) - self.non_local = NonLocalInteraction(feat_dim=feat_dim, - activation=activation, - dropout=dropout) - - def forward(self, - x, - xyz, - nbrs, - num_atoms): + self.residual_1 = Residual( + feat_dim=feat_dim, activation=activation, dropout=dropout, num_layers=residual_layers + ) + self.residual_2 = Residual( + feat_dim=feat_dim, activation=activation, dropout=dropout, num_layers=residual_layers + ) + self.resmlp = ResidualMLP( + feat_dim=feat_dim, activation=activation, dropout=dropout, residual_layers=residual_layers + ) + self.local = LocalInteraction( + feat_dim=feat_dim, + bern_k=bern_k, + gamma=gamma, + r_cut=r_cut, + activation=activation, + max_z=max_z, + dropout=dropout, + ) + self.non_local = NonLocalInteraction(feat_dim=feat_dim, activation=activation, dropout=dropout) + def forward(self, x, xyz, nbrs, num_atoms): x_tilde = self.residual_1(x) - l = self.local(xyz=xyz, - x_tilde=x_tilde, - nbrs=nbrs) - n = self.non_local(x_tilde=x_tilde, - num_atoms=num_atoms) + l = self.local(xyz=xyz, x_tilde=x_tilde, nbrs=nbrs) + n = self.non_local(x_tilde=x_tilde, num_atoms=num_atoms) x_t = self.residual_2(x_tilde + l + n) # x_t = self.residual_2(x_tilde + l ) @@ -755,7 +666,6 @@ def forward(self, def get_f_switch(r, r_on, r_off): - arg = (r - r_on) / (r_off - r_on) x = arg y = 1 - arg @@ -782,43 +692,29 @@ def get_f_switch(r, r_on, r_off): class Electrostatics(nn.Module): - def __init__(self, - feat_dim, - r_cut, - max_z=DEFAULT_MAX_Z): + def __init__(self, feat_dim, r_cut, max_z=DEFAULT_MAX_Z): super().__init__() - self.w = Dense(in_features=feat_dim, - out_features=1, - bias=False, - activation=None) + self.w = Dense(in_features=feat_dim, out_features=1, bias=False, activation=None) self.z_embed = nn.Embedding(max_z, 1, padding_idx=0) self.r_on = r_cut / 4 self.r_off = 3 * r_cut / 4 def f_switch(self, r): - out = get_f_switch(r=r, - r_on=self.r_on, - r_off=self.r_off) + out = get_f_switch(r=r, r_on=self.r_on, r_off=self.r_off) return out - def get_charge(self, - f, - z, - total_charge, - num_atoms): - + def get_charge(self, f, z, total_charge, num_atoms): w_f = self.w(f) q_z = self.z_embed(z) charge = w_f + q_z - mol_sum = scatter_mol(atomwise=charge, - num_atoms=num_atoms).reshape(-1) + mol_sum = scatter_mol(atomwise=charge, num_atoms=num_atoms).reshape(-1) correction = 1 / num_atoms * (total_charge - mol_sum) new_charges = [] for i, n in enumerate(num_atoms): counter = num_atoms[:i].sum() - old_val = charge[counter: counter + n] + old_val = charge[counter : counter + n] new_val 
= old_val + correction[i] new_charges.append(new_val) @@ -826,44 +722,22 @@ def get_charge(self, return new_charges - def get_en(self, - q, - xyz, - num_atoms, - mol_nbrs): - + def get_en(self, q, xyz, num_atoms, mol_nbrs): r_ij = norm(xyz[mol_nbrs[:, 0]] - xyz[mol_nbrs[:, 1]]) q_i = q[mol_nbrs[:, 0]].reshape(-1) q_j = q[mol_nbrs[:, 1]].reshape(-1) - arg_0 = (self.f_switch(r_ij) - / (r_ij ** 2 + BOHR_RADIUS ** 2) ** 0.5) + arg_0 = self.f_switch(r_ij) / (r_ij**2 + BOHR_RADIUS**2) ** 0.5 arg_1 = (1 - self.f_switch(r_ij)) / r_ij - pairwise = (KE_KCAL * q_i * q_j * (arg_0 + arg_1)) + pairwise = KE_KCAL * q_i * q_j * (arg_0 + arg_1) - energy = (scatter_pairwise(pairwise=pairwise, - num_atoms=num_atoms, - nbrs=mol_nbrs) - .reshape(-1, 1)) + energy = scatter_pairwise(pairwise=pairwise, num_atoms=num_atoms, nbrs=mol_nbrs).reshape(-1, 1) return energy - def forward(self, - f, - z, - xyz, - total_charge, - num_atoms, - mol_nbrs): - - q = self.get_charge(f=f, - z=z, - total_charge=total_charge, - num_atoms=num_atoms) - energy = self.get_en(q=q, - xyz=xyz, - num_atoms=num_atoms, - mol_nbrs=mol_nbrs) + def forward(self, f, z, xyz, total_charge, num_atoms, mol_nbrs): + q = self.get_charge(f=f, z=z, total_charge=total_charge, num_atoms=num_atoms) + energy = self.get_en(q=q, xyz=xyz, num_atoms=num_atoms, mol_nbrs=mol_nbrs) return energy, q @@ -882,45 +756,28 @@ def __init__(self, r_cut): new_val = nn.Parameter(val.reshape(-1, 1)) setattr(self, key, new_val) - def zbl_phi(self, - r_ij, - z_i, - z_j): - + def zbl_phi(self, r_ij, z_i, z_j): d = self.d.clamp(0) z_exp = self.z_exp.clamp(0) c = self.c.clamp(0) c = c / c.sum() exponents = self.exponents.clamp(0) - a = ((d / (z_i ** z_exp + z_j ** z_exp)) - .reshape(-1)) + a = (d / (z_i**z_exp + z_j**z_exp)).reshape(-1) - out = (c * torch.exp(-exponents * r_ij.reshape(-1) / a) - ).sum(0) + out = (c * torch.exp(-exponents * r_ij.reshape(-1) / a)).sum(0) return out - def forward(self, - xyz, - z, - nbrs, - num_atoms): - + def forward(self, xyz, z, nbrs, num_atoms): undirec = make_undirected(nbrs) z_i = z[undirec[:, 0]].to(torch.float32) z_j = z[undirec[:, 1]].to(torch.float32) r_ij = norm(xyz[undirec[:, 0]] - xyz[undirec[:, 1]]) - phi = self.zbl_phi(r_ij=r_ij, - z_i=z_i, - z_j=z_j) - pairwise = (KE_KCAL * z_i * z_j / r_ij - * phi - * spooky_f_cut(r_ij, self.r_cut)) - energy = scatter_pairwise(pairwise=pairwise, - num_atoms=num_atoms, - nbrs=undirec).reshape(-1, 1) + phi = self.zbl_phi(r_ij=r_ij, z_i=z_i, z_j=z_j) + pairwise = KE_KCAL * z_i * z_j / r_ij * phi * spooky_f_cut(r_ij, self.r_cut) + energy = scatter_pairwise(pairwise=pairwise, num_atoms=num_atoms, nbrs=undirec).reshape(-1, 1) # this_z_i = z[undirec[:, 0][:36]] # this_z_j = z[undirec[:, 1][:36]] @@ -961,34 +818,20 @@ def forward(self, class AtomwiseReadout(nn.Module): - def __init__(self, - feat_dim, - max_z=DEFAULT_MAX_Z): + def __init__(self, feat_dim, max_z=DEFAULT_MAX_Z): super().__init__() - self.w_e = Dense(in_features=feat_dim, - out_features=1, - bias=False, - activation=None) + self.w_e = Dense(in_features=feat_dim, out_features=1, bias=False, activation=None) self.z_bias = nn.Embedding(max_z, 1, padding_idx=0) - def forward(self, - z, - f, - num_atoms): - + def forward(self, z, f, num_atoms): atomwise = self.w_e(f) + self.z_bias(z) - e_total = scatter_mol(atomwise=atomwise, - num_atoms=num_atoms) + e_total = scatter_mol(atomwise=atomwise, num_atoms=num_atoms) return e_total -def get_dipole(xyz, - q, - num_atoms): - +def get_dipole(xyz, q, num_atoms): qr = q * xyz - dipole = 
scatter_mol(atomwise=qr, - num_atoms=num_atoms) + dipole = scatter_mol(atomwise=qr, num_atoms=num_atoms) return dipole diff --git a/nff/nn/modules/torchmd_net.py b/nff/nn/modules/torchmd_net.py index 624d936a..23a49af4 100644 --- a/nff/nn/modules/torchmd_net.py +++ b/nff/nn/modules/torchmd_net.py @@ -1,26 +1,18 @@ import torch from torch import nn -from nff.nn.layers import (CosineEnvelope, Dense) + +from nff.nn.layers import CosineEnvelope, Dense from nff.utils.scatter import scatter_add from nff.utils.tools import layer_types class DistanceEmbeding(nn.Module): - - def __init__(self, - feat_dim, - dropout, - rbf, - bias=False): - + def __init__(self, feat_dim, dropout, rbf, bias=False): super().__init__() n_rbf = rbf.mu.shape[0] cutoff = rbf.cutoff - dense = Dense(in_features=n_rbf, - out_features=feat_dim, - bias=bias, - dropout_rate=dropout) + dense = Dense(in_features=n_rbf, out_features=feat_dim, bias=bias, dropout_rate=dropout) self.block = nn.Sequential(rbf, dense) self.f_cut = CosineEnvelope(cutoff=cutoff) @@ -33,53 +25,31 @@ def forward(self, dist): class EmbeddingBlock(nn.Module): - def __init__(self, - feat_dim, - dropout, - rbf): - + def __init__(self, feat_dim, dropout, rbf): super().__init__() self.atom_embed = nn.Embedding(100, feat_dim, padding_idx=0) self.feat_dim = feat_dim - self.distance_embed = DistanceEmbeding(feat_dim=feat_dim, - dropout=dropout, - rbf=rbf, - bias=False) + self.distance_embed = DistanceEmbeding(feat_dim=feat_dim, dropout=dropout, rbf=rbf, bias=False) - self.concat_dense = Dense(in_features=2*feat_dim, - out_features=feat_dim, - dropout_rate=dropout, - activation=None) - - def forward(self, - z_number, - nbrs, - dist): + self.concat_dense = Dense( + in_features=2 * feat_dim, out_features=feat_dim, dropout_rate=dropout, activation=None + ) + def forward(self, z_number, nbrs, dist): num_atoms = z_number.shape[0] node_embeddings = self.atom_embed(z_number) nbr_embeddings = self.atom_embed(z_number[nbrs[:, 1]]) edge_feats = self.distance_embed(dist) * nbr_embeddings - aggr_embeddings = scatter_add(src=edge_feats, - index=nbrs[:, 0], - dim=0, - dim_size=num_atoms) + aggr_embeddings = scatter_add(src=edge_feats, index=nbrs[:, 0], dim=0, dim_size=num_atoms) - final_embeddings = self.concat_dense(torch.cat([node_embeddings, - aggr_embeddings], - dim=-1)) + final_embeddings = self.concat_dense(torch.cat([node_embeddings, aggr_embeddings], dim=-1)) return final_embeddings class AttentionHeads(nn.Module): - def __init__(self, - feat_dim, - activation, - num_heads, - rbf): - + def __init__(self, feat_dim, activation, num_heads, rbf): super().__init__() self.rbf = rbf @@ -89,35 +59,26 @@ def __init__(self, self.dk_layer = nn.Sequential( nn.Conv1d( - in_channels=n_rbf * num_heads, - out_channels=feat_dim * num_heads, - kernel_size=1, - groups=num_heads), - layer_types[activation]() + in_channels=n_rbf * num_heads, out_channels=feat_dim * num_heads, kernel_size=1, groups=num_heads + ), + layer_types[activation](), ) - self.query_layer = nn.Conv1d(in_channels=feat_dim * num_heads, - out_channels=feat_dim * num_heads, - kernel_size=1, - groups=num_heads) + self.query_layer = nn.Conv1d( + in_channels=feat_dim * num_heads, out_channels=feat_dim * num_heads, kernel_size=1, groups=num_heads + ) - self.key_layer = nn.Conv1d(in_channels=feat_dim * num_heads, - out_channels=feat_dim * num_heads, - kernel_size=1, - groups=num_heads) + self.key_layer = nn.Conv1d( + in_channels=feat_dim * num_heads, out_channels=feat_dim * num_heads, kernel_size=1, groups=num_heads + ) 
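# Note (illustration, not part of the patch): the grouped 1x1 convolutions
# above implement num_heads independent linear maps in a single op. With
# groups=num_heads, each head's feat_dim-wide slice of the channels gets its
# own weight block, so heads never mix. A minimal, self-contained sketch of
# that behavior; all sizes here are chosen arbitrarily for the demo:

import torch
from torch import nn

feat_dim, num_heads, n_edges = 8, 4, 10
proj = nn.Conv1d(
    in_channels=feat_dim * num_heads,
    out_channels=feat_dim * num_heads,
    kernel_size=1,
    groups=num_heads,  # one independent feat_dim -> feat_dim map per head
)
# Conv1d expects (batch, channels, length); edge features are repeated per
# head and given a trailing length-1 axis, mirroring the forward pass above.
x = torch.randn(n_edges, feat_dim).repeat(1, num_heads).unsqueeze(-1)
out = proj(x)
assert out.shape == (n_edges, feat_dim * num_heads, 1)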
self.activation = layer_types[activation]() self.num_heads = num_heads - def forward(self, - dist, - nbrs, - x_i): - + def forward(self, dist, nbrs, x_i): x_i_nbrs = x_i[nbrs[:, 0]] x_j = x_i[nbrs[:, 1]] - edge_feats = (self.rbf(dist) - * self.f_cut(dist).reshape(-1, 1)) + edge_feats = self.rbf(dist) * self.f_cut(dist).reshape(-1, 1) x_i_inp = x_i_nbrs.repeat(1, self.num_heads).unsqueeze(-1) x_j_inp = x_j.repeat(1, self.num_heads).unsqueeze(-1) @@ -136,46 +97,30 @@ def forward(self, class MessageBlock(nn.Module): - def __init__(self, - feat_dim, - num_heads, - activation, - rbf): - + def __init__(self, feat_dim, num_heads, activation, rbf): super().__init__() n_rbf = rbf.mu.shape[0] - self.attention_heads = AttentionHeads(feat_dim=feat_dim, - activation=activation, - num_heads=num_heads, - rbf=rbf) + self.attention_heads = AttentionHeads(feat_dim=feat_dim, activation=activation, num_heads=num_heads, rbf=rbf) - self.v_layer = nn.Conv1d(in_channels=feat_dim * num_heads, - out_channels=feat_dim * num_heads, - kernel_size=1, - groups=num_heads) + self.v_layer = nn.Conv1d( + in_channels=feat_dim * num_heads, out_channels=feat_dim * num_heads, kernel_size=1, groups=num_heads + ) self.dv_layer = nn.Sequential( nn.Conv1d( - in_channels=n_rbf * num_heads, - out_channels=feat_dim * num_heads, - kernel_size=1, - groups=num_heads), - layer_types[activation]() + in_channels=n_rbf * num_heads, out_channels=feat_dim * num_heads, kernel_size=1, groups=num_heads + ), + layer_types[activation](), ) self.rbf = self.attention_heads.rbf self.f_cut = self.attention_heads.f_cut self.num_heads = num_heads - def forward(self, - dist, - nbrs, - x_i): - + def forward(self, dist, nbrs, x_i): x_j = x_i[nbrs[:, 1]] - edge_feats = (self.rbf(dist) - * self.f_cut(dist).reshape(-1, 1)) + edge_feats = self.rbf(dist) * self.f_cut(dist).reshape(-1, 1) x_j_inp = x_j.repeat(1, self.num_heads).unsqueeze(-1) edge_inp = edge_feats.repeat(1, self.num_heads).unsqueeze(-1) @@ -185,14 +130,10 @@ def forward(self, prod_v = v_feats * d_v - weights = self.attention_heads(dist=dist, - nbrs=nbrs, - x_i=x_i) + weights = self.attention_heads(dist=dist, nbrs=nbrs, x_i=x_i) # dimension num_edges x num_heads x num_features - scaled_v = (prod_v.reshape(prod_v.shape[0], - self.num_heads, -1) * - weights.unsqueeze(-1)) + scaled_v = prod_v.reshape(prod_v.shape[0], self.num_heads, -1) * weights.unsqueeze(-1) # reshape it into a concatenation, i.e. 
dimension # num_edges x (num_heads * num_features) @@ -203,28 +144,16 @@ def forward(self, class UpdateBlock(nn.Module): - def __init__(self, - num_heads, - feat_dim, - dropout): + def __init__(self, num_heads, feat_dim, dropout): super().__init__() - self.concat_dense = Dense(in_features=(num_heads * feat_dim), - out_features=feat_dim, - bias=True, - dropout_rate=dropout, - activation=None) - - def forward(self, - nbrs, - x_i, - scaled_v): + self.concat_dense = Dense( + in_features=(num_heads * feat_dim), out_features=feat_dim, bias=True, dropout_rate=dropout, activation=None + ) + def forward(self, nbrs, x_i, scaled_v): # dimension num_nodes x num_heads x num_features - x_i_prime = scatter_add(src=scaled_v, - index=nbrs[:, 0], - dim=0, - dim_size=x_i.shape[0]) + x_i_prime = scatter_add(src=scaled_v, index=nbrs[:, 0], dim=0, dim_size=x_i.shape[0]) x_i = x_i + self.concat_dense(x_i_prime) return x_i diff --git a/nff/nn/tensorgrad.py b/nff/nn/tensorgrad.py index c9a5f6b1..191c5f49 100644 --- a/nff/nn/tensorgrad.py +++ b/nff/nn/tensorgrad.py @@ -1,18 +1,18 @@ -"""Summary -""" -import numpy as np +"""Summary""" + import copy import inspect +import numpy as np import torch -from torch.autograd import grad import torch.nn.functional as F +from torch.autograd import grad from torch.utils.data import DataLoader def compute_jacobian(inputs, output, device): """ - Compute Jacobians + Compute Jacobians Args: inputs (torch.Tensor): size (N_in, ) @@ -46,31 +46,31 @@ def compute_jacobian(inputs, output, device): return torch.transpose(jacobian, dim0=0, dim1=1) -def compute_grad(inputs, - output, - allow_unused=False): - ''' +def compute_grad(inputs, output, allow_unused=False): + """ Args: inputs (torch.Tensor): size (N_in, ) output (torch.Tensor): size (..., -1) Returns: torch.Tensor: size (N_in, ) - ''' + """ assert inputs.requires_grad - gradspred, = grad(output, - inputs, - grad_outputs=output.data.new(output.shape).fill_(1), - create_graph=True, - retain_graph=True, - allow_unused=allow_unused) + (gradspred,) = grad( + output, + inputs, + grad_outputs=output.data.new(output.shape).fill_(1), + create_graph=True, + retain_graph=True, + allow_unused=allow_unused, + ) return gradspred def compute_hess(inputs, output, device): - ''' + """ Compute Hessians for arbitary model Args: @@ -80,7 +80,7 @@ def compute_hess(inputs, output, device): Returns: torch.Tensor: N_in, N_in, N_out - ''' + """ gradient = compute_grad(inputs, output) hess = compute_jacobian(inputs, gradient, device=device) @@ -88,14 +88,14 @@ def compute_hess(inputs, output, device): def get_schnet_hessians(batch, model, device=0): - """Get Hessians from schnet models + """Get Hessians from schnet models Args: batch (dict): batch of data model (TYPE): Description device (int, optional): Description """ - N_atom = batch['nxyz'].shape[0] + N_atom = batch["nxyz"].shape[0] xyz_reshape = batch["nxyz"][:, 1:].reshape(1, N_atom * 3) xyz_reshape.requires_grad = True xyz_input = xyz_reshape.reshape(N_atom, 3) @@ -106,9 +106,10 @@ def get_schnet_hessians(batch, model, device=0): return hess + def get_painn_hessians(batch, model, device=0): """Get Hessians from painn models. Hessian is returned in kcal/mol/A**2. - Use this method for painn models instead of hess from atoms. Tested both with + Use this method for painn models instead of hess from atoms. Tested both with molecular data (water) and periodic structures (quartz). 
Args: @@ -116,26 +117,24 @@ def get_painn_hessians(batch, model, device=0): model (TYPE): Description device (int, optional): Description """ - N_atom = batch['nxyz'].shape[0] + N_atom = batch["nxyz"].shape[0] xyz_reshape = batch["nxyz"][:, 1:].reshape(1, N_atom * 3) xyz_reshape.requires_grad = True xyz_input = xyz_reshape.reshape(N_atom, 3) - results = model(batch,xyz=xyz_input) + results = model(batch, xyz=xyz_input) energy = results["energy"] - hess=compute_hess(xyz_reshape, energy, device=device) + hess = compute_hess(xyz_reshape, energy, device=device) return hess def adj_nbrs_and_z(batch, xyz, max_dim, stacked): - nan_dims = [i for i, row in enumerate(xyz) if torch.isnan(row).all()] new_nbrs = copy.deepcopy(batch["nbr_list"]) new_z = copy.deepcopy(batch["nxyz"][:, 0]) for dim in nan_dims: - # adjust the neighbor list to account for the increased length # of the nxyz @@ -143,10 +142,14 @@ def adj_nbrs_and_z(batch, xyz, max_dim, stacked): new_nbrs[mask] += 1 # add dummy atomic numbers for these new nan's - new_z = torch.cat([new_z[:dim], - torch.Tensor([float("nan")]).to(new_z.device), - # torch.Tensor([float("1")]).to(new_z.device), - new_z[dim:]]) + new_z = torch.cat( + [ + new_z[:dim], + torch.Tensor([float("nan")]).to(new_z.device), + # torch.Tensor([float("1")]).to(new_z.device), + new_z[dim:], + ] + ) # change the neighbor list in the batch batch["real_nbrs"] = copy.deepcopy(batch["nbr_list"]) @@ -154,8 +157,7 @@ def adj_nbrs_and_z(batch, xyz, max_dim, stacked): # change the nxyz in the batch batch["real_nxyz"] = copy.deepcopy(batch["nxyz"]) - batch["nxyz"] = torch.cat([new_z.reshape(-1, 1), xyz], - dim=-1) + batch["nxyz"] = torch.cat([new_z.reshape(-1, 1), xyz], dim=-1) # change the number of atoms in the batch batch["real_num_atoms"] = copy.deepcopy(batch["num_atoms"]) @@ -169,7 +171,6 @@ def adj_nbrs_and_z(batch, xyz, max_dim, stacked): def pad(batch): - nxyz = batch["nxyz"] N = batch["num_atoms"].tolist() @@ -183,10 +184,7 @@ def pad(batch): num_pads = [max_dim - i.shape[0] for i in reshaped] # pad each geometry and stack the resulting nxyz's - stacked = torch.stack([F.pad(i, [0, num_pad], - value=nan) - for i, num_pad in - zip(reshaped, num_pads)]) + stacked = torch.stack([F.pad(i, [0, num_pad], value=nan) for i, num_pad in zip(reshaped, num_pads)]) # Get the stacked `xyz` by applying a mask to # remove the atomic numbers in the nxyz. 
We need @@ -218,7 +216,6 @@ def pad(batch): def hess_from_pad(stacked, output, device, N): - gradient = compute_grad(stacked, output) pad_hess = compute_jacobian(stacked, gradient, device=device) hess_list = [] @@ -230,11 +227,7 @@ def hess_from_pad(stacked, output, device, N): return hess_list -def schnet_batched_hessians(batch, - model, - device=0, - energy_keys=["energy"]): - +def schnet_batched_hessians(batch, model, device=0, energy_keys=["energy"]): from nff.nn.graphop import batch_and_sum stack_xyz, xyz, batch = pad(batch) @@ -246,10 +239,7 @@ def schnet_batched_hessians(batch, for key in energy_keys: output = results[key] - hess = hess_from_pad(stacked=stack_xyz, - output=output, - device=device, - N=N) + hess = hess_from_pad(stacked=stack_xyz, output=output, device=device, N=N) hess_dic[key + "_hess"] = hess # change these keys back to their original values @@ -265,12 +255,8 @@ def schnet_batched_hessians(batch, return hess_dic -def results_from_stack(batch, - model=None, - forward=None, - **kwargs): - - batch['nxyz'] = batch['nxyz'].detach() +def results_from_stack(batch, model=None, forward=None, **kwargs): + batch["nxyz"] = batch["nxyz"].detach() stack_xyz, xyz, batch = pad(batch) # Make sure the model takes `xyz` as an input @@ -283,33 +269,22 @@ def results_from_stack(batch, forward = model.forward info = inspect.getargspec(forward) - if 'xyz' not in info.args: - raise Exception(("Model does not take xyz as input. " - "Please modify the model so that it can take " - "an external xyz.")) - results = forward(batch=batch, - xyz=xyz, - **kwargs) + if "xyz" not in info.args: + raise Exception( + "Model does not take xyz as input. " "Please modify the model so that it can take " "an external xyz." + ) + results = forward(batch=batch, xyz=xyz, **kwargs) return xyz, stack_xyz, results -def hess_from_results(results, - xyz, - stack_xyz, - keys, - batch, - device): - +def hess_from_results(results, xyz, stack_xyz, keys, batch, device): hess_dic = {} - N = batch['real_num_atoms'] + N = batch["real_num_atoms"] for key in keys: output = results[key] - hess = hess_from_pad(stacked=stack_xyz, - output=output, - device=device, - N=N) + hess = hess_from_pad(stacked=stack_xyz, output=output, device=device, N=N) hess_dic[key + "_hess"] = hess # change these keys back to their original values @@ -327,27 +302,13 @@ def hess_from_results(results, return results -def general_batched_hessian(batch, - keys, - device, - model=None, - forward=None, - **kwargs): - +def general_batched_hessian(batch, keys, device, model=None, forward=None, **kwargs): # doesn't seem to work for painn, at least with non-locality - assert any([i is not None for i in [model, forward]]) - xyz, stack_xyz, results = results_from_stack(batch=batch, - model=model, - forward=forward, - **kwargs) + assert any(i is not None for i in [model, forward]) + xyz, stack_xyz, results = results_from_stack(batch=batch, model=model, forward=forward, **kwargs) - results = hess_from_results(results=results, - xyz=xyz, - stack_xyz=stack_xyz, - keys=keys, - batch=batch, - device=device) + results = hess_from_results(results=results, xyz=xyz, stack_xyz=stack_xyz, keys=keys, batch=batch, device=device) return results @@ -363,8 +324,7 @@ def hess_from_atoms(atoms): """ - from nff.data import Dataset - from nff.data import collate_dicts + from nff.data import Dataset, collate_dicts from nff.train import batch_to from nff.utils import constants as const @@ -377,8 +337,7 @@ def hess_from_atoms(atoms): n = atoms.get_atomic_numbers().reshape(-1, 1) nxyz = 
np.concatenate([n, xyz], axis=-1) dset = Dataset(props={"nxyz": [nxyz]}) - dset.generate_neighbor_list(cutoff, - undirected=(not directed)) + dset.generate_neighbor_list(cutoff, undirected=(not directed)) loader = DataLoader(dset, collate_fn=collate_dicts) batch = next(iter(loader)) @@ -388,17 +347,12 @@ def hess_from_atoms(atoms): # get the results key = getattr(atoms.calc, "en_key", "energy") - results = general_batched_hessian(batch=batch, - keys=[key], - device=device, - model=model) + results = general_batched_hessian(batch=batch, keys=[key], device=device, model=model) hess_key = key + "_hess" hessian = torch.stack(results[hess_key]) hessian = hessian.reshape(*hessian.shape[1:]) - hessian = (hessian.detach().cpu().numpy() * - const.KCAL_TO_AU['energy'] * - const.BOHR_RADIUS ** 2) + hessian = hessian.detach().cpu().numpy() * const.KCAL_TO_AU["energy"] * const.BOHR_RADIUS**2 return hessian diff --git a/nff/nn/utils.py b/nff/nn/utils.py index c5e18b5b..212e6301 100644 --- a/nff/nn/utils.py +++ b/nff/nn/utils.py @@ -1,15 +1,16 @@ """Tools to build layers""" + import collections -import numpy as np -import torch import copy +import numpy as np +import torch from torch.nn import ModuleDict, Sequential + from nff.nn.activations import shifted_softplus from nff.nn.layers import Dense, Diagonalize from nff.utils.scatter import scatter_add - layer_types = { "linear": torch.nn.Linear, "Tanh": torch.nn.Tanh, @@ -40,8 +41,7 @@ def construct_sequential(layers): """ return Sequential( collections.OrderedDict( - [layer["name"] + str(i), layer_types[layer["name"]](**layer["param"])] - for i, layer in enumerate(layers) + [layer["name"] + str(i), layer_types[layer["name"]](**layer["param"])] for i, layer in enumerate(layers) ) ) @@ -148,9 +148,7 @@ def clean_matrix(matrix, eps=1e-12): return matrix -def torch_nbr_list( - atomsobject, cutoff, device="cuda:0", directed=True, requires_large_offsets=True -): +def torch_nbr_list(atomsobject, cutoff, device="cuda:0", directed=True, requires_large_offsets=True): """Pytorch implementations of nbr_list for minimum image convention, the offsets are only limited to 0, 1, -1: it means that no pair interactions is allowed for more than 1 periodic box length. It is so much faster than neighbor_list algorithm in ase. 
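# Note (illustration, not part of the patch): the docstring above describes
# the minimum image convention, in which each pair displacement is wrapped by
# at most one box length per axis before the distance test. A minimal NumPy
# sketch of that wrapping for a cubic box; the function name and the sizes
# below are hypothetical:

import numpy as np

def minimum_image_nbrs(xyz, box, cutoff):
    """Return (i, j) pairs within `cutoff` of each other under PBC."""
    d = xyz[None, :, :] - xyz[:, None, :]  # all pair displacements
    d -= np.round(d / box) * box           # wrap into [-box/2, box/2] per axis
    dist_sq = (d**2).sum(-1)
    return np.nonzero((dist_sq < cutoff**2) & (dist_sq > 0))

xyz = np.random.uniform(0.0, 10.0, size=(20, 3))
i, j = minimum_image_nbrs(xyz, box=10.0, cutoff=3.0)  # directed: both (i, j) and (j, i)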
@@ -170,10 +168,7 @@ def torch_nbr_list( # otherwise, default to the "robust" nbr_list function below for small cells if ( np.all(2 * cutoff < atomsobject.cell.cellpar()[:3]) - and not np.count_nonzero( - atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T)) - ) - != 0 + and np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) == 0 ): # "fast" nbr_list function for large cells (pbc) xyz = torch.Tensor(atomsobject.get_positions(wrap=False)).to(device) @@ -183,17 +178,13 @@ def torch_nbr_list( shift = torch.round(torch.divide(dis_mat, cell_dim)) offsets = -shift else: - offsets = -dis_mat.ge(0.5 * cell_dim).to(torch.float) + dis_mat.lt( - -0.5 * cell_dim - ).to(torch.float) + offsets = -dis_mat.ge(0.5 * cell_dim).to(torch.float) + dis_mat.lt(-0.5 * cell_dim).to(torch.float) dis_mat = dis_mat + offsets * cell_dim dis_sq = dis_mat.pow(2).sum(-1) mask = (dis_sq < cutoff**2) & (dis_sq != 0) nbr_list = mask.nonzero(as_tuple=False) - offsets = ( - offsets[nbr_list[:, 0], nbr_list[:, 1], :].detach().to("cpu").numpy() - ) + offsets = offsets[nbr_list[:, 0], nbr_list[:, 1], :].detach().to("cpu").numpy() else: # "robust" nbr_list function for all cells (pbc) @@ -205,9 +196,7 @@ def torch_nbr_list( unwrapped_positions = atomsobject.get_positions(wrap=False) shift = positions - unwrapped_positions cell = atomsobject.cell - cell = np.broadcast_to( - cell.T, (shift.shape[0], cell.shape[0], cell.shape[1]) - ) + cell = np.broadcast_to(cell.T, (shift.shape[0], cell.shape[0], cell.shape[1])) shift = np.linalg.solve(cell, shift).round().astype(int) # estimate getting close to the cutoff with supercell expansion @@ -222,12 +211,8 @@ def torch_nbr_list( lattice_points_frac = lattice_points_in_supercell(supercell_matrix) lattice_points = np.dot(lattice_points_frac, supercell) # need to get all negative lattice translation vectors but remove duplicate 0 vector - zero_idx = np.where( - np.all(lattice_points.__eq__(np.array([0, 0, 0])), axis=1) - )[0][0] - lattice_points = np.concatenate( - [lattice_points[zero_idx:, :], lattice_points[:zero_idx, :]] - ) + zero_idx = np.where(np.all(lattice_points.__eq__(np.array([0, 0, 0])), axis=1))[0][0] + lattice_points = np.concatenate([lattice_points[zero_idx:, :], lattice_points[:zero_idx, :]]) N = len(lattice_points) # perform lattice translation vectors on positions @@ -247,9 +232,7 @@ def torch_nbr_list( ] # get offsets as original integer multiples of lattice vectors - cell = np.broadcast_to( - cell.T, (offsets.shape[0], cell.shape[0], cell.shape[1]) - ) + cell = np.broadcast_to(cell.T, (offsets.shape[0], cell.shape[0], cell.shape[1])) offsets = offsets.detach().to("cpu").numpy() offsets = np.linalg.solve(cell, offsets).round().astype(int) @@ -279,10 +262,7 @@ def torch_nbr_list( nbr_list[:, 1].detach().to("cpu").numpy(), ) - if any(atomsobject.pbc): - offsets = offsets - else: - offsets = np.zeros((nbr_list.shape[0], 3)) + offsets = offsets if any(atomsobject.pbc) else np.zeros((nbr_list.shape[0], 3)) return i, j, offsets @@ -488,9 +468,7 @@ def chemprop_msg_to_node(h, nbrs, num_nodes): h_to_add = h[good_idx] # add together - node_features = scatter_add( - src=h_to_add, index=match_idx, dim=0, dim_size=num_nodes - ) + node_features = scatter_add(src=h_to_add, index=match_idx, dim=0, dim_size=num_nodes) return node_features diff --git a/nff/opt/algos.py b/nff/opt/algos.py index ba736b20..fbfced06 100644 --- a/nff/opt/algos.py +++ b/nff/opt/algos.py @@ -1,50 +1,45 @@ -from ase.optimize.sciopt import SciPyFminCG, SciPyFminBFGS -from 
ase.optimize import BFGS -import scipy.optimize as opt import numpy as np - -from warnings import warn - +import scipy.optimize as opt +from ase.optimize import BFGS +from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG class Converged(Exception): pass + class OptimizerConvergenceError(Exception): pass - class NeuralCG(SciPyFminCG): - def call_fmin(self, fmax, steps): - output = opt.fmin_cg(self.f, - self.x0(), - fprime=self.fprime, - # args=(), - gtol=fmax * 0.1, # Should never be reached - norm=np.inf, - # epsilon= - maxiter=steps, - full_output=1, - disp=0, - # retall=0, - callback=self.callback) + output = opt.fmin_cg( + self.f, + self.x0(), + fprime=self.fprime, + # args=(), + gtol=fmax * 0.1, # Should never be reached + norm=np.inf, + # epsilon= + maxiter=steps, + full_output=1, + disp=0, + # retall=0, + callback=self.callback, + ) warnflag = output[-1] if warnflag == 2: - raise OptimizerConvergenceError( - 'Warning: Desired error not necessarily achieved ' - 'due to precision loss') + raise OptimizerConvergenceError("Warning: Desired error not necessarily achieved " "due to precision loss") def run(self, fmax=0.05, steps=100000000): - if self.force_consistent is None: self.set_force_consistent() self.fmax = fmax try: # want to update the neighbor list every step self.atoms.update_nbr_list() - + # As SciPy does not log the zeroth iteration, we do that manually self.callback(None) @@ -54,11 +49,8 @@ def run(self, fmax=0.05, steps=100000000): pass - class NeuralBFGS(SciPyFminBFGS): - def run(self, fmax=0.05, steps=100000000): - if self.force_consistent is None: self.set_force_consistent() self.fmax = fmax @@ -74,9 +66,7 @@ def run(self, fmax=0.05, steps=100000000): pass - class NeuralAseBFGS(BFGS): - def step(self, f=None): atoms = self.atoms @@ -90,6 +80,7 @@ def step(self, f=None): self.update(r.flat, f, self.r0, self.f0) from numpy.linalg import eigh + omega, V = eigh(self.H) # FUTURE: Log this properly @@ -105,9 +96,9 @@ def step(self, f=None): # self.logfile.flush() dr = np.dot(V, np.dot(f, V) / np.fabs(omega)).reshape((-1, 3)) - steplengths = (dr**2).sum(1)**0.5 + steplengths = (dr**2).sum(1) ** 0.5 dr = self.determine_step(dr, steplengths) atoms.set_positions(r + dr) self.r0 = r.flat.copy() self.f0 = f.copy() - self.dump((self.H, self.r0, self.f0, self.maxstep)) \ No newline at end of file + self.dump((self.H, self.r0, self.f0, self.maxstep)) diff --git a/nff/qm/integrals/overlap.py b/nff/qm/integrals/overlap.py index dfedf5c3..6ab300bb 100644 --- a/nff/qm/integrals/overlap.py +++ b/nff/qm/integrals/overlap.py @@ -5,21 +5,18 @@ """ # import numba as nb +import time + import numpy as np import torch -import time # from typing import Union from nff.utils.scatter import compute_grad - # @torch.jit.script -def horizontal(s: torch.Tensor, - p: float, - r_pa: torch.Tensor): - +def horizontal(s: torch.Tensor, p: float, r_pa: torch.Tensor): num = s.shape[0] for i in range(1, num): @@ -33,11 +30,7 @@ def horizontal(s: torch.Tensor, # @torch.jit.script -def vertical(s: torch.Tensor, - r_pb: torch.Tensor, - p: float, - device: int): - +def vertical(s: torch.Tensor, r_pb: torch.Tensor, p: float, device: int): l_1 = s.shape[2] i_range = torch.arange(s.shape[1]).to(device) @@ -46,12 +39,11 @@ def vertical(s: torch.Tensor, zeros = torch.zeros(3, 1).to(device) for j in range(1, l_1): - s_term = r_pb_shape * s[:, :, j - 1].clone() new_s = torch.cat((zeros, s[:, :-1, j - 1]), dim=1) i_term = i_range / (2 * p) * new_s - s[:, :, j] = (s_term + i_term) + s[:, :, j] = s_term + i_term if j > 1: 
j_term = (j - 1) / (2 * p) * s[:, :, j - 2] @@ -61,11 +53,7 @@ def vertical(s: torch.Tensor, # @torch.jit.script -def get_prelims(r_a: torch.Tensor, - r_b: torch.Tensor, - a: float, - b: float): - +def get_prelims(r_a: torch.Tensor, r_b: torch.Tensor, a: float, b: float): p = a + b mu = a * b / (a + b) @@ -75,20 +63,12 @@ def get_prelims(r_a: torch.Tensor, r_pa = big_p - r_a r_pb = big_p - r_b - s_0 = (torch.sqrt(torch.tensor(np.pi / p)) - * torch.exp(-mu * r_ab ** 2)) + s_0 = torch.sqrt(torch.tensor(np.pi / p)) * torch.exp(-mu * r_ab**2) return r_pa, r_pb, s_0, p -def compute_overlaps(l_0, - l_1, - p, - r_pa, - r_pb, - s_0, - device): - +def compute_overlaps(l_0, l_1, p, r_pa, r_pb, s_0, device): r_pa = r_pa.to(device) r_pb = r_pb.to(device) @@ -96,24 +76,16 @@ def compute_overlaps(l_0, s[0, :] = s_0.to(device) s = horizontal(s, p, r_pa) - s_t = (s.transpose(0, 1).reshape(3, 1, -1) - .transpose(1, 2)) + s_t = s.transpose(0, 1).reshape(3, 1, -1).transpose(1, 2) - zeros = (torch.zeros((3, l_0, l_1 - 1)) - .to(device)) + zeros = torch.zeros((3, l_0, l_1 - 1)).to(device) s = torch.cat([s_t, zeros], dim=2) s = vertical(s, r_pb, p, device) return s -def pos_to_overlaps(r_a, - r_b, - a, - b, - l_0, - l_1, - device): +def pos_to_overlaps(r_a, r_b, a, b, l_0, l_1, device): """ Overlaps between the Cartesian Gaussian orbitals of two atoms at r_a and r_b, respectively, @@ -121,18 +93,9 @@ def pos_to_overlaps(r_a, and maximum angular momenta l_0 and l_1. """ - r_pa, r_pb, s_0, p = get_prelims(r_a=r_a, - r_b=r_b, - a=a, - b=b) + r_pa, r_pb, s_0, p = get_prelims(r_a=r_a, r_b=r_b, a=a, b=b) - s = compute_overlaps(l_0=l_0, - l_1=l_1, - p=p, - r_pa=r_pa, - r_pb=r_pb, - s_0=s_0, - device=device) + s = compute_overlaps(l_0=l_0, l_1=l_1, p=p, r_pa=r_pa, r_pb=r_pb, s_0=s_0, device=device) return s @@ -150,13 +113,7 @@ def test(): l_1 = 5 start = time.time() - s = pos_to_overlaps(r_a, - r_b, - a, - b, - l_0, - l_1, - device='cpu') + s = pos_to_overlaps(r_a, r_b, a, b, l_0, l_1, device="cpu") end = time.time() delta = end - start print("%.5e seconds" % delta) @@ -173,8 +130,7 @@ def test(): targs = [0.167162, 7.52983e-5, -0.0320324] for i, idx in enumerate(idx_pairs): print(idx) - print("Predicted value: %.5e" % - (s[idx[0], idx[1], idx[2]])) + print("Predicted value: %.5e" % (s[idx[0], idx[1], idx[2]])) print("Target value: %.5e" % targs[i]) diff --git a/nff/reactive_tools/ev_following.py b/nff/reactive_tools/ev_following.py index 1ed80585..6d1828bd 100644 --- a/nff/reactive_tools/ev_following.py +++ b/nff/reactive_tools/ev_following.py @@ -1,12 +1,12 @@ import torch +from neuralnet.vib import hessian_and_modes from nff.io.ase_calcs import NeuralFF from nff.reactive_tools.utils import ( neural_energy_ase, neural_force_ase, ) -from nff.utils.constants import EV_TO_AU, BOHR_RADIUS -from neuralnet.vib import hessian_and_modes +from nff.utils.constants import BOHR_RADIUS, EV_TO_AU CONVG_LINE = "Optimization converged!" 
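# Note (illustration, not part of the patch): powell_update, reformatted in
# the next hunk, applies a symmetric rank-two correction so that the Hessian
# estimate B reproduces the observed gradient change along the last step h.
# Writing V = (g_new - g_old) - B @ h (the usual secant residual; the exact
# definition of V sits outside the visible hunk), the update is
#     B' = B + (V h^T + h V^T) / (h.h) - (V.h) h h^T / (h.h)^2
# which enforces the secant condition B' h = g_new - g_old. A NumPy sketch
# under that assumed form of V:

import numpy as np

def powell_update_np(B, h, g_old, g_new):
    V = (g_new - g_old) - B @ h  # secant residual (assumed definition)
    hh = h @ h
    dB = (np.outer(V, h) + np.outer(h, V) - (V @ h) / hh * np.outer(h, h)) / hh
    return B + dB

B = np.eye(3)
h = np.array([0.10, 0.00, -0.05])
g_old = np.zeros(3)
g_new = np.array([0.12, -0.01, -0.04])
B_new = powell_update_np(B, h, g_old, g_new)
assert np.allclose(B_new @ h, g_new - g_old)  # secant condition holds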
@@ -30,11 +30,7 @@ def powell_update(hessian_old, h, gradient_old, gradient_new): update = ( torch.mm(V.reshape(-1, 1), h.reshape(1, -1)) + torch.mm(h.reshape(-1, 1), V.reshape(1, -1)) - - ( - torch.dot(V, h) - / torch.dot(h, h) - * torch.mm(h.reshape(-1, 1), h.reshape(1, -1)) - ) + - (torch.dot(V, h) / torch.dot(h, h) * torch.mm(h.reshape(-1, 1), h.reshape(1, -1))) ) / torch.dot(h, h) powell_hessian = hessian_old + update @@ -92,9 +88,7 @@ def eigvec_following( lambda_n = torch.linalg.eigvalsh(matrix_n, UPLO="U")[0] - lambda_n = lambda_n.new_full((Ndim * len(old_xyz[0]) - 1,), lambda_n.item()).to( - device - ) + lambda_n = lambda_n.new_full((Ndim * len(old_xyz[0]) - 1,), lambda_n.item()).to(device) h_p = -1.0 * F[0] * eigvecs_t[0] / (eigenvalues[0] - lambda_p) h_n = -1.0 * F[1:] * eigvecs_t[1:] / ((eigenvalues[1:] - lambda_n).reshape(-1, 1)) @@ -102,10 +96,7 @@ def eigvec_following( h = torch.add(h_p, torch.sum(h_n, dim=0)).reshape(-1, len(old_xyz[0]), Ndim) step_size = h.norm() - if step_size <= maxstepsize: - new_xyz = old_xyz + h - else: - new_xyz = old_xyz + (h / (step_size / maxstepsize)) + new_xyz = old_xyz + h if step_size <= maxstepsize else old_xyz + h / (step_size / maxstepsize) output = (new_xyz.detach(), grad.detach(), hessian.detach(), h.reshape(-1).detach()) print(f"STEP {step}:", output) @@ -138,9 +129,7 @@ def ev_run( rmslist = [] maxlist = [] - calc_kwargs = get_calc_kwargs( - calc_kwargs=calc_kwargs, device=device, nff_dir=nff_dir - ) + calc_kwargs = get_calc_kwargs(calc_kwargs=calc_kwargs, device=device, nff_dir=nff_dir) nff = NeuralFF.from_file(**calc_kwargs) ev_atoms.set_calculator(nff) @@ -150,27 +139,15 @@ def ev_run( if step % nbr_update_period == 0: ev_atoms.update_nbr_list() - if step == 0: - args = [] - else: - args = [hessian, grad, h] + args = [] if step == 0 else [hessian, grad, h] - xyz, grad, hessian, h = eigvec_following( - ev_atoms, step, maxstepsize, device, method, *args - ) + xyz, grad, hessian, h = eigvec_following(ev_atoms, step, maxstepsize, device, method, *args) - if step == 0: - xyz_all = xyz - else: - xyz_all = torch.cat((xyz_all, xyz), dim=0) + xyz_all = xyz if step == 0 else torch.cat((xyz_all, xyz), dim=0) rmslist.append(grad.pow(2).sqrt().mean()) maxlist.append(grad.pow(2).sqrt().max()) - print( - "RMS: {}, MAX: {}".format( - grad.pow(2).sqrt().mean(), grad.pow(2).sqrt().max() - ) - ) + print(f"RMS: {grad.pow(2).sqrt().mean()}, MAX: {grad.pow(2).sqrt().max()}") if grad.pow(2).sqrt().max() < convergence: print(CONVG_LINE) diff --git a/nff/reactive_tools/kabsch.py b/nff/reactive_tools/kabsch.py index 0c4571a1..44bd2995 100644 --- a/nff/reactive_tools/kabsch.py +++ b/nff/reactive_tools/kabsch.py @@ -1,15 +1,12 @@ -import numpy as np import io +import numpy as np + try: - from alog import Logger - from acore import settings import acore as ac + from alog import Logger except ModuleNotFoundError: - print( - "You need to install the group's fork of aRMSD and put it in your path " - "https://github.mit.edu/MLMat/aRMSD" - ) + print("You need to install the group's fork of aRMSD and put it in your path " "https://github.mit.edu/MLMat/aRMSD") VERSION, YEAR = "0.9.4", "2017" @@ -17,7 +14,7 @@ def write_coord(coord): """Adjusts whitespace for coordinates""" - return "{:06.8f}".format(coord) if coord < 0.0 else " " + "{:06.8f}".format(coord) + return f"{coord:06.8f}" if coord < 0.0 else " " + f"{coord:06.8f}" def data_to_xyz(sym, cor): @@ -77,9 +74,7 @@ def kabsch(rxn, indexedproductgeom_raw, reactantgeom_raw, rid, pid): idp_element_symbol, 
idp_element_xyz = ac.read_xyz_file(logger, idp_data) idp_element_xyz_std = None # Create a molecule object - molecule1 = ac.Molecule( - idp_mol_name, idp_element_symbol, idp_element_xyz, idp_element_xyz_std - ) + molecule1 = ac.Molecule(idp_mol_name, idp_element_symbol, idp_element_xyz, idp_element_xyz_std) molecule1.get_charge() molecule1.get_mass() molecule1.calc_com(calc_for="molecule") @@ -95,9 +90,7 @@ def kabsch(rxn, indexedproductgeom_raw, reactantgeom_raw, rid, pid): rxt_element_symbol, rxt_element_xyz = ac.read_xyz_file(logger, rxt_data) rxt_element_xyz_std = None # Create a molecule object - molecule2 = ac.Molecule( - rxt_mol_name, rxt_element_symbol, rxt_element_xyz, rxt_element_xyz_std - ) + molecule2 = ac.Molecule(rxt_mol_name, rxt_element_symbol, rxt_element_xyz, rxt_element_xyz_std) molecule2.get_charge() molecule2.get_mass() molecule2.calc_com(calc_for="molecule") diff --git a/nff/reactive_tools/neb.py b/nff/reactive_tools/neb.py index 67b3c08e..18499914 100644 --- a/nff/reactive_tools/neb.py +++ b/nff/reactive_tools/neb.py @@ -1,14 +1,13 @@ import copy -from nff.io.ase import AtomsBatch -from nff.io.ase_calcs import NeuralFF -from nff.reactive_tools.utils import xyz_to_ase_atoms - from ase.io import read - from ase.neb import NEB from ase.optimize import BFGS +from nff.io.ase import AtomsBatch +from nff.io.ase_calcs import NeuralFF +from nff.reactive_tools.utils import xyz_to_ase_atoms + def neural_neb_ase( reactantxyzfile, @@ -36,17 +35,17 @@ def neural_neb_ase( neb.interpolate() neb.idpp_interpolate(optimizer=BFGS, steps=steps) - images = read("idpp.traj@-{}:".format(str(n_images + 2))) + images = read(f"idpp.traj@-{n_images + 2!s}:") # # Set calculators: nff_ase = NeuralFF.from_file(nff_dir, device="cuda:0") neb.set_calculators(nff_ase) # # Optimize: - optimizer = BFGS(neb, trajectory="{}/{}.traj".format(nff_dir, rxn_name)) + optimizer = BFGS(neb, trajectory=f"{nff_dir}/{rxn_name}.traj") optimizer.run(fmax=fmax, steps=steps) # Read NEB images from File - images = read("{}/{}.traj@-{}:".format(nff_dir, rxn_name, str(n_images + 2))) + images = read(f"{nff_dir}/{rxn_name}.traj@-{n_images + 2!s}:") return images diff --git a/nff/reactive_tools/nms.py b/nff/reactive_tools/nms.py index 4a6fc76f..6f75a5c9 100644 --- a/nff/reactive_tools/nms.py +++ b/nff/reactive_tools/nms.py @@ -1,97 +1,104 @@ -import scipy import numpy as np +from ase.units import Bohr from scipy.stats import rv_discrete -from ase.units import Bohr,Rydberg,kJ,kB,fs,Hartree,mol,kcal,second CM_2_AU = 4.5564e-6 ANGS_2_AU = 1.8897259886 AMU_2_AU = 1822.88985136 k_B = 1.38064852e-23 -PLANCKS_CONS = 6.62607015e-34 -HA2J = 4.359744E-18 +PLANCKS_CONS = 6.62607015e-34 +HA2J = 4.359744e-18 BOHRS2ANG = 0.529177 -SPEEDOFLIGHT = 2.99792458E8 -AMU2KG = 1.660538782E-27 +SPEEDOFLIGHT = 2.99792458e8 +AMU2KG = 1.660538782e-27 + class Boltzmann_gen(rv_discrete): "Boltzmann distribution" + def _pmf(self, k, nu, temperature): - return ((np.exp(-(k * PLANCKS_CONS * nu)/(k_B * temperature))) * - (1 - np.exp(-(PLANCKS_CONS * nu)/(k_B * temperature)))) - -def reactive_normal_mode_sampling(xyz, force_constants_J_m_2, - proj_vib_freq_cm_1, proj_hessian_eigvec, - temperature, - kick=1): - + return (np.exp(-(k * PLANCKS_CONS * nu) / (k_B * temperature))) * ( + 1 - np.exp(-(PLANCKS_CONS * nu) / (k_B * temperature)) + ) + + +def reactive_normal_mode_sampling( + xyz, force_constants_J_m_2, proj_vib_freq_cm_1, proj_hessian_eigvec, temperature, kick=1 +): """Normal Mode Sampling for Transition States. 
Takes in xyz(1,N,3), force_constants(3N-6) in J/m^2, projected vibrational frequencies(3N-6) in cm^-1,mass-weighted projected hessian eigenvectors(3N-6,3N) - ,temperature in K, and scaling factor of initial velocity of the lowest imaginary mode. + ,temperature in K, and scaling factor of initial velocity of the lowest imaginary mode. Returns displaces xyz and a pair of velocities(forward and backwards)""" - - #Determine the highest level occupany of each mode + + # Determine the highest level occupany of each mode occ_vib_modes = [] boltzmann = Boltzmann_gen(a=0, b=1000000, name="boltzmann") for i, nu in enumerate(proj_vib_freq_cm_1): if nu > 50: - occ_vib_modes.append(boltzmann.rvs(nu * SPEEDOFLIGHT * 100, - temperature)) + occ_vib_modes.append(boltzmann.rvs(nu * SPEEDOFLIGHT * 100, temperature)) elif i == 0: - occ_vib_modes.append(boltzmann.rvs(-1 * nu * SPEEDOFLIGHT * 100, - temperature)) + occ_vib_modes.append(boltzmann.rvs(-1 * nu * SPEEDOFLIGHT * 100, temperature)) else: occ_vib_modes.append(-1) - - #Determine maximum displacement (amplitude) of each mode - + + # Determine maximum displacement (amplitude) of each mode + amplitudes = [] - freqs = [] for i, occ in enumerate(occ_vib_modes): if occ >= 0: - energy = proj_vib_freq_cm_1[i] * SPEEDOFLIGHT * 100 * PLANCKS_CONS # cm-1 to Joules - amplitudes.append(np.sqrt((0.5 * (occ + 1) * energy) / force_constants_J_m_2[i]) * 1e9) #Angstom + energy = proj_vib_freq_cm_1[i] * SPEEDOFLIGHT * 100 * PLANCKS_CONS # cm-1 to Joules + amplitudes.append(np.sqrt((0.5 * (occ + 1) * energy) / force_constants_J_m_2[i]) * 1e9) # Angstom else: amplitudes.append(0) - #Determine the actual displacements and velocities + # Determine the actual displacements and velocities displacements = [] velocities = [] - random_0_1 = [np.random.normal(0,1) for i in range(len(amplitudes))] + random_0_1 = [np.random.normal(0, 1) for i in range(len(amplitudes))] for i, amplitude in enumerate(amplitudes): - if force_constants_J_m_2[i] > 0: - - displacements.append(amplitude - * np.cos(2 * np.pi * random_0_1[i]) - * proj_hessian_eigvec[i]) - - velocities.append(-1 * proj_vib_freq_cm_1[i] * SPEEDOFLIGHT * 100 * 2 * np.pi - * amplitude - * np.sin(2 * np.pi * random_0_1[i]) - * proj_hessian_eigvec[i] / Bohr**2) - + displacements.append(amplitude * np.cos(2 * np.pi * random_0_1[i]) * proj_hessian_eigvec[i]) + + velocities.append( + -1 + * proj_vib_freq_cm_1[i] + * SPEEDOFLIGHT + * 100 + * 2 + * np.pi + * amplitude + * np.sin(2 * np.pi * random_0_1[i]) + * proj_hessian_eigvec[i] + / Bohr**2 + ) + elif i == 0: - displacements.append(0) velocities.append(0) - + # Extra kick for lowest imagninary mode(s) - velocities.append(-1 * proj_vib_freq_cm_1[i] * SPEEDOFLIGHT * 100 * 2 * np.pi - * amplitude - * np.sin(2 * np.pi * random_0_1[i]) - * proj_hessian_eigvec[i] / Bohr**2) - - + velocities.append( + -1 + * proj_vib_freq_cm_1[i] + * SPEEDOFLIGHT + * 100 + * 2 + * np.pi + * amplitude + * np.sin(2 * np.pi * random_0_1[i]) + * proj_hessian_eigvec[i] + / Bohr**2 + ) # todo: properly import / name / document units and unit conversions + else: - displacements.append(0) velocities.append(0) - - tot_disp = np.sum(np.array(displacements),axis=0) - #In angstroms - disp_xyz = xyz + tot_disp.reshape(1,-1,3) - #In angstroms per second - tot_vel_plus = np.sum(np.array(velocities),axis=0).reshape(1,-1,3) + + tot_disp = np.sum(np.array(displacements), axis=0) + # In angstroms + disp_xyz = xyz + tot_disp.reshape(1, -1, 3) + # In angstroms per second + tot_vel_plus = np.sum(np.array(velocities), 
axis=0).reshape(1, -1, 3) tot_vel_minus = -1 * tot_vel_plus - + return disp_xyz, tot_vel_plus, tot_vel_minus diff --git a/nff/reactive_tools/reactive_langevin.py b/nff/reactive_tools/reactive_langevin.py index ffecd1c7..b6ab0ddf 100644 --- a/nff/reactive_tools/reactive_langevin.py +++ b/nff/reactive_tools/reactive_langevin.py @@ -1,75 +1,72 @@ +import numpy as np from ase.io import Trajectory -from ase.md.langevin import * -from ase import Atoms -from ase.units import Bohr,Rydberg,kJ,kB,fs,Hartree,mol,kcal,second,Ang +# todo check if this necessary, probably better to properly typehint and check the input arguments +from ase.md.langevin import * # noqa +from ase.units import Ang, fs, kB, second + from nff.md.utils import NeuralMDLogger, write_traj + class Reactive_Dynamics: - - def __init__(self, - atomsbatch, - nms_vel, - mdparam, - ): - - # initialize the atoms batch system + def __init__( + self, + atomsbatch, + nms_vel, + mdparam, + ): + # initialize the atoms batch system self.atomsbatch = atomsbatch self.mdparam = mdparam - - #initialize velocity from nms + + # initialize velocity from nms self.vel = nms_vel - - self.temperature = self.mdparam['T_init'] - - self.friction = self.mdparam['friction'] - + + self.temperature = self.mdparam["T_init"] + + self.friction = self.mdparam["friction"] + # todo: structure optimization before starting - - # intialize system momentum by normal mode sampling - self.atomsbatch.set_velocities(self.vel.reshape(-1,3) * Ang / second) - - # set thermostats - integrator = self.mdparam['thermostat'] - - self.integrator = integrator(self.atomsbatch, - self.mdparam['time_step'] * fs, - self.temperature * kB, - self.friction) - - # attach trajectory dump - self.traj = Trajectory(self.mdparam['traj_filename'], 'w', self.atomsbatch) - self.integrator.attach(self.traj.write, interval=mdparam['save_frequency']) - + + # intialize system momentum by normal mode sampling + self.atomsbatch.set_velocities(self.vel.reshape(-1, 3) * Ang / second) + + # set thermostats + integrator = self.mdparam["thermostat"] + + self.integrator = integrator( + self.atomsbatch, self.mdparam["time_step"] * fs, self.temperature * kB, self.friction + ) + + # attach trajectory dump + self.traj = Trajectory(self.mdparam["traj_filename"], "w", self.atomsbatch) + self.integrator.attach(self.traj.write, interval=mdparam["save_frequency"]) + # attach log file - self.integrator.attach(NeuralMDLogger(self.integrator, - self.atomsbatch, - self.mdparam['thermo_filename'], - mode='a'), interval=mdparam['save_frequency']) + self.integrator.attach( + NeuralMDLogger(self.integrator, self.atomsbatch, self.mdparam["thermo_filename"], mode="a"), + interval=mdparam["save_frequency"], + ) def run(self): - - self.integrator.run(self.mdparam['steps']) + self.integrator.run(self.mdparam["steps"]) + + # self.traj.close() - #self.traj.close() - - def save_as_xyz(self, filename): - - ''' - TODO: save time information + """ + TODO: save time information TODO: subclass TrajectoryReader/TrajectoryReader to digest AtomsBatch instead of Atoms? 
- TODO: other system variables in .xyz formats - ''' - traj = Trajectory(self.mdparam['traj_filename'], mode='r') - + TODO: other system variables in .xyz formats + """ + traj = Trajectory(self.mdparam["traj_filename"], mode="r") + xyz = [] for snapshot in traj: - frames = np.concatenate([ - snapshot.get_atomic_numbers().reshape(-1, 1), - snapshot.get_positions().reshape(-1, 3) - ], axis=1) - + frames = np.concatenate( + [snapshot.get_atomic_numbers().reshape(-1, 1), snapshot.get_positions().reshape(-1, 3)], axis=1 + ) + xyz.append(frames) - + write_traj(filename, np.array(xyz)) diff --git a/nff/reactive_tools/utils.py b/nff/reactive_tools/utils.py index 6d4891e8..669dfead 100644 --- a/nff/reactive_tools/utils.py +++ b/nff/reactive_tools/utils.py @@ -1,12 +1,9 @@ -from ase.vibrations import Vibrations -from ase.units import Bohr, mol, kcal -from ase import Atoms - import numpy as np - +from ase import Atoms +from ase.units import Bohr, kcal, mol +from ase.vibrations import Vibrations from rdkit import Chem - PT = Chem.GetPeriodicTable() HA2J = 4.359744e-18 @@ -38,8 +35,8 @@ def xyz_to_ase_atoms(xyz_file): sym = [] pos = [] - f = open(xyz_file, "r") - lines = f.readlines() + with open(xyz_file, "r") as f: + lines = f.readlines() for i, line in enumerate(lines): if i > 1: @@ -56,9 +53,7 @@ def xyz_to_ase_atoms(xyz_file): def moi_tensor(massvec, expmassvec, xyz): # Center of Mass - com = np.sum(expmassvec.reshape(-1, 3) * xyz.reshape(-1, 3), axis=0) / np.sum( - massvec - ) + com = np.sum(expmassvec.reshape(-1, 3) * xyz.reshape(-1, 3), axis=0) / np.sum(massvec) # xyz shifted to COM xyz_com = xyz.reshape(-1, 3) - com @@ -111,24 +106,18 @@ def trans_rot_vec(massvec, xyz_com, moi_eigvec): big_p = np.matmul(xyz_com, moi_eigvec) d4 = ( - np.repeat(big_p[:, 1], 3).reshape(-1) - * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 2], 3).reshape(-1) - * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) + np.repeat(big_p[:, 1], 3).reshape(-1) * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 2], 3).reshape(-1) * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) ) * expsqrtmassvec d5 = ( - np.repeat(big_p[:, 2], 3).reshape(-1) - * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 0], 3).reshape(-1) - * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) + np.repeat(big_p[:, 2], 3).reshape(-1) * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 0], 3).reshape(-1) * np.tile(moi_eigvec[:, 2], len(massvec)).reshape(-1) ) * expsqrtmassvec d6 = ( - np.repeat(big_p[:, 0], 3).reshape(-1) - * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) - - np.repeat(big_p[:, 1], 3).reshape(-1) - * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) + np.repeat(big_p[:, 0], 3).reshape(-1) * np.tile(moi_eigvec[:, 1], len(massvec)).reshape(-1) + - np.repeat(big_p[:, 1], 3).reshape(-1) * np.tile(moi_eigvec[:, 0], len(massvec)).reshape(-1) ) * expsqrtmassvec d1_norm = d1 / np.linalg.norm(d1) @@ -148,12 +137,7 @@ def vib_analy(r, xyz, hessian): # xyz is the cartesian coordinates in Angstrom # Hessian elements in atomic units (Ha/bohr^2) - massvec = np.array( - [ - PT.GetAtomicWeight(i.item()) * AMU2KG - for i in list(np.array(r.reshape(-1)).astype(int)) - ] - ) + massvec = np.array([PT.GetAtomicWeight(i.item()) * AMU2KG for i in list(np.array(r.reshape(-1)).astype(int))]) expmassvec = np.repeat(massvec, 3) sqrtinvmassvec = np.divide(1.0, np.sqrt(expmassvec)) hessian_mwc = np.einsum("i,ij,j->ij", sqrtinvmassvec, hessian, 
sqrtinvmassvec) @@ -178,9 +162,7 @@ def vib_analy(r, xyz, hessian): hessian_eigval_abs = np.abs(hessian_eigval) - pre_vib_freq_cm_1 = np.sqrt(hessian_eigval_abs * HA2J * 10e19) / ( - SPEEDOFLIGHT * 2 * np.pi * BOHRS2ANG * 100 - ) + pre_vib_freq_cm_1 = np.sqrt(hessian_eigval_abs * HA2J * 10e19) / (SPEEDOFLIGHT * 2 * np.pi * BOHRS2ANG * 100) vib_freq_cm_1 = pre_vib_freq_cm_1.copy() @@ -193,9 +175,7 @@ def vib_analy(r, xyz, hessian): if np.abs(freq) < 1.0: trans_rot_elms.append(i) - force_constants_J_m_2 = np.delete( - hessian_eigval * HA2J * 1e20 / (BOHRS2ANG**2) * AMU2KG, trans_rot_elms - ) + force_constants_J_m_2 = np.delete(hessian_eigval * HA2J * 1e20 / (BOHRS2ANG**2) * AMU2KG, trans_rot_elms) proj_vib_freq_cm_1 = np.delete(vib_freq_cm_1, trans_rot_elms) proj_hessian_eigvec = np.delete(hessian_eigvec.T, trans_rot_elms, 0) diff --git a/nff/data/tests/__init__.py b/nff/tests/__init__.py similarity index 100% rename from nff/data/tests/__init__.py rename to nff/tests/__init__.py diff --git a/nff/tests/conftest.py b/nff/tests/conftest.py new file mode 100644 index 00000000..87cbfc4e --- /dev/null +++ b/nff/tests/conftest.py @@ -0,0 +1,15 @@ +import os + +import pytest +import torch + +torch.set_num_threads(int(os.getenv("OMP_NUM_THREADS", "1"))) + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cpu", help="Whether to use the CPU or GPU for the tests") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") diff --git a/nff/tests/data/azo_diabat.pth.tar b/nff/tests/data/azo_diabat.pth.tar new file mode 100644 index 00000000..1065bcae Binary files /dev/null and b/nff/tests/data/azo_diabat.pth.tar differ diff --git a/nff/tests/data/dataset.pth.tar b/nff/tests/data/dataset.pth.tar new file mode 100644 index 00000000..51dff90c Binary files /dev/null and b/nff/tests/data/dataset.pth.tar differ diff --git a/nff/md/zhu_nakamura/dynamics_test.py b/nff/tests/dynamics_test.py similarity index 75% rename from nff/md/zhu_nakamura/dynamics_test.py rename to nff/tests/dynamics_test.py index bf391d24..f0a09b24 100644 --- a/nff/md/zhu_nakamura/dynamics_test.py +++ b/nff/tests/dynamics_test.py @@ -1,46 +1,122 @@ +import copy import os -import numpy as np +import pickle import random -import json -import pdb -import logging +import unittest as ut from datetime import datetime -from pytz import timezone -import torch -import copy -import csv -import pickle +from pathlib import Path -from ase.md.md import MolecularDynamics +import numpy as np +import pytest +import torch from ase.io.trajectory import Trajectory -from ase import Atoms +from torch.utils.data import DataLoader -from nff.md.utils import mol_dot, mol_norm, ZhuNakamuraLogger, atoms_to_nxyz -from nff.md.nvt_test import NoseHoover, NoseHooverChain -from nff.utils.constants import BOHR_RADIUS, FS_TO_AU, AMU_TO_AU, FS_TO_ASE, ASE_TO_FS, EV_TO_AU from nff.data import Dataset, collate_dicts -from nff.utils.cuda import batch_to -from nff.utils.constants import KCAL_TO_AU, KB_EV +from nff.io.ase import AtomsBatch +from nff.io.ase_calcs import NeuralFF +from nff.md.nvt import Langevin +from nff.md.nvt_ax import NoseHoover, NoseHooverChain +from nff.md.utils_ax import ZhuNakamuraLogger, atoms_to_nxyz, mol_dot, mol_norm from nff.train import load_model - -from torch.utils.data import DataLoader - - +from nff.utils.constants import AMU_TO_AU, ASE_TO_FS, BOHR_RADIUS, EV_TO_AU, FS_TO_AU, KCAL_TO_AU +from nff.utils.cuda import batch_to HBAR = 1 OUT_FILE = "trj.csv" LOG_FILE = "trj.log" 
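# Note (illustration, not part of the patch): the new conftest.py above turns
# a --device command-line flag into a pytest fixture, so the same suite can
# target CPU or GPU without code changes. A hypothetical test showing how that
# fixture is consumed; the test name and body are illustrative only:

import torch

def test_runs_on_requested_device(device):  # `device` injected from conftest.py
    x = torch.zeros(3, device=device)
    assert x.device.type == ("cuda" if device.startswith("cuda") else "cpu")

# invoked as:  pytest nff/tests --device cuda:0   (default is cpu)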
+this_file = Path(__file__).resolve()
+ETHANOL_MODEL_PATH = (
+    this_file.parent.parent.parent / "tutorials" / "models" / "cco_1" / "best_model"
+)  # Simon's SchNet model
 
-METHOD_DIC = {
-    "nosehoover": NoseHoover,
-    "nosehooverchain": NoseHooverChain
-    }
-
+METHOD_DIC = {"nosehoover": NoseHoover, "nosehooverchain": NoseHooverChain}
 
-class ZhuNakamuraDynamics(ZhuNakamuraLogger):
+def get_directed_ethanol():
+    """Returns an ethanol molecule.
+    Returns:
+        ethanol (Atoms)
+    """
+    props = {
+        "nxyz": torch.Tensor(
+            [
+                [6.0000e00, 5.5206e-03, 5.9149e-01, -8.1382e-04],
+                [6.0000e00, -1.2536e00, -2.5536e-01, -2.9801e-02],
+                [8.0000e00, 1.0878e00, -3.0755e-01, 4.8230e-02],
+                [1.0000e00, 6.2821e-02, 1.2838e00, -8.4279e-01],
+                [1.0000e00, 6.0567e-03, 1.2303e00, 8.8535e-01],
+                [1.0000e00, -2.2182e00, 1.8981e-01, -5.8160e-02],
+                [1.0000e00, -9.1097e-01, -1.0539e00, -7.8160e-01],
+                [1.0000e00, -1.1920e00, -7.4248e-01, 9.2197e-01],
+                [1.0000e00, 1.8488e00, -2.8632e-02, -5.2569e-01],
+            ]
+        ),
+        "energy": torch.tensor(-4.3701),
+        "energy_grad": torch.Tensor(
+            [
+                [10.2030, -33.6563, 1.9132],
+                [-59.5878, 42.4086, 10.0746],
+                [-36.9785, 2.0060, 18.7998],
+                [-1.8185, 5.6604, 4.6715],
+                [-1.8685, 0.9660, -1.9927],
+                [11.0286, -11.6878, 18.4956],
+                [38.0142, -24.5804, -16.6240],
+                [5.8505, 15.7041, -12.9981],
+                [35.1569, 3.1794, -22.3399],
+            ]
+        ),
+        "smiles": "CCO",
+        "num_atoms": torch.tensor(9),
+        "nbr_list": torch.Tensor(
+            [
+                [0, 1],
+                [0, 2],
+                [0, 3],
+                [0, 4],
+                [0, 5],
+                [0, 6],
+                [0, 7],
+                [0, 8],
+                [1, 2],
+                [1, 3],
+                [1, 4],
+                [1, 5],
+                [1, 6],
+                [1, 7],
+                [1, 8],
+                [2, 3],
+                [2, 4],
+                [2, 5],
+                [2, 6],
+                [2, 7],
+                [2, 8],
+                [3, 4],
+                [3, 5],
+                [3, 6],
+                [3, 7],
+                [3, 8],
+                [4, 5],
+                [4, 6],
+                [4, 7],
+                [4, 8],
+                [5, 6],
+                [5, 7],
+                [5, 8],
+                [6, 7],
+                [6, 8],
+                [7, 8],
+            ]
+        ),
+        "charge": torch.tensor(0.0),
+        "spin": torch.tensor(0.0),
+    }
+    return AtomsBatch(positions=props["nxyz"][:, 1:], directed=True, numbers=props["nxyz"][:, 0], props=props)
+
+
+class ZhuNakamuraDynamics(ZhuNakamuraLogger):
     """
     Class for running Zhu-Nakamura surface-hopping dynamics. This method follows the description in
     Yu et. al, "Trajectory based nonadiabatic molecular dynamics without calculating nonadiabatic
@@ -104,12 +180,13 @@ class ZhuNakamuraDynamics(ZhuNakamuraLogger):
     Properties:
-        positions: returns self._positions. Updating positions updates self._positions, self.positions_list, and positions of
-        self.atoms.
-        velocities: returns self._velocities. Updating positions updates self._velocities, self.velocities_list and velocities
-        of self.atoms.
+        positions: returns self._positions. Updating positions updates self._positions, self.positions_list,
+        and positions of self.atoms.
+        velocities: returns self._velocities. Updating velocities updates self._velocities, self.velocities_list
+        and velocities of self.atoms.
         forces: returns self._forces. Updating forces updates self._forces, self.forces_list and forces of self.atoms.
-        energies: returns self._energies. Updating energies updates self._energies, self.energy_list and energies of self.atoms.
+        energies: returns self._energies. Updating energies updates self._energies, self.energy_list
+        and energies of self.atoms.
         surf: returns self._surf. Updating surf updates self._surf and self.surf_list.
         in_trj: returns self._in_trj. Updating in_trj updates self._in_trj and self.
         time: returns self._time. Updating time updates self.time_list.
@@ -117,16 +194,18 @@ class ZhuNakamuraDynamics(ZhuNakamuraLogger): self.hopping_probability_list """ - def __init__(self, - atoms, - timestep, - max_time, - initial_time=0.0, - initial_surf=1, - num_states=2, - out_file=OUT_FILE, - log_file=LOG_FILE, - **kwargs): + def __init__( + self, + atoms, + timestep, + max_time, + initial_time=0.0, + initial_surf=1, + num_states=2, + out_file=OUT_FILE, + log_file=LOG_FILE, + **kwargs, + ): """ Initializes a ZhuNakamura instance. @@ -149,23 +228,22 @@ def __init__(self, """ self.atoms = atoms - self.dt = timestep*FS_TO_AU - self.max_time = max_time*FS_TO_AU + self.dt = timestep * FS_TO_AU + self.max_time = max_time * FS_TO_AU self.num_states = num_states self.Natom = atoms.get_number_of_atoms() self.out_file = out_file self.log_file = log_file self.setup_logging() - # everything in a.u. other than positions (which are in angstrom) self._positions = atoms.get_positions() - self._velocities = atoms.get_velocities()*EV_TO_AU/(ASE_TO_FS*FS_TO_AU) + self._velocities = atoms.get_velocities() * EV_TO_AU / (ASE_TO_FS * FS_TO_AU) self._forces = None self._energies = None self._surf = initial_surf self._in_trj = True - self._time = initial_time*FS_TO_AU + self._time = initial_time * FS_TO_AU self._hopping_probabilities = [] self.position_list = [self._positions] @@ -189,15 +267,19 @@ def __init__(self, self.ke_parallel = 0.0 self.ke = 0.0 - - save_keys = ["position_list", "velocity_list", "force_list", "energy_list", "surf_list", "in_trj_list", - "hopping_probability_list", "time_list"] + save_keys = [ + "position_list", + "velocity_list", + "force_list", + "energy_list", + "surf_list", + "in_trj_list", + "hopping_probability_list", + "time_list", + ] super().__init__(save_keys=save_keys, **self.__dict__) - - - @property def positions(self): return self._positions @@ -289,7 +371,6 @@ def energies(self, value): else: self.energy_list = [value] - @surf.setter def surf(self, value): """ @@ -362,7 +443,7 @@ def get_masses(self): Returns: self.atoms.get_masses() (numpy.ndarray): masses """ - return self.atoms.get_masses()*AMU_TO_AU + return self.atoms.get_masses() * AMU_TO_AU def get_accel(self): """ @@ -374,11 +455,10 @@ def get_accel(self): # the force is force acting on the current state force = self.forces[self.surf] - accel = ( force / self.get_masses().reshape(-1, 1) ) + accel = force / self.get_masses().reshape(-1, 1) return accel def position_step(self): - # get current acceleration and velocity accel = self.get_accel() self.old_accel = accel @@ -388,11 +468,9 @@ def position_step(self): # Note also that we don't use += here, because that causes problems with # setters. - self.positions = self.positions + (self.velocities * self.dt + 1 / - 2 * accel * self.dt ** 2) * BOHR_RADIUS + self.positions = self.positions + (self.velocities * self.dt + 1 / 2 * accel * self.dt**2) * BOHR_RADIUS def velocity_step(self): - new_accel = self.get_accel() self.velocities = self.velocities + 1 / 2 * (new_accel + self.old_accel) * self.dt # assume the current frame is in the trajectory until finding out otherwise @@ -400,15 +478,14 @@ def velocity_step(self): # update surf (which also appends to surf_list) self.surf = self.surf self.time = self.time + self.dt - self.log("Completed step {}. Currently in state {}.".format( - int(self.time/self.dt), self.surf)) - self.log("Relative energies are {} eV".format(", ".join( - ((self.energies - self.energies[0])*27.2).reshape(-1).astype("str").tolist()))) - + self.log(f"Completed step {int(self.time / self.dt)}. 
Currently in state {self.surf}.") + self.log( + "Relative energies are {} eV".format( + ", ".join(((self.energies - self.energies[0]) * 27.2).reshape(-1).astype("str").tolist()) + ) + ) def md_step(self): - - """ Take a regular molecular dynamics step on the current surface. """ @@ -419,7 +496,6 @@ def md_step(self): self.update_energies() self.velocity_step() - def check_crossing(self): """Check if we're at an avoided crossing by seeing if the energy gap was at a minimum in the last step. Args: @@ -429,7 +505,6 @@ def check_crossing(self): new_surfs (list): list of surfaces that are at an avoided crossing with the current surface. """ - new_surfs = [] at_crossing = False @@ -448,8 +523,7 @@ def check_crossing(self): if i == self.surf: continue # list of energy gaps - gaps = [abs(energies[i] - energies[self.surf]) - for energies in self.energy_list[-3:]] + gaps = [abs(energies[i] - energies[self.surf]) for energies in self.energy_list[-3:]] # whether or not the middle gap is the smallest of the three gap_min = gaps[0] > gaps[1] and gaps[2] > gaps[1] if gap_min: @@ -477,23 +551,25 @@ def update_diabatic_quants(self, lower_state, upper_state): r_12 = self.position_list[-2] - self.position_list[-1] # diabatic forecs on the lower state - lower_diabatic_forces = -(-self.force_list[-1][lower_state] * r_10 + - self.force_list[-3][upper_state] * r_12) / r_20 + lower_diabatic_forces = ( + -(-self.force_list[-1][lower_state] * r_10 + self.force_list[-3][upper_state] * r_12) / r_20 + ) # diabatic forces on the upper state - upper_diabatic_forces = -(-self.force_list[-1][upper_state] * r_10 + - self.force_list[-3][lower_state] * r_12) / r_20 + upper_diabatic_forces = ( + -(-self.force_list[-1][upper_state] * r_10 + self.force_list[-3][lower_state] * r_12) / r_20 + ) # array of forces on the lower and upper diabatic states - self.diabatic_forces = np.append([lower_diabatic_forces], [ - upper_diabatic_forces], axis=0) + self.diabatic_forces = np.append([lower_diabatic_forces], [upper_diabatic_forces], axis=0) # update diabatic coupling self.diabatic_coupling = ( - self.energy_list[-2][upper_state].item() - self.energy_list[-2][lower_state].item()) / 2 + self.energy_list[-2][upper_state].item() - self.energy_list[-2][lower_state].item() + ) / 2 # update Zhu difference parameter norm_vec = mol_norm(self.diabatic_forces[1] - self.diabatic_forces[0]) - self.zhu_difference = np.sum(norm_vec ** 2 / self.get_masses()) ** 0.5 + self.zhu_difference = np.sum(norm_vec**2 / self.get_masses()) ** 0.5 # update Zhu product parameter and the Zhu sign parameter prods = self.diabatic_forces[0] * self.diabatic_forces[1] @@ -503,19 +579,15 @@ def update_diabatic_quants(self, lower_state, upper_state): # get parallel component of velocity and the associated KE # First normalize s-vector to give n-vector - s = (self.diabatic_forces[1] - self.diabatic_forces[0] - ) / self.get_masses().reshape(-1, 1) ** 0.5 + s = (self.diabatic_forces[1] - self.diabatic_forces[0]) / self.get_masses().reshape(-1, 1) ** 0.5 self.n_vector = s / mol_norm(s).reshape(-1, 1) # Then get ke's self.v_parallel = mol_dot(self.velocity_list[-2], self.n_vector) - self.ke_parallel = np.sum( - self.get_masses() * (self.v_parallel ** 2) / 2) - self.ke = np.sum(self.get_masses() * - mol_norm(self.velocity_list[-2]) ** 2 / 2) + self.ke_parallel = np.sum(self.get_masses() * (self.v_parallel**2) / 2) + self.ke = np.sum(self.get_masses() * mol_norm(self.velocity_list[-2]) ** 2 / 2) def rescale_v(self, old_surf, new_surf): - """ Re-scale the velocity after a hop. 
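+        Only the component of the velocity parallel to the normalized diabatic
+        force-difference direction is rescaled, by the factor
+        sqrt((E_old + KE_parallel - E_new) / KE_parallel), so that total energy
+        is conserved across the hop; a negative argument means the hop is
+        energetically forbidden and "err" is returned.
+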
Args: @@ -532,20 +604,16 @@ def rescale_v(self, old_surf, new_surf): v_par_vec = self.n_vector * (self.v_parallel).reshape(-1, 1) # the scaling factor for the velocities - scale_arg = (((energy[old_surf] + (self.ke_parallel)) - - energy[new_surf]) / (self.ke_parallel)) + scale_arg = ((energy[old_surf] + (self.ke_parallel)) - energy[new_surf]) / (self.ke_parallel) if scale_arg < 0: return "err" - scale = (((energy[old_surf] + (self.ke_parallel)) - - energy[new_surf]) / (self.ke_parallel)) ** 0.5 - self.velocities = scale * v_par_vec + \ - (self.velocity_list[-2] - v_par_vec) + scale = (((energy[old_surf] + (self.ke_parallel)) - energy[new_surf]) / (self.ke_parallel)) ** 0.5 + self.velocities = scale * v_par_vec + (self.velocity_list[-2] - v_par_vec) + return None def update_probabilities(self): - - """ Update the Zhu a, b and p probabilities. """ @@ -558,50 +626,44 @@ def update_probabilities(self): return # if the molecule's exploded then move on - if 'nan' in self.positions.astype("str") or 'nan' in self.forces.astype("str"): + if "nan" in self.positions.astype("str") or "nan" in self.forces.astype("str"): self.hopping_probabilities = hopping_probabilities return for new_surf in new_surfs: - # get the upper and lower state by sorting the current surface and the new one lower_state, upper_state = sorted((self.surf, new_surf)) self.update_diabatic_quants(lower_state, upper_state) # use context manager to ignore any divide by 0's - with np.errstate(divide='ignore', invalid='ignore'): - + with np.errstate(divide="ignore", invalid="ignore"): # calculate the zhu a parameter - a_numerator = HBAR ** 2 / 2 * self.zhu_product * self.zhu_difference + a_numerator = HBAR**2 / 2 * self.zhu_product * self.zhu_difference a_denominator = (2 * self.diabatic_coupling) ** 3 - zhu_a = np.nan_to_num( - np.divide(a_numerator, a_denominator) ** 0.5) + zhu_a = np.nan_to_num(np.divide(a_numerator, a_denominator) ** 0.5) # calculate the zhu b parameter, starting with Et and Ex et = self.ke_parallel + self.energy_list[-2][self.surf].item() - ex = (self.energy_list[-2][upper_state].item() + - self.energy_list[-2][lower_state].item()) / 2 - b_numerator = (et - ex) * self.zhu_difference / \ - self.zhu_product + ex = (self.energy_list[-2][upper_state].item() + self.energy_list[-2][lower_state].item()) / 2 + b_numerator = (et - ex) * self.zhu_difference / self.zhu_product b_denominator = 2 * self.diabatic_coupling - zhu_b = np.nan_to_num( - np.divide(b_numerator, b_denominator) ** 0.5) + zhu_b = np.nan_to_num(np.divide(b_numerator, b_denominator) ** 0.5) # calculating the hopping probability - zhu_p = np.nan_to_num(np.exp(-np.pi / 4 / zhu_a * (2 / (zhu_b ** 2 + - (abs((zhu_b ** 4) + ( - self.zhu_sign) * 1.0)) ** 0.5)) ** 0.5)) + zhu_p = np.nan_to_num( + np.exp( + -np.pi / 4 / zhu_a * (2 / (zhu_b**2 + (abs((zhu_b**4) + (self.zhu_sign) * 1.0)) ** 0.5)) ** 0.5 + ) + ) # add this info to the list of hopping probabilities - hopping_probabilities.append( - {"zhu_a": zhu_a, "zhu_b": zhu_b, "zhu_p": zhu_p, "new_surf": new_surf}) + hopping_probabilities.append({"zhu_a": zhu_a, "zhu_b": zhu_b, "zhu_p": zhu_p, "new_surf": new_surf}) self.hopping_probabilities = hopping_probabilities def should_hop(self, zhu_a, zhu_b, zhu_p): - - """ + """ Decide whether or not to hop based on the zhu a, b and p parameters. Args: zhu_a (float): Zhu a parameter @@ -626,7 +688,6 @@ def should_hop(self, zhu_a, zhu_b, zhu_p): return will_hop def hop(self, new_surf): - """ Hop from the current surface to a new surface at an avoided crossing. 
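+        The velocities are re-scaled for the new surface via rescale_v; if that
+        rescaling reports "err" (insufficient parallel kinetic energy), the hop
+        is abandoned and "err" is passed back to full_step.
+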
        Args:
@@ -659,6 +720,7 @@
         self.time = self.time - self.dt
 
         self.modify_save()
+        return None
 
     def full_step(self, compute_internal_forces=True):
         """
@@ -683,14 +745,16 @@
         # loop through sets of states to hop between
         for probability_dic in self.hopping_probabilities:
-
             zhu_a = probability_dic["zhu_a"]
             zhu_b = probability_dic["zhu_b"]
             zhu_p = probability_dic["zhu_p"]
             new_surf = probability_dic["new_surf"]
 
-            self.log("Attempting hop from state {} to state {}. Probability is {}.".format(
-                self.surf, probability_dic["new_surf"], zhu_p))
+            self.log(
+                "Attempting hop from state {} to state {}. Probability is {}.".format(
+                    self.surf, probability_dic["new_surf"], zhu_p
+                )
+            )
 
             # decide whether or not to hop based on Zhu a, b, and p
             will_hop = self.should_hop(zhu_a, zhu_b, zhu_p)
@@ -699,15 +763,12 @@
             if will_hop:
                 out = self.hop(new_surf)
                 if out != "err":
-                    self.log("Hopped from from state {} to state {}.".format(
-                        self.surf, probability_dic["new_surf"]))
+                    self.log("Hopped from state {} to state {}.".format(self.surf, probability_dic["new_surf"]))
                     return
             else:
-                self.log("Did not hop from from state {} to state {}.".format(
-                    self.surf, probability_dic["new_surf"]))
+                self.log("Did not hop from state {} to state {}.".format(self.surf, probability_dic["new_surf"]))
 
     def run(self):
-
         # save intitial conditions
         self.update_energies()
@@ -715,9 +776,11 @@
         self.save()
 
         self.log("Beginning surface hopping at {}.".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
-        self.log("Relative energies are {} eV".format(", ".join(
-            ((self.energies - self.energies[0])*27.2).reshape(-1).astype("str").tolist())))
-
+        self.log(
+            "Relative energies are {} eV".format(
+                ", ".join(((self.energies - self.energies[0]) * 27.2).reshape(-1).astype("str").tolist())
+            )
+        )
 
         while self.time < self.max_time:
             self.step()
@@ -727,15 +790,11 @@
 
         self.log("Surface hopping completed normally at {}.".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
 
-
-
-
 class BatchedZhuNakamura:
-
     """
-    A class for running several Zhu Nakamura trajectories at once. This is done by taking a half step for each trajectory,
-    combining all the xyz's into a dataset and batching it for the network, and then de-batching to put the forces and energies
-    back in the trajectories.
+    A class for running several Zhu Nakamura trajectories at once. This is done by taking a half step
+    for each trajectory, combining all the xyz's into a dataset and batching it for the network, and then de-batching
+    to put the forces and energies back in the trajectories.
 
     Attributes:
         num_trj (int): number of concurrent trajectories
@@ -753,7 +812,6 @@ class BatchedZhuNakamura:
     """
 
     def __init__(self, atoms_list, props, batched_params, zhu_params):
-
        """
        Initialize.
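+        One ZhuNakamuraDynamics instance is created per trajectory (see
+        make_zhu_trjs), each writing to its own numbered .csv and .log file.
+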
Args: @@ -763,13 +821,11 @@ def __init__(self, atoms_list, props, batched_params, zhu_params): zhu_params (dict): parameters related to Zhu Nakamura """ - self.num_trj = batched_params["num_trj"] self.zhu_trjs = self.make_zhu_trjs(props, atoms_list, zhu_params) self.max_time = self.zhu_trjs[0].max_time - self.energy_keys = ["energy_{}".format(i) for i in range(self.zhu_trjs[0].num_states)] - self.grad_keys = ["{}_grad".format(key) for key in self.energy_keys] - + self.energy_keys = [f"energy_{i}" for i in range(self.zhu_trjs[0].num_states)] + self.grad_keys = [f"{key}_grad" for key in self.energy_keys] self.props = self.duplicate_props(props) self.nbr_update_period = batched_params["nbr_update_period"] @@ -779,9 +835,7 @@ def __init__(self, atoms_list, props, batched_params, zhu_params): self.batch_size = batched_params["batch_size"] self.cutoff = batched_params["cutoff"] - def make_zhu_trjs(self, props, atoms_list, zhu_params): - """ Instantiate the Zhu Nakamura objects. Args: @@ -800,17 +854,15 @@ def make_zhu_trjs(self, props, atoms_list, zhu_params): zhu_trjs = [] for i, atoms in enumerate(atoms_list): - these_params = copy.deepcopy(zhu_params) - these_params["out_file"] = "{}_{}.csv".format(base_out_name, i) - these_params["log_file"] = "{}_{}.log".format(base_log_name, i) + these_params["out_file"] = f"{base_out_name}_{i}.csv" + these_params["log_file"] = f"{base_log_name}_{i}.log" zhu_trjs.append(ZhuNakamuraDynamics(atoms=atoms, **these_params)) return zhu_trjs def duplicate_props(self, props): - """ Duplicate properties, once for each trajectory. Args: @@ -822,10 +874,10 @@ def duplicate_props(self, props): new_props = dict() for key, val in props.items(): if type(val) is list: - new_props[key] = val*self.num_trj + new_props[key] = val * self.num_trj elif hasattr(val, "tolist"): typ = type(val) - new_props[key] = typ((val.tolist())*self.num_trj) + new_props[key] = typ((val.tolist()) * self.num_trj) else: raise Exception @@ -834,9 +886,7 @@ def duplicate_props(self, props): return new_props - def update_energies_forces(self, trjs, get_new_neighbors): - """ Update the energies and forces for the molecules of each trajectory. 
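+        The current geometries of all trajectories are gathered into a single
+        Dataset, evaluated in batches by the model, and the resulting energies
+        and forces (converted to a.u.) are written back onto each trajectory.
+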
Args: @@ -845,11 +895,10 @@ def update_energies_forces(self, trjs, get_new_neighbors): Returns: None """ - + nxyz_data = [atoms_to_nxyz(trj.atoms) for trj in trjs] self.props.update({"nxyz": nxyz_data}) - dataset = Dataset(props=self.props.copy(), units='kcal/mol') - + dataset = Dataset(props=self.props.copy(), units="kcal/mol") if get_new_neighbors: dataset.generate_neighbor_list(cutoff=self.cutoff) @@ -859,22 +908,25 @@ def update_energies_forces(self, trjs, get_new_neighbors): loader = DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_dicts) for i, batch in enumerate(loader): - batch = batch_to(batch, self.device) results = self.model(batch) for key in self.grad_keys: - N = batch["num_atoms"].cpu().detach().numpy().tolist() + N = batch["num_atoms"].cpu().detach().numpy().tolist() results[key] = torch.split(results[key], N) - current_trj = i*self.batch_size + current_trj = i * self.batch_size - for j, trj in enumerate(trjs[current_trj:current_trj+self.batch_size]): + for j, trj in enumerate(trjs[current_trj : current_trj + self.batch_size]): energies = [] forces = [] for key in self.energy_keys: - energy = (results[key][j].item())*KCAL_TO_AU["energy"] - force = ((-results[key + "_grad"][j]).detach().cpu().numpy())*KCAL_TO_AU["energy"]*KCAL_TO_AU["_grad"] + energy = (results[key][j].item()) * KCAL_TO_AU["energy"] + force = ( + ((-results[key + "_grad"][j]).detach().cpu().numpy()) + * KCAL_TO_AU["energy"] + * KCAL_TO_AU["_grad"] + ) energies.append(energy) forces.append(force) @@ -882,7 +934,6 @@ def update_energies_forces(self, trjs, get_new_neighbors): trj.forces = np.array(forces) def step(self, get_new_neighbors): - """ Take a step for each trajectory Args: @@ -902,7 +953,7 @@ def step(self, get_new_neighbors): # take a velocity step trj.velocity_step() # take a "full_step" with compute_internal_forces=False, - # which just amounts to checking if you're at a crossing and + # which just amounts to checking if you're at a crossing and # potentially hopping trj.full_step(compute_internal_forces=False) @@ -911,7 +962,6 @@ def step(self, get_new_neighbors): trj.save() def run(self): - """ Run all the trajectories """ @@ -924,26 +974,19 @@ def run(self): num_steps = 0 while not complete: - num_steps += 1 - if np.mod(num_steps, self.nbr_update_period) == 0: - get_new_neighbors = True - else: - get_new_neighbors = False + get_new_neighbors = np.mod(num_steps, self.nbr_update_period) == 0 self.step(get_new_neighbors=get_new_neighbors) - print("Completed step {}".format(num_steps)) - - complete = all([trj.time >= self.max_time for trj in self.zhu_trjs]) + print(f"Completed step {num_steps}") + complete = all(trj.time >= self.max_time for trj in self.zhu_trjs) for trj in self.zhu_trjs: trj.output_to_json() - class CombinedZhuNakamura: - """ Class for combining an initial ground state MD simulation with BatchedZhuNakamura. 
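+    The equilibrated ground-state run supplies the initial geometries and
+    velocities that seed the batched excited-state trajectories.
+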
Attributes: @@ -959,7 +1002,6 @@ class CombinedZhuNakamura: """ def __init__(self, atoms, zhu_params, batched_params, ground_params, props): - """ Initialize: atoms: ase Atoms objects @@ -976,7 +1018,6 @@ def __init__(self, atoms, zhu_params, batched_params, ground_params, props): ase_ground_params["trajectory"] = ground_params["savefile"] # ase_ground_params["temperature"] = ground_params["temperature"]*KB_EV - method = METHOD_DIC[ase_ground_params["thermostat"]] self.ground_dynamics = method(atoms, **ase_ground_params) self.ground_savefile = ground_params["savefile"] @@ -989,13 +1030,12 @@ def __init__(self, atoms, zhu_params, batched_params, ground_params, props): self.ground_params = ground_params def sample_ground_geoms(self): - - with open('atoms.pickle', 'rb') as f: + with open("atoms.pickle", "rb") as f: atoms = pickle.load(f) return [atoms] * self.num_trj """ - Run a ground state trajectory and extract starting geometries and velocities for each + Run a ground state trajectory and extract starting geometries and velocities for each Zhu Nakamura trajectory. Args: None @@ -1003,32 +1043,68 @@ def sample_ground_geoms(self): actual_states (list): list of atoms objects extracted from the trajectories. """ - steps = int(self.ground_params["max_time"]/self.ground_params["timestep"]) - equil_steps = int(self.ground_params["equil_time"]/self.ground_params["timestep"]) + steps = int(self.ground_params["max_time"] / self.ground_params["timestep"]) + equil_steps = int(self.ground_params["equil_time"] / self.ground_params["timestep"]) self.ground_dynamics.run(steps=steps) trj = Trajectory(self.ground_savefile) possible_states = [trj[index] for index in range(equil_steps, len(trj))] - random_indices = random.sample(range(len(possible_states)), self.num_trj) + random_indices = random.sample(range(len(possible_states)), self.num_trj) actual_states = [possible_states[index] for index in random_indices] return actual_states def run(self): - """ Run a ground state trajectory followed by a set of parallel Zhu Nakamura trajectories. 
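+        Initial conditions are drawn from the ground-state run via
+        sample_ground_geoms and handed to BatchedZhuNakamura, which propagates
+        all surface-hopping trajectories concurrently.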
""" atoms_list = self.sample_ground_geoms() - batched_zn = BatchedZhuNakamura(atoms_list=atoms_list, props=self.props, batched_params=self.batched_params, - zhu_params=self.zhu_params) + batched_zn = BatchedZhuNakamura( + atoms_list=atoms_list, props=self.props, batched_params=self.batched_params, zhu_params=self.zhu_params + ) batched_zn.run() - - - +# @pytest.mark.usefixtures("device") +@pytest.mark.skip("Works locally but need to update to work on remote CI") +class TestLangevin(ut.TestCase): + def setUp(self): + self.ethanol = get_directed_ethanol() + self.device = self._test_fixture_device + self.model = NeuralFF.from_file(ETHANOL_MODEL_PATH, device=self.device) + self.ethanol.set_calculator(self.model) + if os.path.exists("langevin.traj"): + os.remove("langevin.traj") + if os.path.exists("langevin.log"): + os.remove("langevin.log") + + @pytest.mark.timeout(30) + def test_langevin(self): + # Set up Langevin dynamics + my_dt = 1.0 # fs + my_temp = 100 # K + my_friction = 1.0 + logfile = "langevin.log" + + dyn = Langevin( + self.ethanol, + timestep=my_dt, + temperature=my_temp, + friction=my_friction, + maxwell_temp=my_temp, + logfile=logfile, + trajectory="langevin.traj", + ) + dyn.run(steps=40) + + # Check that the trajectory file was created + assert os.path.exists("langevin.traj") + assert os.path.exists("langevin.log") + + +if __name__ == "__main__": + ut.main() diff --git a/nff/io/tests/test_ase.py b/nff/tests/test_ase.py similarity index 91% rename from nff/io/tests/test_ase.py rename to nff/tests/test_ase.py index 14ee02d7..d3f99642 100644 --- a/nff/io/tests/test_ase.py +++ b/nff/tests/test_ase.py @@ -3,6 +3,7 @@ import networkx as nx import numpy as np +import pytest from ase import Atoms from nff.io.ase import AtomsBatch @@ -19,6 +20,8 @@ def compare_dicts(d1: dict, d2: dict): for key, value in d1.items(): if isinstance(value, dict): compare_dicts(value, d2[key]) + elif isinstance(value, str): + assert value == d2[key] elif isinstance(value, Iterable): assert np.allclose(value, d2[key]) else: @@ -47,10 +50,17 @@ def get_ethanol(): return Atoms(nxyz[:, 0].astype(int), positions=nxyz[:, 1:]) -# @ut.skip("skip this for now") +@pytest.mark.usefixtures("device") # Ensure the fixture is accessible class TestAtomsBatch(ut.TestCase): def setUp(self): self.ethanol = get_ethanol() + # Access the device value from the pytest fixture + self.device = self._test_fixture_device + + @pytest.fixture(autouse=True) + def inject_device(self, device): + # Automatically set the fixture value to an attribute + self._test_fixture_device = device @ut.skip("skip this for now") def test_AtomsBatch(self): @@ -111,7 +121,7 @@ def test_AtomsBatch(self): ] ) - atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5) + atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5, device=self.device) atoms_batch.update_nbr_list() G1 = nx.from_edgelist(expected_nbrlist_cutoff_2dot5) @@ -120,13 +130,13 @@ def test_AtomsBatch(self): assert nx.is_isomorphic(G1, G2) def test_get_batch(self): - atoms_batch = AtomsBatch(self.ethanol, cutoff=5) + atoms_batch = AtomsBatch(self.ethanol, cutoff=5, device=self.device) batch = atoms_batch.get_batch() assert "nxyz" in batch def test_from_atoms(self): - atoms_batch = AtomsBatch.from_atoms(self.ethanol, cutoff=2.5) + atoms_batch = AtomsBatch.from_atoms(self.ethanol, cutoff=2.5, device=self.device) # ensure atomic numbers, positions, and cell are the same assert np.allclose(atoms_batch.get_atomic_numbers(), self.ethanol.get_atomic_numbers()) @@ -134,7 +144,7 @@ def test_from_atoms(self): 
assert np.allclose(atoms_batch.get_cell(), self.ethanol.get_cell()) def test_copy(self): - atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5) + atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5, device=self.device) atoms_batch.get_batch() # update props atoms_batch_copy = atoms_batch.copy() @@ -154,7 +164,7 @@ def test_copy(self): assert atoms_batch.requires_large_offsets == atoms_batch_copy.requires_large_offsets def test_fromdict(self): - atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5) + atoms_batch = AtomsBatch(self.ethanol, cutoff=2.5, device=self.device) ab_dict = atoms_batch.todict(update_props=True) ab_from_dict = AtomsBatch.fromdict(ab_dict) @@ -183,6 +193,7 @@ def test_fromdict(self): compare_dicts(ab_dict_props, ab_dict_again_props) +@pytest.mark.usefixtures("device") # Ensure the fixture is loaded class TestPeriodic(ut.TestCase): def setUp(self): nxyz = np.array( @@ -205,9 +216,16 @@ def setUp(self): [0.0, 0.0, 5.51891759], ] ) - self.quartz = AtomsBatch(nxyz[:, 0].astype(int), positions=nxyz[:, 1:], cell=lattice, pbc=True) + self.quartz = AtomsBatch( + nxyz[:, 0].astype(int), positions=nxyz[:, 1:], cell=lattice, pbc=True, device=self._test_fixture_device + ) + + @pytest.fixture(autouse=True) + def inject_device(self, device): + # Automatically set the fixture value to an attribute + self._test_fixture_device = device - def test_ase(self): + def test_print(self): print(self.quartz) def test_nbrlist(self): @@ -469,7 +487,6 @@ def test_nbrlist(self): ] ) assert np.allclose(nbrlist, expected_nbrlist) - print(offsets) if __name__ == "__main__": diff --git a/nff/io/tests/__init__.py b/nff/tests/test_data/__init__.py similarity index 100% rename from nff/io/tests/__init__.py rename to nff/tests/test_data/__init__.py diff --git a/nff/data/tests/data/SrIrO3_bulk_55_nff_all_dataset.pth.tar b/nff/tests/test_data/data/SrIrO3_bulk_55_nff_all_dataset.pth.tar similarity index 100% rename from nff/data/tests/data/SrIrO3_bulk_55_nff_all_dataset.pth.tar rename to nff/tests/test_data/data/SrIrO3_bulk_55_nff_all_dataset.pth.tar diff --git a/nff/data/tests/test_dataset.py b/nff/tests/test_data/test_dataset.py similarity index 82% rename from nff/data/tests/test_dataset.py rename to nff/tests/test_data/test_dataset.py index cdbf73bf..df167f26 100644 --- a/nff/data/tests/test_dataset.py +++ b/nff/tests/test_data/test_dataset.py @@ -4,6 +4,7 @@ from pathlib import Path import numpy as np +import pytest import torch from nff.data.dataset import ( @@ -14,8 +15,8 @@ ) current_path = Path(__file__).parent -DATASET_PATH = current_path / "../../../tutorials/data/dataset.pth.tar" -PEROVSKITE_DATA_PATH = current_path / "./data/SrIrO3_bulk_55_nff_all_dataset.pth.tar" +DATASET_PATH = os.path.join(current_path, "..", "..", "..", "tutorials", "data", "dataset.pth.tar") +PEROVSKITE_DATA_PATH = os.path.join(current_path, "data", "SrIrO3_bulk_55_nff_all_dataset.pth.tar") TARG_NAME = "formula" VAL_SIZE = 0.1 TEST_SIZE = 0.1 @@ -114,13 +115,13 @@ def test_split_train_validation_test(self): min_count=MIN_COUNT, ) - self.assertEqual(len(train_dset), 43) - self.assertEqual(len(val_dset), 6) - self.assertEqual(len(test_dset), 6) + assert len(train_dset) == 43 + assert len(val_dset) == 6 + assert len(test_dset) == 6 - self.assertEqual(Counter(train_dset.props[TARG_NAME]), self.train_formula_count) - self.assertEqual(Counter(val_dset.props[TARG_NAME]), self.val_formula_count) - self.assertEqual(Counter(test_dset.props[TARG_NAME]), self.test_formula_count) + assert Counter(train_dset.props[TARG_NAME]) == 
self.train_formula_count + assert Counter(val_dset.props[TARG_NAME]) == self.val_formula_count + assert Counter(test_dset.props[TARG_NAME]) == self.test_formula_count def test_stratified_split(self): idx_train, idx_test = stratified_split( @@ -131,8 +132,8 @@ def test_stratified_split(self): min_count=MIN_COUNT, ) - self.assertEqual(idx_train, self.idx_train) - self.assertEqual(idx_test, self.idx_test) + assert idx_train == self.idx_train + assert idx_test == self.idx_test class TestConcatenate(unittest.TestCase): @@ -167,19 +168,19 @@ def setUp(self): def test_concat_1(self): ab = concatenate_dict(self.dict_a, self.dict_b) - self.assertEqual(ab, self.dict_ab) + assert ab == self.dict_ab def test_concat_2(self): ac = concatenate_dict(self.dict_a, self.dict_c) - self.assertEqual(ac, self.dict_ac) + assert ac == self.dict_ac def test_concat_single_dict(self): a = concatenate_dict(self.dict_a) - self.assertEqual(a, self.dict_a_list) + assert a == self.dict_a_list def test_concat_single_dict_lists(self): a = concatenate_dict(self.dict_a_list) - self.assertEqual(a, self.dict_a_list) + assert a == self.dict_a_list def test_tensors(self): d1 = {"a": torch.tensor([1.0])} @@ -192,11 +193,11 @@ def test_tensors(self): torch.tensor(3.0), ] } - self.assertEqual(dcat, expected) + assert dcat == expected def test_concat_list_lists(self): dd = concatenate_dict(self.dict_d, self.dict_d) - self.assertEqual(dd, self.dict_dd) + assert dd == self.dict_dd def test_concat_tensors(self): t = { @@ -212,7 +213,7 @@ def test_concat_tensors(self): concat = concatenate_dict(t, t) for key, val in concat.items(): for i, j in zip(val, tt[key]): - self.assertTrue((i == j).all().item()) + assert (i == j).all().item() def test_inexistent_list_lists(self): a = {"a": [[[1, 2]], [[3, 4]]], "b": [5, 6]} @@ -220,9 +221,10 @@ def test_inexistent_list_lists(self): b = {"b": [7, 8]} ab = concatenate_dict(a, b) expected = {"a": [[[1, 2]], [[3, 4]], None, None], "b": [5, 6, 7, 8]} - self.assertEqual(ab, expected) + assert ab == expected +@pytest.mark.usefixtures("device") # Ensure the fixture is accessible class TestPeriodicDataset(unittest.TestCase): def setUp(self): self.quartz = { @@ -248,10 +250,15 @@ def setUp(self): ), } - self.qtz_dataset = Dataset(concatenate_dict(*[self.quartz] * 3)) + self.qtz_dataset = Dataset(concatenate_dict(*[self.quartz] * 3), device=self._test_fixture_device) + + @pytest.fixture(autouse=True) + def inject_device(self, device): + # Automatically set the fixture value to an attribute + self._test_fixture_device = device def test_neighbor_list(self): - nbrs, offs = self.qtz_dataset.generate_neighbor_list(cutoff=5) + self.qtz_dataset.generate_neighbor_list(cutoff=5) if __name__ == "__main__": diff --git a/nff/data/tests/test_stats.py b/nff/tests/test_data/test_stats.py similarity index 86% rename from nff/data/tests/test_stats.py rename to nff/tests/test_data/test_stats.py index 682d4310..84a74c44 100644 --- a/nff/data/tests/test_stats.py +++ b/nff/tests/test_data/test_stats.py @@ -1,4 +1,3 @@ -import os import unittest from pathlib import Path @@ -18,17 +17,17 @@ def test_get_atom_count(self): # Test case 1: Single atom formula formula = "H" expected_result = {"H": 1} - self.assertEqual(get_atom_count(formula), expected_result) + assert get_atom_count(formula) == expected_result # Test case 2: Formula with multiple atoms formula = "H2O" expected_result = {"H": 2, "O": 1} - self.assertEqual(get_atom_count(formula), expected_result) + assert get_atom_count(formula) == expected_result # Test case 3: 
Formula with repeated atoms formula = "CH3CH2CH3" expected_result = {"C": 3, "H": 8} - self.assertEqual(get_atom_count(formula), expected_result) + assert get_atom_count(formula) == expected_result def test_all_atoms(self): unique_formulas = ["H2O", "CH4", "CO2"] @@ -36,7 +35,7 @@ def test_all_atoms(self): result = all_atoms(unique_formulas) - self.assertEqual(result, expected_result, "Incorrect atom set") + assert result == expected_result, "Incorrect atom set" class TestStats(unittest.TestCase): @@ -53,7 +52,7 @@ def test_remove_outliers_scalar(self): new_array = new_dset.props[TEST_KEY].cpu().numpy() ref_std = np.std(array) - ref_mean = np.mean(array) + np.mean(array) assert np.max(new_array) - np.min(new_array) <= 2 * STD_AWAY * ref_std, "range is not working" @@ -68,7 +67,7 @@ def test_remove_outliers_tensor(self): stats_array = torch.cat(array, dim=0).flatten().cpu().numpy() ref_std = np.std(stats_array) - ref_mean = np.mean(stats_array) + np.mean(stats_array) new_stats_array = torch.cat(new_array, dim=0).flatten().cpu().numpy() diff --git a/nff/tests/test_excited_states_training.py b/nff/tests/test_excited_states_training.py new file mode 100644 index 00000000..481a5425 --- /dev/null +++ b/nff/tests/test_excited_states_training.py @@ -0,0 +1,133 @@ +import os +import pathlib + +import pytest +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.sampler import RandomSampler + +from nff.data import Dataset, collate_dicts, split_train_validation_test +from nff.train import Trainer, evaluate, get_model, hooks, loss, metrics + + +@pytest.mark.skip("still taking too long, disable for now") +def test_excited_training(device, tmpdir): + # define loss + loss_dict = { + "mse": [ + {"coef": 0.01, "params": {"key": "d_00"}}, + {"coef": 0.01, "params": {"key": "d_11"}}, + {"coef": 0.01, "params": {"key": "d_22"}}, + {"coef": 0.2, "params": {"key": "energy_0"}}, + {"coef": 1, "params": {"key": "energy_0_grad"}}, + {"coef": 0.1, "params": {"key": "energy_1"}}, + {"coef": 1, "params": {"key": "energy_1_grad"}}, + {"coef": 0.5, "params": {"key": "energy_1_energy_0_delta"}}, + ], + "nacv": [{"coef": 1, "params": {"abs": False, "key": "force_nacv_10", "max": False}}], + } + loss_fn = loss.build_multi_loss(loss_dict) + + # define model + diabat_keys = [["d_00", "d_01", "d_02"], ["d_01", "d_11", "d_12"], ["d_02", "d_12", "d_22"]] + modelparams = { + "feat_dim": 128, + "activation": "swish", + "n_rbf": 20, + "cutoff": 5.0, + "num_conv": 3, + "output_keys": ["energy_0", "energy_1"], + "grad_keys": ["energy_0_grad", "energy_1_grad"], + "diabat_keys": diabat_keys, + "add_nacv": True, + } + model = get_model(modelparams, model_type="PainnDiabat") + + # define training + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = Adam(trainable_params, lr=1e-4) + train_metrics = [ + metrics.MeanAbsoluteError("energy_0"), + metrics.MeanAbsoluteError("energy_1"), + metrics.MeanAbsoluteError("energy_0_grad"), + metrics.MeanAbsoluteError("energy_1_grad"), + metrics.MeanAbsoluteError("energy_1_energy_0_delta"), + ] + + # output + outdir = tmpdir + train_hooks = [ + hooks.CSVHook( + outdir, + metrics=train_metrics, + ), + hooks.PrintingHook(outdir, metrics=train_metrics, separator=" | ", time_strf="%M:%S"), + hooks.ReduceLROnPlateauHook( + optimizer=optimizer, + # patience in the original paper + patience=50, + factor=0.5, + min_lr=1e-7, + window_length=1, + stop_after_min=True, + ), + ] + + # data set + dset = 
Dataset.from_file(os.path.join(pathlib.Path(__file__).parent.absolute(), "data/azo_diabat.pth.tar")) + train, val, test = split_train_validation_test(dset, val_size=0.1, test_size=0.1) + batch_size = 20 + train_loader = DataLoader(train, batch_size=batch_size, collate_fn=collate_dicts, sampler=RandomSampler(train)) + val_loader = DataLoader(val, batch_size=batch_size, collate_fn=collate_dicts) + test_loader = DataLoader(test, batch_size=batch_size, collate_fn=collate_dicts) + + # train + T = Trainer( + model_path=outdir, + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + train_loader=train_loader, + validation_loader=val_loader, + checkpoint_interval=1, + hooks=train_hooks, + mini_batches=1, + ) + T.train(device=device, n_epochs=10) + + # evaluation + def correct_nacv(results, targets, key): + num_atoms = targets["num_atoms"] + if not isinstance(num_atoms, list): + num_atoms = num_atoms.tolist() + pred = torch.split(torch.cat(results[key]), num_atoms) + targ = torch.split(torch.cat(targets[key]), num_atoms) + + real_pred = [] + + for p, t in zip(pred, targ): + sub_err = (p - t).abs().mean() + add_err = (p + t).abs().mean() + sign = 1 if sub_err < add_err else -1 + real_pred.append(sign * p) + + return real_pred + + results, targets, test_loss = evaluate( + T.get_best_model(), test_loader, loss_fn=lambda x, y: torch.Tensor([0]), device=device + ) + real_nacv = correct_nacv(results, targets, "force_nacv_10") + results["force_nacv_10"] = real_nacv + + en_keys = ["energy_0", "energy_1", "energy_1_energy_0_delta"] + grad_keys = ["energy_0_grad", "energy_1_grad"] + + for key in [*en_keys, *grad_keys, "force_nacv_10"]: + pred = results[key] + targ = targets[key] + targ_dim = len(targets["energy_0"][0].shape) + fn = torch.stack if targ_dim == 0 else torch.cat + pred = torch.cat(pred).reshape(-1) + targ = fn(targ).reshape(-1) + assert abs(pred - targ).mean() < 12.0 diff --git a/nff/tests/test_training.py b/nff/tests/test_training.py new file mode 100644 index 00000000..f7940686 --- /dev/null +++ b/nff/tests/test_training.py @@ -0,0 +1,71 @@ +import os +import pathlib + +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader + +from nff.data import Dataset, collate_dicts, split_train_validation_test +from nff.train import Trainer, evaluate, get_model, hooks, loss, metrics + + +def test_training(device, tmpdir): + # data set + OUTDIR = tmpdir + dataset = Dataset.from_file(os.path.join(pathlib.Path(__file__).parent.absolute(), "data", "dataset.pth.tar")) + train, val, test = split_train_validation_test(dataset, val_size=0.2, test_size=0.2) + train_loader = DataLoader(train, batch_size=50, collate_fn=collate_dicts) + val_loader = DataLoader(val, batch_size=50, collate_fn=collate_dicts) + test_loader = DataLoader(test, batch_size=50, collate_fn=collate_dicts) + + # define model + params = { + "n_atom_basis": 256, + "n_filters": 256, + "n_gaussians": 32, + "n_convolutions": 4, + "cutoff": 5.0, + "trainable_gauss": True, + "dropout_rate": 0.2, + } + model = get_model(params) + + # define training + loss_fn = loss.build_mse_loss(loss_coef={"energy": 0.01, "energy_grad": 1}) + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = Adam(trainable_params, lr=3e-4) + train_metrics = [metrics.MeanAbsoluteError("energy"), metrics.MeanAbsoluteError("energy_grad")] + + # output + train_hooks = [ + hooks.MaxEpochHook(7), + hooks.CSVHook( + OUTDIR, + metrics=train_metrics, + ), + hooks.PrintingHook(OUTDIR, metrics=train_metrics, separator=" | ", 
time_strf="%M:%S"), + hooks.ReduceLROnPlateauHook( + optimizer=optimizer, patience=30, factor=0.5, min_lr=1e-7, window_length=1, stop_after_min=True + ), + ] + + # train + T = Trainer( + model_path=OUTDIR, + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + train_loader=train_loader, + validation_loader=val_loader, + checkpoint_interval=1, + hooks=train_hooks, + ) + T.train(device=device, n_epochs=7) + + # evaluation + results, targets, val_loss = evaluate(T.get_best_model(), test_loader, loss_fn, device=device) + for key in ["energy_grad", "energy"]: + pred = torch.stack(results[key], dim=0).view(-1).detach().cpu().numpy() + targ = torch.stack(targets[key], dim=0).view(-1).detach().cpu().numpy() + mae = abs(pred - targ).mean() + assert mae < 10.0 diff --git a/nff/train/builders/model.py b/nff/train/builders/model.py index db2ed57b..f6b8b70f 100644 --- a/nff/train/builders/model.py +++ b/nff/train/builders/model.py @@ -498,9 +498,9 @@ def check_parameters(params_type, params): if key in params_type and not isinstance(val, params_type[key]): raise ParameterError(f"{key} is not {params_type[key]}") - for model in PARAMS_TYPE: + for model, value in PARAMS_TYPE.items(): if key == f"{model.lower()}_params": - check_parameters(PARAMS_TYPE[model], val) + check_parameters(value, val) def get_model(params: dict, model_type: str = "SchNet", **kwargs): diff --git a/nff/train/builders/trainer.py b/nff/train/builders/trainer.py index b12fab74..dc437e76 100644 --- a/nff/train/builders/trainer.py +++ b/nff/train/builders/trainer.py @@ -1,14 +1,16 @@ """Helper function to create a trainer for a given model. -Adapted from https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/training.py +Adapted from: +https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/training.py """ -import os + import json +import os -import nff -import torch from torch.optim import Adam +import nff + def get_trainer(args, model, train_loader, val_loader, metrics, loss_fn=None): # setup hook and logging diff --git a/nff/train/chgnet.py b/nff/train/chgnet.py index f24696f8..08ffe396 100644 --- a/nff/train/chgnet.py +++ b/nff/train/chgnet.py @@ -1,4 +1,5 @@ -from typing import Dict, Iterable, Union +from collections.abc import Iterable +from typing import Dict, Union import torch from chgnet.trainer.trainer import CombinedLoss @@ -9,6 +10,7 @@ class CombinedLossNFF(CombinedLoss): """Wrapper for the combined loss function that maps keys from NFF to CHGNet keys.""" + def __init__(self, *args, key_mappings=None, **kwargs): super().__init__(*args, **kwargs) if not key_mappings: @@ -37,11 +39,13 @@ def forward(self, targets: Dict[str, Tensor], predictions: Dict[str, Tensor], ke raise ValueError("key_style must be either 'nff' or 'chgnet'") targets = {k: self.split_props(k, v, detach(targets["num_atoms"]).tolist()) for k, v in targets.items()} - predictions = {k: self.split_props(k, v, detach(predictions["num_atoms"]).tolist()) for k, v in predictions.items()} + predictions = { + k: self.split_props(k, v, detach(predictions["num_atoms"]).tolist()) for k, v in predictions.items() + } if key_style == "nff": - targets = {self.key_mappings.get(k, k): self.negate_value(k, v) for k, v in targets.items()} - predictions = {self.key_mappings.get(k, k): self.negate_value(k, v) for k, v in predictions.items()} + targets = {self.key_mappings.get(k, k): self.negate_value(k, v) for k, v in targets.items()} + predictions = {self.key_mappings.get(k, k): 
self.negate_value(k, v) for k, v in predictions.items()} out = super().forward(targets, predictions) loss = out["loss"] @@ -64,9 +68,7 @@ def negate_value(self, key: str, value: Iterable) -> Union[list, Tensor]: return -value return value - def split_props( - self, key: str, value: Union[list, Tensor], num_atoms: Union[list, Tensor] - ) -> Union[list, Tensor]: + def split_props(self, key: str, value: Union[list, Tensor], num_atoms: Union[list, Tensor]) -> Union[list, Tensor]: """Split the properties if the key is in the split_keys list. Args: diff --git a/nff/train/evaluate.py b/nff/train/evaluate.py index 334738b6..71eecb6e 100644 --- a/nff/train/evaluate.py +++ b/nff/train/evaluate.py @@ -114,9 +114,8 @@ def evaluate( if not return_results: return {}, {}, eval_loss - else: - # this step can be slow, - all_results = concatenate_dict(*all_results) - all_batches = concatenate_dict(*all_batches) + # this step can be slow, + all_results = concatenate_dict(*all_results) + all_batches = concatenate_dict(*all_batches) - return all_results, all_batches, eval_loss + return all_results, all_batches, eval_loss diff --git a/nff/train/hooks/base_hook.py b/nff/train/hooks/base_hook.py index 2e9093f4..bcb108f3 100644 --- a/nff/train/hooks/base_hook.py +++ b/nff/train/hooks/base_hook.py @@ -3,6 +3,7 @@ Retrieved from https://github.com/atomistic-machine-learning/schnetpack/tree/dev/src/schnetpack/train/hooks """ + class Hook: """Base class for hooks.""" @@ -30,7 +31,6 @@ def on_epoch_begin(self, trainer): trainer (Trainer): instance of schnetpack.train.trainer.Trainer class. """ - pass def on_batch_begin(self, trainer, train_batch): """Log at the beginning of train batch. @@ -40,7 +40,6 @@ def on_batch_begin(self, trainer, train_batch): train_batch (dict of torch.Tensor): SchNetPack dictionary of input tensors. 
""" - pass def on_batch_end(self, trainer, train_batch, result, loss): pass diff --git a/nff/train/hooks/logging.py b/nff/train/hooks/logging.py index b65ee761..5e7da171 100644 --- a/nff/train/hooks/logging.py +++ b/nff/train/hooks/logging.py @@ -3,15 +3,16 @@ Retrieved from https://github.com/atomistic-machine-learning/schnetpack/tree/dev/src/schnetpack/train/hooks """ +import json import os +import sys import time + import numpy as np import torch -import json -import sys from nff.train.hooks import Hook -from nff.train.metrics import (RootMeanSquaredError, PrAuc, RocAuc) +from nff.train.metrics import PrAuc, RocAuc, RootMeanSquaredError class LoggingHook(Hook): @@ -37,7 +38,7 @@ def __init__( log_learning_rate=True, mini_batches=1, global_rank=0, - world_size=1 + world_size=1, ): self.log_train_loss = log_train_loss self.log_validation_loss = log_validation_loss @@ -70,16 +71,14 @@ def on_epoch_begin(self, trainer): self._train_loss = None def on_batch_end(self, trainer, train_batch, result, loss): - if self.log_train_loss: n_samples = self._batch_size(result) self._train_loss += float(loss.data) * n_samples self._counter += n_samples def _batch_size(self, result): - if type(result) is dict: - n_samples = list(result.values())[0].size(0) + n_samples = next(iter(result.values())).size(0) elif type(result) in [list, tuple]: n_samples = result[0].size(0) else: @@ -127,8 +126,7 @@ def get_par_folders(self): """ base_folder = self.get_base_folder() - par_folders = [os.path.join(base_folder, str(i)) - for i in range(self.world_size)] + par_folders = [os.path.join(base_folder, str(i)) for i in range(self.world_size)] return par_folders def save_metrics(self, epoch, test): @@ -143,10 +141,9 @@ def save_metrics(self, epoch, test): # save metrics to json file par_folder = self.par_folders[self.global_rank] if test: - json_file = os.path.join( - par_folder, "epoch_{}_test.json".format(epoch)) + json_file = os.path.join(par_folder, f"epoch_{epoch}_test.json") else: - json_file = os.path.join(par_folder, "epoch_{}.json".format(epoch)) + json_file = os.path.join(par_folder, f"epoch_{epoch}.json") # if the json file you're saving to already exists, # then load its contents @@ -159,8 +156,7 @@ def save_metrics(self, epoch, test): # update with metrics for metric in self.metrics: if type(metric) in [RocAuc, PrAuc]: - m = {"y_true": metric.actual, - "y_pred": metric.pred} + m = {"y_true": metric.actual, "y_pred": metric.pred} else: m = metric.aggregate() dic[metric.name] = m @@ -194,11 +190,9 @@ def avg_parallel_metrics(self, epoch, test): while None in par_dic.values(): for folder in self.par_folders: if test: - path = os.path.join( - folder, "epoch_{}_test.json".format(epoch)) + path = os.path.join(folder, f"epoch_{epoch}_test.json") else: - path = os.path.join( - folder, "epoch_{}.json".format(epoch)) + path = os.path.join(folder, f"epoch_{epoch}.json") try: with open(path, "r") as f: path_dic = json.load(f) @@ -209,8 +203,7 @@ def avg_parallel_metrics(self, epoch, test): # average appropriately if isinstance(metric, RootMeanSquaredError): - metric_val = np.mean( - np.array(list(par_dic.values)) ** 2) ** 0.5 + metric_val = np.mean(np.array(list(par_dic.values)) ** 2) ** 0.5 elif type(metric) in [RocAuc, PrAuc]: y_true = [] y_pred = [] @@ -239,8 +232,7 @@ def aggregate(self, trainer, test=False): # if parallel, average over parallel metrics if self.parallel: - metric_dic = self.avg_parallel_metrics(epoch=trainer.epoch, - test=test) + metric_dic = self.avg_parallel_metrics(epoch=trainer.epoch, 
test=test) # otherwise aggregate as usual else: @@ -276,19 +268,24 @@ def __init__( every_n_epochs=1, mini_batches=1, global_rank=0, - world_size=1 + world_size=1, ): log_path = os.path.join(log_path, "log.csv") super().__init__( - log_path, metrics, log_train_loss, log_validation_loss, - log_learning_rate, mini_batches, global_rank, world_size + log_path, + metrics, + log_train_loss, + log_validation_loss, + log_learning_rate, + mini_batches, + global_rank, + world_size, ) self._offset = 0 self._restart = False self.every_n_epochs = every_n_epochs def on_train_begin(self, trainer): - if os.path.exists(self.log_path): remove_file = False with open(self.log_path, "r") as f: @@ -393,13 +390,19 @@ def __init__( log_histogram=False, mini_batches=1, global_rank=0, - world_size=1 + world_size=1, ): from tensorboardX import SummaryWriter super().__init__( - log_path, metrics, log_train_loss, log_validation_loss, - log_learning_rate, mini_batches, global_rank, world_size + log_path, + metrics, + log_train_loss, + log_validation_loss, + log_learning_rate, + mini_batches, + global_rank, + world_size, ) self.writer = SummaryWriter(self.log_path) self.every_n_epochs = every_n_epochs @@ -409,10 +412,7 @@ def __init__( def on_epoch_end(self, trainer): if trainer.epoch % self.every_n_epochs == 0: if self.log_train_loss: - self.writer.add_scalar( - "train/loss", - self._train_loss / self._counter, trainer.epoch - ) + self.writer.add_scalar("train/loss", self._train_loss / self._counter, trainer.epoch) if self.log_learning_rate: self.writer.add_scalar( "train/learning_rate", @@ -427,43 +427,31 @@ def on_validation_end(self, trainer, val_loss): m = metric_dic[metric.name] if np.isscalar(m): - self.writer.add_scalar( - "metrics/%s" % metric.name, float(m), trainer.epoch - ) - elif m.ndim == 2: - if trainer.epoch % self.img_every_n_epochs == 0: - import matplotlib.pyplot as plt - - # tensorboardX only accepts images as numpy arrays. - # we therefore convert plots in numpy array - # see https://github.com/lanpa/tensorboard- - # pytorch/blob/master/examples/matplotlib_demo.py - fig = plt.figure() - plt.colorbar(plt.pcolor(m)) - fig.canvas.draw() - - np_image = np.fromstring( - fig.canvas.tostring_rgb(), dtype="uint8" - ) - np_image = np_image.reshape( - fig.canvas.get_width_height()[::-1] + (3,) - ) - - plt.close(fig) - - self.writer.add_image( - "metrics/%s" % metric.name, np_image, trainer.epoch - ) + self.writer.add_scalar("metrics/%s" % metric.name, float(m), trainer.epoch) + elif m.ndim == 2 and trainer.epoch % self.img_every_n_epochs == 0: + import matplotlib.pyplot as plt + + # tensorboardX only accepts images as numpy arrays. 
+ # we therefore convert plots in numpy array + # see https://github.com/lanpa/tensorboard- + # pytorch/blob/master/examples/matplotlib_demo.py + fig = plt.figure() + plt.colorbar(plt.pcolor(m)) + fig.canvas.draw() + + np_image = np.fromstring(fig.canvas.tostring_rgb(), dtype="uint8") + np_image = np_image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + + plt.close(fig) + + self.writer.add_image("metrics/%s" % metric.name, np_image, trainer.epoch) if self.log_validation_loss: - self.writer.add_scalar( - "train/val_loss", float(val_loss), trainer.step) + self.writer.add_scalar("train/val_loss", float(val_loss), trainer.step) if self.log_histogram: for name, param in trainer._model.named_parameters(): - self.writer.add_histogram( - name, param.detach().cpu().numpy(), trainer.epoch - ) + self.writer.add_histogram(name, param.detach().cpu().numpy(), trainer.epoch) def on_train_ends(self, trainer): self.writer.close() @@ -496,18 +484,23 @@ def __init__( log_learning_rate=True, log_memory=True, every_n_epochs=1, - separator=' ', - time_strf=r'%Y-%m-%d %H:%M:%S', - str_format=r'{1:>{0}}', + separator=" ", + time_strf=r"%Y-%m-%d %H:%M:%S", + str_format=r"{1:>{0}}", mini_batches=1, global_rank=0, - world_size=1 + world_size=1, ): - log_path = os.path.join(log_path, "log_human_read.csv") super().__init__( - log_path, metrics, log_train_loss, log_validation_loss, - log_learning_rate, mini_batches, global_rank, world_size + log_path, + metrics, + log_train_loss, + log_validation_loss, + log_learning_rate, + mini_batches, + global_rank, + world_size, ) self.every_n_epochs = every_n_epochs @@ -516,12 +509,12 @@ def __init__( self._separator = separator self.time_strf = time_strf self._headers = { - 'time': 'Time', - 'epoch': 'Epoch', - 'lr': 'Learning rate', - 'train_loss': 'Train loss', - 'val_loss': 'Validation loss', - 'memory': 'GPU Memory (MB)' + "time": "Time", + "epoch": "Epoch", + "lr": "Learning rate", + "train_loss": "Train loss", + "val_loss": "Validation loss", + "memory": "GPU Memory (MB)", } self.str_format = str_format self.log_memory = log_memory @@ -533,112 +526,83 @@ def print(self, log): sys.stdout.flush() def on_train_begin(self, trainer): - log_dir = os.path.dirname(self.log_path) if not os.path.exists(log_dir): os.makedirs(log_dir) - log = self.str_format.format( - len(time.strftime(self.time_strf)), - self._headers['time'] - ) + log = self.str_format.format(len(time.strftime(self.time_strf)), self._headers["time"]) if self.log_epoch: log += self._separator - log += self.str_format.format( - len(self._headers['epoch']), self._headers['epoch'] - ) + log += self.str_format.format(len(self._headers["epoch"]), self._headers["epoch"]) if self.log_learning_rate: log += self._separator - log += self.str_format.format( - len(self._headers['lr']), self._headers['lr'] - ) + log += self.str_format.format(len(self._headers["lr"]), self._headers["lr"]) if self.log_train_loss: log += self._separator - log += self.str_format.format( - len(self._headers['train_loss']), self._headers['train_loss'] - ) + log += self.str_format.format(len(self._headers["train_loss"]), self._headers["train_loss"]) if self.log_validation_loss: log += self._separator - log += self.str_format.format( - len(self._headers['val_loss']), self._headers['val_loss'] - ) + log += self.str_format.format(len(self._headers["val_loss"]), self._headers["val_loss"]) if len(self.metrics) > 0: log += self._separator - for i, metric in enumerate(self.metrics): + for metric in self.metrics: header = str(metric.name) log += 
self.str_format.format(len(header), header) log += self._separator if self.log_memory: - log += self.str_format.format( - len(self._headers['memory']), self._headers['memory'] - ) + log += self.str_format.format(len(self._headers["memory"]), self._headers["memory"]) self.print(log) def on_validation_end(self, trainer, val_loss): if trainer.epoch % self.every_n_epochs == 0: - log = time.strftime(self.time_strf) if self.log_epoch: log += self._separator - log += self.str_format.format( - len(self._headers['epoch']), - '%d' % trainer.epoch - ) + log += self.str_format.format(len(self._headers["epoch"]), "%d" % trainer.epoch) if self.log_learning_rate: log += self._separator log += self.str_format.format( - len(self._headers['lr']), - '%.3e' % trainer.optimizer.param_groups[0]['lr'] + len(self._headers["lr"]), "%.3e" % trainer.optimizer.param_groups[0]["lr"] ) if self.log_train_loss: log += self._separator log += self.str_format.format( - len(self._headers['train_loss']), - '%.4f' % (self._train_loss / self._counter) + len(self._headers["train_loss"]), "%.4f" % (self._train_loss / self._counter) ) if self.log_validation_loss: log += self._separator - log += self.str_format.format( - len(self._headers['val_loss']), - '%.4f' % val_loss - ) + log += self.str_format.format(len(self._headers["val_loss"]), "%.4f" % val_loss) if len(self.metrics) > 0: log += self._separator metric_dic = self.aggregate(trainer) - for i, metric in enumerate(self.metrics): + for metric in self.metrics: m = metric_dic[metric.name] - if hasattr(m, '__iter__'): + if hasattr(m, "__iter__"): log += self._separator.join([str(j) for j in m]) else: - log += self.str_format.format( - len(metric.name), - '%.4f' % m - ) + log += self.str_format.format(len(metric.name), "%.4f" % m) log += self._separator if self.log_memory: memory = torch.cuda.max_memory_allocated(device=None) * 1e-6 - log += self.str_format.format( - len(self._headers['memory']), - '%d' % memory - ) + log += self.str_format.format(len(self._headers["memory"]), "%d" % memory) self.print(log) def on_train_failed(self, trainer): - self.print('the training has failed') + self.print("the training has failed") diff --git a/nff/train/hooks/scheduling.py b/nff/train/hooks/scheduling.py index dedfe492..0656a76d 100644 --- a/nff/train/hooks/scheduling.py +++ b/nff/train/hooks/scheduling.py @@ -4,7 +4,6 @@ """ import numpy as np -import torch from torch.optim.lr_scheduler import ( CosineAnnealingLR, ReduceLROnPlateau, @@ -52,7 +51,6 @@ def on_validation_end(self, trainer, val_loss): class WarmRestartHook(Hook): - def __init__( self, optimizer, @@ -108,9 +106,7 @@ def on_validation_end(self, trainer, val_loss): self.Tmax *= self.Tmult self.scheduler.last_epoch = -1 self.scheduler.T_max = self.Tmax - self.scheduler.base_lrs = [ - base_lr * self.lr_factor for base_lr in self.scheduler.base_lrs - ] + self.scheduler.base_lrs = [base_lr * self.lr_factor for base_lr in self.scheduler.base_lrs] trainer.optimizer.load_state_dict(self.init_opt_state) if self.best_current > self.best_previous: @@ -318,12 +314,10 @@ def on_batch_end(self, trainer, train_batch, result, loss): class WarmUpLR(_LRScheduler): - def __init__(self, optimizer, n_steps, max_lr, last_epoch=-1, verbose=False): - self.n_steps = n_steps self.max_lr = max_lr - super(WarmUpLR, self).__init__(optimizer, last_epoch, verbose) + super().__init__(optimizer, last_epoch, verbose) for param_group in self.optimizer.param_groups: param_group["lr"] = 0 @@ -345,7 +339,6 @@ def __init__(self, optimizer, n_steps, max_lr): 
self.scheduler = WarmUpLR(optimizer=optimizer, n_steps=n_steps, max_lr=max_lr) def on_batch_end(self, trainer, train_batch, result, loss): - self.scheduler.step() if self.scheduler._step_count >= self.scheduler.n_steps: trainer._stop = True diff --git a/nff/train/loss.py b/nff/train/loss.py index 0125b61f..38324250 100644 --- a/nff/train/loss.py +++ b/nff/train/loss.py @@ -1,6 +1,5 @@ import numpy as np import torch -from torch.nn import CrossEntropyLoss from nff.utils import constants as const @@ -56,10 +55,7 @@ def loss_fn(ground_truth, results): loss = 0.0 for key, coef in loss_coef.items(): - if key not in ground_truth.keys(): - ground_key = correspondence_keys[key] - else: - ground_key = key + ground_key = correspondence_keys[key] if key not in ground_truth else key targ = ground_truth[ground_key] pred = results[key].view(targ.shape) diff --git a/nff/train/metrics.py b/nff/train/metrics.py index e29f00b7..4ce7495b 100644 --- a/nff/train/metrics.py +++ b/nff/train/metrics.py @@ -1,6 +1,6 @@ import numpy as np import torch -from sklearn.metrics import roc_auc_score, auc, precision_recall_curve +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score class Metric: @@ -33,7 +33,7 @@ def reset(self): self.n_entries = 0.0 def add_batch(self, batch, results): - """ Add a batch to calculate the metric on """ + """Add a batch to calculate the metric on""" y = batch[self.target] yp = results[self.target] @@ -96,9 +96,7 @@ def __init__( name=None, ): name = "RMSE_" + target if name is None else name - super().__init__( - target, name - ) + super().__init__(target, name) def aggregate(self): """Aggregate metric over all previously added batches.""" @@ -129,7 +127,6 @@ def __init__( @staticmethod def loss_fn(y, yp): - # select only properties which are given yp = yp.reshape(*y.shape) @@ -167,7 +164,6 @@ def __init__( @staticmethod def loss_fn(y, yp): - # select only properties which are given yp = yp.reshape(*y.shape) @@ -179,9 +175,7 @@ def loss_fn(y, yp): pos_delta = (abs(y - yp)).mean(-1) neg_delta = (abs(y + yp)).mean(-1) - signs = (torch.ones(pos_delta.shape[0], - dtype=torch.long) - .to(pos_delta.device)) + signs = torch.ones(pos_delta.shape[0], dtype=torch.long).to(pos_delta.device) signs[neg_delta < pos_delta] = -1 y = y * signs.reshape(-1, 1) @@ -191,7 +185,7 @@ def loss_fn(y, yp): class Classifier(Metric): - """" Metric for binary classification.""" + """Metric for binary classification.""" def __init__( self, @@ -205,7 +199,7 @@ def __init__( ) def add_batch(self, batch, results): - """ Add a batch to calculate the metric on """ + """Add a batch to calculate the metric on""" y = batch[self.target] yp = results[self.target] @@ -216,7 +210,6 @@ def add_batch(self, batch, results): self.loss += loss def non_nan(self): - actual = torch.Tensor(self.actual) pred = torch.Tensor(self.pred) @@ -228,16 +221,13 @@ def non_nan(self): def aggregate(self): """Aggregate metric over all previously added batches.""" - if self.n_entries == 0: - result = float('nan') - else: - result = self.loss / self.n_entries + result = float("nan") if self.n_entries == 0 else self.loss / self.n_entries return result class FalsePositives(Classifier): """ - Percentage of claimed positives that are actually wrong for a + Percentage of claimed positives that are actually wrong for a binary classifier.
""" @@ -254,13 +244,11 @@ def __init__( @staticmethod def loss_fn(y, yp): - actual = y.detach().cpu().numpy().round().reshape(-1) pred = yp.detach().cpu().numpy().round().reshape(-1) all_positives = [i for i, item in enumerate(pred) if item == 1] - false_positives = [i for i in range(len(pred)) if pred[i] - == 1 and pred[i] != actual[i]] + false_positives = [i for i in range(len(pred)) if pred[i] == 1 and pred[i] != actual[i]] # number of predicted negatives num_pred = len(all_positives) @@ -270,9 +258,8 @@ def loss_fn(y, yp): class FalseNegatives(Classifier): - """ - Percentage of claimed negatives that are actually wrong for a + Percentage of claimed negatives that are actually wrong for a binary classifier. """ @@ -289,13 +276,11 @@ def __init__( @staticmethod def loss_fn(y, yp): - actual = y.detach().cpu().numpy().round().reshape(-1) pred = yp.detach().cpu().numpy().round().reshape(-1) all_negatives = [i for i, item in enumerate(pred) if item == 0] - false_negatives = [i for i in range(len(pred)) if pred[i] - == 0 and pred[i] != actual[i]] + false_negatives = [i for i in range(len(pred)) if pred[i] == 0 and pred[i] != actual[i]] # number of predicted negatives num_pred = len(all_negatives) num_pred_correct = len(false_negatives) @@ -304,9 +289,8 @@ def loss_fn(y, yp): class TruePositives(Classifier): - """ - Percentage of claimed positives that are actually right for a + Percentage of claimed positives that are actually right for a binary classifier. """ @@ -323,13 +307,11 @@ def __init__( @staticmethod def loss_fn(y, yp): - actual = y.detach().cpu().numpy().round().reshape(-1) pred = yp.detach().cpu().numpy().round().reshape(-1) all_positives = [i for i, item in enumerate(pred) if item == 1] - true_positives = [i for i in range(len(pred)) if pred[i] - == 1 and pred[i] == actual[i]] + true_positives = [i for i in range(len(pred)) if pred[i] == 1 and pred[i] == actual[i]] # number of predicted negatives num_pred = len(all_positives) @@ -339,9 +321,8 @@ def loss_fn(y, yp): class TrueNegatives(Classifier): - """ - Percentage of claimed negatives that are actually right for a + Percentage of claimed negatives that are actually right for a binary classifier. """ @@ -358,13 +339,11 @@ def __init__( @staticmethod def loss_fn(y, yp): - actual = y.detach().cpu().numpy().round().reshape(-1) pred = yp.detach().cpu().numpy().round().reshape(-1) all_negatives = [i for i, item in enumerate(pred) if item == 0] - true_negatives = [i for i in range(len(pred)) if pred[i] - == 0 and pred[i] == actual[i]] + true_negatives = [i for i in range(len(pred)) if pred[i] == 0 and pred[i] == actual[i]] # number of predicted negatives num_pred = len(all_negatives) @@ -374,7 +353,6 @@ def loss_fn(y, yp): class RocAuc(Classifier): - """ AUC metric (area under true-positive vs. false-positive curve). """ @@ -411,7 +389,7 @@ def loss_fn(self, y, yp): return actual, pred def add_batch(self, batch, results): - """ Add a batch to calculate the metric on """ + """Add a batch to calculate the metric on""" y = batch[self.target] yp = results[self.target] @@ -434,7 +412,6 @@ def aggregate(self): class PrAuc(Classifier): - """ AUC metric (area under true-positive vs. false-positive curve). 
""" @@ -471,7 +448,7 @@ def loss_fn(self, y, yp): return actual, pred def add_batch(self, batch, results): - """ Add a batch to calculate the metric on """ + """Add a batch to calculate the metric on""" y = batch[self.target] yp = results[self.target] @@ -487,8 +464,7 @@ def aggregate(self): pred, actual = self.non_nan() try: - precision, recall, thresholds = precision_recall_curve( - y_true=actual, probas_pred=pred) + precision, recall, thresholds = precision_recall_curve(y_true=actual, probas_pred=pred) pr_auc = auc(recall, precision) except ValueError: @@ -498,7 +474,6 @@ def aggregate(self): class Accuracy(Classifier): - """ Overall accuracy of classifier. """ @@ -516,7 +491,6 @@ def __init__( @staticmethod def loss_fn(y, yp): - actual = y.detach().cpu().numpy().round().reshape(-1) pred = yp.detach().cpu().numpy().round().reshape(-1) diff --git a/nff/train/parallel.py b/nff/train/parallel.py index 06ff7cf5..e678ac7a 100644 --- a/nff/train/parallel.py +++ b/nff/train/parallel.py @@ -3,16 +3,15 @@ between processes. """ -import pickle import os +import pickle def get_grad(optimizer): - grad_list = [] for group in optimizer.param_groups: grad_list.append([]) - for param in group['params']: + for param in group["params"]: if param.grad is None: grad_list[-1].append(param.grad) else: @@ -20,31 +19,15 @@ def get_grad(optimizer): return grad_list -def save_grad(optimizer, - loss_size, - rank, - weight_path, - batch_num, - epoch): - +def save_grad(optimizer, loss_size, rank, weight_path, batch_num, epoch): grad_list = get_grad(optimizer) save_dic = {"grad": grad_list, "loss_size": loss_size} - save_path = os.path.join(weight_path, str(rank), - "grad_{}_{}.pickle".format( - epoch, batch_num)) + save_path = os.path.join(weight_path, str(rank), f"grad_{epoch}_{batch_num}.pickle") with open(save_path, "wb") as f: pickle.dump(save_dic, f) -def add_grads(optimizer, - loss_size, - weight_path, - rank, - world_size, - batch_num, - epoch, - device): - +def add_grads(optimizer, loss_size, weight_path, rank, world_size, batch_num, epoch, device): # Set the optimizer to have zero gradient and then load in all # the grads. 
This ensures no differences in the gradients between the # different processes, which would occur due to loss of precision @@ -54,30 +37,25 @@ def add_grads(optimizer, # paths to all pickle files - paths = [os.path.join(weight_path, str(index), - "grad_{}_{}.pickle".format(epoch, batch_num)) - for index in range(world_size)] + paths = [os.path.join(weight_path, str(index), f"grad_{epoch}_{batch_num}.pickle") for index in range(world_size)] loaded_grads = {path: None for path in paths} while None in loaded_grads.values(): - missing_paths = [key for key, val - in loaded_grads.items() if val is None] + missing_paths = [key for key, val in loaded_grads.items() if val is None] for path in missing_paths: try: with open(path, "rb") as f: loaded_grads[path] = pickle.load(f) - except (EOFError, FileNotFoundError, pickle.UnpicklingError): + except (EOFError, FileNotFoundError, pickle.UnpicklingError): # noqa continue # total size is the sum of all sizes from each process - total_size = sum([grad_dic["loss_size"] for - grad_dic in loaded_grads.values()] - ) + total_size = sum([grad_dic["loss_size"] for grad_dic in loaded_grads.values()]) for k, grad_dic in enumerate(loaded_grads.values()): for i, group in enumerate(optimizer.param_groups): - for j, param in enumerate(group['params']): + for j, param in enumerate(group["params"]): if param.grad is None: continue param.grad += grad_dic["grad"][i][j].to(device) @@ -91,13 +69,7 @@ def add_grads(optimizer, return optimizer -def del_grad(rank, - epoch, - batch_num, - weight_path, - del_interval, - max_batch_iters): - +def del_grad(rank, epoch, batch_num, weight_path, del_interval, max_batch_iters): # epoch starts counting from 1 and batch_num starts # counting from 0 @@ -113,37 +85,30 @@ def del_grad(rank, os.remove(file_path) -def update_optim(optimizer, - loss_size, - rank, - world_size, - weight_path, - batch_num, - epoch, - del_interval, - device, - max_batch_iters): - - save_grad(optimizer=optimizer, - loss_size=loss_size, - rank=rank, - weight_path=weight_path, - batch_num=batch_num, - epoch=epoch) - - optimizer = add_grads(optimizer=optimizer, - loss_size=loss_size, - weight_path=weight_path, - rank=rank, - world_size=world_size, - batch_num=batch_num, - epoch=epoch, - device=device) - del_grad(rank=rank, - epoch=epoch, - batch_num=batch_num, - weight_path=weight_path, - del_interval=del_interval, - max_batch_iters=max_batch_iters) +def update_optim( + optimizer, loss_size, rank, world_size, weight_path, batch_num, epoch, del_interval, device, max_batch_iters +): + save_grad( + optimizer=optimizer, loss_size=loss_size, rank=rank, weight_path=weight_path, batch_num=batch_num, epoch=epoch + ) + + optimizer = add_grads( + optimizer=optimizer, + loss_size=loss_size, + weight_path=weight_path, + rank=rank, + world_size=world_size, + batch_num=batch_num, + epoch=epoch, + device=device, + ) + del_grad( + rank=rank, + epoch=epoch, + batch_num=batch_num, + weight_path=weight_path, + del_interval=del_interval, + max_batch_iters=max_batch_iters, + ) return optimizer diff --git a/nff/train/trainer.py b/nff/train/trainer.py index 20bb9d80..d59a4411 100644 --- a/nff/train/trainer.py +++ b/nff/train/trainer.py @@ -118,9 +118,7 @@ def __init__( # how many times you've called loss.backward() self.back_count = 0 # maximum number of batches to iterate through - self.max_batch_iters = ( - max_batch_iters if (max_batch_iters is not None) else len(self.train_loader) - ) + self.max_batch_iters = max_batch_iters if (max_batch_iters is not None) else len(self.train_loader) 
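+        # e.g., with len(self.train_loader) == 100 and max_batch_iters == 25, each epoch + # iterates over at most the first 25 batches of the loader (illustrative values, + # not part of the original change)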
self.model_kwargs = model_kwargs if (model_kwargs is not None) else {} self.batch_stop = False self.nloss = 0 @@ -165,9 +163,7 @@ def to(self, device): def _check_is_parallel(self): data_par = isinstance(self._model, torch.nn.DataParallel) - dist_dat_par = isinstance( - self._model, torch.nn.parallel.DistributedDataParallel - ) + dist_dat_par = isinstance(self._model, torch.nn.parallel.DistributedDataParallel) return any((data_par, dist_dat_par)) def _load_model_state_dict(self, state_dict): @@ -191,10 +187,7 @@ def get_best_model(self): return model def call_model(self, batch, train): - if (self.torch_parallel and self.parallel) and not train: - model = self._model.module - else: - model = self._model + model = self._model.module if (self.torch_parallel and self.parallel) and not train else self._model return model(batch, **self.model_kwargs) @@ -227,9 +220,7 @@ def state_dict(self, state_dict): hook.scheduler.optimizer = self.optimizer def store_checkpoint(self): - chkpt = os.path.join( - self.checkpoint_path, "checkpoint-" + str(self.epoch) + ".pth.tar" - ) + chkpt = os.path.join(self.checkpoint_path, "checkpoint-" + str(self.epoch) + ".pth.tar") torch.save(self.state_dict, chkpt) chpts = [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pth.tar")] @@ -249,9 +240,7 @@ def restore_checkpoint(self, epoch=None): ] ) - chkpt = os.path.join( - self.checkpoint_path, "checkpoint-" + str(epoch) + ".pth.tar" - ) + chkpt = os.path.join(self.checkpoint_path, "checkpoint-" + str(epoch) + ".pth.tar") self.state_dict = torch.load(chkpt, map_location="cpu") def loss_backward(self, loss): @@ -341,7 +330,7 @@ def call_and_loss(self, batch, device, calc_loss): except RuntimeError as err: if "CUDA out of memory" in str(err): - print(("CUDA out of memory. Doing this batch " "on cpu.")) + print("CUDA out of memory. Doing this batch " "on cpu.") use_device = "cpu" torch.cuda.empty_cache() else: @@ -387,9 +376,7 @@ def train(self, device, n_epochs=MAX_EPOCHS): for hook in self.hooks: hook.on_batch_begin(self, batch) - batch, results, mini_loss, _ = self.call_and_loss( - batch=batch, device=device, calc_loss=True - ) + batch, results, mini_loss, _ = self.call_and_loss(batch=batch, device=device, calc_loss=True) if not torch.isnan(mini_loss): loss += mini_loss.cpu().detach().to(device) @@ -468,9 +455,7 @@ def get_par_folders(self): # each parallel folder just has the name of its global rank - par_folders = [ - os.path.join(self.model_path, str(i)) for i in range(self.world_size) - ] + par_folders = [os.path.join(self.model_path, str(i)) for i in range(self.world_size)] self_folder = par_folders[self.global_rank] # if the folder of this global rank doesn't exist yet then @@ -496,7 +481,7 @@ def save_val_loss(self, val_loss, n_val): # write the loss as a number to a file called "val_epoch_i" # for epoch i. 
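+        # e.g., at epoch 3, the process with global rank 0 and a mean validation loss of + # 0.0123 over 4096 validation points writes the single line "0.0123,4096" to the file + # <model_path>/0/val_epoch_3 (illustrative values, not part of the original change)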
- info_file = os.path.join(self_folder, "val_epoch_{}".format(self.epoch)) + info_file = os.path.join(self_folder, f"val_epoch_{self.epoch}") with open(info_file, "w") as f_open: loss_float = val_loss.item() string = f"{loss_float},{n_val}" @@ -526,7 +511,7 @@ def load_val_loss(self): # then no need to load anything if loaded_vals[folder] is not None: continue - val_file = os.path.join(folder, "val_epoch_{}".format(self.epoch)) + val_file = os.path.join(folder, f"val_epoch_{self.epoch}") # try opening the file and getting the value try: with open(val_file, "r") as f_open: @@ -543,9 +528,7 @@ def load_val_loss(self): # average the losses according to number of atoms # or molecules in each denom = sum(list(n_vals.values())) - avg_loss = ( - sum([n_vals[key] * loaded_vals[key] for key in n_vals.keys()]) / denom - ) + avg_loss = sum([n_vals[key] * loaded_vals[key] for key in n_vals]) / denom else: # add the losses avg_loss = np.sum(list(loaded_vals.values())) @@ -613,9 +596,7 @@ def validate(self, device, test=False): vsize = val_batch["nxyz"].size(0) n_val += vsize - val_batch, results, _, use_device = self.call_and_loss( - batch=val_batch, device=device, calc_loss=False - ) + val_batch, results, _, use_device = self.call_and_loss(batch=val_batch, device=device, calc_loss=False) # detach from the graph results = batch_to(batch_detach(results), use_device) diff --git a/nff/train/transfer.py b/nff/train/transfer.py index 54c250f2..ea3db699 100644 --- a/nff/train/transfer.py +++ b/nff/train/transfer.py @@ -71,7 +71,6 @@ def model_tl( """ Function to transfer learn a model. Defined in the subclasses. """ - pass class PainnLayerFreezer(LayerFreezer): @@ -88,9 +87,7 @@ def unfreeze_painn_readout(self, model: torch.nn.Module, freeze_skip: bool) -> N unfreeze_skip = not freeze_skip for i, block in enumerate(model.readout_blocks): - if unfreeze_skip: - self.unfreeze_parameters(block) - elif i == num_readouts - 1: + if unfreeze_skip or i == num_readouts - 1: self.unfreeze_parameters(block) def unfreeze_painn_pooling(self, model: torch.nn.Module) -> None: @@ -251,9 +248,7 @@ def unfreeze_mace_readout(self, model: torch.nn.Module, freeze_skip: bool = Fals unfreeze_skip = not freeze_skip for i, block in enumerate(model.readouts): - if unfreeze_skip: - self.unfreeze_parameters(block) - elif i == num_readouts - 1: + if unfreeze_skip or i == num_readouts - 1: self.unfreeze_parameters(block) print(f"Unfreezing {block.__class__.__name__}") @@ -406,9 +401,7 @@ def unfreeze_chgnet_readout(self, model: torch.nn.Module, freeze_skip: bool = Fa unfreeze_skip = not freeze_skip for i, block in enumerate(model.mlp.layers): - if unfreeze_skip: - self.unfreeze_parameters(block) - elif i == num_readouts - 1: + if unfreeze_skip or i == num_readouts - 1: self.unfreeze_parameters(block) def model_tl( diff --git a/nff/train/uncertainty.py b/nff/train/uncertainty.py index 6940aa88..8d6de4eb 100644 --- a/nff/train/uncertainty.py +++ b/nff/train/uncertainty.py @@ -1,19 +1,19 @@ import os import warnings -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch -from ..io.gmm import GaussianMixture +from nff.io.gmm import GaussianMixture __all__ = [ - "Uncertainty", + "ConformalPrediction", "EnsembleUncertainty", "EvidentialUncertainty", - "MVEUncertainty", "GMMUncertainty", - "ConformalPrediction", + "MVEUncertainty", + "Uncertainty", ] @@ -23,7 +23,7 @@ def __init__( order: str, calibrate: bool, cp_alpha: Union[None, float] = None, - min_uncertainty: float = None, + 
min_uncertainty: Optional[float] = None, *args, **kwargs, ): @@ -52,23 +52,19 @@ def set_min_uncertainty(self, min_uncertainty: float, force: bool = False) -> No """ Set the minimum uncertainty value to be used for scaling the uncertainty. """ - if getattr(self, "umin") is None: + if self.umin is None: self.umin = min_uncertainty elif force: - warnings.warn( - f"Uncertainty: min_uncertainty already set to {self.umin}. Overwriting." - ) + warnings.warn(f"Uncertainty: min_uncertainty already set to {self.umin}. Overwriting.", stacklevel=2) self.umin = min_uncertainty else: raise Exception(f"Uncertainty: min_uncertainty already set to {self.umin}") - def scale_to_min_uncertainty( - self, uncertainty: Union[np.ndarray, torch.Tensor] - ) -> Union[np.ndarray, torch.Tensor]: + def scale_to_min_uncertainty(self, uncertainty: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: """ Scale the uncertainty to the minimum value. """ - if getattr(self, "umin") is not None: + if self.umin is not None: if self.order not in ["system_mean_squared"]: uncertainty = uncertainty - self.umin else: @@ -92,16 +88,14 @@ def calibrate_uncertainty( """ Calibrate the uncertainty using Conformal Prediction. """ - if getattr(self.CP, "qhat") is None: + if self.CP.qhat is None: raise Exception("Uncertainty: ConformalPrediction not fitted.") cp_uncertainty, qhat = self.CP.predict(uncertainty) return cp_uncertainty - def get_system_uncertainty( - self, uncertainty: torch.Tensor, num_atoms: List[int] - ) -> torch.Tensor: + def get_system_uncertainty(self, uncertainty: torch.Tensor, num_atoms: List[int]) -> torch.Tensor: """ Get the uncertainty for the entire system. """ @@ -109,9 +103,7 @@ def get_system_uncertainty( assert len(uncertainty) == len(num_atoms), "Number of systems do not match" - assert all( - [len(u) == n for u, n in zip(uncertainty, num_atoms)] - ), "Number of atoms in each system do not match" + assert all(len(u) == n for u, n in zip(uncertainty, num_atoms)), "Number of atoms in each system do not match" if self.order == "system_sum": uncertainty = uncertainty.sum(dim=-1) @@ -155,9 +147,7 @@ def fit( scores = np.array(scores) n = len(residuals_calib) - qhat = torch.quantile( - torch.from_numpy(scores), np.ceil((n + 1) * (1 - self.alpha)) / n - ) + qhat = torch.quantile(torch.from_numpy(scores), np.ceil((n + 1) * (1 - self.alpha)) / n) qhat_value = np.float64(qhat.numpy()).item() self.qhat = qhat_value @@ -198,9 +188,7 @@ def __init__( self.targ_unit = targ_unit self.std_or_var = std_or_var - def convert_units( - self, value: Union[float, np.ndarray], orig_unit: str, targ_unit: str - ): + def convert_units(self, value: Union[float, np.ndarray], orig_unit: str, targ_unit: str): """ Convert the energy/forces units of the value from orig_unit to targ_unit. """ @@ -224,9 +212,7 @@ def get_energy_uncertainty( Get the uncertainty for the energy. """ if self.orig_unit is not None and self.targ_unit is not None: - results[self.q] = self.convert_units( - results[self.q], orig_unit=self.orig_unit, targ_unit=self.targ_unit - ) + results[self.q] = self.convert_units(results[self.q], orig_unit=self.orig_unit, targ_unit=self.targ_unit) if self.std_or_var == "std": val = results[self.q].std(-1) @@ -244,9 +230,7 @@ def get_forces_uncertainty( Get the uncertainty for the forces. 
""" if self.orig_unit is not None and self.targ_unit is not None: - results[self.q] = self.convert_units( - results[self.q], orig_unit=self.orig_unit, targ_unit=self.targ_unit - ) + results[self.q] = self.convert_units(results[self.q], orig_unit=self.orig_unit, targ_unit=self.targ_unit) splits = torch.split(results[self.q], list(num_atoms)) stack_split = torch.stack(splits, dim=0) @@ -263,9 +247,7 @@ def get_forces_uncertainty( return val - def get_uncertainty( - self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs - ): + def get_uncertainty(self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs): if self.q == "energy": val = self.get_energy_uncertainty(results=results) elif self.q in ["energy_grad", "forces"]: @@ -293,7 +275,7 @@ def __init__( source: str = "epistemic", calibrate: bool = False, cp_alpha: Union[float, None] = None, - min_uncertainty: float = None, + min_uncertainty: Optional[float] = None, *args, **kwargs, ): @@ -308,9 +290,7 @@ def __init__( self.shared_v = shared_v self.source = source - def check_params( - self, results: dict, num_atoms=None - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def check_params(self, results: dict, num_atoms=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Check if the parameters are present in the results, if the shapes are correct. If the order is "atomic" and shared_v is True, then the v @@ -324,7 +304,8 @@ def check_params( num_systems = len(num_atoms) total_atoms = torch.sum(num_atoms) if self.order == "atomic" and self.shared_v: - assert v.shape[0] == num_systems and alpha.shape[0] == total_atoms + assert v.shape[0] == num_systems + assert alpha.shape[0] == total_atoms v = torch.split(v, list(num_atoms)) v = torch.stack(v, dim=0) v = v.mean(-1, keepdims=True) @@ -332,9 +313,7 @@ def check_params( return v, alpha, beta - def get_uncertainty( - self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs - ) -> torch.Tensor: + def get_uncertainty(self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs) -> torch.Tensor: v, alpha, beta = self.check_params(results=results, num_atoms=num_atoms) if self.source == "aleatoric": @@ -348,9 +327,7 @@ def get_uncertainty( splits = torch.split(uncertainty, list(num_atoms)) stack_split = torch.stack(splits, dim=0) - uncertainty = self.get_system_uncertainty( - uncertainty=stack_split, num_atoms=num_atoms - ) + uncertainty = self.get_system_uncertainty(uncertainty=stack_split, num_atoms=num_atoms) uncertainty = self.scale_to_min_uncertainty(uncertainty) @@ -370,7 +347,7 @@ def __init__( variance_key: str = "var", quantity: str = "forces", order: str = "atomic", - min_uncertainty: float = None, + min_uncertainty: Optional[float] = None, *args, **kwargs, ): @@ -378,9 +355,7 @@ def __init__( self.vkey = variance_key self.q = quantity - def get_uncertainty( - self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs - ) -> torch.Tensor: + def get_uncertainty(self, results: dict, num_atoms: Union[List[int], None] = None, *args, **kwargs) -> torch.Tensor: var = results[self.vkey].squeeze() assert results[self.q].shape[0] == var.shape[0] @@ -388,9 +363,7 @@ def get_uncertainty( splits = torch.split(var, list(num_atoms)) stack_split = torch.stack(splits, dim=0) - var = self.get_system_uncertainty( - uncertainty=stack_split, num_atoms=num_atoms - ) + var = self.get_system_uncertainty(uncertainty=stack_split, num_atoms=num_atoms) var = self.scale_to_min_uncertainty(var) @@ -468,7 
+441,7 @@ def fit_gmm(self, Xtrain: torch.Tensor) -> None: self.gm_model.fit(self.Xtrain.squeeze().cpu().numpy()) # Save the fitted GMM model if gmm_path is specified - if hasattr(self, "gmm_path") and not os.path.exists(getattr(self, "gmm_path")): + if hasattr(self, "gmm_path") and not os.path.exists(self.gmm_path): self.gm_model.save(self.gmm_path) print(f"Saved fitted GMM model to {self.gmm_path}") @@ -503,9 +476,7 @@ def _set_gmm_params(self) -> None: raise Exception("GMMUncertainty: GMM does not exist/is not fitted") self.means = self._check_tensor(self.gm_model.means_) - self.precisions_cholesky = self._check_tensor( - self.gm_model.precisions_cholesky_ - ) + self.precisions_cholesky = self._check_tensor(self.gm_model.precisions_cholesky_) self.weights = self._check_tensor(self.gm_model.weights_) def estimate_log_prob(self, X: torch.Tensor) -> torch.Tensor: @@ -518,9 +489,7 @@ def estimate_log_prob(self, X: torch.Tensor) -> torch.Tensor: n_clusters, _ = self.means.shape log_det = torch.sum( - torch.log( - self.precisions_cholesky.reshape(n_clusters, -1)[:, :: n_features + 1] - ), + torch.log(self.precisions_cholesky.reshape(n_clusters, -1)[:, :: n_features + 1]), dim=1, ) @@ -552,9 +521,7 @@ def log_likelihood(self, X: torch.Tensor) -> torch.Tensor: # below, the calculation below makes it stable # log(sum_i(a_i)) = log(exp(a_max) * sum_i(exp(a_i - a_max))) = a_max + log(sum_i(exp(a_i - a_max))) wlp_stable = weighted_log_prob - weighted_log_prob_max.reshape(-1, 1) - logsumexp = weighted_log_prob_max + torch.log( - torch.sum(torch.exp(wlp_stable), dim=1) - ) + logsumexp = weighted_log_prob_max + torch.log(torch.sum(torch.exp(wlp_stable), dim=1)) return logsumexp @@ -599,9 +566,7 @@ def get_uncertainty( splits = torch.split(uncertainty, list(num_atoms)) stack_split = torch.stack(splits, dim=0) - uncertainty = self.get_system_uncertainty( - uncertainty=stack_split, num_atoms=num_atoms - ).squeeze() + uncertainty = self.get_system_uncertainty(uncertainty=stack_split, num_atoms=num_atoms).squeeze() uncertainty = self.scale_to_min_uncertainty(uncertainty) diff --git a/nff/utils/cellfilters.py b/nff/utils/cellfilters.py index d963ebe0..1f72db02 100644 --- a/nff/utils/cellfilters.py +++ b/nff/utils/cellfilters.py @@ -1,6 +1,5 @@ -from ase.constraints import Filter, UnitCellFilter, ExpCellFilter import numpy as np - +from ase.constraints import UnitCellFilter from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress @@ -39,9 +38,7 @@ def set_velocities(self, velocities): return self.atoms.set_velocities(velocities=velocities) def set_cell(self, cell, scale_atoms=False, apply_constraint=True): - return self.atoms.set_cell( - cell, scale_atoms=scale_atoms, apply_constraint=apply_constraint - ) + return self.atoms.set_cell(cell, scale_atoms=scale_atoms, apply_constraint=apply_constraint) def get_cell(self, complete=False): return self.atoms.get_cell(complete=complete) @@ -60,9 +57,7 @@ def get_stress(self, voigt=True, apply_constraint=True, include_ideal_gas=False) ) volume = self.atoms.get_volume() - virial = -volume * ( - voigt_6_to_full_3x3_stress(stress) + np.diag([self.scalar_pressure] * 3) - ) + virial = -volume * (voigt_6_to_full_3x3_stress(stress) + np.diag([self.scalar_pressure] * 3)) cur_deform_grad = self.deform_grad() virial = np.linalg.solve(cur_deform_grad, virial.T).T @@ -97,9 +92,7 @@ def get_forces(self, **kwargs): return forces def get_potential_energy(self, force_consistent=False, apply_constraint=True): - return self.atoms.get_potential_energy( - 
force_consistent=force_consistent, apply_constraint=apply_constraint - ) + return self.atoms.get_potential_energy(force_consistent=force_consistent, apply_constraint=apply_constraint) def get_global_number_of_atoms(self): return self.atoms.get_global_number_of_atoms() diff --git a/nff/utils/confgen.py b/nff/utils/confgen.py index d35dc26d..47d0e650 100644 --- a/nff/utils/confgen.py +++ b/nff/utils/confgen.py @@ -1,30 +1,29 @@ +import copy +import json import os +import pickle import random -import subprocess import re import socket +import subprocess import time -import numpy as np -import json -import pickle -import copy -import math -from rdkit.Chem import (AddHs, MolFromSmiles, inchi, GetPeriodicTable, - Conformer, MolToSmiles) -from rdkit.Chem.AllChem import (EmbedMultipleConfs, - UFFGetMoleculeForceField, - MMFFGetMoleculeForceField, - MMFFGetMoleculeProperties, - GetConformerRMS) -from rdkit.Chem.rdmolops import RemoveHs, GetFormalCharge +import numpy as np +from rdkit.Chem import AddHs, Conformer, GetPeriodicTable, MolFromSmiles, MolToSmiles, inchi +from rdkit.Chem.AllChem import ( + EmbedMultipleConfs, + GetConformerRMS, + MMFFGetMoleculeForceField, + MMFFGetMoleculeProperties, + UFFGetMoleculeForceField, +) +from rdkit.Chem.rdmolops import GetFormalCharge, RemoveHs from nff.utils.misc import read_csv, tqdm_enum -from nff.data.parallel import gen_parallel PERIODICTABLE = GetPeriodicTable() -UFF_ELEMENTS = ['B', 'Al'] +UFF_ELEMENTS = ["B", "Al"] DEFAULT_GEOM_COMPARE_TIMEOUT = 300 XYZ_NAME = "{0}_Conf_{1}.xyz" MAX_CONFS = 10000 @@ -35,60 +34,64 @@ def write_xyz(coords, filename, comment): - ''' + """ Write an xyz file from coords - ''' + """ with open(filename, "w") as f_p: f_p.write(str(len(coords)) + "\n") f_p.write(str(comment) + "\n") for atom in coords: - f_p.write("%s %.4f %.4f %.4f\n" % - (atom[0], atom[1][0], atom[1][1], atom[1][2])) + f_p.write("%s %.4f %.4f %.4f\n" % (atom[0], atom[1][0], atom[1][1], atom[1][2])) def obfit_rmsd(file1, file2, smarts, path): - cmd = ["obfit", "'" + str(smarts) + "'", - os.path.join(path, file1 + '.xyz'), - os.path.join(path, file2 + '.xyz')] - ret = subprocess.check_output(" ".join(cmd), - stdin=None, - stderr=subprocess.STDOUT, - shell=True, - universal_newlines=False, - timeout=DEFAULT_GEOM_COMPARE_TIMEOUT) - rmsd = float(ret.decode('utf-8')[5:13]) + cmd = ["obfit", "'" + str(smarts) + "'", os.path.join(path, file1 + ".xyz"), os.path.join(path, file2 + ".xyz")] + ret = subprocess.check_output( + " ".join(cmd), + stdin=None, + stderr=subprocess.STDOUT, + shell=True, + universal_newlines=False, + timeout=DEFAULT_GEOM_COMPARE_TIMEOUT, + ) + rmsd = float(ret.decode("utf-8")[5:13]) return rmsd def align_rmsd(file1, file2, path, smarts=None): - cmd = ["obabel", - os.path.join(path, file1 + '.xyz'), - os.path.join(path, file2 + '.xyz'), - '-o', 'smi', - '--align', - '--append', - 'rmsd'] + cmd = [ + "obabel", + os.path.join(path, file1 + ".xyz"), + os.path.join(path, file2 + ".xyz"), + "-o", + "smi", + "--align", + "--append", + "rmsd", + ] if smarts: - cmd += ['-s', str(smarts)] - ret = subprocess.check_output(cmd, - stdin=None, - stderr=subprocess.STDOUT, - shell=False, - universal_newlines=False, - timeout=DEFAULT_GEOM_COMPARE_TIMEOUT) - rmsd = ret.decode('utf-8').split()[-1] + cmd += ["-s", str(smarts)] + ret = subprocess.check_output( + cmd, + stdin=None, + stderr=subprocess.STDOUT, + shell=False, + universal_newlines=False, + timeout=DEFAULT_GEOM_COMPARE_TIMEOUT, + ) + rmsd = ret.decode("utf-8").split()[-1] return float(rmsd) -class 
ConformerGenerator(object): - ''' +class ConformerGenerator: + """ Generates conformations of molecules from 2D representation. - ''' + """ def __init__(self, smiles, forcefield="mmff"): - ''' + """ Initialises the class - ''' + """ self.mol = MolFromSmiles(smiles) self.full_clusters = [] self.forcefield = forcefield @@ -96,18 +99,15 @@ def __init__(self, smiles, forcefield="mmff"): self.initial_confs = None self.smiles = smiles - def generate(self, - max_generated_conformers=50, - prune_thresh=0.01, - maxattempts_per_conformer=5, - output=None, - threads=1): - ''' + def generate( + self, max_generated_conformers=50, prune_thresh=0.01, maxattempts_per_conformer=5, output=None, threads=1 + ): + """ Generates conformers Note the number max_generated _conformers required is related to the number of rotatable bonds - ''' + """ self.mol = AddHs(self.mol, addCoords=True) self.initial_confs = EmbedMultipleConfs( self.mol, @@ -118,14 +118,11 @@ def generate(self, # Despite what the documentation says -1 is a seed!! # It doesn't mean random generation numThreads=threads, - randomSeed=random.randint( - 1, 10000000) + randomSeed=random.randint(1, 10000000), ) if len(self.initial_confs) == 0: - output.write((f"Generated {len(self.initial_confs)} " - "initial confs\n")) - output.write((f"Trying again with {max_generated_conformers * 10} " - "attempts and random coords\n")) + output.write(f"Generated {len(self.initial_confs)} " "initial confs\n") + output.write(f"Trying again with {max_generated_conformers * 10} " "attempts and random coords\n") self.initial_confs = EmbedMultipleConfs( self.mol, @@ -137,35 +134,28 @@ def generate(self, # It doesn't mean random # generatrion numThreads=threads, - randomSeed=random.randint( - 1, 10000000) + randomSeed=random.randint(1, 10000000), ) - output.write("Generated " + - str(len(self.initial_confs)) + " initial confs\n") + output.write("Generated " + str(len(self.initial_confs)) + " initial confs\n") return self.initial_confs - def minimise(self, - output=None, - minimize=True): - ''' + def minimise(self, output=None, minimize=True): + """ Minimises conformers using a force field - ''' + """ if "\\" in self.smiles or "/" in self.smiles: - output.write(("WARNING: Smiles string contains slashes, " - "which specify cis/trans stereochemistry.\n")) - output.write(("Bypassing force-field minimization to avoid generating " - "incorrect isomer.\n")) + output.write("WARNING: Smiles string contains slashes, " "which specify cis/trans stereochemistry.\n") + output.write("Bypassing force-field minimization to avoid generating " "incorrect isomer.\n") minimize = False if self.forcefield != "mmff" and self.forcefield != "uff": raise ValueError("Unrecognised force field") if self.forcefield == "mmff": props = MMFFGetMoleculeProperties(self.mol) - for i in range(0, len(self.initial_confs)): - potential = MMFFGetMoleculeForceField( - self.mol, props, confId=i) + for i in range(len(self.initial_confs)): + potential = MMFFGetMoleculeForceField(self.mol, props, confId=i) if potential is None: output.write("MMFF not available, using UFF\n") potential = UFFGetMoleculeForceField(self.mol, confId=i) @@ -178,7 +168,7 @@ def minimise(self, self.conf_energies.append((i, mmff_energy)) elif self.forcefield == "uff": - for i in range(0, len(self.initial_confs)): + for i in range(len(self.initial_confs)): potential = UFFGetMoleculeForceField(self.mol, confId=i) assert potential is not None if minimize: @@ -188,15 +178,10 @@ def minimise(self, self.conf_energies = sorted(self.conf_energies, 
key=lambda tup: tup[1]) return self.mol - def cluster(self, - rms_tolerance=0.1, - max_ranked_conformers=10, - energy_window=5, - Report_e_tol=10, - output=None): - ''' + def cluster(self, rms_tolerance=0.1, max_ranked_conformers=10, energy_window=5, Report_e_tol=10, output=None): + """ Removes duplicates after minimization - ''' + """ self.counter = 0 self.factormax = 3 self.mol_no_h = RemoveHs(self.mol) @@ -209,13 +194,14 @@ def cluster(self, for i, pair_1 in enumerate(confs): if i == 0: index_0, energy_0 = pair_1 - output.write((f"clustering cluster {i} of " - f"{len(self.conf_energies)}\n")) + output.write(f"clustering cluster {i} of " f"{len(self.conf_energies)}\n") index_1, energy_1 = pair_1 if abs(energy_1 - energy_0) > Report_e_tol: - output.write(("Breaking because hit Report Energy Window, " - f"E was {energy_1} kcal/mol " - f"and minimum was {energy_0} \n")) + output.write( + "Breaking because hit Report Energy Window, " + f"E was {energy_1} kcal/mol " + f"and minimum was {energy_0} \n" + ) break if i in ignore: @@ -223,7 +209,7 @@ def cluster(self, continue self.counter += 1 if self.counter == self.factormax * max_ranked_conformers: - output.write('Breaking because hit MaxNConfs \n') + output.write("Breaking because hit MaxNConfs \n") break clustered = [[self.mol.GetConformer(id=index_1), energy_1, 0.00]] ignore.append(i) @@ -236,22 +222,16 @@ def cluster(self, if abs(energy_1 - energy_2) > energy_window: break if abs(energy_1 - energy_2) <= 1e-3: - clustered.append([self.mol.GetConformer(id=index_2), - energy_2, 0.00]) + clustered.append([self.mol.GetConformer(id=index_2), energy_2, 0.00]) ignore.append(j) - rms = GetConformerRMS(self.mol_no_h, - index_1, - index_2) + rms = GetConformerRMS(self.mol_no_h, index_1, index_2) calcs_performed += 1 if rms <= rms_tolerance: - clustered.append( - [self.mol.GetConformer(id=index_2), - energy_2, rms]) + clustered.append([self.mol.GetConformer(id=index_2), energy_2, rms]) ignore.append(j) self.full_clusters.append(clustered) output.write(f"{ignored} ignore passes made\n") - output.write((f"{calcs_performed} overlays needed out " - f"of a possible {len(self.conf_energies) ** 2}\n")) + output.write(f"{calcs_performed} overlays needed out " f"of a possible {len(self.conf_energies) ** 2}\n") ranked_clusters = [] for i, cluster in enumerate(self.full_clusters): @@ -260,16 +240,18 @@ def cluster(self, return ranked_clusters - def recluster(self, - path, - rms_tolerance=0.1, - max_ranked_conformers=10, - energy_window=5, - output=None, - clustered_confs=[], - molecule=None, - key=None, - fallback_to_align=False): + def recluster( + self, + path, + rms_tolerance=0.1, + max_ranked_conformers=10, + energy_window=5, + output=None, + clustered_confs=[], + molecule=None, + key=None, + fallback_to_align=False, + ): self.removed = [] self.counter = 0 i = -1 @@ -280,63 +262,50 @@ def recluster(self, for k in range(i, len(clustered_confs)): if os.path.isfile(key + "_Conf_" + str(k + 1) + ".xyz"): os.remove(key + "_Conf_" + str(k + 1) + ".xyz") - output.write("Removed " + key + - "_Conf_" + str(k + 1) + ".xyz\n") + output.write("Removed " + key + "_Conf_" + str(k + 1) + ".xyz\n") break if i in self.removed: continue self.counter += 1 - for conf_b in clustered_confs[i + 1:]: + for conf_b in clustered_confs[i + 1 :]: j += 1 if conf_b[1] - conf_a[1] > energy_window: break if j in self.removed: continue try: - rms = obfit_rmsd(key + "_Conf_" + str(i + 1), - key + "_Conf_" + str(j + 1), - str(molecule), - path=path) - except (subprocess.CalledProcessError, 
ValueError, - subprocess.TimeoutExpired) as e: + rms = obfit_rmsd(key + "_Conf_" + str(i + 1), key + "_Conf_" + str(j + 1), str(molecule), path=path) + except (subprocess.CalledProcessError, ValueError, subprocess.TimeoutExpired) as e: if fallback_to_align: - output.write( - 'obfit failed, falling back to obabel --align') - output.write(f'Exception {e}\n') + output.write("obfit failed, falling back to obabel --align") + output.write(f"Exception {e}\n") try: - rms = align_rmsd(f"{key}_Conf_{str(i + 1)}", - f"{key}_Conf_{str(j + 1)}", - path) + rms = align_rmsd(f"{key}_Conf_{i + 1!s}", f"{key}_Conf_{j + 1!s}", path) except (ValueError, subprocess.TimeoutExpired): continue else: continue - output.write("Comparing " + str(i + 1) + " " + - str(j + 1) + ' RMSD ' + str(rms) + "\n") + output.write("Comparing " + str(i + 1) + " " + str(j + 1) + " RMSD " + str(rms) + "\n") if rms > rms_tolerance: pos = _atomic_pos_from_conformer(conf_b[0]) elements = _extract_atomic_type(conf_b[0]) pos = [[-float(coor[k]) for k in range(3)] for coor in pos] coords = list(zip(elements, pos)) - filename = os.path.join(path, key + "_Conf_" + - str(j + 1) + "_inv.xyz") - write_xyz(coords=coords, filename=filename, - comment=conf_b[1]) + filename = os.path.join(path, key + "_Conf_" + str(j + 1) + "_inv.xyz") + write_xyz(coords=coords, filename=filename, comment=conf_b[1]) try: file1 = key + "_Conf_" + str(i + 1) file2 = key + "_Conf_" + str(j + 1) + "_inv" rmsinv = obfit_rmsd(file1, file2, str(molecule)) - except (subprocess.CalledProcessError, ValueError, - subprocess.TimeoutExpired) as e: + except (subprocess.CalledProcessError, ValueError, subprocess.TimeoutExpired) as e: if fallback_to_align: - output.write( - 'obfit failed, falling back to obabel --align') - output.write(f'Exception {e}\n') + output.write("obfit failed, falling back to obabel --align") + output.write(f"Exception {e}\n") try: - i_key = f"{key}_Conf_{str(i + 1)}" - inv_key = f"{key}_Conf_{str(j + 1)}_inv" + i_key = f"{key}_Conf_{i + 1!s}" + inv_key = f"{key}_Conf_{j + 1!s}_inv" rmsinv = align_rmsd(i_key, inv_key) except (ValueError, subprocess.TimeoutExpired): continue @@ -345,8 +314,7 @@ def recluster( rms = min([rms, rmsinv]) os.remove(key + "_Conf_" + str(j + 1) + "_inv.xyz") - output.write((f"Comparing {i + 1} {j + 1} " - f"RMSD after checking inversion {rms}\n")) + output.write(f"Comparing {i + 1} {j + 1} " f"RMSD after checking inversion {rms}\n") if rms <= rms_tolerance: self.removed.append(j) output.write("Removed Conf_" + str(j + 1) + "\n") @@ -354,26 +322,23 @@ def recluster( def _extract_atomic_type(confomer): - ''' + """ Extracts the elements associated with a conformer, in order that they are read in - ''' - elements = [] + """ mol = confomer.GetOwningMol() - for atom in mol.GetAtoms(): - elements.append(atom.GetSymbol()) - return elements + return [atom.GetSymbol() for atom in mol.GetAtoms()] def _atomic_pos_from_conformer(conformer): - ''' + """ Extracts the atomic positions for an RDKit conformer object, to allow writing of input files, uploading to databases, etc.
Returns a list of lists - ''' + """ atom_positions = [] natoms = conformer.GetNumAtoms() - for atom_num in range(0, natoms): + for atom_num in range(natoms): pos = conformer.GetAtomPosition(atom_num) atom_positions.append([pos.x, pos.y, pos.z]) return atom_positions @@ -384,13 +349,12 @@ def rename_xyz_files(path): flist = os.listdir(path) for filename in flist: if filename.endswith(".xyz"): - num = int(filename.split('_')[-1][:-4]) + num = int(filename.split("_")[-1][:-4]) namedict[num] = filename keys = namedict.keys() for i, num in enumerate(sorted(keys)): oldfilename = namedict[num] - newfilename = '_'.join(oldfilename.split( - '_')[:-1]) + '_' + str(i + 1) + ".xyz" + newfilename = "_".join(oldfilename.split("_")[:-1]) + "_" + str(i + 1) + ".xyz" oldfilepath = os.path.join(path, oldfilename) newfilepath = os.path.join(path, newfilename) os.rename(oldfilepath, newfilepath) @@ -398,111 +362,80 @@ def rename_xyz_files(path): def clean(molecule): molecule = str(molecule.split()[0]) - molecule = re.sub('Cl', '[#17]', molecule) - molecule = re.sub('C', '[#6]', molecule) - molecule = re.sub('c', '[#6]', molecule) - molecule = re.sub('\[N-\]', '[#7-]', molecule) - molecule = re.sub('N', '[#7]', molecule) - molecule = re.sub('n', '[#7]', molecule) - molecule = re.sub('\[\[', '[', molecule) - molecule = re.sub('\]\]', ']', molecule) - molecule = re.sub('\]H\]', 'H]', molecule) - molecule = re.sub('=', '~', molecule) + molecule = re.sub("Cl", "[#17]", molecule) + molecule = re.sub("C", "[#6]", molecule) + molecule = re.sub("c", "[#6]", molecule) + molecule = re.sub("\\[N-\\]", "[#7-]", molecule) + molecule = re.sub("N", "[#7]", molecule) + molecule = re.sub("n", "[#7]", molecule) + molecule = re.sub("\\[\\[", "[", molecule) + molecule = re.sub("\\]\\]", "]", molecule) + molecule = re.sub("\\]H\\]", "H]", molecule) + molecule = re.sub("=", "~", molecule) return molecule -def minimize(output, - molecule, - forcefield, - nconf_gen, - prun_tol, - e_window, - rms_tol, - rep_e_window): - +def minimize(output, molecule, forcefield, nconf_gen, prun_tol, e_window, rms_tol, rep_e_window): output.write(f"Analysing smiles string {molecule}\n") MolFromSmiles(molecule) # print "There are", NumRotatableBonds(mol) output.write("Generating initial conformations\n") - confgen = ConformerGenerator( - smiles=molecule, forcefield=forcefield) - output.write((f"Minimising conformations using the {forcefield} " - "force field\n")) - confgen.generate(max_generated_conformers=int(nconf_gen), - prune_thresh=float(prun_tol), - output=output) + confgen = ConformerGenerator(smiles=molecule, forcefield=forcefield) + output.write(f"Minimising conformations using the {forcefield} " "force field\n") + confgen.generate(max_generated_conformers=int(nconf_gen), prune_thresh=float(prun_tol), output=output) gen_time = time.time() confgen.minimise(output=output) min_time = time.time() - output.write(("Minimisation complete, generated conformations " - "with the following energies:\n")) - output.write("\n".join([str(energy[1]) - for energy in confgen.conf_energies])+"\n") - msg = (f"Clustering structures using an energy window of " - f"{e_window} and an rms tolerance of {rms_tol} and a " - f"Report Energy Window of {rep_e_window}\n") + output.write("Minimisation complete, generated conformations " "with the following energies:\n") + output.write("\n".join([str(energy[1]) for energy in confgen.conf_energies]) + "\n") + msg = ( + f"Clustering structures using an energy window of " + f"{e_window} and an rms tolerance of {rms_tol} and a " + 
f"Report Energy Window of {rep_e_window}\n" + ) output.write(msg) return confgen, gen_time, min_time -def write_clusters(output, - idx, - conformer, - inchikey, - path): +def write_clusters(output, idx, conformer, inchikey, path): output.write(f"Cluster {idx} has energy {conformer[1]}\n") pos = _atomic_pos_from_conformer(conformer[0]) elements = _extract_atomic_type(conformer[0]) coords = list(zip(elements, pos)) xyz_file = os.path.join(path, f"{inchikey}_Conf_{(idx + 1)}.xyz") - write_xyz(coords=coords, filename=xyz_file, - comment=conformer[1]) + write_xyz(coords=coords, filename=xyz_file, comment=conformer[1]) -def run_obabel(inchikey, - idx): +def run_obabel(inchikey, idx): try: cmd = ["obabel", f"{inchikey}_Conf_{(idx+1)}.xyz", "-osmi"] except UnboundLocalError as err: print(f"Did not produce any geometries for {inchikey} {err}") raise - molecule = subprocess.check_output(cmd, - stdin=None, - stderr=subprocess.STDOUT, - shell=False, - universal_newlines=False - ).decode('utf-8') + molecule = subprocess.check_output( + cmd, stdin=None, stderr=subprocess.STDOUT, shell=False, universal_newlines=False + ).decode("utf-8") molecule = clean(molecule) return molecule -def summarize(output, - gen_time, - start_time, - min_time, - cluster_time): - +def summarize(output, gen_time, start_time, min_time, cluster_time): recluster_time = time.time() output.write(socket.gethostname() + "\n") - output.write('gen time {0:1f} sec\n'.format( - gen_time - start_time)) - output.write('min time {0:1f} sec\n'.format(min_time - gen_time)) - output.write('cluster time {0:1f} sec\n'.format( - cluster_time - min_time)) - output.write('recluster time {0:1f} sec\n'.format( - recluster_time - cluster_time)) - output.write('total time {0:1f} sec\n'.format( - time.time() - start_time)) - output.write('Terminated successfully\n') + output.write(f"gen time {gen_time - start_time:1f} sec\n") + output.write(f"min time {min_time - gen_time:1f} sec\n") + output.write(f"cluster time {cluster_time - min_time:1f} sec\n") + output.write(f"recluster time {recluster_time - cluster_time:1f} sec\n") + output.write(f"total time {time.time() - start_time:1f} sec\n") + output.write("Terminated successfully\n") def get_mol(smiles): - mol = MolFromSmiles(MolToSmiles( - MolFromSmiles(smiles))) + mol = MolFromSmiles(MolToSmiles(MolFromSmiles(smiles))) return mol @@ -519,9 +452,7 @@ def xyz_to_rdmol(nxyz, smiles): return mol -def update_with_boltz(geom_list, - temp): - +def update_with_boltz(geom_list, temp): rel_ens = np.array([i["relativeenergy"] for i in geom_list]) degens = np.array([i["degeneracy"] for i in geom_list]) k_t = temp * KB_KCAL @@ -547,11 +478,7 @@ def parse_nxyz(lines): return nxyz -def make_geom_dic(lines, - smiles, - geom_list, - idx): - +def make_geom_dic(lines, smiles, geom_list, idx): nxyz = parse_nxyz(lines) energy_kcal = float(lines[1]) # total energy in au @@ -561,18 +488,18 @@ def make_geom_dic(lines, if idx == 0: rel_energy = 0 else: - ref_energy_kcal = (geom_list[0]["totalenergy"] - * AU_TO_KCAL) + ref_energy_kcal = geom_list[0]["totalenergy"] * AU_TO_KCAL rel_energy = energy_kcal - ref_energy_kcal - rd_mol = xyz_to_rdmol(nxyz=nxyz, - smiles=smiles) + rd_mol = xyz_to_rdmol(nxyz=nxyz, smiles=smiles) - geom = {"confnum": idx + 1, - "totalenergy": energy_au, - "relativeenergy": rel_energy, - "degeneracy": 1, - "rd_mol": rd_mol} + geom = { + "confnum": idx + 1, + "totalenergy": energy_au, + "relativeenergy": rel_energy, + "degeneracy": 1, + "rd_mol": rd_mol, + } return geom @@ -583,42 +510,30 @@ def 
get_charge(smiles): return charge -def combine_geom_dics(geom_list, - temp, - other_props, - smiles): - - totalconfs = sum([i["degeneracy"] - for i in geom_list]) +def combine_geom_dics(geom_list, temp, other_props, smiles): + totalconfs = sum([i["degeneracy"] for i in geom_list]) uniqueconfs = len(geom_list) lowestenergy = geom_list[0]["totalenergy"] - poplowestpct = (geom_list[0]["boltzmannweight"] - * 100) + poplowestpct = geom_list[0]["boltzmannweight"] * 100 charge = get_charge(smiles) - combination = {"totalconfs": totalconfs, - "uniqueconfs": uniqueconfs, - "temperature": temp, - "lowestenergy": lowestenergy, - "poplowestpct": poplowestpct, - "charge": charge, - "conformers": geom_list, - "smiles": smiles} + combination = { + "totalconfs": totalconfs, + "uniqueconfs": uniqueconfs, + "temperature": temp, + "lowestenergy": lowestenergy, + "poplowestpct": poplowestpct, + "charge": charge, + "conformers": geom_list, + "smiles": smiles, + } if other_props is not None: combination.update(other_props) return combination -def parse_results(job_dir, - log_file, - inchikey, - smiles, - max_confs, - other_props, - temp, - clean_up): - +def parse_results(job_dir, log_file, inchikey, smiles, max_confs, other_props, temp, clean_up): # import pdb # pdb.set_trace() @@ -628,8 +543,7 @@ def parse_results(job_dir, loglines = f_p.readlines() if loglines[-1].strip() != "Terminated successfully": - msg = ("'Terminated successfully' not found " - "at end of conformer output") + msg = "'Terminated successfully' not found " "at end of conformer output" raise Exception(msg) geom_list = [] @@ -637,21 +551,14 @@ def parse_results(job_dir, path = os.path.join(job_dir, XYZ_NAME.format(inchikey, i + 1)) if not os.path.isfile(path): continue - with open(path, 'r') as f_p: + with open(path, "r") as f_p: lines = f_p.readlines() - geom = make_geom_dic(lines=lines, - smiles=smiles, - geom_list=geom_list, - idx=i) + geom = make_geom_dic(lines=lines, smiles=smiles, geom_list=geom_list, idx=i) geom_list.append(geom) - geom_list = update_with_boltz(geom_list=geom_list, - temp=temp) - summary_dic = combine_geom_dics(geom_list=geom_list, - temp=temp, - other_props=other_props, - smiles=smiles) + geom_list = update_with_boltz(geom_list=geom_list, temp=temp) + summary_dic = combine_geom_dics(geom_list=geom_list, temp=temp, other_props=other_props, smiles=smiles) if clean_up: for file in os.listdir(job_dir): @@ -663,103 +570,102 @@ def parse_results(job_dir, return summary_dic -def one_species_confs(molecule, - log, - other_props, - max_confs, - forcefield, - nconf_gen, - e_window, - rms_tol, - prun_tol, - job_dir, - log_file, - rep_e_window, - fallback_to_align, - temp, - clean_up, - start_time): - +def one_species_confs( + molecule, + log, + other_props, + max_confs, + forcefield, + nconf_gen, + e_window, + rms_tol, + prun_tol, + job_dir, + log_file, + rep_e_window, + fallback_to_align, + temp, + clean_up, + start_time, +): smiles = copy.deepcopy(molecule) with open(log, "w") as output: output.write("The smiles strings that will be run are:\n") - output.write("\n".join([molecule])+"\n") - - if any([element in molecule for element in UFF_ELEMENTS]): - output.write(("Switching to UFF, since MMFF94 does " - "not have boron and/or aluminum\n")) - forcefield = 'uff' - - confgen, gen_time, min_time = minimize(output=output, - molecule=molecule, - forcefield=forcefield, - nconf_gen=nconf_gen, - prun_tol=prun_tol, - e_window=e_window, - rms_tol=rms_tol, - rep_e_window=rep_e_window) - clustered_confs = 
confgen.cluster(rms_tolerance=float(rms_tol), - max_ranked_conformers=int( - max_confs), - energy_window=float(e_window), - Report_e_tol=float(rep_e_window), - output=output) + output.write(f"{molecule}" + "\n") + + if any(element in molecule for element in UFF_ELEMENTS): + output.write("Switching to UFF, since MMFF94 does " "not have boron and/or aluminum\n") + forcefield = "uff" + + confgen, gen_time, min_time = minimize( + output=output, + molecule=molecule, + forcefield=forcefield, + nconf_gen=nconf_gen, + prun_tol=prun_tol, + e_window=e_window, + rms_tol=rms_tol, + rep_e_window=rep_e_window, + ) + clustered_confs = confgen.cluster( + rms_tolerance=float(rms_tol), + max_ranked_conformers=int(max_confs), + energy_window=float(e_window), + Report_e_tol=float(rep_e_window), + output=output, + ) cluster_time = time.time() - inchikey = inchi.MolToInchiKey(get_mol(molecule), - options=INCHI_OPTIONS) + inchikey = inchi.MolToInchiKey(get_mol(molecule), options=INCHI_OPTIONS) for i, conformer in enumerate(clustered_confs): - write_clusters(output=output, - idx=i, - conformer=conformer, - inchikey=inchikey, - path=job_dir) - - molecule = run_obabel(inchikey=inchikey, - idx=i) - confgen.recluster(path=job_dir, - rms_tolerance=float(rms_tol), - max_ranked_conformers=int(max_confs), - energy_window=float(e_window), - output=output, - clustered_confs=clustered_confs, - molecule=molecule, - key=inchikey, - fallback_to_align=fallback_to_align) + write_clusters(output=output, idx=i, conformer=conformer, inchikey=inchikey, path=job_dir) + + molecule = run_obabel(inchikey=inchikey, idx=i) + confgen.recluster( + path=job_dir, + rms_tolerance=float(rms_tol), + max_ranked_conformers=int(max_confs), + energy_window=float(e_window), + output=output, + clustered_confs=clustered_confs, + molecule=molecule, + key=inchikey, + fallback_to_align=fallback_to_align, + ) rename_xyz_files(path=job_dir) - summarize(output=output, - gen_time=gen_time, - start_time=start_time, - min_time=min_time, - cluster_time=cluster_time) - - conf_dic = parse_results(job_dir=job_dir, - log_file=log_file, - inchikey=inchikey, - max_confs=max_confs, - other_props=other_props, - temp=temp, - smiles=smiles, - clean_up=clean_up) + summarize(output=output, gen_time=gen_time, start_time=start_time, min_time=min_time, cluster_time=cluster_time) + + conf_dic = parse_results( + job_dir=job_dir, + log_file=log_file, + inchikey=inchikey, + max_confs=max_confs, + other_props=other_props, + temp=temp, + smiles=smiles, + clean_up=clean_up, + ) return conf_dic -def run_generator(smiles_list, - other_props=None, - max_confs=MAX_CONFS, - forcefield="mmff", - nconf_gen=(10 * MAX_CONFS), - e_window=5.0, - rms_tol=0.1, - prun_tol=0.01, - job_dir="confs", - log_file="confgen.log", - rep_e_window=5.0, - fallback_to_align=False, - temp=298.15, - clean_up=True, - **kwargs): +def run_generator( + smiles_list, + other_props=None, + max_confs=MAX_CONFS, + forcefield="mmff", + nconf_gen=(10 * MAX_CONFS), + e_window=5.0, + rms_tol=0.1, + prun_tol=0.01, + job_dir="confs", + log_file="confgen.log", + rep_e_window=5.0, + fallback_to_align=False, + temp=298.15, + clean_up=True, + **kwargs, +): """ Args: smiles_list (list[str]): list of SMILES strings @@ -786,36 +692,33 @@ def run_generator(smiles_list, conf_dics = [] for molecule in smiles_list: - conf_dic = one_species_confs(molecule=molecule, - log=log, - other_props=other_props, - max_confs=max_confs, - forcefield=forcefield, - nconf_gen=nconf_gen, - e_window=e_window, - rms_tol=rms_tol, - prun_tol=prun_tol, - 
job_dir=job_dir, - log_file=log_file, - rep_e_window=rep_e_window, - fallback_to_align=fallback_to_align, - temp=temp, - clean_up=clean_up, - start_time=start_time) + conf_dic = one_species_confs( + molecule=molecule, + log=log, + other_props=other_props, + max_confs=max_confs, + forcefield=forcefield, + nconf_gen=nconf_gen, + e_window=e_window, + rms_tol=rms_tol, + prun_tol=prun_tol, + job_dir=job_dir, + log_file=log_file, + rep_e_window=rep_e_window, + fallback_to_align=fallback_to_align, + temp=temp, + clean_up=clean_up, + start_time=start_time, + ) conf_dics.append(conf_dic) return conf_dics -def add_to_summary(summary_dic, - conf_dic, - smiles, - save_dir): - inchikey = inchi.MolToInchiKey(get_mol(smiles), - options=INCHI_OPTIONS) +def add_to_summary(summary_dic, conf_dic, smiles, save_dir): + inchikey = inchi.MolToInchiKey(get_mol(smiles), options=INCHI_OPTIONS) pickle_path = os.path.join(os.path.abspath(save_dir), f"{inchikey}.pickle") - summary_dic[smiles] = {key: val for key, val in - conf_dic.items() if key != "conformers"} + summary_dic[smiles] = {key: val for key, val in conf_dic.items() if key != "conformers"} summary_dic[smiles].update({"pickle_path": pickle_path}) return summary_dic, pickle_path @@ -837,16 +740,12 @@ def confs_and_save(config_path): print(f"Saving pickle files to directory {save_dir}") for i, smiles in tqdm_enum(smiles_dic["smiles"]): - other_props = {key: val[i] for key, val in smiles_dic.items() - if key != 'smiles'} + other_props = {key: val[i] for key, val in smiles_dic.items() if key != "smiles"} smiles_list = [smiles] - conf_dic = run_generator(smiles_list=smiles_list, - other_props=other_props, - **info)[0] - summary_dic, pickle_path = add_to_summary(summary_dic=summary_dic, - conf_dic=conf_dic, - smiles=smiles, - save_dir=save_dir) + conf_dic = run_generator(smiles_list=smiles_list, other_props=other_props, **info)[0] + summary_dic, pickle_path = add_to_summary( + summary_dic=summary_dic, conf_dic=conf_dic, smiles=smiles, save_dir=save_dir + ) with open(pickle_path, "wb") as f_open: pickle.dump(conf_dic, f_open) diff --git a/nff/utils/confs.py b/nff/utils/confs.py index cabc6a72..200b2366 100644 --- a/nff/utils/confs.py +++ b/nff/utils/confs.py @@ -2,10 +2,11 @@ Tools for manipulating conformer numbers in a dataset. """ -import torch +import copy import math + import numpy as np -import copy +import torch from nff.utils.misc import tqdm_enum @@ -25,14 +26,11 @@ def assert_ordered(batch): """ weights = batch["weights"].reshape(-1).tolist() - sort_weights = sorted(weights, - key=lambda x: -x) + sort_weights = sorted(weights, key=lambda x: -x) assert weights == sort_weights -def get_batch_dic(batch, - idx_dic, - num_confs): +def get_batch_dic(batch, idx_dic, num_confs): """ Get some conformer information about the batch. Args: @@ -44,7 +42,7 @@ def get_batch_dic(batch, statistical weight will be used. 
num_confs (int): Number of conformers to keep Returns: - info_dic (dict): Dictionary with extra conformer + info_dic (dict): Dictionary with extra conformer information about the batch """ @@ -56,11 +54,9 @@ def get_batch_dic(batch, confs_in_batch = old_num_atoms // mol_size # new number of atoms after trimming - new_num_atoms = int(mol_size * min( - confs_in_batch, num_confs)) + new_num_atoms = int(mol_size * min(confs_in_batch, num_confs)) if idx_dic is None: - assert_ordered(batch) # new number of conformers after trimming real_num_confs = min(confs_in_batch, num_confs) @@ -71,12 +67,14 @@ def get_batch_dic(batch, conf_idx = idx_dic[smiles] real_num_confs = len(conf_idx) - info_dic = {"conf_idx": conf_idx, - "real_num_confs": real_num_confs, - "old_num_atoms": old_num_atoms, - "new_num_atoms": new_num_atoms, - "confs_in_batch": confs_in_batch, - "mol_size": mol_size} + info_dic = { + "conf_idx": conf_idx, + "real_num_confs": real_num_confs, + "old_num_atoms": old_num_atoms, + "new_num_atoms": new_num_atoms, + "confs_in_batch": confs_in_batch, + "mol_size": mol_size, + } return info_dic @@ -86,7 +84,7 @@ def to_xyz_idx(batch_dic): Get the indices of the nxyz corresponding to atoms in conformers we want to keep. Args: - batch_dic (dict): Dictionary with extra conformer + batch_dic (dict): Dictionary with extra conformer information about the batch Returns: xyz_conf_all_idx (torch.LongTensor): nxyz indices of atoms @@ -108,7 +106,6 @@ def to_xyz_idx(batch_dic): # and append them to xyz_conf_all_idx for conf_num in conf_idx: - start_idx = xyz_conf_start_idx[conf_num] end_idx = xyz_conf_start_idx[conf_num + 1] full_idx = torch.arange(start_idx, end_idx) @@ -121,18 +118,15 @@ def to_xyz_idx(batch_dic): return xyz_conf_all_idx -def split_nbrs(nbrs, - mol_size, - confs_in_batch, - conf_idx): +def split_nbrs(nbrs, mol_size, confs_in_batch, conf_idx): """ - Get the indices of the neighbor list that correspond to conformers + Get the indices of the neighbor list that correspond to conformers we're keeping. Args: nbrs (torch.LongTensor): neighbor list mol_size (int): Number of atoms in each conformer confs_in_batch (int): Total number of conformers in the batch - conf_idx (list[int]): Indices of the conformers we're keeping + conf_idx (list[int]): Indices of the conformers we're keeping Returns: tens_idx (torch.LongTensor): nbr indices of conformers we're keeping. @@ -143,7 +137,6 @@ def split_nbrs(nbrs, cutoffs = [i * mol_size - 1 for i in range(1, confs_in_batch + 1)] for i in conf_idx: - # start index of the conformer start = cutoffs[i] - mol_size + 1 # end index of the conformer @@ -166,22 +159,19 @@ def to_nbr_idx(batch_dic, nbrs): """ Apply `split_nbrs` given `batch_dic` Args: - batch_dic (dict): Dictionary with extra conformer + batch_dic (dict): Dictionary with extra conformer information about the batch nbrs (torch.LongTensor): neighbor list Returns: split_nbr_idx (torch.LongTensor): nbr indices of conformers we're - keeping. + keeping. """ mol_size = batch_dic["mol_size"] confs_in_batch = batch_dic["confs_in_batch"] conf_idx = batch_dic["conf_idx"] - split_nbr_idx = split_nbrs(nbrs=nbrs, - mol_size=mol_size, - confs_in_batch=confs_in_batch, - conf_idx=conf_idx) + split_nbr_idx = split_nbrs(nbrs=nbrs, mol_size=mol_size, confs_in_batch=confs_in_batch, conf_idx=conf_idx) return split_nbr_idx @@ -190,7 +180,7 @@ def update_weights(batch, batch_dic): """ Readjust weights so they sum to 1. 
Args: - batch_dic (dict): Dictionary with extra conformer + batch_dic (dict): Dictionary with extra conformer information about the batch batch (dict): Batch dictionary Returns: @@ -227,10 +217,9 @@ def update_nbr_idx_keys(dset, batch, i, old_nbrs, num_confs): # make a mask for the neighbor list indices that are being kept - mol_size = batch['mol_size'] + mol_size = batch["mol_size"] for j in range(num_confs): - mask = (old_nbrs[:, 0] < (j + 1) * mol_size - ) * (old_nbrs[:, 0] >= j * mol_size) + mask = (old_nbrs[:, 0] < (j + 1) * mol_size) * (old_nbrs[:, 0] >= j * mol_size) if j == 0: total_mask = copy.deepcopy(mask) else: @@ -249,7 +238,7 @@ def update_nbr_idx_keys(dset, batch, i, old_nbrs, num_confs): def update_per_conf(dataset, i, old_num_atoms, new_n_confs): - mol_size = dataset.props["mol_size"][i] + dataset.props["mol_size"][i] for key in PER_CONF_KEYS: if key not in dataset.props: continue @@ -263,7 +252,7 @@ def update_dset(batch, batch_dic, dataset, i): number of conformers, for species at index i. Args: batch (dict): Batch dictionary - batch_dic (dict): Dictionary with extra conformer + batch_dic (dict): Dictionary with extra conformer information about the batch dataset (nff.data.dataset): NFF dataset i (int): index of the species whose info we're updating @@ -300,26 +289,18 @@ def update_dset(batch, batch_dic, dataset, i): dataset.props["atom_features"][i] = atom_feats[conf_xyz_idx] # renormalize weights - dataset.props["weights"][i] = update_weights(batch, - batch_dic) + dataset.props["weights"][i] = update_weights(batch, batch_dic) # update anything else that's a per-conformer quantity update_per_conf(dataset, i, old_num_atoms, batch_dic["real_num_confs"]) # update anything that depends on the indexing of the nbr list - update_nbr_idx_keys(dset=dataset, - batch=batch, - i=i, - old_nbrs=nbr_list, - num_confs=batch_dic["real_num_confs"]) + update_nbr_idx_keys(dset=dataset, batch=batch, i=i, old_nbrs=nbr_list, num_confs=batch_dic["real_num_confs"]) return dataset -def trim_confs(dataset, - num_confs, - idx_dic, - enum_func=None): +def trim_confs(dataset, num_confs, idx_dic, enum_func=None): """ Trim conformers for the entire dataset. Args: @@ -330,9 +311,9 @@ def trim_confs(dataset, of the conformers you want to keep. If not specified, then the top `num_confs` conformers with the highest statistical weight will be used. - enum_func (callable, optional): a function with which to + enum_func (callable, optional): a function with which to enumerate the dataset. If not given, we use tqdm - to track progress. + to track progress. Returns: dataset (nff.data.dataset): updated NFF dataset """ @@ -341,23 +322,14 @@ def trim_confs(dataset, enum_func = tqdm_enum for i, batch in tqdm_enum(dataset): + batch_dic = get_batch_dic(batch=batch, idx_dic=idx_dic, num_confs=num_confs) - batch_dic = get_batch_dic(batch=batch, - idx_dic=idx_dic, - num_confs=num_confs) - - dataset = update_dset(batch=batch, - batch_dic=batch_dic, - dataset=dataset, - i=i) + dataset = update_dset(batch=batch, batch_dic=batch_dic, dataset=dataset, i=i) return dataset -def make_split_nbrs(nbr_list, - mol_size, - num_confs, - confs_per_split): +def make_split_nbrs(nbr_list, mol_size, num_confs, confs_per_split): """ Split neighbor list of a species into chunks for each sub-batch. Args: @@ -367,10 +339,10 @@ def make_split_nbrs(nbr_list, confs_per_split (list[int]): number of conformers in each sub-batch. 
Returns: - all_grouped_nbrs (list[torch.LongTensor]): list of + all_grouped_nbrs (list[torch.LongTensor]): list of neighbor lists for each sub-batch. - nbr_masks (list(torch.BoolTensor))): masks that tell you which - indices of the combined neighbor list are being used for the + nbr_masks (list(torch.BoolTensor))): masks that tell you which + indices of the combined neighbor list are being used for the neighbor list of each sub-batch. """ @@ -382,8 +354,7 @@ def make_split_nbrs(nbr_list, # mask = (nbr_list[:, 0] <= (i + 1) * mol_size # ) * (nbr_list[:, 1] <= (i + 1) * mol_size) - mask = (nbr_list[:, 0] < (i + 1) * mol_size - ) * (nbr_list[:, 0] >= i * mol_size) + mask = (nbr_list[:, 0] < (i + 1) * mol_size) * (nbr_list[:, 0] >= i * mol_size) new_nbrs.append(nbr_list[mask]) masks.append(mask) @@ -394,7 +365,6 @@ def make_split_nbrs(nbr_list, sub_batch_masks = [] for i, num in enumerate(confs_per_split): - # neighbor first prev_idx = sum(confs_per_split[:i]) nbr_idx = list(range(prev_idx, prev_idx + num)) @@ -405,18 +375,13 @@ def make_split_nbrs(nbr_list, all_grouped_nbrs.append(grouped_nbrs) # then add together all the masks - mask = sum(masks[i] for i in range(prev_idx, - prev_idx + num)).to(torch.bool) + mask = sum(masks[i] for i in range(prev_idx, prev_idx + num)).to(torch.bool) sub_batch_masks.append(mask) return all_grouped_nbrs, sub_batch_masks -def add_split_nbrs(batch, - mol_size, - num_confs, - confs_per_split, - sub_batches): +def add_split_nbrs(batch, mol_size, num_confs, confs_per_split, sub_batches): """ Add split-up neighbor lists to each sub-batch. Args: @@ -429,8 +394,8 @@ def add_split_nbrs(batch, Returns: sub_batches (list[dict]): list of sub_batches updated with their neighbor lists. - nbr_masks (list(torch.BoolTensor))): masks that tell you which - indices of the combined neighbor list are being used for the + nbr_masks (list(torch.BoolTensor))): masks that tell you which + indices of the combined neighbor list are being used for the neighbor list of each sub-batch. """ @@ -443,10 +408,9 @@ def add_split_nbrs(batch, if key not in batch: continue nbr_list = batch[key] - split_nbrs, masks = make_split_nbrs(nbr_list=nbr_list, - mol_size=mol_size, - num_confs=num_confs, - confs_per_split=confs_per_split) + split_nbrs, masks = make_split_nbrs( + nbr_list=nbr_list, mol_size=mol_size, num_confs=num_confs, confs_per_split=confs_per_split + ) if key == "nbr_list": nbr_masks = masks @@ -457,9 +421,7 @@ def add_split_nbrs(batch, return sub_batches, nbr_masks -def get_confs_per_split(batch, - num_confs, - sub_batch_size): +def get_confs_per_split(batch, num_confs, sub_batch_size): """ Get the number of conformers per sub-batch. Args: @@ -474,8 +436,7 @@ def get_confs_per_split(batch, val_len = len(batch["nxyz"]) inherent_val_len = val_len // num_confs - split_list = [sub_batch_size * inherent_val_len] * math.floor( - num_confs / sub_batch_size) + split_list = [sub_batch_size * inherent_val_len] * math.floor(num_confs / sub_batch_size) # if there's a remainder @@ -487,24 +448,22 @@ def get_confs_per_split(batch, return confs_per_split -def fix_nbr_idx(batch, - masks, - sub_batches): +def fix_nbr_idx(batch, masks, sub_batches): """ Fix anything that is defined with respect to positions of pairs in a neighbor list (e.g. `bond_idx`, `kj_idx`, and `ji_idx`). 
Args: batch (dict): batched sample of species - masks (list(torch.BoolTensor))): masks that tell you which - indices of the combined neighbor list are being used for the + masks (list(torch.BoolTensor))): masks that tell you which + indices of the combined neighbor list are being used for the neighbor list of each sub-batch. sub_batches (list[dict]): sub batches of the batch Returns: sub_batches (list[dict]): corrected sub batches of the batch """ - old_nbr_list = batch['nbr_list'] + old_nbr_list = batch["nbr_list"] new_idx_list = [] for mask in masks: @@ -512,8 +471,7 @@ def fix_nbr_idx(batch, # make everything not in this batch equal to -1 so we # know what's actually not in this batch new_idx = -torch.ones_like(old_nbr_list)[:, 0] - new_idx[mask] = (torch.arange(num_new_nbrs) - .to(mask.device)) + new_idx[mask] = torch.arange(num_new_nbrs).to(mask.device) new_idx_list.append(new_idx) for new_idx, sub_batch in zip(new_idx_list, sub_batches): @@ -529,8 +487,7 @@ def fix_nbr_idx(batch, return sub_batches -def split_batch(batch, - sub_batch_size): +def split_batch(batch, sub_batch_size): """ Split a batch into sub-batches. Args: @@ -545,10 +502,7 @@ def split_batch(batch, num_confs = len(batch["nxyz"]) // mol_size sub_batch_dic = {} - confs_per_split = get_confs_per_split( - batch=batch, - num_confs=num_confs, - sub_batch_size=sub_batch_size) + confs_per_split = get_confs_per_split(batch=batch, num_confs=num_confs, sub_batch_size=sub_batch_size) num_splits = len(confs_per_split) @@ -559,10 +513,9 @@ def split_batch(batch, # get rid of `bond_idx` because it's wrong if key in [*NBR_IDX_KEYS, *REINDEX_KEYS]: continue - elif np.mod(val_len, num_confs) != 0 or val_len == 1: + if np.mod(val_len, num_confs) != 0 or val_len == 1: if key == "num_atoms": - sub_batch_dic[key] = [int(val * num / num_confs) - for num in confs_per_split] + sub_batch_dic[key] = [int(val * num / num_confs) for num in confs_per_split] else: sub_batch_dic[key] = [val] * num_splits continue @@ -575,26 +528,20 @@ def split_batch(batch, # use this to determine the number of items in each # section of the split list - split_list = [inherent_val_len * num - for num in confs_per_split] + split_list = [inherent_val_len * num for num in confs_per_split] # split the value accordingly split_val = torch.split(val, split_list) sub_batch_dic[key] = split_val - sub_batches = [{key: sub_batch_dic[key][i] for key in - sub_batch_dic.keys()} for i in range(num_splits)] + sub_batches = [{key: sub_batch_dic[key][i] for key in sub_batch_dic} for i in range(num_splits)] # fix neighbor list indexing - sub_batches, masks = add_split_nbrs(batch=batch, - mol_size=mol_size, - num_confs=num_confs, - confs_per_split=confs_per_split, - sub_batches=sub_batches) + sub_batches, masks = add_split_nbrs( + batch=batch, mol_size=mol_size, num_confs=num_confs, confs_per_split=confs_per_split, sub_batches=sub_batches + ) # fix anything that relies on the position of a neighbor list pair - sub_batches = fix_nbr_idx(batch=batch, - masks=masks, - sub_batches=sub_batches) + sub_batches = fix_nbr_idx(batch=batch, masks=masks, sub_batches=sub_batches) return sub_batches diff --git a/nff/utils/constants.py b/nff/utils/constants.py index 12402f9c..4238a5d9 100644 --- a/nff/utils/constants.py +++ b/nff/utils/constants.py @@ -31,36 +31,55 @@ 16: 32.06, } + AU_TO_KCAL = { "energy": HARTREE_TO_KCAL_MOL, "_grad": 1.0 / BOHR_RADIUS, + "stress": HARTREE_TO_KCAL_MOL * ((1.0 / BOHR_RADIUS) ** 3), + "_volume": 1 / ((1.0 / BOHR_RADIUS) ** 3), } + AU_TO_EV = { "energy": 
HARTREE_TO_EV, "_grad": 1.0 / BOHR_RADIUS, + "stress": HARTREE_TO_EV * ((1.0 / BOHR_RADIUS) ** 3), + "_volume": 1 / ((1.0 / BOHR_RADIUS) ** 3), } + EV_TO_AU = { "energy": 1.0 / HARTREE_TO_EV, "_grad": BOHR_RADIUS, + "stress": (1.0 / HARTREE_TO_EV) * (BOHR_RADIUS**3), + "_volume": 1 / (BOHR_RADIUS**3), } + EV_TO_KCAL = { "energy": EV_TO_KCAL_MOL, "_grad": 1.0, + "stress": EV_TO_KCAL_MOL * (1.0**3), + "_volume": 1 / (1.0**3), } + KCAL_TO_AU = { "energy": 1.0 / HARTREE_TO_KCAL_MOL, "_grad": BOHR_RADIUS, + "stress": (1.0 / HARTREE_TO_KCAL_MOL) * (BOHR_RADIUS**3), + "_volume": 1 / (BOHR_RADIUS**3), } + KCAL_TO_EV = { "energy": 1.0 / EV_TO_KCAL_MOL, "_grad": 1.0, + "stress": (1.0 / EV_TO_KCAL_MOL) * (1.0**3), + "_volume": 1 / (1.0**3), } + DEFAULT = { "energy": 1.0, "_grad": 1.0, @@ -225,11 +244,7 @@ def exc_ev_to_hartree(props: dict, add_ground_energy: bool = False) -> dict: new_props (dict): dictionary with properties converted. """ assert "energy_0" in props - exc_keys = [ - key - for key in props - if key.startswith("energy") and "grad" not in key and key != "energy_0" - ] + exc_keys = [key for key in props if key.startswith("energy") and "grad" not in key and key != "energy_0"] energy_0 = props["energy_0"] new_props = copy.deepcopy(props) diff --git a/nff/utils/cuda.py b/nff/utils/cuda.py index 219fdc21..093aeec3 100644 --- a/nff/utils/cuda.py +++ b/nff/utils/cuda.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Dict, List, Union +from typing import Dict, List import numpy as np import nvidia_smi @@ -25,7 +25,7 @@ def batch_to(batch: Dict[str, list | torch.Tensor], device: str) -> Dict[str, Li return gpu_batch -def detach(val: torch.Tensor, to_numpy: bool = False) -> Union[torch.Tensor, np.ndarray]: +def detach(val: torch.Tensor, to_numpy: bool = False) -> torch.Tensor | np.ndarray: """Detach GPU tensor Args: @@ -40,9 +40,7 @@ def detach(val: torch.Tensor, to_numpy: bool = False) -> Union[torch.Tensor, np. 
return val.detach().cpu() if hasattr(val, "detach") else val -def batch_detach( - batch: Dict[str, Union[List, torch.Tensor]], to_numpy: bool = False -) -> Dict[str, Union[List, torch.Tensor]]: +def batch_detach(batch: Dict[str, List | torch.Tensor], to_numpy: bool = False) -> Dict[str, List | torch.Tensor]: """Detach batch of GPU tensors Args: @@ -66,8 +64,8 @@ def batch_detach( def to_cpu( - batch: Dict[str, Union[List, torch.Tensor]], -) -> Dict[str, Union[List, torch.Tensor]]: + batch: Dict[str, List | torch.Tensor], +) -> Dict[str, List | torch.Tensor]: """Send batch to CPU Args: @@ -118,5 +116,4 @@ def get_final_device(device: str) -> str: return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}" except nvidia_smi.NVMLError: return "cuda:0" - return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}" return "cpu" diff --git a/nff/utils/dispersion.py b/nff/utils/dispersion.py index 403e107b..75ad4b1b 100644 --- a/nff/utils/dispersion.py +++ b/nff/utils/dispersion.py @@ -7,23 +7,22 @@ https://github.com/MMunibas/PhysNet/tree/master/neural_network/grimme_d3/tables """ +import json import os + import numpy as np import torch -import json +from ase import Atoms +from ase.calculators.dftd3 import DFTD3 +from nff.nn.utils import clean_matrix, lattice_points_in_supercell from nff.utils import constants as const -from nff.nn.utils import lattice_points_in_supercell, clean_matrix from nff.utils.scatter import scatter_add -from ase import Atoms -from ase.calculators.dftd3 import DFTD3 - -base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'table_data') -c6_ref_path = os.path.join(base_dir, 'c6ab.npy') -r_cov_path = os.path.join(base_dir, 'rcov.npy') -r2r4_path = os.path.join(base_dir, 'r2r4.npy') +base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "table_data") +c6_ref_path = os.path.join(base_dir, "c6ab.npy") +r_cov_path = os.path.join(base_dir, "rcov.npy") +r2r4_path = os.path.join(base_dir, "r2r4.npy") func_path = os.path.join(base_dir, "functional_params.json") # reference C6 data for pairs of atom types AB in different reference systems @@ -43,30 +42,24 @@ FUNC_PARAMS = json.load(f) -def get_periodic_nbrs(batch, - xyz, - r_cut=95, - nbrs_info=None, - mol_idx=None): +def get_periodic_nbrs(batch, xyz, r_cut=95, nbrs_info=None, mol_idx=None): """ Get the neighbor list connecting every atom to its neighbor within a given geometry, but not to itself or to atoms in other geometries. - Since this is for perodic systems it also requires getting all possible + Since this is for periodic systems it also requires getting all possible lattice translation vectors. """ device = xyz.device - num_atoms = batch['num_atoms'] + num_atoms = batch["num_atoms"] if not isinstance(num_atoms, list): num_atoms = num_atoms.tolist() if nbrs_info is None: - - nxyz_list = torch.split(batch['nxyz'], num_atoms) + nxyz_list = torch.split(batch["nxyz"], num_atoms) xyzs = torch.split(xyz, num_atoms) - nbrs = [] nbrs_T = [] nbrs = [] z = [] @@ -78,22 +71,14 @@ def get_periodic_nbrs(batch, num_atoms = [] for _xyz, nxyz in zip(xyzs, nxyz_list): # only works if the cell for all crystals in batch are the same - cell = batch['cell'].cpu().numpy() + cell = batch["cell"].cpu().numpy() # cutoff specified by r_cut in Bohr (a.u.) 
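The supercell multipliers computed in the hunk below expand the cell independently along each lattice vector until the dispersion cutoff fits. A minimal standalone sketch of that estimate, assuming a hypothetical cubic 5 Å cell and the default r_cut of 95 Bohr (the helper name `supercell_multipliers` is illustrative, not from the package):

```python
import numpy as np

BOHR_RADIUS = 0.529177  # Angstrom per Bohr

def supercell_multipliers(cell, r_cut_bohr=95):
    """Repeats of each lattice vector needed to span the cutoff."""
    r_cut_ang = r_cut_bohr * BOHR_RADIUS  # ~50.3 Angstrom
    return [int(np.ceil(r_cut_ang / np.linalg.norm(v))) for v in cell]

cell = 5.0 * np.eye(3)  # hypothetical cubic 5 Angstrom cell
print(supercell_multipliers(cell))  # [11, 11, 11]
```

The expansion is taken per axis, which is why the resulting diagonal supercell_matrix is described in the code as an estimate of the volume actually needed.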
# estimate getting close to the cutoff with supercell expansion - a_mul = int(np.ceil( - (r_cut*const.BOHR_RADIUS) / np.linalg.norm(cell[0]) - )) - b_mul = int(np.ceil( - (r_cut*const.BOHR_RADIUS) / np.linalg.norm(cell[1]) - )) - c_mul = int(np.ceil( - (r_cut*const.BOHR_RADIUS) / np.linalg.norm(cell[2]) - )) - supercell_matrix = np.array([[a_mul, 0, 0], - [0, b_mul, 0], - [0, 0, c_mul]]) + a_mul = int(np.ceil((r_cut * const.BOHR_RADIUS) / np.linalg.norm(cell[0]))) + b_mul = int(np.ceil((r_cut * const.BOHR_RADIUS) / np.linalg.norm(cell[1]))) + c_mul = int(np.ceil((r_cut * const.BOHR_RADIUS) / np.linalg.norm(cell[2]))) + supercell_matrix = np.array([[a_mul, 0, 0], [0, b_mul, 0], [0, 0, c_mul]]) supercell = clean_matrix(supercell_matrix @ cell) # cartesian lattice points @@ -102,36 +87,27 @@ def get_periodic_nbrs(batch, # need to get all negative lattice translation vectors # but remove duplicate 0 vector - zero_idx = np.where( - np.all(_lattice_points.__eq__(np.array([0,0,0])), - axis=1) - )[0][0] - _lattice_points = np.concatenate([_lattice_points[zero_idx:, :], - _lattice_points[:zero_idx, :]]) - - _z = nxyz[:,0].long().to(device) + zero_idx = np.where(np.all(_lattice_points.__eq__(np.array([0, 0, 0])), axis=1))[0][0] + _lattice_points = np.concatenate([_lattice_points[zero_idx:, :], _lattice_points[:zero_idx, :]]) + + _z = nxyz[:, 0].long().to(device) _N = len(_lattice_points) # perform lattice translations on positions - lattice_points_T = (torch.tile( - torch.from_numpy(_lattice_points), - ( (len(_xyz),) + - (1,)*(len(_lattice_points.shape)-1) ) - ) / const.BOHR_RADIUS).to(device) - _xyz_T = ((torch.repeat_interleave(_xyz, _N, dim=0) - / const.BOHR_RADIUS).to(device)) + lattice_points_T = ( + torch.tile(torch.from_numpy(_lattice_points), ((len(_xyz),) + (1,) * (len(_lattice_points.shape) - 1))) + / const.BOHR_RADIUS + ).to(device) + _xyz_T = (torch.repeat_interleave(_xyz, _N, dim=0) / const.BOHR_RADIUS).to(device) _xyz_T = _xyz_T + lattice_points_T # get valid indices within the cutoff num = _xyz.shape[0] idx = torch.arange(num) x, y = torch.meshgrid(idx, idx) - _nbrs = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1)], - dim=1).to(device) - _lattice_points = (torch.tile( - torch.from_numpy(_lattice_points).to(device), - ( (len(_nbrs),) + - (1,)*(len(_lattice_points.shape)-1) ) - ) ) + _nbrs = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1)], dim=1).to(device) + _lattice_points = torch.tile( + torch.from_numpy(_lattice_points).to(device), ((len(_nbrs),) + (1,) * (len(_lattice_points.shape) - 1)) + ) # convert everything from Angstroms to Bohr _xyz = _xyz / const.BOHR_RADIUS @@ -141,10 +117,10 @@ def get_periodic_nbrs(batch, # ensure that A != B when T=0 # since first index in _lattice_points corresponds to T=0 # get the idxs on which to apply the mask - idxs_to_apply = torch.tensor([True]*len(_nbrs_T)).to(device) + idxs_to_apply = torch.tensor([True] * len(_nbrs_T)).to(device) idxs_to_apply[::_N] = False # get the mask that we want to apply - mask = _nbrs_T[:,0] != _nbrs_T[:,1] + mask = _nbrs_T[:, 0] != _nbrs_T[:, 1] # do a joint boolean operation to get the mask _mask_applied = torch.logical_or(idxs_to_apply, mask) _nbrs_T = _nbrs_T[_mask_applied] @@ -162,42 +138,31 @@ def get_periodic_nbrs(batch, num_atoms.append(len(_xyz)) else: - nxyz_list = torch.split(batch['nxyz'], num_atoms) + nxyz_list = torch.split(batch["nxyz"], num_atoms) xyzs = torch.split(xyz, num_atoms) nbrs_T, nbrs, z, N, lattice_points, mask_applied = nbrs_info _xyzs = [] num_atoms = [] - for _xyz, nxyz in zip(xyzs, 
nxyz_list): - _xyz = _xyz / const.BOHR_RADIUS # convert to Bohr + for _xyz, nxyz in zip(xyzs, nxyz_list): # noqa + _xyz = _xyz / const.BOHR_RADIUS # convert to Bohr _xyzs.append(_xyz) num_atoms.append(len(_xyz)) if mol_idx is None: - mol_idx = torch.cat([torch.zeros(num) + i - for i, num in enumerate(num_atoms)] - ).long().to(_xyz.device) - - return nbrs_T, nbrs, z, _xyzs, N, lattice_points, mask_applied, \ - r_cut, mol_idx - - -def get_periodic_coordination(xyz, - z, - nbrs_T, - lattice_points, - r_cov, - k1, - k2, - cn_cut=40): + mol_idx = torch.cat([torch.zeros(num) + i for i, num in enumerate(num_atoms)]).long().to(_xyz.device) + + return nbrs_T, nbrs, z, _xyzs, N, lattice_points, mask_applied, r_cut, mol_idx + + +def get_periodic_coordination(xyz, z, nbrs_T, lattice_points, r_cov, k1, k2, cn_cut=40): """ Get coordination numbers for each atom in periodic system """ # r_ij with all lattice translation vectors # vector btwn pairs of atoms - r_ij_T = ( (xyz[nbrs_T[:, 0]] - xyz[nbrs_T[:, 1]]) - + lattice_points ).to(xyz.device) + r_ij_T = ((xyz[nbrs_T[:, 0]] - xyz[nbrs_T[:, 1]]) + lattice_points).to(xyz.device) # r_ab with all lattice translations # distance (scalar) btwn pairs of atoms @@ -209,36 +174,27 @@ def get_periodic_coordination(xyz, nbrs_T_cn = nbrs_T[r_ab_T < cn_cut] r_ab_T_cn = r_ab_T[r_ab_T < cn_cut] - # calculate covalent radii (for coordination number calculation) ra_cov_T = r_cov[z[nbrs_T_cn[:, 0]]].to(r_ab_T.device) rb_cov_T = r_cov[z[nbrs_T_cn[:, 1]]].to(r_ab_T.device) - cn_ab_T = ((1 / (1 + torch.exp( - -k1 * (k2 * (ra_cov_T + rb_cov_T) / r_ab_T_cn - 1)))) - .to(r_ab_T.device)) - cn = scatter_add(cn_ab_T, - nbrs_T_cn[:, 0], - dim_size=xyz.shape[0]) + cn_ab_T = (1 / (1 + torch.exp(-k1 * (k2 * (ra_cov_T + rb_cov_T) / r_ab_T_cn - 1)))).to(r_ab_T.device) + cn = scatter_add(cn_ab_T, nbrs_T_cn[:, 0], dim_size=xyz.shape[0]) return r_ab_T, r_ij_T, cn -def get_nbrs(batch, - xyz, - nbrs=None, - mol_idx=None): +def get_nbrs(batch, xyz, nbrs=None, mol_idx=None): """ Get the directed neighbor list connecting every atom to its neighbor within a given geometry, but not to itself or to atoms in other geometries. 
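For the directed, intra-geometry neighbor list described here, a self-contained sketch with a hypothetical batch of two geometries holding 2 and 3 atoms (`indexing="ij"` is spelled out to avoid the warning on newer PyTorch; the code above calls torch.meshgrid without it):

```python
import torch

# pair every atom with every other atom in the same geometry,
# never with itself or with atoms across geometries
num_atoms = [2, 3]
nbrs, counter = [], 0
for n in num_atoms:
    idx = torch.arange(counter, counter + n)
    x, y = torch.meshgrid(idx, idx, indexing="ij")
    pairs = torch.stack([x.reshape(-1), y.reshape(-1)], dim=1)
    nbrs.append(pairs[pairs[:, 0] != pairs[:, 1]])  # drop self-pairs
    counter += n
nbrs = torch.cat(nbrs)
print(nbrs.tolist())
# [[0, 1], [1, 0], [2, 3], [2, 4], [3, 2], [3, 4], [4, 2], [4, 3]]
```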
""" - num_atoms = batch['num_atoms'] + num_atoms = batch["num_atoms"] if not isinstance(num_atoms, list): num_atoms = num_atoms.tolist() if nbrs is None: - - nxyz_list = torch.split(batch['nxyz'], num_atoms) + nxyz_list = torch.split(batch["nxyz"], num_atoms) counter = 0 nbrs = [] @@ -257,45 +213,31 @@ def get_nbrs(batch, nbrs = torch.cat(nbrs).to(xyz.device) - z = batch['nxyz'][:, 0].long().to(xyz.device) + z = batch["nxyz"][:, 0].long().to(xyz.device) if mol_idx is None: - mol_idx = torch.cat([torch.zeros(num) + i - for i, num in enumerate(num_atoms)] - ).long().to(xyz.device) + mol_idx = torch.cat([torch.zeros(num) + i for i, num in enumerate(num_atoms)]).long().to(xyz.device) return nbrs, mol_idx, z -def get_coordination(xyz, - z, - nbrs, - r_cov, - k1, - k2): +def get_coordination(xyz, z, nbrs, r_cov, k1, k2): """ Get coordination numbers for each atom """ # distances in Bohr radii (atomic units) - r_ab = ((xyz[nbrs[:, 0]] - xyz[nbrs[:, 1]]) - .pow(2).sum(1).sqrt() / const.BOHR_RADIUS) + r_ab = (xyz[nbrs[:, 0]] - xyz[nbrs[:, 1]]).pow(2).sum(1).sqrt() / const.BOHR_RADIUS ra_cov = r_cov[z[nbrs[:, 0]]].to(r_ab.device) rb_cov = r_cov[z[nbrs[:, 1]]].to(r_ab.device) cn_ab = 1 / (1 + torch.exp(-k1 * (k2 * (ra_cov + rb_cov) / r_ab - 1))) - cn = scatter_add(cn_ab, - nbrs[:, 0], - dim_size=xyz.shape[0]) + cn = scatter_add(cn_ab, nbrs[:, 0], dim_size=xyz.shape[0]) return cn, r_ab -def get_c6(z, - cn, - nbrs, - c6_ref, - k3): +def get_c6(z, cn, nbrs, c6_ref, k3): """ Get the C6 parameter for each atom pair """ @@ -308,15 +250,13 @@ def get_c6(z, cn_a = c6ab_ref[..., 1] cn_b = c6ab_ref[..., 2] - r = ((cn_a - cn_a_i.reshape(-1, 1, 1)) ** 2 + - (cn_b - cn_b_j.reshape(-1, 1, 1)) ** 2) + r = (cn_a - cn_a_i.reshape(-1, 1, 1)) ** 2 + (cn_b - cn_b_j.reshape(-1, 1, 1)) ** 2 l_ij = torch.zeros_like(r) # exclude any info that doesn't exist for this (i, j) combination -- # signified in the tables by c6_ab_ref = -1 - valid_idx = torch.bitwise_and(torch.bitwise_and(cn_a >= 0, cn_b >= 0), - c6_ab_ref_ij >= 0) + valid_idx = torch.bitwise_and(torch.bitwise_and(cn_a >= 0, cn_b >= 0), c6_ab_ref_ij >= 0) l_ij[valid_idx] = torch.exp(-k3 * r[valid_idx]) w = l_ij.sum((1, 2)) @@ -326,10 +266,7 @@ def get_c6(z, return c6 -def get_c8(z, - nbrs, - c6, - r2r4): +def get_c8(z, nbrs, c6, r2r4): """ Get the C6 parameter for each atom pair """ @@ -339,108 +276,78 @@ def get_c8(z, return c8 -def disp_from_data(r_ab, - c6, - c8, - s6, - s8, - a1, - a2, - xyz, - nbrs, - mol_idx): - +def disp_from_data(r_ab, c6, c8, s6, s8, a1, a2, xyz, nbrs, mol_idx): r0_ab = (c8 / c6) ** 0.5 f = a1 * r0_ab + a2 - e_ab = -1 / 2 * (s6 * c6 / (r_ab ** 6 + f ** 6) + - s8 * c8 / (r_ab ** 8 + f ** 8)) + e_ab = -1 / 2 * (s6 * c6 / (r_ab**6 + f**6) + s8 * c8 / (r_ab**8 + f**8)) - e_a = scatter_add(e_ab, - nbrs[:, 0], - dim_size=xyz.shape[0]) + e_a = scatter_add(e_ab, nbrs[:, 0], dim_size=xyz.shape[0]) - e_disp = scatter_add(e_a, - mol_idx, - dim_size=int(1 + mol_idx.max())) + e_disp = scatter_add(e_a, mol_idx, dim_size=int(1 + mol_idx.max())) return e_disp -def get_func_info(functional, - disp_type, - func_params): - +def get_func_info(functional, disp_type, func_params): msg = "Parameters not present for dispersion type %s" % disp_type func_params = {key.lower(): val for key, val in func_params.items()} assert disp_type.lower() in func_params, msg - msg = ("Parameters not present for functional %s with dispersion type %s" - % (functional, disp_type)) + msg = "Parameters not present for functional %s with dispersion type %s" % (functional, disp_type) - 
sub_params = {key.lower(): val for key, val in - func_params[disp_type.lower()].items()} + sub_params = {key.lower(): val for key, val in func_params[disp_type.lower()].items()} assert functional.lower() in sub_params, msg all_params = sub_params[functional.lower()] - all_params.update(sub_params['universal']) + all_params.update(sub_params["universal"]) return all_params -def get_dispersion(batch, - xyz, - disp_type, - functional, - c6_ref=C6_REF, - r_cov=R_COV, - r2r4=R2R4, - func_params=FUNC_PARAMS, - nbrs=None, - mol_idx=None): - - params = get_func_info(functional=functional, - disp_type=disp_type, - func_params=func_params) - - periodic = (batch.get('cell',None) is not None) +def get_dispersion( + batch, + xyz, + disp_type, + functional, + c6_ref=C6_REF, + r_cov=R_COV, + r2r4=R2R4, + func_params=FUNC_PARAMS, + nbrs=None, + mol_idx=None, +): + params = get_func_info(functional=functional, disp_type=disp_type, func_params=func_params) + + periodic = batch.get("cell", None) is not None device = xyz.device if periodic: - (nbrs_T, nbrs, z, _xyzs, N, - lattice_points, mask_applied, - r_cut, mol_idx) = get_periodic_nbrs(batch=batch, - xyz=xyz, - nbrs_info=nbrs, - mol_idx=mol_idx) + (nbrs_T, nbrs, z, _xyzs, N, lattice_points, mask_applied, r_cut, mol_idx) = get_periodic_nbrs( + batch=batch, xyz=xyz, nbrs_info=nbrs, mol_idx=mol_idx + ) r_ij_T = [] c6 = [] c8 = [] filtered_nbrs_T = [] - for _nbrs_T, _nbrs, _z, _xyz, _N, _lattice_points, _mask_applied \ - in zip(nbrs_T, nbrs, z, _xyzs, N, lattice_points, mask_applied): + for _nbrs_T, _nbrs, _z, _xyz, _N, _lattice_points, _mask_applied in zip( + nbrs_T, nbrs, z, _xyzs, N, lattice_points, mask_applied + ): _r_ab_T, _r_ij_T, cn = get_periodic_coordination( - xyz=_xyz, - z=_z, - nbrs_T=_nbrs_T, - lattice_points=_lattice_points, - r_cov=r_cov, - k1=params["k1"], - k2=params["k2"] - ) - - _c6 = get_c6(z=_z, - cn=cn, - nbrs=_nbrs, - c6_ref=c6_ref, - k3=params["k3"]) - - _c8 = get_c8(z=_z, - nbrs=_nbrs, - c6=_c6, - r2r4=r2r4) + xyz=_xyz, + z=_z, + nbrs_T=_nbrs_T, + lattice_points=_lattice_points, + r_cov=r_cov, + k1=params["k1"], + k2=params["k2"], + ) + + _c6 = get_c6(z=_z, cn=cn, nbrs=_nbrs, c6_ref=c6_ref, k3=params["k3"]) + + _c8 = get_c8(z=_z, nbrs=_nbrs, c6=_c6, r2r4=r2r4) # get original pairwise interactions from within unit cell # change shape of all tensors to account for the fake expansion @@ -466,62 +373,52 @@ def get_dispersion(batch, c8 = torch.cat(c8) mask_applied = torch.cat(mask_applied).to(device) - + count = 0 counter = [] for _xyz in _xyzs: counter.append(count) - count+=len(_xyz) + count += len(_xyz) - filtered_nbrs_T = [_nbrs_T+count for _nbrs_T, count - in zip(filtered_nbrs_T, counter)] + filtered_nbrs_T = [_nbrs_T + count for _nbrs_T, count in zip(filtered_nbrs_T, counter)] nbrs_T = torch.cat(filtered_nbrs_T).to(device) xyzs = torch.cat(_xyzs).to(device) - e_disp=disp_from_data(r_ab=r_ab_T, - c6=c6, - c8=c8, - s6=params["s6"], - s8=params["s8"], - a1=params["a1"], - a2=params["a2"], - xyz=xyzs, - nbrs=nbrs_T, - mol_idx=mol_idx) + e_disp = disp_from_data( + r_ab=r_ab_T, + c6=c6, + c8=c8, + s6=params["s6"], + s8=params["s8"], + a1=params["a1"], + a2=params["a2"], + xyz=xyzs, + nbrs=nbrs_T, + mol_idx=mol_idx, + ) else: - nbrs, mol_idx, z = get_nbrs(batch=batch, - xyz=xyz, - nbrs=nbrs, - mol_idx=mol_idx) - cn, r_ab = get_coordination(xyz=xyz, - z=z, - nbrs=nbrs, - r_cov=r_cov.to(xyz.device), - k1=params["k1"], - k2=params["k2"]) - - c6 = get_c6(z=z, - cn=cn, - nbrs=nbrs, - c6_ref=c6_ref.to(xyz.device), - k3=params["k3"]) - - c8 = 
get_c8(z=z, - nbrs=nbrs, - c6=c6, - r2r4=r2r4) - - e_disp = disp_from_data(r_ab=r_ab, - c6=c6, - c8=c8, - s6=params["s6"], - s8=params["s8"], - a1=params["a1"], - a2=params["a2"], - xyz=xyz, - nbrs=nbrs, - mol_idx=mol_idx) + nbrs, mol_idx, z = get_nbrs(batch=batch, xyz=xyz, nbrs=nbrs, mol_idx=mol_idx) + cn, r_ab = get_coordination( + xyz=xyz, z=z, nbrs=nbrs, r_cov=r_cov.to(xyz.device), k1=params["k1"], k2=params["k2"] + ) + + c6 = get_c6(z=z, cn=cn, nbrs=nbrs, c6_ref=c6_ref.to(xyz.device), k3=params["k3"]) + + c8 = get_c8(z=z, nbrs=nbrs, c6=c6, r2r4=r2r4) + + e_disp = disp_from_data( + r_ab=r_ab, + c6=c6, + c8=c8, + s6=params["s6"], + s8=params["s8"], + a1=params["a1"], + a2=params["a2"], + xyz=xyz, + nbrs=nbrs, + mol_idx=mol_idx, + ) r_ij_T = None nbrs_T = None @@ -529,14 +426,16 @@ def get_dispersion(batch, def grimme_dispersion(batch, xyz, disp_type, functional): - - d3 = DFTD3(xc='pbe', damping='bj', grad=True) - atoms = Atoms(cell=batch.get('cell',None).detach().cpu().numpy(), - numbers=batch['nxyz'][:, 0].detach().cpu().numpy(), - positions=xyz.detach().cpu().numpy(), pbc=True) + d3 = DFTD3(xc="pbe", damping="bj", grad=True) + atoms = Atoms( + cell=batch.get("cell", None).detach().cpu().numpy(), + numbers=batch["nxyz"][:, 0].detach().cpu().numpy(), + positions=xyz.detach().cpu().numpy(), + pbc=True, + ) atoms.calc = d3 e_disp = atoms.get_potential_energy() stress_disp = atoms.get_stress(voigt=False) forces_disp = atoms.get_forces() - return e_disp, stress_disp, forces_disp \ No newline at end of file + return e_disp, stress_disp, forces_disp diff --git a/nff/utils/fast_attention.py b/nff/utils/fast_attention.py index e9f24b08..a54c3a9a 100644 --- a/nff/utils/fast_attention.py +++ b/nff/utils/fast_attention.py @@ -1,10 +1,8 @@ -import torch import numpy as np +import torch -def make_w(feat_dim, - rand_dim): - +def make_w(feat_dim, rand_dim): w = np.random.rand(rand_dim, feat_dim) q, r = np.linalg.qr(w) iid = np.random.randn(q.shape[1]).reshape(1, -1) @@ -13,12 +11,10 @@ def make_w(feat_dim, return orth -def phi_pos(w, - x): - +def phi_pos(w, x): rand_dim = w.shape[0] - h = torch.exp(-(x ** 2).sum(-1) / 2) - pref = h / rand_dim ** 0.5 + h = torch.exp(-(x**2).sum(-1) / 2) + pref = h / rand_dim**0.5 arg = torch.exp(torch.matmul(w, x)) out = pref * arg diff --git a/nff/utils/functions.py b/nff/utils/functions.py index 558ff398..b287e92b 100644 --- a/nff/utils/functions.py +++ b/nff/utils/functions.py @@ -1,30 +1,32 @@ """ Special functions for DimeNet and SpookyNet. -Dimenet functions taken directly from +Dimenet functions taken directly from https://github.com/klicperajo/ dimenet/blob/master/dimenet/model/ layers/basis_utils.py. 
""" -import numpy as np -from scipy.optimize import brentq -from scipy import special as sp -import sympy as sym +# ruff: noqa: E741 import copy -import torch import math +import numpy as np +import sympy as sym +import torch +from scipy import special as sp +from scipy.optimize import brentq EPS = 1e-15 # DimeNet + def Jn(r, n): """ numerical spherical bessel functions of order n """ - return np.sqrt(np.pi/(2*r)) * sp.jv(n+0.5, r) + return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r) def Jn_zeros(n, k): @@ -49,13 +51,13 @@ def spherical_bessel_formulas(n): """ Computes the sympy formulas for the spherical bessel functions up to order n (excluded) """ - x = sym.symbols('x') + x = sym.symbols("x") - f = [sym.sin(x)/x] - a = sym.sin(x)/x + f = [sym.sin(x) / x] + a = sym.sin(x) / x for i in range(1, n): - b = sym.diff(a, x)/x - f += [sym.simplify(b*(-x)**i)] + b = sym.diff(a, x) / x + f += [sym.simplify(b * (-x) ** i)] a = sym.simplify(b) return f @@ -71,19 +73,17 @@ def bessel_basis(n, k): for order in range(n): normalizer_tmp = [] for i in range(k): - normalizer_tmp += [0.5*Jn(zeros[order, i], order+1)**2] - normalizer_tmp = 1/np.array(normalizer_tmp)**0.5 + normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1) ** 2] + normalizer_tmp = 1 / np.array(normalizer_tmp) ** 0.5 normalizer += [normalizer_tmp] f = spherical_bessel_formulas(n) - x = sym.symbols('x') + x = sym.symbols("x") bess_basis = [] for order in range(n): bess_basis_tmp = [] for i in range(k): - bess_basis_tmp += [sym.simplify(normalizer[order] - [i]*f[order].subs( - x, zeros[order, i]*x))] + bess_basis_tmp += [sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))] bess_basis += [bess_basis_tmp] return bess_basis @@ -95,40 +95,36 @@ def sph_harm_prefactor(l, m): l: int, l>=0 m: int, -l<=m<=l """ - return ((2*l+1) * np.math.factorial(l-abs(m)) - / (4*np.pi*np.math.factorial(l+abs(m))))**0.5 + return ((2 * l + 1) * np.math.factorial(l - abs(m)) / (4 * np.pi * np.math.factorial(l + abs(m)))) ** 0.5 def associated_legendre_polynomials(l, zero_m_only=True): """ Computes sympy formulas of the associated legendre polynomials up to order l (excluded). """ - z = sym.symbols('z') - P_l_m = [[0]*(j+1) for j in range(l)] + z = sym.symbols("z") + P_l_m = [[0] * (j + 1) for j in range(l)] P_l_m[0][0] = 1 if l > 0: P_l_m[1][0] = z for j in range(2, l): - P_l_m[j][0] = sym.simplify( - ((2*j-1)*z*P_l_m[j-1][0] - (j-1)*P_l_m[j-2][0])/j) + P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] - (j - 1) * P_l_m[j - 2][0]) / j) if not zero_m_only: for i in range(1, l): - P_l_m[i][i] = sym.simplify((1-2*i)*P_l_m[i-1][i-1]) + P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1]) if i + 1 < l: - P_l_m[i+1][i] = sym.simplify((2*i+1)*z*P_l_m[i][i]) + P_l_m[i + 1][i] = sym.simplify((2 * i + 1) * z * P_l_m[i][i]) for j in range(i + 2, l): P_l_m[j][i] = sym.simplify( - ((2*j-1) * z * P_l_m[j-1][i] - - (i+j-1) * P_l_m[j-2][i]) / (j - i)) + ((2 * j - 1) * z * P_l_m[j - 1][i] - (i + j - 1) * P_l_m[j - 2][i]) / (j - i) + ) return P_l_m -def real_sph_harm(l, - zero_m_only=True, - spherical_coordinates=True): +def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True): """ Computes formula strings of the the real part of the spherical harmonics up to order l (excluded). Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta. 
@@ -137,58 +133,52 @@ def real_sph_harm(l, S_m = [0] C_m = [1] for i in range(1, l): - x = sym.symbols('x') - y = sym.symbols('y') - S_m += [x*S_m[i-1] + y*C_m[i-1]] - C_m += [x*C_m[i-1] - y*S_m[i-1]] + x = sym.symbols("x") + y = sym.symbols("y") + S_m += [x * S_m[i - 1] + y * C_m[i - 1]] + C_m += [x * C_m[i - 1] - y * S_m[i - 1]] P_l_m = associated_legendre_polynomials(l, zero_m_only) if spherical_coordinates: - theta = sym.symbols('theta') - z = sym.symbols('z') + theta = sym.symbols("theta") + z = sym.symbols("z") for i in range(len(P_l_m)): for j in range(len(P_l_m[i])): - if type(P_l_m[i][j]) != int: + if not isinstance(P_l_m[i][j], int): P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta)) if not zero_m_only: - phi = sym.symbols('phi') + phi = sym.symbols("phi") for i in range(len(S_m)): - S_m[i] = S_m[i].subs(x, sym.sin( - theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + S_m[i] = S_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) for i in range(len(C_m)): - C_m[i] = C_m[i].subs(x, sym.sin( - theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi)) + C_m[i] = C_m[i].subs(x, sym.sin(theta) * sym.cos(phi)).subs(y, sym.sin(theta) * sym.sin(phi)) - Y_func_l_m = [['0']*(2*j + 1) for j in range(l)] + Y_func_l_m = [["0"] * (2 * j + 1) for j in range(l)] for i in range(l): Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0]) if not zero_m_only: for i in range(1, l): for j in range(1, i + 1): - Y_func_l_m[i][j] = sym.simplify( - 2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) + Y_func_l_m[i][j] = sym.simplify(2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j]) for i in range(1, l): for j in range(1, i + 1): - Y_func_l_m[i][-j] = sym.simplify( - 2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) + Y_func_l_m[i][-j] = sym.simplify(2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j]) return Y_func_l_m # SpookyNet + def A_m(x, y, m): device = x.device - p_vals = torch.arange(0, m + 1, - device=device) + p_vals = torch.arange(0, m + 1, device=device) q_vals = m - p_vals x_p = x.reshape(-1, 1) ** p_vals y_q = y.reshape(-1, 1) ** q_vals sin = torch.sin(np.pi / 2 * (m - p_vals)) - binoms = (torch.Tensor([sp.binom(m, int(p)) - for p in p_vals]) - .to(device)) + binoms = torch.Tensor([sp.binom(m, int(p)) for p in p_vals]).to(device) out = (binoms * x_p * y_q * sin).sum(-1) return out @@ -196,26 +186,26 @@ def A_m(x, y, m): def B_m(x, y, m): device = x.device - p_vals = torch.arange(0, m + 1, - device=device) + p_vals = torch.arange(0, m + 1, device=device) q_vals = m - p_vals x_p = x.reshape(-1, 1) ** p_vals y_q = y.reshape(-1, 1) ** q_vals cos = torch.cos(np.pi / 2 * (m - p_vals)) - binoms = (torch.Tensor([sp.binom(m, int(p)) for p in p_vals]) - .to(device)) + binoms = torch.Tensor([sp.binom(m, int(p)) for p in p_vals]).to(device) out = (binoms * x_p * y_q * cos).sum(-1) return out def c_plm(p, l, m): - terms = [(-1) ** p, - 1 / (2 ** l), - sp.binom(l, p), - sp.binom(2 * l - 2 * p, l), - sp.factorial(l - 2 * p), - 1 / sp.factorial(l - 2 * p - m)] + terms = [ + (-1) ** p, + 1 / (2**l), + sp.binom(l, p), + sp.binom(2 * l - 2 * p, l), + sp.factorial(l - 2 * p), + 1 / sp.factorial(l - 2 * p - m), + ] out = torch.Tensor(terms).prod() return out @@ -223,27 +213,18 @@ def c_plm(p, l, m): def make_c_table(l_max): c_table = {} for l in range(l_max + 1): - for m in range(-l, l+1): - for p in range(0, math.floor((l - m) / 2) + 1): + for m in range(-l, l + 1): + for p in range(math.floor((l - m) / 2) + 1): c_table[(p, l, 
m)] = c_plm(p, l, m) return c_table -def pi_l_m(r, - z, - l, - m, - c_table): - +def pi_l_m(r, z, l, m, c_table): device = r.device pref = (sp.factorial(l - m) / sp.factorial(l + m)) ** 0.5 - p_vals = (torch.arange(0, math.floor((l - m) / 2) + 1, - device=device, - dtype=torch.float)) + p_vals = torch.arange(0, math.floor((l - m) / 2) + 1, device=device, dtype=torch.float) - c_coefs = (torch.Tensor([c_table[(int(p), l, m)] - for p in p_vals]) - .to(device)) + c_coefs = torch.Tensor([c_table[(int(p), l, m)] for p in p_vals]).to(device) r_p = r.reshape(-1, 1) ** (2 * p_vals - l) z_q = z.reshape(-1, 1) ** (l - 2 * p_vals - m) @@ -253,34 +234,25 @@ def pi_l_m(r, def norm(vec): - result = ((vec ** 2 + EPS).sum(-1)) ** 0.5 + result = ((vec**2 + EPS).sum(-1)) ** 0.5 return result -def y_lm(r_ij, - r, - l, - m, - c_table): - +def y_lm(r_ij, r, l, m, c_table): x = r_ij[:, 0].reshape(-1, 1) y = r_ij[:, 1].reshape(-1, 1) z = r_ij[:, 2].reshape(-1, 1) - pi = pi_l_m(r=r, - z=z, - l=l, - m=abs(m), - c_table=c_table) + pi = pi_l_m(r=r, z=z, l=l, m=abs(m), c_table=c_table) if m < 0: a = A_m(x, y, abs(m)) - out = (2 ** 0.5) * pi * a + out = (2**0.5) * pi * a elif m == 0: out = pi elif m > 0: b = B_m(x, y, abs(m)) - out = (2 ** 0.5) * pi * b + out = (2**0.5) * pi * b return out @@ -289,72 +261,47 @@ def make_y_lm(l_max): c_table = make_c_table(l_max) def func(r_ij, r, l, m): - out = y_lm(r_ij=r_ij, - r=r, - l=l, - m=m, - c_table=c_table) + out = y_lm(r_ij=r_ij, r=r, l=l, m=m, c_table=c_table) return out + return func def spooky_f_cut(r, r_cut): - arg = r ** 2 / ((r_cut - r) * (r_cut + r)) + arg = r**2 / ((r_cut - r) * (r_cut + r)) # arg < 20 is for numerical stability # Anything > 20 will give under 1e-9 - output = torch.where( - (r < r_cut) * (arg < 20), - torch.exp(-arg), - torch.Tensor([0]).to(r.device) - ) + output = torch.where((r < r_cut) * (arg < 20), torch.exp(-arg), torch.Tensor([0]).to(r.device)) return output -def b_k(x, - bern_k): +def b_k(x, bern_k): device = x.device - k_vals = (torch.arange(0, bern_k, device=device) - .to(torch.float)) - binoms = (torch.Tensor([sp.binom(bern_k - 1, int(k)) - for k in k_vals]) - .to(device)) - out = binoms * (x ** k_vals) * (1-x) ** (bern_k - 1 - k_vals) + k_vals = torch.arange(0, bern_k, device=device).to(torch.float) + binoms = torch.Tensor([sp.binom(bern_k - 1, int(k)) for k in k_vals]).to(device) + out = binoms * (x**k_vals) * (1 - x) ** (bern_k - 1 - k_vals) return out -def rho_k(r, - r_cut, - bern_k, - gamma): - +def rho_k(r, r_cut, bern_k, gamma): arg = torch.exp(-gamma * r) out = b_k(arg, bern_k) * spooky_f_cut(r, r_cut) return out -def get_g_func(l, - r_cut, - bern_k, - gamma, - y_lm_fn): - +def get_g_func(l, r_cut, bern_k, gamma, y_lm_fn): def fn(r_ij): - r = norm(r_ij).reshape(-1, 1) n_pairs = r_ij.shape[0] device = r_ij.device m_vals = list(range(-l, l + 1)) - y = torch.stack([y_lm_fn(r_ij, r, l, m) for m in - m_vals]).transpose(0, 1) + y = torch.stack([y_lm_fn(r_ij, r, l, m) for m in m_vals]).transpose(0, 1) rho = rho_k(r, r_cut, bern_k, gamma) - g = torch.ones(n_pairs, - bern_k, - len(m_vals), - device=device) + g = torch.ones(n_pairs, bern_k, len(m_vals), device=device) g = g * rho.reshape(n_pairs, -1, 1) g = g * y.reshape(n_pairs, 1, -1) @@ -363,23 +310,15 @@ def fn(r_ij): return fn -def make_g_funcs(bern_k, - gamma, - r_cut, - l_max=2): +def make_g_funcs(bern_k, gamma, r_cut, l_max=2): y_lm_fn = make_y_lm(l_max) g_funcs = {} letters = {0: "s", 1: "p", 2: "d"} - for l in range(0, l_max + 1): - + for l in range(l_max + 1): letter = letters[l] 
name = f"g_{letter}" - g_func = get_g_func(l=l, - r_cut=r_cut, - bern_k=bern_k, - gamma=gamma, - y_lm_fn=y_lm_fn) + g_func = get_g_func(l=l, r_cut=r_cut, bern_k=bern_k, gamma=gamma, y_lm_fn=y_lm_fn) g_funcs[name] = copy.deepcopy(g_func) return g_funcs diff --git a/nff/utils/geom.py b/nff/utils/geom.py index dee380e2..2fbdbf8b 100644 --- a/nff/utils/geom.py +++ b/nff/utils/geom.py @@ -5,37 +5,40 @@ import numpy as np import torch from torch.utils.data import DataLoader -from nff.utils.scatter import scatter_add +from nff.utils.scatter import scatter_add BATCH_SIZE = 3000 def quaternion_to_matrix(q): - q0 = q[:, 0] q1 = q[:, 1] q2 = q[:, 2] q3 = q[:, 3] - R_q = torch.stack([q0**2 + q1**2 - q2**2 - q3**2, - 2 * (q1 * q2 - q0 * q3), - 2 * (q1 * q3 + q0 * q2), - 2 * (q1 * q2 + q0 * q3), - q0**2 - q1**2 + q2**2 - q3**2, - 2 * (q2 * q3 - q0 * q1), - 2 * (q1 * q3 - q0 * q2), - 2 * (q2 * q3 + q0 * q1), - q0**2 - q1**2 - q2**2 + q3**2] - ).transpose(0, 1).reshape(-1, 3, 3) + R_q = ( + torch.stack( + [ + q0**2 + q1**2 - q2**2 - q3**2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q1 * q3 + q0 * q2), + 2 * (q1 * q2 + q0 * q3), + q0**2 - q1**2 + q2**2 - q3**2, + 2 * (q2 * q3 - q0 * q1), + 2 * (q1 * q3 - q0 * q2), + 2 * (q2 * q3 + q0 * q1), + q0**2 - q1**2 - q2**2 + q3**2, + ] + ) + .transpose(0, 1) + .reshape(-1, 3, 3) + ) return R_q -def rotation_matrix_from_points(m0, - m1, - store_grad=False): - +def rotation_matrix_from_points(m0, m1, store_grad=False): v0 = m0[:, None, :, :] # don't have to clone this because we don't modify its actual value below v1 = m1 @@ -45,23 +48,40 @@ def rotation_matrix_from_points(m0, r_22 = out_0[:, 1] r_33 = out_0[:, 2] - out_1 = torch.sum(v0 * torch.roll(v1, -1, dims=1), dim=-1 - ).reshape(-1, 3) + out_1 = torch.sum(v0 * torch.roll(v1, -1, dims=1), dim=-1).reshape(-1, 3) r_12 = out_1[:, 0] r_23 = out_1[:, 1] r_31 = out_1[:, 2] - out_2 = torch.sum(v0 * torch.roll(v1, -2, dims=1), dim=-1 - ).reshape(-1, 3) + out_2 = torch.sum(v0 * torch.roll(v1, -2, dims=1), dim=-1).reshape(-1, 3) r_13 = out_2[:, 0] r_21 = out_2[:, 1] r_32 = out_2[:, 2] - f = torch.stack([r_11 + r_22 + r_33, r_23 - r_32, r_31 - r_13, r_12 - r_21, - r_23 - r_32, r_11 - r_22 - r_33, r_12 + r_21, r_13 + r_31, - r_31 - r_13, r_12 + r_21, -r_11 + r_22 - r_33, r_23 + r_32, - r_12 - r_21, r_13 + r_31, r_23 + r_32, -r_11 - r_22 + r_33] - ).transpose(0, 1).reshape(-1, 4, 4) + f = ( + torch.stack( + [ + r_11 + r_22 + r_33, + r_23 - r_32, + r_31 - r_13, + r_12 - r_21, + r_23 - r_32, + r_11 - r_22 - r_33, + r_12 + r_21, + r_13 + r_31, + r_31 - r_13, + r_12 + r_21, + -r_11 + r_22 - r_33, + r_23 + r_32, + r_12 - r_21, + r_13 + r_31, + r_23 + r_32, + -r_11 - r_22 + r_33, + ] + ) + .transpose(0, 1) + .reshape(-1, 4, 4) + ) # Really slow on a GPU / with torch for some reason. 
# See https://github.com/pytorch/pytorch/issues/22573: @@ -115,10 +135,7 @@ def rotation_matrix_from_points(m0, return r_with_nan -def minimize_rotation_and_translation(targ_nxyz, - this_nxyz, - store_grad=False): - +def minimize_rotation_and_translation(targ_nxyz, this_nxyz, store_grad=False): base_p = this_nxyz[:, :, 1:] if store_grad: base_p.requires_grad = True @@ -130,9 +147,7 @@ def minimize_rotation_and_translation(targ_nxyz, c0 = p0.mean(1).reshape(-1, 1, 3) p0 -= c0 - R = rotation_matrix_from_points(p.transpose(1, 2), - p0.transpose(1, 2), - store_grad=store_grad) + R = rotation_matrix_from_points(p.transpose(1, 2), p0.transpose(1, 2), store_grad=store_grad) num_repeats = targ_nxyz.shape[0] p_repeat = torch.repeat_interleave(p, num_repeats, dim=0) @@ -142,14 +157,11 @@ def minimize_rotation_and_translation(targ_nxyz, return new_p, p0, R, base_p -def compute_rmsd(targ_nxyz, - this_nxyz): - +def compute_rmsd(targ_nxyz, this_nxyz): targ_nxyz = torch.Tensor(targ_nxyz).reshape(1, -1, 4) this_nxyz = torch.Tensor(this_nxyz).reshape(1, -1, 4) - out = minimize_rotation_and_translation(targ_nxyz=targ_nxyz, - this_nxyz=this_nxyz) + out = minimize_rotation_and_translation(targ_nxyz=targ_nxyz, this_nxyz=this_nxyz) xyz_0, new_targ, _, _ = out num_mols_1 = targ_nxyz.shape[0] @@ -160,20 +172,13 @@ def compute_rmsd(targ_nxyz, delta_sq = (xyz_0 - xyz_1) ** 2 num_atoms = delta_sq.shape[1] - distances = (((delta_sq.sum((1, 2)) / num_atoms) ** 0.5) - .reshape(num_mols_0, num_mols_1) - .cpu().reshape(-1).item()) + distances = ((delta_sq.sum((1, 2)) / num_atoms) ** 0.5).reshape(num_mols_0, num_mols_1).cpu().reshape(-1).item() return distances -def compute_distance(targ_nxyz, - atom_nxyz, - store_grad=False): - - out = minimize_rotation_and_translation(targ_nxyz=targ_nxyz, - this_nxyz=atom_nxyz, - store_grad=store_grad) +def compute_distance(targ_nxyz, atom_nxyz, store_grad=False): + out = minimize_rotation_and_translation(targ_nxyz=targ_nxyz, this_nxyz=atom_nxyz, store_grad=store_grad) xyz_0, new_targ, R, base_p = out @@ -185,22 +190,15 @@ def compute_distance(targ_nxyz, delta_sq = (xyz_0 - xyz_1) ** 2 num_atoms = delta_sq.shape[1] - distances = ((delta_sq.sum((1, 2)) / num_atoms) ** - 0.5).reshape(num_mols_0, num_mols_1).cpu() + distances = ((delta_sq.sum((1, 2)) / num_atoms) ** 0.5).reshape(num_mols_0, num_mols_1).cpu() R = R.cpu() if store_grad: return distances, R, base_p - else: - return distances.detach(), R + return distances.detach(), R -def compute_distances(dataset, - device, - batch_size=BATCH_SIZE, - dataset_1=None, - store_grad=False, - collate_dicts=None): +def compute_distances(dataset, device, batch_size=BATCH_SIZE, dataset_1=None, store_grad=False, collate_dicts=None): """ Compute distances between different configurations for one molecule. 
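Once two geometries are optimally aligned, compute_rmsd and compute_distance reduce to the same per-atom root-mean-square expression, `(delta_sq.sum() / num_atoms) ** 0.5`. A tiny numeric check with hypothetical pre-aligned coordinates:

```python
import numpy as np

xyz_0 = np.zeros((3, 3))  # three atoms at the origin
xyz_1 = np.eye(3)         # each displaced by exactly 1 Angstrom
print(np.sqrt(((xyz_0 - xyz_1) ** 2).sum(-1).mean()))  # 1.0
```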
""" @@ -216,36 +214,26 @@ def compute_distances(dataset, shape += [3, 3] R_mat = torch.zeros(tuple(shape)) - loader_0 = DataLoader(dataset, - batch_size=batch_size, - collate_fn=collate_dicts) + loader_0 = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_dicts) - loader_1 = DataLoader(dataset_1, - batch_size=batch_size, - collate_fn=collate_dicts) + loader_1 = DataLoader(dataset_1, batch_size=batch_size, collate_fn=collate_dicts) i_start = 0 xyz_list = [] for batch_0 in loader_0: - j_start = 0 for batch_1 in loader_1: - num_mols_0 = len(batch_0["num_atoms"]) num_mols_1 = len(batch_1["num_atoms"]) - targ_nxyz = (batch_0["nxyz"] - .reshape(num_mols_0, -1, 4).to(device)) - atom_nxyz = (batch_1["nxyz"] - .reshape(num_mols_1, -1, 4).to(device)) + targ_nxyz = batch_0["nxyz"].reshape(num_mols_0, -1, 4).to(device) + atom_nxyz = batch_1["nxyz"].reshape(num_mols_1, -1, 4).to(device) - out = compute_distance(targ_nxyz=targ_nxyz, - atom_nxyz=atom_nxyz, - store_grad=store_grad) + out = compute_distance(targ_nxyz=targ_nxyz, atom_nxyz=atom_nxyz, store_grad=store_grad) if store_grad: distances, R, xyz_0 = out - num_atoms = batch_1["num_atoms"].tolist() + batch_1["num_atoms"].tolist() xyz_list.append(xyz_0) else: @@ -253,18 +241,14 @@ def compute_distances(dataset, distances = distances.transpose(0, 1) - all_indices = (torch.ones_like(distances) - .nonzero(as_tuple=False) - .cpu()) + all_indices = torch.ones_like(distances).nonzero(as_tuple=False).cpu() all_indices[:, 0] += i_start all_indices[:, 1] += j_start - distance_mat[all_indices[:, 0], - all_indices[:, 1]] = distances.reshape(-1) + distance_mat[all_indices[:, 0], all_indices[:, 1]] = distances.reshape(-1) - R_mat[all_indices[:, 0], - all_indices[:, 1]] = R.detach() + R_mat[all_indices[:, 0], all_indices[:, 1]] = R.detach() j_start += num_mols_1 @@ -272,8 +256,7 @@ def compute_distances(dataset, if store_grad: return distance_mat, R_mat, xyz_list - else: - return distance_mat, R_mat + return distance_mat, R_mat """ @@ -286,10 +269,7 @@ def compute_distances(dataset, """ -def batched_translate(ref_xyz, - query_xyz, - mol_idx, - num_atoms_tensor): +def batched_translate(ref_xyz, query_xyz, mol_idx, num_atoms_tensor): """ Translate a set of batched atomic coordinates concatenated together from different molecules, so they align with the COM of the reference molecule. 
@@ -307,60 +287,40 @@ def batched_translate(ref_xyz,
         num_atoms_tensor (torch.LongTensor): tensor of number of
             atoms in each molecule
     """
 
-    ref_sum = scatter_add(src=ref_xyz,
-                          index=mol_idx,
-                          dim=0,
-                          dim_size=mol_idx.max() + 1)
+    ref_sum = scatter_add(src=ref_xyz, index=mol_idx, dim=0, dim_size=mol_idx.max() + 1)
 
     ref_com = ref_sum / num_atoms_tensor.reshape(-1, 1)
 
-    query_sum = scatter_add(src=query_xyz,
-                            index=mol_idx,
-                            dim=1,
-                            dim_size=mol_idx.max() + 1)
+    query_sum = scatter_add(src=query_xyz, index=mol_idx, dim=1, dim_size=mol_idx.max() + 1)
 
     query_com = query_sum / num_atoms_tensor.reshape(-1, 1)
 
-    ref_centered = ref_xyz - torch.repeat_interleave(ref_com,
-                                                     num_atoms_tensor,
-                                                     dim=0)
+    ref_centered = ref_xyz - torch.repeat_interleave(ref_com, num_atoms_tensor, dim=0)
 
     # reshape to match query
     ref_centered = ref_centered.unsqueeze(0)
 
-    query_centered = query_xyz - torch.repeat_interleave(query_com,
-                                                         num_atoms_tensor,
-                                                         dim=1)
+    query_centered = query_xyz - torch.repeat_interleave(query_com, num_atoms_tensor, dim=1)
 
     return ref_centered, query_centered
 
 
-def rmat_from_batched_points(ref_centered,
-                             query_centered,
-                             mol_idx,
-                             num_atoms_tensor,
-                             store_grad=False):
+def rmat_from_batched_points(ref_centered, query_centered, mol_idx, num_atoms_tensor, store_grad=False):
     """
     Rotation matrix from a set of atomic coordinates concatenated together
     from different molecules.
     """
 
-    out_0 = scatter_add(src=(ref_centered * query_centered),
-                        index=mol_idx,
-                        dim=1)
+    out_0 = scatter_add(src=(ref_centered * query_centered), index=mol_idx, dim=1)
 
     r_11 = out_0[:, :, 0]
     r_22 = out_0[:, :, 1]
     r_33 = out_0[:, :, 2]
 
-    out_1 = scatter_add(src=(ref_centered * torch.roll(query_centered, -1, dims=2)),
-                        index=mol_idx,
-                        dim=1)
+    out_1 = scatter_add(src=(ref_centered * torch.roll(query_centered, -1, dims=2)), index=mol_idx, dim=1)
 
     r_12 = out_1[:, :, 0]
     r_23 = out_1[:, :, 1]
     r_31 = out_1[:, :, 2]
 
-    out_2 = scatter_add(src=(ref_centered * torch.roll(query_centered, -2, dims=2)),
-                        index=mol_idx,
-                        dim=1)
+    out_2 = scatter_add(src=(ref_centered * torch.roll(query_centered, -2, dims=2)), index=mol_idx, dim=1)
 
     r_13 = out_2[:, :, 0]
     r_21 = out_2[:, :, 1]
@@ -368,15 +328,14 @@
     f_0 = [r_11 + r_22 + r_33, r_23 - r_32, r_31 - r_13, r_12 - r_21]
     f_1 = [r_23 - r_32, r_11 - r_22 - r_33, r_12 + r_21, r_13 + r_31]
-    f_2 = [r_31 - r_13, r_12 + r_21,
-           -r_11 + r_22 - r_33, r_23 + r_32]
+    f_2 = [r_31 - r_13, r_12 + r_21, -r_11 + r_22 - r_33, r_23 + r_32]
     f_3 = [r_12 - r_21, r_13 + r_31, r_23 + r_32, -r_11 - r_22 + r_33]
 
-    f = torch.stack(
-        [torch.stack(f_0),
-         torch.stack(f_1),
-         torch.stack(f_2),
-         torch.stack(f_3)]
-    ).permute(2, 3, 0, 1).reshape(-1, 4, 4)
+    f = (
+        torch.stack([torch.stack(f_0), torch.stack(f_1), torch.stack(f_2), torch.stack(f_3)])
+        .permute(2, 3, 0, 1)
+        .reshape(-1, 4, 4)
+    )
 
     if store_grad:
         w, V = torch.linalg.eigh(f)
@@ -398,52 +357,42 @@
         raise NotImplementedError("Not yet implemented in numpy")
 
 
-def batch_minimize_rot_trans(ref_nxyz,
-                             query_nxyz,
-                             mol_idx,
-                             num_atoms_tensor,
-                             store_grad=False):
-
+def batch_minimize_rot_trans(ref_nxyz, query_nxyz, mol_idx, num_atoms_tensor, store_grad=False):
     ref_xyz = ref_nxyz[:, 1:]
     if store_grad:
         ref_xyz.requires_grad = True
     query_xyz = query_nxyz[:, :, 1:]
 
-    ref_centered, query_centered = batched_translate(ref_xyz=ref_xyz,
-                                                     query_xyz=query_xyz,
-                                                     mol_idx=mol_idx,
-                                                     num_atoms_tensor=num_atoms_tensor)
+    ref_centered, query_centered = batched_translate(
+        ref_xyz=ref_xyz, query_xyz=query_xyz, mol_idx=mol_idx, num_atoms_tensor=num_atoms_tensor
+    )
 
-    r = rmat_from_batched_points(ref_centered=ref_centered,
-                                 query_centered=query_centered,
-                                 mol_idx=mol_idx,
-                                 num_atoms_tensor=num_atoms_tensor,
-                                 store_grad=store_grad)
+    r = rmat_from_batched_points(
+        ref_centered=ref_centered,
+        query_centered=query_centered,
+        mol_idx=mol_idx,
+        num_atoms_tensor=num_atoms_tensor,
+        store_grad=store_grad,
+    )
 
-    query_center_rot = torch.einsum('...kj,...k->...j', r, query_centered)
+    query_center_rot = torch.einsum("...kj,...k->...j", r, query_centered)
 
     return ref_xyz, ref_centered, query_center_rot, r
 
 
-def batch_compute_distance(ref_nxyz,
-                           query_nxyz,
-                           mol_idx,
-                           num_atoms_tensor,
-                           store_grad=False):
-
-    out = batch_minimize_rot_trans(ref_nxyz=ref_nxyz,
-                                   query_nxyz=query_nxyz,
-                                   mol_idx=mol_idx,
-                                   num_atoms_tensor=num_atoms_tensor,
-                                   store_grad=store_grad)
+def batch_compute_distance(ref_nxyz, query_nxyz, mol_idx, num_atoms_tensor, store_grad=False):
+    out = batch_minimize_rot_trans(
+        ref_nxyz=ref_nxyz,
+        query_nxyz=query_nxyz,
+        mol_idx=mol_idx,
+        num_atoms_tensor=num_atoms_tensor,
+        store_grad=store_grad,
+    )
 
     ref_xyz, ref_centered, query_center_rot, r = out
 
     delta_sq = (ref_centered - query_center_rot) ** 2
 
-    delta_sq_sum = scatter_add(src=delta_sq,
-                               index=mol_idx,
-                               dim=1,
-                               dim_size=mol_idx.max() + 1).sum(-1)
+    delta_sq_sum = scatter_add(src=delta_sq, index=mol_idx, dim=1, dim_size=mol_idx.max() + 1).sum(-1)
 
     delta_sq_mean = delta_sq_sum / num_atoms_tensor.reshape(1, -1)
-    rmsd = delta_sq_mean ** 0.5
+    rmsd = delta_sq_mean**0.5
 
     return rmsd, ref_xyz
diff --git a/nff/utils/misc.py b/nff/utils/misc.py
index 1d2e1b21..75887b50 100644
--- a/nff/utils/misc.py
+++ b/nff/utils/misc.py
@@ -56,20 +56,16 @@
 ]
 
 
-def tqdm_enum(iter):
+def tqdm_enum(iterable):
     """
     Wrap tqdm around `enumerate`.
     Args:
-        iter (iterable): an iterable (e.g. list)
+        iterable (iterable): an iterable (e.g. list)
     Returns
         i (int): current index
         y: current value
     """
-
-    i = 0
-    for y in tqdm(iter):
-        yield i, y
-        i += 1
+    yield from enumerate(tqdm(iterable))
 
 
 def log(prefix, msg):
@@ -80,7 +76,7 @@
         prefix (str)
         msg (str)
     """
-    print("{:>12}: {}".format(prefix.upper(), msg))
+    print(f"{prefix.upper():>12}: {msg}")
 
 
 def add_json_args(args, config_flag="config_file"):
@@ -179,15 +175,12 @@
     else:
         for i, item in enumerate(header_items):
             sub_keys = metric.split("_")
-            if all([key.lower() in item.lower() for key in sub_keys]):
+            if all(key.lower() in item.lower() for key in sub_keys):
                 idx = i
 
     optim = METRIC_DIC[metric]
 
-    if optim == "minimize":
-        best_score = float("inf")
-    else:
-        best_score = -float("inf")
+    best_score = float("inf") if optim == "minimize" else -float("inf")
 
     best_epoch = -1
 
@@ -266,7 +259,7 @@
         None
     """
-    keys = sorted(list(dic.keys()))
+    keys = sorted(dic.keys())
     if "smiles" in keys:
         keys.remove("smiles")
         keys.insert(0, "smiles")
@@ -364,10 +357,7 @@
         msg = f"Requested {string}, which are mutually exclusive"
         raise Exception(msg)
 
-    if len(requested) != 0:
-        names = requested
-    else:
-        names = ["train", "val", "test"]
+    names = requested if len(requested) != 0 else ["train", "val", "test"]
 
     return names
 
@@ -412,10 +402,7 @@
     """
     if metric == "auc":
         pred = preprocess_class(pred)
-        if max(pred) == 0:
-            score = 0
-        else:
-            score = roc_auc_score(y_true=actual, y_score=pred)
+        score = 0 if max(pred) == 0 else roc_auc_score(y_true=actual, y_score=pred)
     elif metric == "prc-auc":
         pred = preprocess_class(pred)
         if max(pred) == 0:
@@ -455,8 +442,8 @@
     all_nbrs = []
     for nbrs in dset.props["nbr_list"]:
         for pair in nbrs:
-            all_nbrs.append(tuple(pair.tolist()))
-    all_nbrs_tuple = list(set(tuple(all_nbrs)))
+            all_nbrs.append(tuple(pair.tolist()))  # noqa
+    all_nbrs_tuple = list(set(all_nbrs))
 
     all_nbrs = torch.LongTensor([list(i) for i in all_nbrs_tuple])
 
@@ -516,7 +503,7 @@
     parser = argparse.ArgumentParser(description=description)
     default_args.pop("description")
 
-    required = parser.add_argument_group(("required arguments (either in " "the command line or the config " "file)"))
+    required = parser.add_argument_group("required arguments (either in " "the command line or the config " "file)")
     optional = parser.add_argument_group("optional arguments")
 
     for name, info in default_args.items():
diff --git a/nff/utils/scatter.py b/nff/utils/scatter.py
index 9c0b9087..bab3e0d7 100644
--- a/nff/utils/scatter.py
+++ b/nff/utils/scatter.py
@@ -1,38 +1,34 @@
 from itertools import repeat
+
 from torch.autograd import grad
 
 
-def compute_grad(inputs,
-                 output,
-                 allow_unused=False):
+def compute_grad(inputs, output, allow_unused=False):
     """Compute gradient of the scalar output with respect to inputs.
     Args:
         inputs (torch.Tensor): torch tensor, requires_grad=True
-        output (torch.Tensor): scalar output
+        output (torch.Tensor): scalar output
 
     Returns:
-        torch.Tensor: gradients with respect to each input component
+        torch.Tensor: gradients with respect to each input component
     """
 
     assert inputs.requires_grad
 
-    gradspred, = grad(output,
-                      inputs,
-                      grad_outputs=output.data.new(output.shape).fill_(1),
-                      create_graph=True,
-                      retain_graph=True,
-                      allow_unused=allow_unused)
+    (gradspred,) = grad(
+        output,
+        inputs,
+        grad_outputs=output.data.new(output.shape).fill_(1),
+        create_graph=True,
+        retain_graph=True,
+        allow_unused=allow_unused,
+    )
 
     return gradspred
 
 
-def gen(src,
-        index,
-        dim=-1,
-        out=None,
-        dim_size=None,
-        fill_value=0):
+def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     dim = range(src.dim())[dim]  # Get real dim value.
 
     # Automatically expand index tensor to the right dimensions.
@@ -51,19 +47,8 @@
     return src, out, index, dim
 
 
-def scatter_add(src,
-                index,
-                dim=-1,
-                out=None,
-                dim_size=None,
-                fill_value=0):
-
-    src, out, index, dim = gen(src=src,
-                               index=index,
-                               dim=dim,
-                               out=out,
-                               dim_size=dim_size,
-                               fill_value=fill_value)
+def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
+    src, out, index, dim = gen(src=src, index=index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value)
 
     output = out.scatter_add_(dim, index, src)
 
     return output
diff --git a/nff/utils/script_utils/loaders.py b/nff/utils/script_utils/loaders.py
index 2e12c6b7..465c6d94 100644
--- a/nff/utils/script_utils/loaders.py
+++ b/nff/utils/script_utils/loaders.py
@@ -1,19 +1,18 @@
 import torch
-import nff.data
-
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import RandomSampler
+
+import nff.data
 from nff.data.loader import collate_dicts
 
 
 def get_loaders(args, logging=None):
-
     if logging is not None:
         logging.info("loading dataset...")
 
     dataset = torch.load(args.data_path)
 
-    if args.mode == 'eval':
+    if args.mode == "eval":
         test_loader = DataLoader(
             dataset,
             batch_size=args.batch_size,
@@ -22,38 +21,26 @@
 
         return test_loader
 
-    elif args.mode == 'train':
-
+    if args.mode == "train":
         if logging is not None:
             logging.info("creating splits...")
 
         train, val, test = nff.data.split_train_validation_test(
-            dataset,
-            val_size=args.split[0],
-            test_size=args.split[1]
+            dataset, val_size=args.split[0], test_size=args.split[1]
         )
-
+
         if logging is not None:
             logging.info("load data...")
-
+
         train_loader = DataLoader(
             train,
             batch_size=args.batch_size,
             num_workers=args.workers,
             collate_fn=collate_dicts,
-            sampler=RandomSampler(train)
+            sampler=RandomSampler(train),
         )
-        val_loader = DataLoader(
-            val,
-            batch_size=args.batch_size,
-            num_workers=args.workers,
-            collate_fn=collate_dicts
-        )
-        test_loader = DataLoader(
-            test,
-            batch_size=args.batch_size,
-            num_workers=args.workers,
-            collate_fn=collate_dicts
-        )
-
+        val_loader = DataLoader(val, batch_size=args.batch_size, num_workers=args.workers, collate_fn=collate_dicts)
+        test_loader = DataLoader(test, batch_size=args.batch_size, num_workers=args.workers, collate_fn=collate_dicts)
+
         return train_loader, val_loader, test_loader
+    return None
diff --git a/nff/utils/script_utils/parsers.py b/nff/utils/script_utils/parsers.py
index d4a76cca..ed6aa9a7 100644
--- a/nff/utils/script_utils/parsers.py
+++ b/nff/utils/script_utils/parsers.py
@@ -1,23 +1,23 @@
 """Argument parsing from the command line.
 
-From https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/script_parsing.py
+From:
+https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/script_parsing.py
 """
 
 import argparse
 
 
 def get_main_parser():
-    """ Setup parser for command line arguments """
-    ## command-specific
+    """Setup parser for command line arguments"""
+    # command-specific
     cmd_parser = argparse.ArgumentParser(add_help=False)
     cmd_parser.add_argument(
         "--device",
-        default='cuda',
+        default="cuda",
         help="Device to use",
     )
     cmd_parser.add_argument(
         "--parallel",
-        help="Run data-parallel on all available GPUs (specify with environment"
-        " variable CUDA_VISIBLE_DEVICES)",
+        help="Run data-parallel on all available GPUs (specify with environment" " variable CUDA_VISIBLE_DEVICES)",
         action="store_true",
     )
     cmd_parser.add_argument(
@@ -30,16 +30,12 @@
 
 
 def add_subparsers(cmd_parser, defaults={}):
-    ## training
+    # training
     train_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser])
    train_parser.add_argument("data_path", help="Dataset to use")
     train_parser.add_argument("model_path", help="Destination for models and logs")
-    train_parser.add_argument(
-        "--seed", type=int, default=None, help="Set random seed for torch and numpy."
-    )
-    train_parser.add_argument(
-        "--overwrite", help="Remove previous model directory.", action="store_true"
-    )
+    train_parser.add_argument("--seed", type=int, default=None, help="Set random seed for torch and numpy.")
+    train_parser.add_argument("--overwrite", help="Remove previous model directory.", action="store_true")
 
     train_parser.add_argument(
         "--split",
@@ -57,9 +53,8 @@
     train_parser.add_argument(
         "--lr_patience",
         type=int,
-        help="Epochs without improvement before reducing the learning rate "
-        "(default: %(default)s)",
-        default=25 if "lr_patience" not in defaults.keys() else defaults["lr_patience"],
+        help="Epochs without improvement before reducing the learning rate " "(default: %(default)s)",
+        default=defaults.get("lr_patience", 25),
     )
     train_parser.add_argument(
         "--lr_decay",
@@ -111,22 +106,22 @@
         default='{"energy": 0.1, "energy_grad": 1.0}',
     )
 
-    ## evaluation
+    # evaluation
     eval_parser = argparse.ArgumentParser(add_help=False, parents=[cmd_parser])
     eval_parser.add_argument("data_path", help="Dataset to use")
     eval_parser.add_argument("model_path", help="Path of stored model")
-#    eval_parser.add_argument(
-#        "--split",
-#        help="Evaluate trained model on given split",
-#        choices=["train", "validation", "test"],
-#        default=["test"],
-#        nargs="+",
-#    )
+    # eval_parser.add_argument(
+    #     "--split",
+    #     help="Evaluate trained model on given split",
+    #     choices=["train", "validation", "test"],
+    #     default=["test"],
+    #     nargs="+",
+    # )
 
     # model-specific parsers
     model_parser = argparse.ArgumentParser(add_help=False)
 
-    ####### SchNet #######
+    # SchNet
     schnet_parser = argparse.ArgumentParser(add_help=False, parents=[model_parser])
     schnet_parser.add_argument(
         "--n_atom_basis",
@@ -134,18 +129,14 @@
         help="Size of atom-wise representation",
         default=256,
     )
-    schnet_parser.add_argument(
-        "--n_filters", type=int, help="Size of atom-wise representation", default=25
-    )
+    schnet_parser.add_argument("--n_filters", type=int, help="Size of atom-wise representation", default=25)
     schnet_parser.add_argument(
         "--n_gaussians",
         type=int,
         default=25,
         help="Number of Gaussians to expand distances (default: %(default)s)",
     )
-    schnet_parser.add_argument(
-        "--n_convolutions", type=int, help="Number of interaction blocks", default=6
-    )
+    schnet_parser.add_argument("--n_convolutions", type=int, help="Number of interaction blocks", default=6)
     schnet_parser.add_argument(
         "--cutoff",
         type=float,
@@ -154,7 +145,7 @@
     )
     schnet_parser.add_argument(
         "--trainable_gauss",
-        action='store_true',
+        action="store_true",
         help="If set, sets gaussians as learnable parameters (default: False)",
     )
     schnet_parser.add_argument(
@@ -164,10 +155,8 @@
         help="Dropout rate for SchNet convolutions (default: %(default)s)",
     )
 
-    ## setup subparser structure
-    cmd_subparsers = cmd_parser.add_subparsers(
-        dest="mode", help="Command-specific arguments"
-    )
+    # setup subparser structure
+    cmd_subparsers = cmd_parser.add_subparsers(dest="mode", help="Command-specific arguments")
     cmd_subparsers.required = True
     subparser_train = cmd_subparsers.add_parser("train", help="Training help")
     subparser_eval = cmd_subparsers.add_parser("eval", help="Eval help")
@@ -175,22 +164,12 @@
     subparser_export = cmd_subparsers.add_parser("export", help="Export help")
     subparser_export.add_argument("data_path", help="Dataset to use")
     subparser_export.add_argument("model_path", help="Path of stored model")
-    subparser_export.add_argument(
-        "dest_path", help="Destination path for exported model"
-    )
+    subparser_export.add_argument("dest_path", help="Destination path for exported model")
 
-    train_subparsers = subparser_train.add_subparsers(
-        dest="model", help="Model-specific arguments"
-    )
+    train_subparsers = subparser_train.add_subparsers(dest="model", help="Model-specific arguments")
     train_subparsers.required = True
-    train_subparsers.add_parser(
-        "schnet", help="SchNet help", parents=[train_parser, schnet_parser]
-    )
+    train_subparsers.add_parser("schnet", help="SchNet help", parents=[train_parser, schnet_parser])
 
-    eval_subparsers = subparser_eval.add_subparsers(
-        dest="model", help="Model-specific arguments"
-    )
+    eval_subparsers = subparser_eval.add_subparsers(dest="model", help="Model-specific arguments")
     eval_subparsers.required = True
-    eval_subparsers.add_parser(
-        "schnet", help="SchNet help", parents=[eval_parser, schnet_parser]
-    )
+    eval_subparsers.add_parser("schnet", help="SchNet help", parents=[eval_parser, schnet_parser])
diff --git a/nff/utils/script_utils/setup.py b/nff/utils/script_utils/setup.py
index 89a8d8e4..0c6bf871 100644
--- a/nff/utils/script_utils/setup.py
+++ b/nff/utils/script_utils/setup.py
@@ -1,12 +1,14 @@
 """Helper function to setup the run from the command line.
-Adapted from https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/setup.py
+Adapted from:
+https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/script_utils/setup.py
 """
-import os
+
 import logging
+import os
 from shutil import rmtree
 
-from nff.utils.tools import to_json, set_random_seed, read_from_json
+from nff.utils.tools import read_from_json, set_random_seed, to_json
 
 __all__ = ["setup_run"]
 
@@ -16,8 +18,8 @@
     jsonpath = os.path.join(args.model_path, "args.json")
 
     # absolute paths
-    argparse_dict['data_path'] = os.path.abspath(argparse_dict['data_path'])
-    argparse_dict['model_path'] = os.path.abspath(argparse_dict['model_path'])
+    argparse_dict["data_path"] = os.path.abspath(argparse_dict["data_path"])
+    argparse_dict["model_path"] = os.path.abspath(argparse_dict["model_path"])
 
     if args.mode == "train":
         if args.overwrite and os.path.exists(args.model_path):
diff --git a/nff/utils/tools.py b/nff/utils/tools.py
index abb2200b..75c23b4b 100644
--- a/nff/utils/tools.py
+++ b/nff/utils/tools.py
@@ -1,25 +1,24 @@
 """Assorted tools in the package.
 Adapted from https://github.com/atomistic-machine-learning/schnetpack/blob/dev/src/schnetpack/utils/spk_utils.py
 """
+
+import collections
 import json
 import logging
-import collections
 from argparse import Namespace
 
 import numpy as np
 import torch
-
 from torch.nn import ModuleDict, Sequential
 
-from nff.nn.activations import (shifted_softplus, Swish,
-                                LearnableSwish)
-from nff.nn.layers import Dense
+from nff.nn.activations import LearnableSwish, Swish, shifted_softplus
+from nff.nn.layers import Dense
 
 __all__ = [
-    "set_random_seed", "compute_params",
-    "to_json", "read_from_json",
+    "set_random_seed",
+    "to_json",
 ]
 
 layer_types = {
@@ -31,31 +30,32 @@
     "sigmoid": torch.nn.Sigmoid,
     "Dropout": torch.nn.Dropout,
     "LeakyReLU": torch.nn.LeakyReLU,
-    "ELU": torch.nn.ELU, 
+    "ELU": torch.nn.ELU,
     "swish": Swish,
     "learnable_swish": LearnableSwish,
-    "softplus": torch.nn.Softplus
+    "softplus": torch.nn.Softplus,
 }
 
 
 def construct_Sequential(layers):
-    """Construct a sequential model from list of params 
+    """Construct a sequential model from list of params
 
     Args:
-        layers (list): list to describe the stacked layer params 
+        layers (list): list to describe the stacked layer params
             example:    [
                             {'name': 'linear', 'param' : {'in_features': 10, 'out_features': 20}},
                             {'name': 'linear', 'param' : {'in_features': 10, 'out_features': 1}}
                         ]
 
     Returns:
-        Sequential: Stacked Sequential Model 
+        Sequential: Stacked Sequential Model
     """
-    return Sequential(collections.OrderedDict([layer['name']+str(i),
-                                               layer_types[layer['name']](
-                                                   **layer['param'])
-                                               ] for i, layer in enumerate(layers)))
+    return Sequential(
+        collections.OrderedDict(
+            [layer["name"] + str(i), layer_types[layer["name"]](**layer["param"])] for i, layer in enumerate(layers)
+        )
+    )
 
 
 def construct_ModuleDict(moduledict):
@@ -81,6 +81,7 @@
         seed (int, optional): if seed not present, it is generated based on time
     """
     import time
+
     import numpy as np
 
     # 1) if seed not present, generate based on time
@@ -99,7 +100,7 @@
     np.random.seed(seed)
     # 3) Set seed for torch (manual_seed now seeds all CUDA devices automatically)
     torch.manual_seed(seed)
-    logging.info("Random state initialized with seed {:<10d}".format(seed))
+    logging.info(f"Random state initialized with seed {seed:<10d}")
 
 
 def compute_params(model):
@@ -146,7 +147,6 @@ def make_directed(nbr_list):
-
     gtr_ij = (nbr_list[:, 0] > nbr_list[:, 1]).any().item()
     gtr_ji = (nbr_list[:, 1] > nbr_list[:, 0]).any().item()
     directed = gtr_ij and gtr_ji
@@ -157,6 +157,7 @@
     new_nbrs = torch.cat([nbr_list, nbr_list.flip(1)], dim=0)
     return new_nbrs, directed
 
+
 def make_undirected(nbr_list):
     gtr_ij = (nbr_list[:, 0] > nbr_list[:, 1]).any().item()
     gtr_ji = (nbr_list[:, 1] > nbr_list[:, 0]).any().item()
@@ -165,5 +166,5 @@
     if not directed:
         return nbr_list, directed
     nbrs = nbr_list[nbr_list[:, 1] > nbr_list[:, 0]]
-    
+
     return nbrs, directed
diff --git a/nff/utils/xyz2mol.py b/nff/utils/xyz2mol.py
index d63feafa..19061240 100644
--- a/nff/utils/xyz2mol.py
+++ b/nff/utils/xyz2mol.py
@@ -12,15 +12,13 @@
 """
 
 import copy
-import itertools
-import pickle
-from functools import wraps
 import errno
+import itertools
 import os
+import pickle
 import signal
+from functools import wraps
 
-
-from rdkit.Chem import rdmolops
 try:
     from rdkit.Chem import rdEHTTools  # requires RDKit 2019.9.1 or later
 except ImportError:
@@ -28,44 +26,121 @@
 from collections import defaultdict
 
-import numpy as np
 import networkx as nx
-
+import numpy as np
 from rdkit import Chem
-from rdkit.Chem import AllChem, rdmolops, GetPeriodicTable
+from rdkit.Chem import AllChem, GetPeriodicTable
 from rdkit.Chem.rdchem import EditableMol
 
-
-global __ATOM_LIST__
-__ATOM_LIST__ = \
-    ['h', 'he',
-     'li', 'be', 'b', 'c', 'n', 'o', 'f', 'ne',
-     'na', 'mg', 'al', 'si', 'p', 's', 'cl', 'ar',
-     'k', 'ca', 'sc', 'ti', 'v ', 'cr', 'mn', 'fe', 'co', 'ni', 'cu',
-     'zn', 'ga', 'ge', 'as', 'se', 'br', 'kr',
-     'rb', 'sr', 'y', 'zr', 'nb', 'mo', 'tc', 'ru', 'rh', 'pd', 'ag',
-     'cd', 'in', 'sn', 'sb', 'te', 'i', 'xe',
-     'cs', 'ba', 'la', 'ce', 'pr', 'nd', 'pm', 'sm', 'eu', 'gd', 'tb', 'dy',
-     'ho', 'er', 'tm', 'yb', 'lu', 'hf', 'ta', 'w', 're', 'os', 'ir', 'pt',
-     'au', 'hg', 'tl', 'pb', 'bi', 'po', 'at', 'rn',
-     'fr', 'ra', 'ac', 'th', 'pa', 'u', 'np', 'pu']
+__ATOM_LIST__ = [
+    "h",
+    "he",
+    "li",
+    "be",
+    "b",
+    "c",
+    "n",
+    "o",
+    "f",
+    "ne",
+    "na",
+    "mg",
+    "al",
+    "si",
+    "p",
+    "s",
+    "cl",
+    "ar",
+    "k",
+    "ca",
+    "sc",
+    "ti",
+    "v ",
+    "cr",
+    "mn",
+    "fe",
+    "co",
+    "ni",
+    "cu",
+    "zn",
+    "ga",
+    "ge",
+    "as",
+    "se",
+    "br",
+    "kr",
+    "rb",
+    "sr",
+    "y",
+    "zr",
+    "nb",
+    "mo",
+    "tc",
+    "ru",
+    "rh",
+    "pd",
+    "ag",
+    "cd",
+    "in",
+    "sn",
+    "sb",
+    "te",
+    "i",
+    "xe",
+    "cs",
+    "ba",
+    "la",
+    "ce",
+    "pr",
+    "nd",
+    "pm",
+    "sm",
+    "eu",
+    "gd",
+    "tb",
+    "dy",
+    "ho",
+    "er",
+    "tm",
+    "yb",
+    "lu",
+    "hf",
+    "ta",
+    "w",
+    "re",
+    "os",
+    "ir",
+    "pt",
+    "au",
+    "hg",
+    "tl",
+    "pb",
+    "bi",
+    "po",
+    "at",
+    "rn",
+    "fr",
+    "ra",
+    "ac",
+    "th",
+    "pa",
+    "u",
+    "np",
+    "pu",
+]
 
-global atomic_valence
-global atomic_valence_electrons
-
 atomic_valence = defaultdict(list)
-
 atomic_valence_electrons = {}
 
 PERIODICTABLE = GetPeriodicTable()
 
 for i in range(100):
     dics = [atomic_valence, atomic_valence_electrons]
-    if all([i in dic for dic in dics]):
+    if all(i in dic for dic in dics):
         continue
-    valence_list = [j for j in PERIODICTABLE.GetValenceList(i)]
+    valence_list = list(PERIODICTABLE.GetValenceList(i))
     valence_num = PERIODICTABLE.GetNOuterElecs(i)
 
     atomic_valence[i] = valence_list
@@ -81,7 +156,10 @@
 class TimeoutError(Exception):
     pass
 
 
-def timeout(seconds, error_message=os.strerror(errno.ETIME)):
+ERROR_MESSAGE = os.strerror(errno.ETIME)
+
+
+def timeout(seconds, error_message=ERROR_MESSAGE):
     def decorator(func):
         def _handle_timeout(signum, frame):
             raise TimeoutError(error_message)
@@ -105,24 +183,18 @@
 def str_atom(atom):
     """
     convert integer atom to string atom
     """
-    global __ATOM_LIST__
-    atom = __ATOM_LIST__[atom - 1]
-    return atom
+    return __ATOM_LIST__[atom - 1]
 
 
 def int_atom(atom):
     """
     convert str atom to integer atom
     """
-    global __ATOM_LIST__
-    print(atom)
-    atom = atom.lower()
-    return __ATOM_LIST__.index(atom) + 1
+    return __ATOM_LIST__.index(atom.lower()) + 1
 
 
 def get_UA(maxValence_list, valence_list):
-    """
-    """
+    """ """
     UA = []
     DU = []
     for i, (maxValence, valence) in enumerate(zip(maxValence_list, valence_list)):
@@ -134,8 +206,7 @@
 
 def get_BO(AC, UA, DU, valences, UA_pairs, use_graph=True):
-    """
-    """
+    """ """
     BO = AC.copy()
     DU_save = []
 
@@ -153,18 +224,12 @@
 
 def valences_not_too_large(BO, valences):
-    """
-    """
+    """ """
     number_of_bonds_list = BO.sum(axis=1)
-    for valence, number_of_bonds in zip(valences, number_of_bonds_list):
-        if number_of_bonds > valence:
-            return False
+    return all(number_of_bonds <= valence for valence, number_of_bonds in zip(valences, number_of_bonds_list))
 
-    return True
-
 
-def BO_is_OK(BO, AC, charge, DU, atomic_valence_electrons, atoms, valances,
-             allow_charged_fragments=True):
+def BO_is_OK(BO, AC, charge, DU, atomic_valence_electrons, atoms, valances, allow_charged_fragments=True):
     """
     Sanity of bond-orders
 
     args:
         BO -
         AC -
         charge -
-        DU - 
+        DU -
 
     optional
-        allow_charges_fragments - 
+        allow_charges_fragments -
 
     returns:
@@ -193,18 +258,16 @@
     q_list = []
 
     if allow_charged_fragments:
-
         BO_valences = list(BO.sum(axis=1))
         for i, atom in enumerate(atoms):
-            q = get_atomic_charge(
-                atom, atomic_valence_electrons[atom], BO_valences[i])
+            q = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i])
             Q += q
             if atom == 6:
                 number_of_single_bonds_to_C = list(BO[i, :]).count(1)
                 if number_of_single_bonds_to_C == 2 and BO_valences[i] == 2:
                     Q += 1
                     q = 2
-                if number_of_single_bonds_to_C == 3 and Q + 1 < charge:
+                if number_of_single_bonds_to_C == 3 and charge > Q + 1:
                     Q += 2
                     q = 1
 
@@ -215,23 +278,17 @@
     check_charge = charge == Q
     # check_len = len(q_list) <= abs(charge)
 
-    if check_sum and check_charge:
-        return True
-
-    return False
+    return bool(check_sum and check_charge)
 
 
 def get_atomic_charge(atom, atomic_valence_electrons, BO_valence):
-    """
-    """
+    """ """
     if atom == 1:
         charge = 1 - BO_valence
     elif atom == 5:
         charge = 3 - BO_valence
-    elif atom == 15 and BO_valence == 5:
-        charge = 0
-    elif atom == 16 and BO_valence == 6:
+    elif (atom == 15 and BO_valence == 5) or (atom == 16 and BO_valence == 6):
         charge = 0
     else:
         charge = atomic_valence_electrons - 8 + BO_valence
@@ -254,10 +311,12 @@
     #               '[O:1]=[c:2][c-:3]>>[*-:1][*:2][*+0:3]',
     #               '[O:1]=[C:2][C-:3]>>[*-:1][*:2]=[*+0:3]']
 
-    rxn_smarts = ['[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][CX3-,NX3-:5][#6,#7:6]1=[#6,#7:7]>>'
-                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][-0,-0:5]=[#6,#7:6]1[#6-,#7-:7]',
-                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3](=[#6,#7:4])[#6,#7:5]=[#6,#7:6][CX3-,NX3-:7]1>>'
-                  '[#6,#7:1]1=[#6,#7:2][#6,#7:3]([#6-,#7-:4])=[#6,#7:5][#6,#7:6]=[-0,-0:7]1']
+    rxn_smarts = [
+        "[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][CX3-,NX3-:5][#6,#7:6]1=[#6,#7:7]>>"
"[#6,#7:1]1=[#6,#7:2][#6,#7:3]=[#6,#7:4][-0,-0:5]=[#6,#7:6]1[#6-,#7-:7]", + "[#6,#7:1]1=[#6,#7:2][#6,#7:3](=[#6,#7:4])[#6,#7:5]=[#6,#7:6][CX3-,NX3-:7]1>>" + "[#6,#7:1]1=[#6,#7:2][#6,#7:3]([#6-,#7-:4])=[#6,#7:5][#6,#7:6]=[-0,-0:7]1", + ] fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) @@ -269,16 +328,12 @@ def clean_charges(mol): ps = rxn.RunReactants((fragment,)) fragment = ps[0][0] Chem.SanitizeMol(fragment) - if i == 0: - mol = fragment - else: - mol = Chem.CombineMols(mol, fragment) + mol = fragment if i == 0 else Chem.CombineMols(mol, fragment) return mol -def BO2mol(mol, BO_matrix, atoms, atomic_valence_electrons, - mol_charge, allow_charged_fragments=True): +def BO2mol(mol, BO_matrix, atoms, atomic_valence_electrons, mol_charge, allow_charged_fragments=True): """ based on code written by Paolo Toscani @@ -300,26 +355,21 @@ def BO2mol(mol, BO_matrix, atoms, atomic_valence_electrons, """ - l = len(BO_matrix) + l1 = len(BO_matrix) l2 = len(atoms) BO_valences = list(BO_matrix.sum(axis=1)) - if (l != l2): - raise RuntimeError( - 'sizes of adjMat ({0:d}) and Atoms {1:d} differ'.format(l, l2)) + if l1 != l2: + raise RuntimeError(f"sizes of adjMat ({l1:d}) and Atoms {l2:d} differ") rwMol = Chem.RWMol(mol) - bondTypeDict = { - 1: Chem.BondType.SINGLE, - 2: Chem.BondType.DOUBLE, - 3: Chem.BondType.TRIPLE - } + bondTypeDict = {1: Chem.BondType.SINGLE, 2: Chem.BondType.DOUBLE, 3: Chem.BondType.TRIPLE} - for i in range(l): - for j in range(i + 1, l): + for i in range(l1): + for j in range(i + 1, l1): bo = int(round(BO_matrix[i, j])) - if (bo == 0): + if bo == 0: continue bt = bondTypeDict.get(bo, Chem.BondType.SINGLE) rwMol.AddBond(i, j, bt) @@ -327,29 +377,19 @@ def BO2mol(mol, BO_matrix, atoms, atomic_valence_electrons, mol = rwMol.GetMol() if allow_charged_fragments: - mol = set_atomic_charges( - mol, - atoms, - atomic_valence_electrons, - BO_valences, - BO_matrix, - mol_charge) + mol = set_atomic_charges(mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge) else: - mol = set_atomic_radicals( - mol, atoms, atomic_valence_electrons, BO_valences) + mol = set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences) return mol -def set_atomic_charges(mol, atoms, atomic_valence_electrons, - BO_valences, BO_matrix, mol_charge): - """ - """ +def set_atomic_charges(mol, atoms, atomic_valence_electrons, BO_valences, BO_matrix, mol_charge): + """ """ q = 0 for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge( - atom, atomic_valence_electrons[atom], BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) q += charge if atom == 6: number_of_single_bonds_to_C = list(BO_matrix[i, :]).count(1) @@ -360,7 +400,7 @@ def set_atomic_charges(mol, atoms, atomic_valence_electrons, q += 2 charge = 1 - if (abs(charge) > 0): + if abs(charge) > 0: a.SetFormalCharge(int(charge)) mol = clean_charges(mol) @@ -376,35 +416,28 @@ def set_atomic_radicals(mol, atoms, atomic_valence_electrons, BO_valences): """ for i, atom in enumerate(atoms): a = mol.GetAtomWithIdx(i) - charge = get_atomic_charge( - atom, - atomic_valence_electrons[atom], - BO_valences[i]) + charge = get_atomic_charge(atom, atomic_valence_electrons[atom], BO_valences[i]) - if (abs(charge) > 0): + if abs(charge) > 0: a.SetNumRadicalElectrons(abs(int(charge))) return mol def get_bonds(UA, AC): - """ - - """ + """ """ bonds = [] for k, i in enumerate(UA): - for j in UA[k + 1:]: + for j in UA[k + 1 :]: if AC[i, j] == 1: - 
bonds.append(tuple(sorted([i, j]))) + bonds.append(tuple(sorted([i, j]))) # noqa return bonds def get_UA_pairs(UA, AC, use_graph=True): - """ - - """ + """ """ bonds = get_bonds(UA, AC) @@ -435,7 +468,7 @@ def get_UA_pairs(UA, AC, use_graph=True): def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): """ - implemenation of algorithm shown in Figure 2 + implementation of the algorithm shown in Figure 2 UA: unsaturated atoms @@ -445,14 +478,9 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): """ - global atomic_valence - global atomic_valence_electrons - # make a list of valences, e.g. for CO: [[4],[2,1]] - valences_list_of_lists = [] AC_valence = list(AC.sum(axis=1)) - for atomicNum in atoms: - valences_list_of_lists.append(atomic_valence[atomicNum]) + valences_list_of_lists = [atomic_valence[atomicNum] for atomicNum in atoms] # convert [[4],[2,1]] to [[4,2],[4,1]] valences_list = itertools.product(*valences_list_of_lists) @@ -460,14 +488,20 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): best_BO = AC.copy() for valences in valences_list: - UA, DU_from_AC = get_UA(valences, AC_valence) - check_len = (len(UA) == 0) + check_len = len(UA) == 0 if check_len: - check_bo = BO_is_OK(AC, AC, charge, DU_from_AC, - atomic_valence_electrons, atoms, valences, - allow_charged_fragments=allow_charged_fragments) + check_bo = BO_is_OK( + AC, + AC, + charge, + DU_from_AC, + atomic_valence_electrons, + atoms, + valences, + allow_charged_fragments=allow_charged_fragments, + ) else: check_bo = None @@ -476,48 +510,43 @@ def AC2BO(AC, atoms, charge, allow_charged_fragments=True, use_graph=True): UA_pairs_list = get_UA_pairs(UA, AC, use_graph=use_graph) for UA_pairs in UA_pairs_list: - BO = get_BO(AC, UA, DU_from_AC, valences, - UA_pairs, use_graph=use_graph) - status = BO_is_OK(BO, AC, charge, DU_from_AC, - atomic_valence_electrons, atoms, valences, - allow_charged_fragments=allow_charged_fragments) + BO = get_BO(AC, UA, DU_from_AC, valences, UA_pairs, use_graph=use_graph) + status = BO_is_OK( + BO, + AC, + charge, + DU_from_AC, + atomic_valence_electrons, + atoms, + valences, + allow_charged_fragments=allow_charged_fragments, + ) if status: return BO, atomic_valence_electrons - elif BO.sum() >= best_BO.sum() and valences_not_too_large(BO, valences): + if BO.sum() >= best_BO.sum() and valences_not_too_large(BO, valences): best_BO = BO.copy() return best_BO, atomic_valence_electrons def AC2mol(mol, AC, atoms, charge, allow_charged_fragments=True, use_graph=True): - """ - """ + """ """ # convert AC matrix to bond order (BO) matrix BO, atomic_valence_electrons = AC2BO( - AC, - atoms, - charge, - allow_charged_fragments=allow_charged_fragments, - use_graph=use_graph) + AC, atoms, charge, allow_charged_fragments=allow_charged_fragments, use_graph=use_graph + ) # add BO connectivity and charge info to mol object - mol = BO2mol( - mol, - BO, - atoms, - atomic_valence_electrons, - charge, - allow_charged_fragments=allow_charged_fragments) + mol = BO2mol(mol, BO, atoms, atomic_valence_electrons, charge, allow_charged_fragments=allow_charged_fragments) return mol def get_proto_mol(atoms): - """ - """ + """ """ mol = Chem.MolFromSmarts("[#" + str(atoms[0]) + "]") rwMol = Chem.RWMol(mol) for i in range(1, len(atoms)): @@ -530,20 +559,17 @@ def get_proto_mol(atoms): def read_xyz_file(filename, look_for_charge=True): - """ - """ + """ """ atomic_symbols = [] xyz_coordinates = [] charge = 0 - title = "" with open(filename, "r") as file: for 
line_number, line in enumerate(file): if line_number == 0: - num_atoms = int(line) + int(line) elif line_number == 1: - title = line if "charge=" in line: charge = int(line.split("=")[1]) else: @@ -577,12 +603,10 @@ def xyz2AC(atoms, xyz, charge, use_huckel=False): if use_huckel: return xyz2AC_huckel(atoms, xyz, charge) - else: - return xyz2AC_vdW(atoms, xyz) + return xyz2AC_vdW(atoms, xyz) def xyz2AC_vdW(atoms, xyz): - # Get mol template mol = get_proto_mol(atoms) @@ -668,10 +692,9 @@ def xyz2AC_huckel(atomicNumList, xyz, charge): passed, result = rdEHTTools.RunMol(mol_huckel) opop = result.GetReducedOverlapPopulationMatrix() tri = np.zeros((num_atoms, num_atoms)) - tri[np.tril(np.ones((num_atoms, num_atoms), dtype=bool)) - ] = opop # lower triangular to square matrix + tri[np.tril(np.ones((num_atoms, num_atoms), dtype=bool))] = opop # lower triangular to square matrix for i in range(num_atoms): - for j in range(i+1, num_atoms): + for j in range(i + 1, num_atoms): pair_pop = abs(tri[j, i]) if pair_pop >= 0.15: # arbitry cutoff for bond. May need adjustment AC[i, j] = 1 @@ -696,9 +719,7 @@ def chiral_stereo_check(mol): return -def check_mol(mol, - coordinates): - +def check_mol(mol, coordinates): conf = mol.GetConformers()[0] new_coords = conf.GetPositions() old_coords = np.array(coordinates) @@ -711,40 +732,34 @@ def check_mol(mol, new_pos = mol.GetConformers()[0].GetPositions() - dist = np.linalg.norm(new_pos.reshape(1, *new_pos.shape) - - old_coords.reshape(old_coords.shape[0], - 1, - old_coords.shape[1]), - axis=-1) + dist = np.linalg.norm( + new_pos.reshape(1, *new_pos.shape) - old_coords.reshape(old_coords.shape[0], 1, old_coords.shape[1]), axis=-1 + ) new_idx = dist.argmin(-1).tolist() rev_idx = dist.argmin(0).tolist() - ed_mol = EditableMol(Chem.MolFromSmiles('')) + ed_mol = EditableMol(Chem.MolFromSmiles("")) - for i, idx in enumerate(new_idx): + for idx in new_idx: atom = mol.GetAtoms()[idx] ed_mol.AddAtom(atom) all_old_bond_idx = [] all_old_bond_types = [] - for i, atom in enumerate(mol.GetAtoms()): - + for atom in mol.GetAtoms(): bonds = atom.GetBonds() - old_bond_idx = [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] - for i in bonds] + old_bond_idx = [[i.GetBeginAtomIdx(), i.GetEndAtomIdx()] for i in bonds] bond_types = [i.GetBondType() for i in bonds] - use_idx = [j for j, idx in enumerate(old_bond_idx) - if idx not in all_old_bond_idx] + use_idx = [j for j, idx in enumerate(old_bond_idx) if idx not in all_old_bond_idx] all_old_bond_idx += [old_bond_idx[j] for j in use_idx] all_old_bond_types += [bond_types[j] for j in use_idx] for bond_idx, bond_type in zip(all_old_bond_idx, all_old_bond_types): - new_bond_idx = [rev_idx[bond_idx[0]], - rev_idx[bond_idx[1]]] + new_bond_idx = [rev_idx[bond_idx[0]], rev_idx[bond_idx[1]]] ed_mol.AddBond(new_bond_idx[0], new_bond_idx[1], bond_type) @@ -761,13 +776,9 @@ def check_mol(mol, @timeout(seconds=MAX_TIME) -def xyz2mol(atoms, - coordinates, - charge=0, - allow_charged_fragments=True, - use_graph=True, - use_huckel=False, - embed_chiral=True): +def xyz2mol( + atoms, coordinates, charge=0, allow_charged_fragments=True, use_graph=True, use_huckel=False, embed_chiral=True +): """ Generate a rdkit molobj from atoms, coordinates and a total_charge. 
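For orientation, a minimal usage sketch of this entry point on a single water molecule (the coordinates are illustrative, the keyword arguments mirror the signature in this diff, and the import path is assumed):

    from rdkit import Chem

    from nff.utils.xyz2mol import xyz2mol  # assumed import path

    atoms = [8, 1, 1]  # atomic numbers: O, H, H
    coordinates = [
        [0.000, 0.000, 0.000],
        [0.758, 0.587, 0.000],
        [-0.758, 0.587, 0.000],
    ]

    # AC matrix -> BO matrix -> RDKit mol, as implemented in the functions above
    mol = xyz2mol(
        atoms,
        coordinates,
        charge=0,
        allow_charged_fragments=True,
        use_graph=True,
        use_huckel=False,
        embed_chiral=True,
    )
    print(Chem.MolToSmiles(Chem.RemoveHs(mol)))  # expected: "O"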
@@ -793,62 +804,39 @@
 
     # Convert AC to bond order matrix and add connectivity and charge info to
     # mol object
-    new_mol = AC2mol(mol, AC, atoms, charge,
-                     allow_charged_fragments=allow_charged_fragments,
-                     use_graph=use_graph)
+    new_mol = AC2mol(mol, AC, atoms, charge, allow_charged_fragments=allow_charged_fragments, use_graph=use_graph)
 
     # Check for stereocenters and chiral centers
     if embed_chiral:
         chiral_stereo_check(new_mol)
 
-    new_mol = check_mol(mol=new_mol,
-                        coordinates=coordinates)
+    new_mol = check_mol(mol=new_mol, coordinates=coordinates)
 
     return new_mol
 
 
 def main():
-
     return
 
 
 if __name__ == "__main__":
-
     import argparse
 
-    parser = argparse.ArgumentParser(usage='%(prog)s [options] molecule.xyz')
-    parser.add_argument('structure', metavar='structure', type=str)
-    parser.add_argument('-s', '--sdf',
-                        action="store_true",
-                        help="Dump sdf file")
-    parser.add_argument('--ignore-chiral',
-                        action="store_true",
-                        help="Ignore chiral centers")
-    parser.add_argument('--no-charged-fragments',
-                        action="store_true",
-                        help="Allow radicals to be made")
-    parser.add_argument('--no-graph',
-                        action="store_true",
-                        help="Run xyz2mol without networkx dependencies")
+    parser = argparse.ArgumentParser(usage="%(prog)s [options] molecule.xyz")
+    parser.add_argument("structure", metavar="structure", type=str)
+    parser.add_argument("-s", "--sdf", action="store_true", help="Dump sdf file")
+    parser.add_argument("--ignore-chiral", action="store_true", help="Ignore chiral centers")
+    parser.add_argument("--no-charged-fragments", action="store_true", help="Allow radicals to be made")
+    parser.add_argument("--no-graph", action="store_true", help="Run xyz2mol without networkx dependencies")
 
     # huckel uses extended Huckel bond orders to locate bonds (requires RDKit 2019.9.1 or later)
     # otherwise van der Waals radii are used
-    parser.add_argument('--use-huckel',
-                        action="store_true",
-                        help="Use Huckel method for atom connectivity")
-    parser.add_argument('-o', '--output-format',
-                        action="store",
-                        type=str,
-                        help="Output format [smiles,sdf] (default=sdf)")
-    parser.add_argument('-c', '--charge',
-                        action="store",
-                        metavar="int",
-                        type=int,
-                        help="Total charge of the system")
-    parser.add_argument('--save_name',
-                        type=str,
-                        default=DEFAULT_SAVE,
-                        help='Save name for RDKit mol')
+    parser.add_argument("--use-huckel", action="store_true", help="Use Huckel method for atom connectivity")
+    parser.add_argument(
+        "-o", "--output-format", action="store", type=str, help="Output format [smiles,sdf] (default=sdf)"
+    )
+    parser.add_argument("-c", "--charge", action="store", metavar="int", type=int, help="Total charge of the system")
+    parser.add_argument("--save_name", type=str, default=DEFAULT_SAVE, help="Save name for RDKit mol")
 
     args = parser.parse_args()
 
@@ -878,12 +866,15 @@
         charge = int(args.charge)
 
     # Get the molobj
-    mol = xyz2mol(atoms, xyz_coordinates,
-                  charge=charge,
-                  use_graph=quick,
-                  allow_charged_fragments=charged_fragments,
-                  embed_chiral=embed_chiral,
-                  use_huckel=use_huckel)
+    mol = xyz2mol(
+        atoms,
+        xyz_coordinates,
+        charge=charge,
+        use_graph=quick,
+        allow_charged_fragments=charged_fragments,
+        embed_chiral=embed_chiral,
+        use_huckel=use_huckel,
+    )
 
     # Print output
     if args.output_format == "sdf":
@@ -899,5 +890,5 @@
     print(smiles)
 
     save_name = args.save_name
-    with open(save_name, 'wb') as f:
+    with open(save_name, "wb") as f:
         pickle.dump(mol, f)
diff --git a/pyproject.toml b/pyproject.toml
index cf84ccc1..122b4960 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,11 +12,12 @@
 readme = "README.md"
 license = { text = "MIT" }
 dependencies = [
     "ase==3.23.0",
+    "numpy >=1.26.4, <2",
     "pymatgen>=2023.3.10",
     "rdkit",
     "scikit-learn",
     "scipy",
-    "torch>=2.2.0",
+    "torch >= 2.2.0, < 2.6.0",
     "tqdm",
     "mace-torch>=0.3.4",
     "chgnet>=0.3.5",
@@ -53,6 +54,20 @@
 Homepage = "https://github.mit.edu/MLMat/NeuralForceField/"
 
 [tool.setuptools]
 packages.find = { where = ["."], include = ["nff*"] }
 
+[tool.flake8]
+max_line_length = 120
+# F401 and F403: star imports, they are bad but everywhere
+per-file-ignores = [
+    '__init__.py:F401, F403',
+]
+# E741 ambiguous variable name 'l'
+# F405 rely on objects from star imports, maybe worth removing in the future
+# E203 no whitespace before ':', ruff prefers this for list comprehensions
+extend-ignore = ['E741', 'F405', 'E203']
+exclude = [
+    'nff/nn/models/spooky_net_source/modules/electron_configurations.py'
+]
+
 [tool.ruff]
 include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
 exclude = ["__init__.py"]
@@ -124,5 +139,17 @@
 ignore = [
     "S310",  # url open functions can be unsafe
     "TRY003",  # long exception messages not defined in the exception class itself
     "UP015",  # unnecessary "r" in open call
+    "Q000",  # single quotes
+    "D103",  # missing docstring in public function
+    "D",  # do not fight with doc for now, todo in the future
+    "UP035",  # deprecated Dict and List, but keep to support python<3.10
+    "UP006",  # deprecated Dict and List, but keep to support python<3.10
+    "ICN001",  # import name convention
+    "UP031",  # Use format specifiers instead of percent format
+    "PD901",  # don't name dataframes df
+    "FA100",  # from future suggestion
+    "RUF012",  # type annotation for mutable class attributes
+    "RUF002",  # ambiguous characters in docstrings
+    "RUF003",  # ambiguous characters in comments
 ]
 pydocstyle.convention = "google"
diff --git a/scripts/train_nff.py b/scripts/train_nff.py
index 149bf943..16177e6d 100644
--- a/scripts/train_nff.py
+++ b/scripts/train_nff.py
@@ -255,7 +255,7 @@ def main(
         logger.info("Fine-tuning model")
         model = load_model(model_path, model_type=model_type, map_location=device, device=device)
         if "NffScaleMACE" in model_type and trim_embeddings:
-            atomic_numbers = to_tensor(train.props["nxyz"], stack=True)[:, 0].unique().to(int).tolist()
+            atomic_numbers = to_tensor(train.props["nxyz"], stack=True)[:, 0].unique().to(torch.int64).tolist()
             logger.info("Trimming embeddings with MACE model and atomic numbers %s", atomic_numbers)
             model = reduce_foundations(model, atomic_numbers, load_readout=True)
             model_freezer = get_layer_freezer(model_type)
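The dtype change in the last hunk is a spelling clarification rather than a behavior change: on a float tensor, .to(int) and .to(torch.int64) both yield int64. A small sketch of what that line consumes, assuming the package's nxyz layout (atomic number in column 0, followed by x, y, z):

    import torch

    # toy nxyz block for one water molecule: [Z, x, y, z] per atom
    nxyz = torch.tensor([
        [8.0, 0.000, 0.000, 0.000],
        [1.0, 0.758, 0.587, 0.000],
        [1.0, -0.758, 0.587, 0.000],
    ])

    atomic_numbers = nxyz[:, 0].unique().to(torch.int64).tolist()
    print(atomic_numbers)  # [1, 8] -- unique() returns sorted values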