diff --git a/models/foundation_models/chgnet/0.2.0/README.md b/models/foundation_models/chgnet/0.2.0/README.md deleted file mode 100755 index 6dcf0356..00000000 --- a/models/foundation_models/chgnet/0.2.0/README.md +++ /dev/null @@ -1,74 +0,0 @@ -## Model 0.2.0 - -This is the pretrained weights published with CHGNet Nature Machine Intelligence paper. -All the experiments and results shown in the paper were performed with this version of weights. - -Date: 2/24/2023 - -Author: Bowen Deng - -## Model Parameters - -```python -model = CHGNet( - atom_fea_dim=64, - bond_fea_dim=64, - angle_fea_dim=64, - composition_model="MPtrj", - num_radial=9, - num_angular=9, - n_conv=4, - atom_conv_hidden_dim=64, - update_bond=True, - bond_conv_hidden_dim=64, - update_angle=True, - angle_layer_hidden_dim=0, - conv_dropout=0, - read_out="ave", - mlp_hidden_dims=[64, 64], - mlp_first=True, - is_intensive=True, - non_linearity="silu", - atom_graph_cutoff=5, - bond_graph_cutoff=3, - graph_converter_algorithm="fast", - cutoff_coeff=5, - learnable_rbf=True, - mlp_out_bias=True, -) -``` - -## Dataset Used - -MPtrj dataset with 8-1-1 train-val-test splitting - -## Trainer - -```python -trainer = Trainer( - model=model, - targets='efsm', - energy_loss_ratio=1, - force_loss_ratio=1, - stress_loss_ratio=0.1, - mag_loss_ratio=0.1, - optimizer='Adam', - weight_decay=0, - scheduler='CosLR', - criterion='Huber', - delta=0.1, - epochs=20, - starting_epoch=0, - learning_rate=1e-3, - use_device='cuda', - print_freq=1000 -) -``` - -## Mean Absolute Error (MAE) logs - -| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | -| ---------- | ----------------- | ------------- | ------------ | ------------ | -| Train | 22 | 59 | 0.246 | 0.030 | -| Validation | 30 | 75 | 0.350 | 0.033 | -| Test | 30 | 77 | 0.348 | 0.032 | diff --git a/models/foundation_models/chgnet/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar b/models/foundation_models/chgnet/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar deleted file mode 100644 index f82e2f20..00000000 Binary files a/models/foundation_models/chgnet/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar and /dev/null differ diff --git a/models/foundation_models/chgnet/0.3.0/README.md b/models/foundation_models/chgnet/0.3.0/README.md deleted file mode 100755 index 50223bbc..00000000 --- a/models/foundation_models/chgnet/0.3.0/README.md +++ /dev/null @@ -1,80 +0,0 @@ -## Model 0.3.0 - -Major changes: - -1. Increased AtomGraph cutoff to 6A -2. Resolved discontinuity issue when no BondGraph presents -3. Added some normalization layers -4. 
Slight improvements on energy, force, stress accuracies - -Date: 10/22/2023 - -Author: Bowen Deng - -## Model Parameters - -```python -model = CHGNet( - atom_fea_dim=64, - bond_fea_dim=64, - angle_fea_dim=64, - composition_model="MPtrj", - num_radial=31, - num_angular=31, - n_conv=4, - atom_conv_hidden_dim=64, - update_bond=True, - bond_conv_hidden_dim=64, - update_angle=True, - angle_layer_hidden_dim=0, - conv_dropout=0, - read_out="ave", - gMLP_norm='layer', - readout_norm='layer', - mlp_hidden_dims=[64, 64, 64], - mlp_first=True, - is_intensive=True, - non_linearity="silu", - atom_graph_cutoff=6, - bond_graph_cutoff=3, - graph_converter_algorithm="fast", - cutoff_coeff=8, - learnable_rbf=True, -) -``` - -## Dataset Used - -MPtrj dataset with 9-0.5-0.5 train-val-test splitting - -## Trainer - -```python -trainer = Trainer( - model=model, - targets='efsm', - energy_loss_ratio=1, - force_loss_ratio=1, - stress_loss_ratio=0.1, - mag_loss_ratio=0.1, - optimizer='Adam', - weight_decay=0, - scheduler='CosLR', - scheduler_params={'decay_fraction': 0.5e-2}, - criterion='Huber', - delta=0.1, - epochs=30, - starting_epoch=0, - learning_rate=5e-3, - use_device='cuda', - print_freq=1000 -) -``` - -## Mean Absolute Error (MAE) logs - -| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | -| ---------- | ----------------- | ------------- | ------------ | ------------ | -| Train | 26 | 60 | 0.266 | 0.037 | -| Validation | 29 | 70 | 0.308 | 0.037 | -| Test | 29 | 68 | 0.314 | 0.037 | diff --git a/models/foundation_models/chgnet/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar b/models/foundation_models/chgnet/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar deleted file mode 100644 index 20ce3228..00000000 Binary files a/models/foundation_models/chgnet/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar and /dev/null differ diff --git a/nff/analysis/loss_plot.py b/nff/analysis/loss_plot.py index e46c82d4..6286e2cd 100644 --- a/nff/analysis/loss_plot.py +++ b/nff/analysis/loss_plot.py @@ -4,18 +4,24 @@ from . import mpl_settings -def plot_loss(energy_history, forces_history, figname, train_key="train", val_key="val"): +def plot_loss( + energy_history: dict, + forces_history: dict, + figname: str, + train_key: str = "train", + val_key: str = "val", +) -> None: """Plot the loss history of the model. 
- Args: - energy_history (dict): energy loss history of the model for training and validation - forces_history (dict): forces loss history of the model for training and validation - figname (str): name of the figure - Returns: - None + Args: + energy_history: energy loss history of the model for training and validation + forces_history: forces loss history of the model for training and validation + figname: name of the figure + train_key: key for training data in the history dictionary + val_key: key for validation data in the history dictionary """ epochs = np.arange(1, len(energy_history[train_key]) + 1) - fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6), dpi=mpl_settings.DPI) + fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI) ax_fig[0].semilogy(epochs, energy_history[train_key], label="train", color=mpl_settings.colors[1]) ax_fig[0].semilogy(epochs, energy_history[val_key], label="val", color=mpl_settings.colors[2]) ax_fig[0].legend() diff --git a/nff/analysis/mpl_settings.py b/nff/analysis/mpl_settings.py index 2a238246..d3f23b80 100644 --- a/nff/analysis/mpl_settings.py +++ b/nff/analysis/mpl_settings.py @@ -1,34 +1,44 @@ +from __future__ import annotations + import json from pathlib import Path -from typing import List, Optional +from typing import List -import matplotlib +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np plt.style.use("default") -DPI = 100 -LINEWIDTH = 2 -FONTSIZE = 20 -LABELSIZE = 18 +dir_name = Path(__file__).parent + +DPI = 300 +LINEWIDTH = 1.25 +FONTSIZE = 8 +LABELSIZE = 8 ALPHA = 0.8 -LINE_MARKERSIZE = 15 * 25 -MARKERSIZE = 15 -GRIDSIZE = 40 -MAJOR_TICKLEN = 6 -MINOR_TICKLEN = 3 -TICKPADDING = 5 +MARKERSIZE = 25 +GRIDSIZE = 20 +MAJOR_TICKLEN = 4 +MINOR_TICKLEN = 2 +TICKPADDING = 3 SECONDARY_CMAP = "inferno" -params = { +custom_settings = { "mathtext.default": "regular", "font.family": "Arial", "font.size": FONTSIZE, "axes.labelsize": LABELSIZE, "axes.titlesize": FONTSIZE, + "axes.linewidth": LINEWIDTH, "grid.linewidth": LINEWIDTH, "lines.linewidth": LINEWIDTH, + "lines.color": "black", + "axes.labelcolor": "black", + "axes.edgecolor": "black", + "axes.titlecolor": "black", + "axes.titleweight": "bold", + "axes.grid": False, "lines.markersize": MARKERSIZE, "xtick.major.size": MAJOR_TICKLEN, "xtick.minor.size": MINOR_TICKLEN, @@ -38,66 +48,67 @@ "ytick.minor.size": MINOR_TICKLEN, "ytick.major.pad": TICKPADDING, "ytick.minor.pad": TICKPADDING, - "axes.linewidth": LINEWIDTH, - "legend.fontsize": LABELSIZE, - "figure.dpi": DPI, - "savefig.dpi": DPI, "ytick.major.width": LINEWIDTH, "xtick.major.width": LINEWIDTH, "ytick.minor.width": LINEWIDTH, "xtick.minor.width": LINEWIDTH, + "legend.fontsize": LABELSIZE, + "figure.dpi": DPI, + "savefig.dpi": DPI, + "savefig.format": "png", + "savefig.bbox": "tight", + "savefig.pad_inches": 0.1, + "figure.facecolor": "white", } -plt.rcParams.update(params) +plt.rcParams.update(custom_settings) + +def update_custom_settings(custom_settings: dict | None = custom_settings) -> None: + """Update the custom settings for Matplotlib. -def hex_to_rgb(value: str) -> tuple: + Args: + custom_settings: Custom settings for Matplotlib. Defaults to + custom_settings. 
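+
+    Example (illustrative; any valid Matplotlib rcParams keys can be passed)::
+
+        update_custom_settings({"font.size": 10, "figure.dpi": 150})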
""" - Converts hex to rgb colours + current_settings = plt.rcParams.copy() + new_settings = current_settings | custom_settings + plt.rcParams.update(new_settings) - Parameters - ---------- - value: string of 6 characters representing a hex colour - Returns - ---------- - tuple of 3 integers representing the RGB values - """ +def hex_to_rgb(value: str) -> list[float]: + """Converts hex to rgb colors. + Args: + value: string of 6 characters representing a hex color. + """ value = value.strip("#") # removes hash symbol if present lv = len(value) return tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)) -def rgb_to_dec(value: list): - """ - Converts rgb to decimal colours (i.e. divides each value by 256) +def rgb_to_dec(value: list[float]) -> list[float]: + """Converts rgb to decimal colors (i.e. divides each value by 256). - Parameters - ---------- - value: list of 3 integers representing the RGB values + Args: + value: string of 6 characters representing a hex color. - Returns - ---------- - list of 3 floats representing the RGB values + Returns: + list: length 3 of RGB values """ - return [v / 256 for v in value] -def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] = None) -> matplotlib.colors.Colormap: - """ - Creates and returns a color map that can be used in heat map figures. - If float_list is not provided, colour map graduates linearly between each color in hex_list. - If float_list is provided, each color in hex_list is mapped to the respective location in float_list. - - Parameters - ---------- - hex_list: list of hex code strings - float_list: list of floats between 0 and 1, same length as hex_list. Must start with 0 and end with 1. - - Returns - ---------- - Colormap +def get_continuous_cmap( + hex_list: list[str], float_list: list[float] | None = None +) -> mpl.colors.LinearSegmentedColormap: + """Creates a color map that can be used in heat map figures. If float_list is not provided, + color map graduates linearly between each color in hex_list. If float_list is provided, + each color in hex_list is mapped to the respective location in float_list. + + Args: + hex_list: list of hex code strings + float_list: list of floats between 0 and 1, same length as hex_list. + Must start with 0 and end with 1. 
""" rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list] if float_list: @@ -109,15 +120,12 @@ def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] = for num, col in enumerate(["red", "green", "blue"]): col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))] cdict[col] = col_list - cmp = matplotlib.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256) - return cmp + return mpl.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256) -# colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier +# Colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier # http://www.vis4.net/palettes/#/100|s|fce1a4,fabf7b,f08f6e,d12959,6e005f|ffffe0,ff005e,93003a|1|1 hex_list: List[str] -dir_name = Path(__file__).parent - with open(dir_name / "config/mpl_settings.json", "r") as f: hex_list = json.load(f)["plot_colors"] diff --git a/nff/analysis/parity_plot.py b/nff/analysis/parity_plot.py index 8893015b..5a44d9d6 100644 --- a/nff/analysis/parity_plot.py +++ b/nff/analysis/parity_plot.py @@ -1,46 +1,51 @@ -from typing import Dict, Literal, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Literal import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch from matplotlib.lines import Line2D from scipy import stats from scipy.stats import gaussian_kde +if TYPE_CHECKING: + from torch import Tensor from nff.data import to_tensor from nff.utils import cuda from . import mpl_settings +plt.style.use("ggplot") +mpl_settings.update_custom_settings() + def plot_parity( - results: Dict[str, Union[list, torch.Tensor]], - targets: Dict[str, Union[list, torch.Tensor]], + results: Dict[str, list | Tensor], + targets: Dict[str, list | Tensor], figname: str, plot_type: Literal["hexbin", "scatter"] = "hexbin", energy_key: str = "energy", force_key: str = "energy_grad", units: Dict[str, str] = {"energy": "eV", "energy_grad": "eV/Ang"}, ) -> tuple[float, float]: - """ - Perform a parity plot between the results and the targets. + """Perform a parity plot between the results and the targets. 
Args: - results (dict): dictionary containing the results - targets (dict): dictionary containing the targets - figname (str): name of the figure - plot_type (str): type of plot to use, either "hexbin" or "scatter" - energy_key (str): key for the energy - force_key (str): key for the forces - units (dict): dictionary containing the units of the keys + results: dictionary containing the results + targets: dictionary containing the targets + figname: name of the figure + plot_type: type of plot to use, either "hexbin" or "scatter" + energy_key: key for the energy + force_key: key for the forces + units: dictionary containing the units of the keys Returns: float: MAE of the energy float: MAE of the forces """ - fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6), dpi=mpl_settings.DPI) + fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI) mae_save = {force_key: 0, energy_key: 0} @@ -48,14 +53,14 @@ def plot_parity( targets = cuda.batch_detach(targets) for ax, key in zip(ax_fig, units.keys()): - pred = to_tensor(results[key], stack=True) - targ = to_tensor(targets[key], stack=True) + pred = to_tensor(results[key], stack=True).numpy() + targ = to_tensor(targets[key], stack=True).numpy() mae = abs(pred - targ).mean() mae_save[key] = mae - lim_min = min(torch.min(pred), torch.min(targ)) - lim_max = max(torch.max(pred), torch.max(targ)) + lim_min = min(np.min(pred), np.min(targ)) + lim_max = max(np.max(pred), np.max(targ)) if lim_min < 0: lim_min *= 1.1 @@ -66,7 +71,6 @@ def plot_parity( lim_max *= 0.9 else: lim_max *= 1.1 - if plot_type.lower() == "hexbin": hb = ax.hexbin( pred, @@ -77,10 +81,11 @@ def plot_parity( cmap=mpl_settings.cmap, edgecolor="None", extent=(lim_min, lim_max, lim_min, lim_max), + rasterized=True, ) else: - hb = ax.scatter(pred, targ, color="#ff7f0e", alpha=0.3) + hb = ax.scatter(pred, targ, color="#ff7f0e", alpha=0.3, rasterized=True) cb = fig.colorbar(hb, ax=ax) cb.set_label("Counts") @@ -103,7 +108,7 @@ def plot_parity( ) plt.tight_layout() - plt.savefig(f"{figname}.png") + plt.savefig(f"{figname}.pdf") plt.show() mae_energy = float(mae_save[energy_key]) mae_forces = float(mae_save[force_key]) @@ -111,8 +116,8 @@ def plot_parity( def plot_err_var( - err: Union[torch.Tensor, np.ndarray], - var: Union[torch.Tensor, np.ndarray], + err: Tensor | np.ndarray, + var: Tensor | np.ndarray, figname: str, units: str = "eV/Å", x_min: float = 0.0, @@ -126,20 +131,17 @@ def plot_err_var( """Plot the error vs variance of the forces. 
Args: - err (torch.Tensor): error of the forces - var (torch.Tensor): variance of the forces - figname (str): name of the figure - units (str): units of the error and variance - x_min (float): minimum value of the x-axis - x_max (float): maximum value of the x-axis - y_min (float): minimum value of the y-axis - y_max (float): maximum value of the y-axis - sample_frac (float): fraction of the data to sample for the plot - num_bins (int): number of bins to use for binning - cb_format (str): format of the colorbar - - Returns: - None + err: error of the forces + var: variance of the forces + figname: name of the figure + units: units of the error and variance + x_min: minimum value of the x-axis + x_max: maximum value of the x-axis + y_min: minimum value of the y-axis + y_max: maximum value of the y-axis + sample_frac: fraction of the data to sample for the plot + num_bins: number of bins to use for binning + cb_format: format of the colorbar """ fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=mpl_settings.DPI) diff --git a/nff/data/dataset.py b/nff/data/dataset.py index a5f3ecdc..e16eadc4 100644 --- a/nff/data/dataset.py +++ b/nff/data/dataset.py @@ -1049,7 +1049,11 @@ def split_train_validation_test( Returns: tuple[Dataset, Dataset, Dataset]: train, validation and test datasets """ - train, validation = split_train_test(dataset, test_size=val_size, seed=seed, **kwargs) - train, test = split_train_test(train, test_size=test_size / (1 - val_size), seed=seed, **kwargs) + if np.isclose(val_size, 0.0): # for no validation set + train, test = split_train_test(dataset, test_size=test_size, seed=seed, **kwargs) + validation = None + else: + train, validation = split_train_test(dataset, test_size=val_size, seed=seed, **kwargs) + train, test = split_train_test(train, test_size=test_size / (1 - val_size), seed=seed, **kwargs) return train, validation, test diff --git a/nff/data/stats.py b/nff/data/stats.py index 03a1627a..31c3685a 100644 --- a/nff/data/stats.py +++ b/nff/data/stats.py @@ -106,11 +106,9 @@ def remove_dataset_outliers( reference_std=reference_std, max_value=max_value, ) - new_props = {key: [val[i] for i in idx] for key, val in dset.props.items()} logging.info("reference_mean: %s", mean) logging.info("reference_std: %s", std) - return Dataset(new_props, units=dset.units), mean, std @@ -147,34 +145,25 @@ def center_dataset( def get_atom_count(formula: str) -> Dict[str, int]: """Count the number of each atom type in the formula. - Parameters - ---------- - formula - The formula parameter is a string representing a chemical formula. - - Returns - ------- - a dictionary containing the count of each atom in the given chemical formula. + Args: + formula (str): A chemical formula. + Returns: + Dict[str, int]: A dictionary containing the count of each atom in the given + chemical formula. """ - - # return dictionary formula = Formula(formula) return formula.count() def all_atoms(unique_formulas: List[str]) -> set: - """Return set of all atoms in the list of formulas. + """Return a set of all atoms present in the list of formulas. - Parameters - ---------- - unique_formulas - list of strings representing the chemical formulas for which you want to count the - occurrences of each atom. + Args: + unique_formulas (List[str]): A list of chemical formula strings. - Returns - ------- - a set containing all the atoms in the list of formulas. + Returns: + set: A set containing all unique atom types found in the provided formulas. 
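+
+    Example (illustrative)::
+
+        all_atoms(["H2O", "CO2"])  # -> {"H", "O", "C"}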
""" atom_set = set() for formula in unique_formulas: @@ -185,47 +174,32 @@ def all_atoms(unique_formulas: List[str]) -> set: def reg_atom_count(formula: str, atoms: List[str]) -> np.ndarray: - """Count the number of each specified atom type in the formula. - - Parameters - ---------- - formula - A string that represents a chemical formula. It can contain elements and - their corresponding subscripts. For example, "H2O" represents water, where "H" is the element - hydrogen and "O" is the element oxygen. The subscript "2" indicates that there are two - atoms - list of strings representing the atoms for which you want to count the - occurrences in the `formula`. - - Returns - ------- - an array containing the count of each atom in the given formula. + """Count the occurrence of specified atom types in the formula. + + Args: + formula (str): A chemical formula string. + atoms (List[str]): A list of atom types to count in the formula. + + Returns: + np.ndarray: An array containing the count of each specified atom in the formula. """ dictio = get_atom_count(formula) count_array = np.array([dictio.get(atom, 0) for atom in atoms]) - return count_array def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str = "energy") -> Dict[str, float]: - """Linear regression to find the per atom energy for each element in the dataset. - - Parameters - ---------- - dset - Dataset object containing properties for each data point. It is assumed to have a property - for the chemical formula of each data point and a property for the energy value of each data point. - formula_key, optional - key for chemical formula in the dset properties dictionary. - energy_key, optional - key for energy in the dset properties dictionary. - - Returns - ------- - a dictionary containing the stoichiometric energy coefficients for each element in the dataset. + """Determine per-atom energy coefficients via linear regression. + + Args: + dset (Dataset): Dataset object containing chemical formulas and energy values. + formula_key (str, optional): Key for chemical formulas in the dataset properties. + energy_key (str, optional): Key for energy values in the dataset properties. + Returns: + Dict[str, float]: A dictionary with stoichiometric energy coefficients for + each element, including an 'offset' representing the intercept. """ - # calculates the linear regresion and return the stoich dictionary formulas = dset.props[formula_key] energies = dset.props[energy_key] logging.debug("formulas: %s", formulas) @@ -233,7 +207,7 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str unique_formulas = list(set(formulas)) logging.debug("unique formulas: %s", unique_formulas) - # find the ground state energy for each formula/stoichiometry + # Find the ground state energy for each formula/stoichiometry. 
ground_en = [ min([energies[i] for i in range(len(formulas)) if formulas[i] == formula]) for formula in unique_formulas ] @@ -243,7 +217,6 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str logging.debug("unique atoms: %s", unique_atoms) x_in = np.stack([reg_atom_count(formula, unique_atoms) for formula in unique_formulas]) - y_out = np.array(ground_en) logging.debug("x_in: %s", x_in) @@ -253,14 +226,13 @@ def get_stoich_dict(dset: Dataset, formula_key: str = "formula", energy_key: str clf.fit(x_in, y_out) pred = (clf.coef_ * x_in).sum(-1) + clf.intercept_ - # pred = clf.predict(x_in) logging.info("coef: %s", clf.coef_) logging.info("intercept: %s", clf.intercept_) logging.debug("pred: %s", pred) err = abs(pred - y_out).mean() # in kcal/mol logging.info("MAE between target energy and stoich energy is %.3f kcal/mol", err) logging.info("R : %s", clf.score(x_in, y_out)) - fit_dic = {atom: coef for atom, coef in zip(unique_atoms, clf.coef_.reshape(-1))} # noqa + fit_dic = dict(zip(unique_atoms, clf.coef_.reshape(-1))) stoich_dict = {**fit_dic, "offset": clf.intercept_.item()} logging.info(stoich_dict) @@ -273,27 +245,18 @@ def perform_energy_offset( formula_key: str = "formula", energy_key: str = "energy", ) -> Dataset: - """Peform energy offset calculation on the dataset. Subtract the energy of the reference state for each atom - from the energy of each data point in the dataset. - - Parameters - ---------- - dset - Dataset object containing properties for each data point. It is assumed to have a property - for the chemical formula of each data point and a property for the energy value of each data point. - stoic_dict - a dictionary containing the stoichiometric energy coefficients for each element in the dataset. - formula_key, optional - key for chemical formula in the dset properties dictionary. - energy_key, optional - key for energy in the dset properties dictionary. - - Returns - ------- - a new dataset with the energy offset performed. + """Perform energy offset calculation by subtracting the reference energy per atom. + Args: + dset (Dataset): Dataset object containing chemical formulas and energy values. + stoic_dict (Dict[str, float]): Dictionary with stoichiometric energy coefficients for + each element. + formula_key (str, optional): Key for chemical formulas in the dataset properties. + energy_key (str, optional): Key for energy values in the dataset properties. + + Returns: + Dataset: A new dataset with the energy offset applied to each energy value. """ - # perform the energy offset formulas = dset.props[formula_key] energies = dset.props[energy_key] diff --git a/nff/io/ase.py b/nff/io/ase.py index 69f16a38..f38c84fc 100644 --- a/nff/io/ase.py +++ b/nff/io/ase.py @@ -1,5 +1,7 @@ """ASE wrapper for the Neural Force Field.""" +import copy + import numpy as np import torch from ase import Atoms, units @@ -427,11 +429,14 @@ def from_atoms(cls, atoms, **kwargs): An instance of the class initialized with the properties of the ASE Atoms object. """ props = kwargs.pop("props", {}) - return cls( + atoms_batch = cls( atoms, props=props, **kwargs, ) + atoms_batch.arrays = copy.deepcopy(atoms.arrays) + atoms_batch.constraints = copy.deepcopy(atoms.constraints) + return atoms_batch def copy(self) -> Self: """Copy the current object. @@ -439,7 +444,7 @@ def copy(self) -> Self: Returns: AtomsBatch: A copy of the current object. 
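+
+        Example (illustrative; ``batch`` is an existing AtomsBatch)::
+
+            new_batch = batch.copy()  # props, arrays and constraints carried over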
""" - return self.__class__.from_atoms( + atoms_batch = self.__class__.from_atoms( self, props=self.props, cutoff=self.cutoff, @@ -449,6 +454,9 @@ def copy(self) -> Self: dense_nbrs=self.mol_nbrs is not None and self.mol_idx is not None, device=self.device, ) + atoms_batch.arrays = copy.deepcopy(self.arrays) + atoms_batch.constraints = copy.deepcopy(self.constraints) + return atoms_batch def todict(self, update_props=True) -> dict: """Serialize the object to a dictionary. Calls the parent class todict method. diff --git a/nff/io/ase_calcs.py b/nff/io/ase_calcs.py index ddfdd6ec..b19a0918 100644 --- a/nff/io/ase_calcs.py +++ b/nff/io/ase_calcs.py @@ -229,7 +229,14 @@ class EnsembleNFF(Calculator): """Produces an ensemble of NFF calculators to predict the discrepancy between the properties""" - implemented_properties = ["energy", "forces", "stress", "energy_std", "forces_std", "stress_std"] + implemented_properties = [ + "energy", + "forces", + "stress", + "energy_std", + "forces_std", + "stress_std", + ] def __init__( self, diff --git a/nff/io/chgnet.py b/nff/io/chgnet.py index 4e3c19dd..c903aa16 100644 --- a/nff/io/chgnet.py +++ b/nff/io/chgnet.py @@ -1,9 +1,10 @@ """Convert NFF Dataset to CHGNet StructureData""" -from typing import Dict +from typing import Dict, List import torch from chgnet.data.dataset import StructureData +from pymatgen.core.structure import Structure from pymatgen.io.ase import AseAtomsAdaptor from nff.data import Dataset @@ -16,25 +17,16 @@ def convert_nff_to_chgnet_structure_data( cutoff: float = 5.0, shuffle: bool = True, ): - """The function `convert_nff_to_chgnet_structure_data` converts a dataset in NFF format to a dataset in - CHGNet structure data format. - - Parameters - ---------- - dataset : Dataset - The `dataset` parameter is an object of the `Dataset` class. - cutoff : float - The `cutoff` parameter is a float value that represents the distance cutoff for constructing the - neighbor list in the conversion process. It determines the maximum distance between atoms within - which they are considered neighbors. Any atoms beyond this distance will not be included in the - neighbor list. - shuffle : bool - The `shuffle` parameter is a boolean value that determines whether the dataset should be shuffled + """ + Converts a dataset in NFF format to a dataset in CHGNet structure data format. - Returns: - ------- - a `chgnet_dataset` object of type `StructureData`. + Args: + dataset (Dataset): An object of the Dataset class. + cutoff (float, optional): Distance cutoff for constructing the neighbor list. Defaults to 5.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to True. + Returns: + StructureData: A CHGNet StructureData object. """ dataset = dataset.copy() dataset.to_units("eV/atom") # convert units to eV @@ -59,39 +51,135 @@ def convert_nff_to_chgnet_structure_data( ) +def convert_chgnet_structure_targets_to_nff( + structures: List[Structure], + targets: List[Dict], + stresses: bool = False, + magmoms: bool = False, +) -> Dataset: + """ + Converts a dataset in CHGNet structure JSON data format to a dataset in NFF format. + + Args: + structures (List[Structure]): List of pymatgen structures. + targets (List[Dict]): List of dictionaries containing the properties of each structure. + stresses (bool, optional): Whether the dataset should include stresses. Defaults to False. + magmoms (bool, optional): Whether the dataset should include magnetic moments. Defaults to False. + + Returns: + Dataset: An NFF Dataset. 
+ """ + energies_per_atom = [] + energy_grad = [] + stresses_list = [] + magmoms_list = [] + for target in targets: + energies_per_atom.append(target["e"]) + energy_grad.append(-target["f"]) + if stresses: + stresses_list.append(target["s"]) + if magmoms: + magmoms_list.append(target["m"]) + + lattice = [] + num_atoms = [] # TODO: check if this is correct + nxyz = [] + units = ["eV/atom" for _ in range(len(structures))] + formula = [] + for structure in structures: + atoms = structure.to_ase_atoms() + lattice.append(atoms.cell.tolist()) + num_atoms.append(len(atoms)) + nxyz.append([torch.cat([torch.tensor([atom.number]), torch.tensor(atom.position)]).tolist() for atom in atoms]) + formula.append(atoms.get_chemical_formula()) + + concated_batch = { + "nxyz": nxyz, + "lattice": lattice, + "num_atoms": num_atoms, + "energy": energies_per_atom, + "energy_grad": energy_grad, + "formula": formula, + "units": units, + } + if stresses: + concated_batch["stress"] = stresses_list + if magmoms: + concated_batch["magmoms"] = magmoms_list + return Dataset(concated_batch, units=units[0]) + + +def convert_chgnet_structure_data_to_nff( + structure_data: StructureData, + cutoff: float = 6.0, + shuffle: bool = False, +) -> Dataset: + """ + Converts a dataset in CHGNet structure data format to a dataset in NFF format. + + Args: + structure_data (StructureData): A CHGNet StructureData object. + cutoff (float, optional): Distance cutoff for constructing the neighbor list. Defaults to 6.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to False. + + Returns: + Dataset: An NFF Dataset. + """ + pymatgen_structures = structure_data.structures + energies_per_atom = structure_data.energies + energy_grad = ( + [-x for x in structure_data.forces] if isinstance(structure_data.forces, list) else -structure_data.forces + ) + stresses = structure_data.stresses + magmoms = structure_data.magmoms + lattice = [] + num_atoms = [structure.num_sites for structure in pymatgen_structures] # TODO: check if this is correct + nxyz = [] + units = ["eV/atom" for _ in range(len(pymatgen_structures))] + formula = [structure.composition.formula for structure in pymatgen_structures] + for structure in pymatgen_structures: + lattice.append(structure.lattice.matrix) + nxyz.append( + [torch.cat([torch.tensor([atom.species.number]), torch.tensor(atom.coords)]).tolist() for atom in structure] + ) + + concated_batch = { + "nxyz": nxyz, + "lattice": lattice, + "num_atoms": num_atoms, + "energy": energies_per_atom, + "energy_grad": energy_grad, + "stress": stresses, + "magmoms": magmoms, + "formula": formula, + "units": units, + } + return Dataset(concated_batch, units=units[0]) + + def convert_data_batch( data_batch: Dict, cutoff: float = 5.0, shuffle: bool = True, ): - """Converts a dataset in NFF format to a dataset in - CHGNet structure data format. - - Parameters - ---------- - data_batch : Dict - A dictionary of properties for each structure in the batch. - Basically the props in NFF Dataset - Example: - props = { - 'nxyz': [np.array([[1, 0, 0, 0], [1, 1.1, 0, 0]]), - np.array([[1, 3, 0, 0], [1, 1.1, 5, 0]])], - 'lattice': [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], - np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])], - 'num_atoms': [2, 2], - } - cutoff : float - The `cutoff` parameter is a float value that represents the distance cutoff for constructing the - neighbor list in the conversion process. It determines the maximum distance between atoms within - which they are considered neighbors. 
Any atoms beyond this distance will not be included in the - neighbor list. - shuffle : bool - The `shuffle` parameter is a boolean value that determines whether the dataset should be shuffled + """ + Converts a dataset in NFF format to a dataset in CHGNet structure data format. + + Args: + data_batch (Dict): Dictionary of properties for each structure in the batch. + Example: + props = { + 'nxyz': [np.array([[1, 0, 0, 0], [1, 1.1, 0, 0]]), + np.array([[1, 3, 0, 0], [1, 1.1, 5, 0]])], + 'lattice': [np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])], + 'num_atoms': [2, 2], + } + cutoff (float, optional): Distance cutoff for neighbor list construction. Defaults to 5.0. + shuffle (bool, optional): Whether the dataset should be shuffled. Defaults to True. Returns: - ------- - a `chgnet_dataset` object of type `StructureData`. - + StructureData: A CHGNet StructureData object. """ detached_batch = batch_detach(data_batch) nxyz = detached_batch["nxyz"] @@ -108,7 +196,7 @@ def convert_data_batch( pymatgen_structures = [AseAtomsAdaptor.get_structure(atoms_batch) for atoms_batch in atoms_list] - energies = data_batch.get("energy") + energies = torch.atleast_1d(data_batch.get("energy")) if energies is not None and len(energies) > 0: energies_per_atom = energies else: @@ -143,4 +231,5 @@ def convert_data_batch( forces=forces, stresses=stresses, magmoms=magmoms, + shuffle=shuffle, ) diff --git a/nff/nn/models/chgnet.py b/nff/nn/models/chgnet.py index 7e2f784e..cdb955c3 100644 --- a/nff/nn/models/chgnet.py +++ b/nff/nn/models/chgnet.py @@ -5,8 +5,10 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict, List +import chgnet import torch from chgnet.data.dataset import collate_graphs +from chgnet.graph import CrystalGraph try: from chgnet.graph.crystalgraph import datatype @@ -21,7 +23,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from chgnet.graph import CrystalGraph module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -182,10 +183,12 @@ def load(cls, model_name: str = "0.3.0", **kwargs) -> CHGNetNFF: Returns: CHGNetNFF: CHGNetNFF foundational model. """ + chgnet_path = Path(chgnet.__file__).parent + try: checkpoint_path = { - "0.3.0": "../../../models/foundation_models/chgnet/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", - "0.2.0": "../../..models/foundation_models/chgnet/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", + "0.3.0": chgnet_path / "pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", + "0.2.0": chgnet_path / "pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", }[model_name] except KeyError as e: @@ -202,8 +205,7 @@ def load(cls, model_name: str = "0.3.0", **kwargs) -> CHGNetNFF: ) def to(self, device: str, **kwargs) -> CHGNetNFF: - """ - Move the model to the specified device. + """Move the model to the specified device. Args: device (str): Device to move the model to. 
@@ -348,7 +350,10 @@ def from_graphs( bond_bases_bg = torch.cat(bond_bases_bg, dim=0) angle_bases = torch.cat(angle_bases, dim=0) if len(angle_bases) != 0 else torch.tensor([]) batched_atom_graph = torch.cat(batched_atom_graph, dim=0) - batched_bond_graph = torch.cat(batched_bond_graph, dim=0) if batched_bond_graph != [] else torch.tensor([]) + if batched_bond_graph != []: + batched_bond_graph = torch.cat(batched_bond_graph, dim=0) + else: # when bond graph is empty or disabled + batched_bond_graph = torch.tensor([]) atom_owners = torch.cat(atom_owners, dim=0).type(torch.int32).to(atomic_numbers.device) directed2undirected = torch.cat(directed2undirected, dim=0) volumes = torch.tensor(volumes, dtype=datatype, device=atomic_numbers.device) diff --git a/nff/train/transfer.py b/nff/train/transfer.py index e85204a2..ea3db699 100644 --- a/nff/train/transfer.py +++ b/nff/train/transfer.py @@ -45,7 +45,9 @@ def custom_unfreeze(self, model: torch.nn.Module, custom_layers: List[str]) -> N from list(model.named_parameters()) """ for module in model.named_parameters(): + print(f"In custom unfreeze: {module[0]}") if module[0] in custom_layers: + print(f"Unfreezing {module[0]}") module[1].requires_grad = True def unfreeze_readout(self, model: torch.nn.Module) -> None: @@ -175,17 +177,46 @@ def model_tl( class MaceLayerFreezer(LayerFreezer): """Class to handle freezing layers in MACE models.""" - def unfreeze_mace_interaction_linears(self, model: torch.nn.Module) -> None: + def unfreeze_mace_node_embedding(self, model: torch.nn.Module) -> None: + """Unfreeze the node embedding layer in a MACE model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + self.unfreeze_parameters(model.node_embedding) + print("Unfreezing node embedding") + + def unfreeze_mace_radial_embedding(self, model: torch.nn.Module) -> None: + """Unfreeze the radial embedding layer in a MACE model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + self.unfreeze_parameters(model.radial_embedding) + print("Unfreezing radial embedding") + + def unfreeze_mace_interaction_linears(self, model: torch.nn.Module, num_layers: int = 1) -> None: """Unfreeze the linear readout layer from the interaction blocks in a MACE model. Args: model (torch.nn.Module): model to be transfer learned """ - interaction_linears = [f"interactions.{i}.linear.weight" for i in range(model.num_interactions.item())] - self.custom_unfreeze(model, interaction_linears) + for i in reversed(range(model.num_interactions.item() - num_layers, model.num_interactions.item())): + print(f"Unfreezing # {i} interaction linear layers from last") + self.custom_unfreeze(model, [f"interactions.{i}.linear.weight"]) + + def unfreeze_mace_interactions(self, model: torch.nn.Module, num_layers: int = 1) -> None: + """Unfreeze the interaction layers in a MACE model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + for i in reversed(range(model.num_interactions.item() - num_layers, model.num_interactions.item())): + print(f"Unfreezing # {i} interaction linear layers from last") + self.unfreeze_parameters(model.interactions[i]) - def unfreeze_mace_produce_linears(self, model: torch.nn.Module) -> None: + def unfreeze_mace_product_linears(self, model: torch.nn.Module) -> None: """Unfreeze the linear readout layer from the interaction blocks in a MACE model. 
@@ -203,6 +234,7 @@ def unfreeze_mace_pooling(self, model: torch.nn.Module) -> None: """ for module in model.products: self.unfreeze_parameters(module) + print("Unfreezing products") def unfreeze_mace_readout(self, model: torch.nn.Module, freeze_skip: bool = False): """Unfreeze the readout layers in a MACE model. @@ -218,6 +250,7 @@ def unfreeze_mace_readout(self, model: torch.nn.Module, freeze_skip: bool = Fals for i, block in enumerate(model.readouts): if unfreeze_skip or i == num_readouts - 1: self.unfreeze_parameters(block) + print(f"Unfreezing {block.__class__.__name__}") def model_tl( self, @@ -236,6 +269,11 @@ def model_tl( model (torch.nn.Module): MACE model freeze_gap_embedding (bool, optional): Unused for MACE, inherited from parent class for consistency with the diabatic models. + Defaults to False. + freeze_interactions (bool, optional): If true, keep all product layers frozen. + Defaults to True. + freeze_products (bool, optional): If true, keep product linear layers frozen. + Defaults to False. freeze_pooling (bool, optional): If true, keep all pooling layers frozen. Defaults to True. freeze_skip (bool, optional): If true, keep all but the last readout layer @@ -250,12 +288,20 @@ def model_tl( else: self.unfreeze_mace_readout(model, freeze_skip=freeze_skip) unfreeze_pool = not freeze_pooling + unfreeze_interactions = not freeze_interactions if unfreeze_pool: self.unfreeze_mace_pooling(model) - if not freeze_interactions: - self.unfreeze_mace_interaction_linears(model) + num_layers = kwargs.get("unfreeze_conv_layers", 0) + if num_layers > 0: + if unfreeze_interactions: + self.unfreeze_mace_interactions(model, num_layers=num_layers) + else: + self.unfreeze_mace_interaction_linears(model, num_layers=num_layers) if not freeze_products: - self.unfreeze_mace_produce_linears(model) + self.unfreeze_mace_product_linears(model) + if kwargs.get("unfreeze_embeddings", False): + self.unfreeze_mace_node_embedding(model) + self.unfreeze_mace_radial_embedding(model) class ChgnetLayerFreezer(LayerFreezer): @@ -267,14 +313,68 @@ class ChgnetLayerFreezer(LayerFreezer): (accessed 2024-03-09) """ - def unfreeze_chgnet_last_atom_conv_layer(self, model: torch.nn.Module) -> None: - """Unfreeze the pooling layers in a CHGNet model. + def unfreeze_chgnet_atom_embedding(self, model: torch.nn.Module) -> None: + """Unfreeze the atom embedding layer in a CHGNet model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + self.unfreeze_parameters(model.atom_embedding) + + def unfreeze_chgnet_bond_embedding(self, model: torch.nn.Module) -> None: + """Unfreeze the bond embedding and weights layers in a CHGNet model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + self.unfreeze_parameters(model.bond_embedding) + self.unfreeze_parameters(model.bond_weights_ag) + self.unfreeze_parameters(model.bond_weights_bg) + + def unfreeze_chgnet_angle_embedding(self, model: torch.nn.Module) -> None: + """Unfreeze the angle embedding and basis expansion layers in a CHGNet model. + + Args: + model (torch.nn.Module): model to be transfer learned + """ + self.unfreeze_parameters(model.angle_embedding) + self.unfreeze_parameters(model.angle_basis_expansion) + + def unfreeze_chgnet_atom_layers(self, model: torch.nn.Module, num_layers: int = 1) -> None: + """Unfreeze the atom layers in a CHGNet model starting from the + last layer. + + Args: + model (torch.nn.Module): model to be transfer learned + num_layers (int, optional): number of layers to unfreeze. Defaults to 1. 
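+
+        Example (illustrative; ``freezer`` is a ChgnetLayerFreezer instance)::
+
+            freezer.unfreeze_chgnet_atom_layers(model, num_layers=2)  # last two atom conv layers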
+ """ + for i, module in enumerate(reversed(model.atom_conv_layers[-num_layers:]), start=1): + print(f"Unfreezing # {i} {module.__class__.__name__} module from last") + self.unfreeze_parameters(module) + + def unfreeze_chgnet_bond_layers(self, model: torch.nn.Module, num_layers: int = 1) -> None: + """Unfreeze the bond layers in a CHGNet model starting from the + last layer. Args: model (torch.nn.Module): model to be transfer learned + num_layers (int, optional): number of layers to unfreeze. Defaults to 1. """ - module = model.atom_conv_layers[-1] - self.unfreeze_parameters(module) + for i, module in enumerate(reversed(model.bond_conv_layers[-num_layers:]), start=1): + print(f"Unfreezing # {i} {module.__class__.__name__} module from last") + self.unfreeze_parameters(module) + + def unfreeze_chgnet_angle_layers(self, model: torch.nn.Module, num_layers: int = 1) -> None: + """Unfreeze the angle layers in a CHGNet model starting from the + last layer. + + Args: + model (torch.nn.Module): model to be transfer learned + num_layers (int, optional): number of layers to unfreeze. Defaults to 1. + """ + for i, module in enumerate(reversed(model.angle_layers[-num_layers:]), start=1): + print(f"Unfreezing # {i} {module.__class__.__name__} module from last") + self.unfreeze_parameters(module) def unfreeze_chgnet_pooling(self, model: torch.nn.Module) -> None: """Unfreeze the "pooling" layers after the representation layers @@ -328,10 +428,23 @@ def model_tl( """ self.freeze_parameters(model) if custom_layers: + print("Custom layers provided. Unfreezing custom layers.") self.custom_unfreeze(model, custom_layers) else: self.unfreeze_chgnet_readout(model, freeze_skip=freeze_skip) unfreeze_pool = not freeze_pooling if unfreeze_pool: - self.unfreeze_chgnet_last_atom_conv_layer(model) self.unfreeze_chgnet_pooling(model) + if "unfreeze_conv_layers" in kwargs: + num_layers = kwargs.get("unfreeze_conv_layers", 1) + self.unfreeze_chgnet_atom_layers( + model, num_layers=num_layers + 1 + ) # additional layer for the last layer + self.unfreeze_chgnet_bond_layers(model, num_layers=num_layers) + self.unfreeze_chgnet_angle_layers(model, num_layers=num_layers) + else: + self.unfreeze_chgnet_atom_layers(model) + if kwargs.get("unfreeze_embeddings", False): + self.unfreeze_chgnet_atom_embedding(model) + self.unfreeze_chgnet_bond_embedding(model) + self.unfreeze_chgnet_angle_embedding(model) diff --git a/nff/utils/cuda.py b/nff/utils/cuda.py index d9a392b0..093aeec3 100644 --- a/nff/utils/cuda.py +++ b/nff/utils/cuda.py @@ -40,9 +40,7 @@ def detach(val: torch.Tensor, to_numpy: bool = False) -> torch.Tensor | np.ndarr return val.detach().cpu() if hasattr(val, "detach") else val -def batch_detach( - batch: Dict[str, List | torch.Tensor], to_numpy: bool = False -) -> Dict[str, List | torch.Tensor]: +def batch_detach(batch: Dict[str, List | torch.Tensor], to_numpy: bool = False) -> Dict[str, List | torch.Tensor]: """Detach batch of GPU tensors Args: @@ -114,5 +112,8 @@ def get_final_device(device: str) -> str: str: final device to use """ if "cuda" in device and torch.cuda.is_available(): - return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}" + try: + return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}" + except nvidia_smi.NVMLError: + return "cuda:0" return "cpu" diff --git a/pyproject.toml b/pyproject.toml index be3f3ed7..122b4960 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.6" readme = "README.md" license = { text = "MIT" } dependencies = [ - "ase==3.22.1", + "ase==3.23.0", 
"numpy >=1.26.4, <2", "pymatgen>=2023.3.10", "rdkit", diff --git a/scripts/evaluate_nff.py b/scripts/evaluate_nff.py index a8f7a9a1..ad4b15d0 100644 --- a/scripts/evaluate_nff.py +++ b/scripts/evaluate_nff.py @@ -40,6 +40,23 @@ def parse_args(): default="./", help="Folder to save output figures.", ) + parser.add_argument( + "--plot_type", + choices=["hexbin", "scatter"], + default="hexbin", + help="Type of plot to use", + ) + parser.add_argument( + "--per_atom_energy", + action="store_true", + help="Whether to calculate per atom energy", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size", + ) parser.add_argument( "--device", choices=["cpu", "cuda"], @@ -47,8 +64,7 @@ def parse_args(): help="device to use for calculations", ) - args = parser.parse_args() - return args + return parser.parse_args() def main( @@ -56,6 +72,9 @@ def main( model_type: str, data_path: str, train_log_path: str, + plot_type: str = "hexbin", + per_atom_energy: bool = False, + batch_size: int = 32, device: str = "cpu", save_folder: str = "./", ): @@ -66,6 +85,9 @@ def main( model_type (str): name of the model data_path (str): path to the data train_log_path (str): path to the training log + plot_type (str, optional): type of plot to use. Defaults to "hexbin". + per_atom_energy (bool, optional): whether to calculate per atom energy. Defaults to False. + batch_size (int, optional): batch size. Defaults to 32. device (str, optional): device to use. Defaults to "cpu". save_folder (str, optional): folder to save the results. Defaults to "./". """ @@ -88,14 +110,13 @@ def main( model.to(device) test_data = Dataset.from_file(data_path) - if hasattr(model, "units"): - units = model.units - else: - units = "eV" + + units = model.units if hasattr(model, "units") else "eV" + test_data.to_units(units) print(f"Using dataset units: {test_data.units}") - test_loader = DataLoader(test_data, batch_size=4, collate_fn=collate_dicts, pin_memory=True) + test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_dicts, pin_memory=True) loss_fn = loss.build_mse_loss(loss_coef={"energy": 0.05, "energy_grad": 1}) @@ -105,14 +126,25 @@ def main( # plot parity plot parity_plot_path = save_path / f"{start_time}_parity_plot" print(f"Saving parity plot to {parity_plot_path}") + + # convert units to per atom if needed + if per_atom_energy and "/atom" not in test_data.units: + units = f"{units}/atom" + results["energy"] = [x / y for x, y in zip(results["energy"], targets["num_atoms"])] + targets["energy"] = [x / y for x, y in zip(targets["energy"], targets["num_atoms"])] + + # Change energy_grad to force + results["force"] = results["energy_grad"] + targets["force"] = targets["energy_grad"] + mae_energy, mae_force = plot_parity( results, targets, parity_plot_path, - plot_type="reg", + plot_type=plot_type, energy_key="energy", - force_key="energy_grad", - units={"energy_grad": "eV/Å", "energy": units}, + force_key="force", + units={"force": "eV/Å", "energy": units}, ) # plot loss curves @@ -132,6 +164,9 @@ def main( model_type=args.model_type, data_path=args.data_path, train_log_path=args.train_log_path, + plot_type=args.plot_type, + per_atom_energy=args.per_atom_energy, + batch_size=args.batch_size, device=args.device, save_folder=args.save_folder, ) diff --git a/scripts/train_nff.py b/scripts/train_nff.py index de7f5d82..16177e6d 100644 --- a/scripts/train_nff.py +++ b/scripts/train_nff.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from nff.data import Dataset, collate_dicts 
+from nff.data.dataset import to_tensor from nff.io.mace import update_mace_init_params from nff.nn.models.mace import reduce_foundations from nff.train import Trainer, get_layer_freezer, get_model, hooks, load_model, loss, metrics @@ -42,6 +43,29 @@ def build_default_arg_parser() -> argparse.ArgumentParser: help="Whether to fine-tune the model", action="store_true", ) + parser.add_argument( + "--custom_layers", + nargs="+", + type=str, + default=[], + help="Which layers to unfreeze for fine-tuning", + ) + parser.add_argument( + "--freeze_pooling", + help="Whether to freeze pooling layers for fine-tuning", + action="store_true", + ) + parser.add_argument( + "--unfreeze_embeddings", + help="Whether to unfreeze embeddings for fine-tuning", + action="store_true", + ) + parser.add_argument( + "--unfreeze_conv_layers", help="Number of convolutional layers to unfreeze for fine-tuning", type=int, default=1 + ) + parser.add_argument( + "--unfreeze_interactions", help="Whether to unfreeze all MACE interactions for fine-tuning", action="store_true" + ) parser.add_argument( "--trim_embeddings", help="Whether to reduce the size of MACE foundational model by resizing the embedding layers", @@ -134,6 +158,11 @@ def main( train_file: Union[str, Path], val_file: Union[str, Path], fine_tune: bool = False, + custom_layers: Iterable[str] = [], + freeze_pooling: bool = False, + unfreeze_embeddings: bool = False, + unfreeze_conv_layers: int = 1, + unfreeze_interactions: bool = False, trim_embeddings: bool = False, targets: Iterable[str] = ["energy", "energy_grad"], loss_weights: Iterable[float] = [0.05, 1.0], @@ -149,6 +178,37 @@ def main( pin_memory: bool = True, seed: int = 1337, ): + """Train a neural network model. + + Args: + name (str): Model name + model_type (str): Type of model + model_params_path (Union[str, Path]): Path to model parameters + model_path (Union[str, Path]): Path to a trained model + train_dir (Union[str, Path]): Model training directory + train_file (Union[str, Path]): Training set pth.tar file + val_file (Union[str, Path]): Validation set pth.tar file + fine_tune (bool, optional): Whether to fine tune an existing model. Defaults to False. + custom_layers (Iterable[str], optional): Named modules to unfreeze for finetuning. Defaults to []. + freeze_pooling (bool, optional): Whether to freeze pooling layers for fine-tuning. Defaults to False. + unfreeze_embeddings (bool, optional): Whether to unfreeze embeddings for fine-tuning. Defaults to False. + unfreeze_conv_layers (int, optional): Number of convolutional layers to unfreeze for fine-tuning. Defaults to 1. + unfreeze_interactions (bool, optional): Whether to unfreeze all MACE interactions for fine-tuning. Defaults to False. + trim_embeddings (bool, optional): Whether to trim MACE embeddings. Defaults to False. + targets (Iterable[str], optional): Model output. Defaults to ["energy", "energy_grad"]. + loss_weights (Iterable[float], optional): Relative weights of output targets. Defaults to [0.05, 1.0]. + criterion (Literal["MSE", "MAE"], optional): Loss function criterion. Defaults to "MSE". + batch_size (int, optional): Batch size. Defaults to 16. + lr (float, optional): Learning rate. Defaults to 1e-3. + min_lr (float, optional): Minimum LR. Defaults to 1e-6. + max_num_epochs (int, optional): Max number training epochs. Defaults to 200. + patience (int, optional): LR patience. Defaults to 25. + lr_decay (float, optional): LR decay rate. Defaults to 0.5. + weight_decay (float, optional): Weight decay for optimizer. Defaults to 0.0. 
+ num_workers (int, optional): Number of workers for data loader. Defaults to 1. + pin_memory (bool, optional): Whether to pin memory for data loader. Defaults to True. + seed (int, optional): Random seed. Defaults to 1337. + """ # Set seeds torch.manual_seed(seed) np.random.seed(seed) @@ -195,12 +255,28 @@ def main( logger.info("Fine-tuning model") model = load_model(model_path, model_type=model_type, map_location=device, device=device) if "NffScaleMACE" in model_type and trim_embeddings: - atomic_numbers = np.unique(train[0]["nxyz"][:, 0]).astype(int).tolist() + atomic_numbers = to_tensor(train.props["nxyz"], stack=True)[:, 0].unique().to(torch.int64).tolist() logger.info("Trimming embeddings with MACE model and atomic numbers %s", atomic_numbers) model = reduce_foundations(model, atomic_numbers, load_readout=True) model_freezer = get_layer_freezer(model_type) - model_freezer.model_tl(model) # TODO: add custom options for freezing layers - + if unfreeze_conv_layers > 0: + model_freezer.model_tl( + model, + custom_layers=custom_layers, + freeze_interactions=not unfreeze_interactions, # freeze MACE all interactions (all conv parameters, not + # just the linear layers) + freeze_pooling=freeze_pooling, + unfreeze_conv_layers=unfreeze_conv_layers, + unfreeze_embeddings=unfreeze_embeddings, + ) + else: + model_freezer.model_tl( + model, + custom_layers=custom_layers, + freeze_interactions=not unfreeze_interactions, + freeze_pooling=freeze_pooling, + unfreeze_embeddings=unfreeze_embeddings, + ) else: # Load model params and save a copy logger.info("Training model from scratch") @@ -257,13 +333,6 @@ def main( ] train_hooks = [ - hooks.WarmRestartHook( - T0=max_num_epochs, - Tmult=1, - min_lr=min_lr, - lr_factor=lr, - optimizer=optimizer, - ), hooks.MaxEpochHook(max_num_epochs), hooks.CSVHook( save_path, @@ -314,6 +383,11 @@ def main( train_file=args.train_file, val_file=args.val_file, fine_tune=args.fine_tune, + custom_layers=args.custom_layers, + freeze_pooling=args.freeze_pooling, + unfreeze_embeddings=args.unfreeze_embeddings, + unfreeze_conv_layers=args.unfreeze_conv_layers, + unfreeze_interactions=args.unfreeze_interactions, trim_embeddings=args.trim_embeddings, targets=args.targets, loss_weights=args.loss_weights,