Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
31e70e6
BLD: update ase version
xiaochendu Jul 18, 2024
06be0d4
BUG: change pretrained CHGNet paths
xiaochendu Jul 23, 2024
99b8fad
MAINT: update copy `AtomsBatch` to copy arrays and constraints
xiaochendu Sep 15, 2024
3681bcf
MAINT: `NeuralFF` update neighbors in `calculate`
xiaochendu Sep 15, 2024
45f9f8a
BUG: chgnet latest version fix import
xiaochendu Sep 18, 2024
73aa4ec
MAINT: update `AtomsBatch.from_atoms` with arrays and constraints copy
xiaochendu Sep 18, 2024
cadbc65
BUG & MAINT: fix `lr_decay` and set default model units
xiaochendu Sep 20, 2024
b020bb5
MAINT: add conversion functions for CHGNet structure data to NFF format
xiaochendu Nov 5, 2024
b5dfbfd
MAINT: allow for no validation set and fix `reference_mean` and `refe…
xiaochendu Nov 10, 2024
dc50ffd
BUG: handle NVMLError when retrieving the final CUDA device
xiaochendu Nov 17, 2024
714f67a
MAINT: convert tensors to numpy arrays in parity plot functions for c…
xiaochendu Nov 27, 2024
c534593
FEAT: add plot type argument to evaluate_nff script for customizable …
xiaochendu Nov 27, 2024
28b5e0c
FEAT: add support for custom layer unfreezing and configurable convol…
xiaochendu Dec 25, 2024
4022cec
FEAT: add option to unfreeze embeddings in CHGNet model training
xiaochendu Jan 7, 2025
4b51145
ENH: add option to unfreeze node embedding layer and improve interact…
xiaochendu Jan 7, 2025
e34fde3
MAINT: update parity plot function to save figures as PDF and improve…
xiaochendu Jan 14, 2025
5deb72c
ENH: add options to unfreeze radial embedding and interaction layers …
xiaochendu Jan 14, 2025
f935839
ENH: add options for per atom energy calculation and configurable bat…
xiaochendu Jan 15, 2025
7d018fc
MAINT: formatting
xiaochendu Jan 17, 2025
6013d7d
ENH: improve plotting functions and update Matplotlib settings for be…
xiaochendu Feb 21, 2025
2573e68
MAINT: remove commented-out method for unfreezing last atom conv laye…
xiaochendu Feb 21, 2025
05cc9fe
ENH: update ase dependency to version 3.23.0
xiaochendu Apr 13, 2025
74501cf
Merge branch 'surface-sampling-0.2.0' into vssr_pourbaix
xiaochendu Apr 13, 2025
df42616
Merge branch 'master' into vssr_pourbaix
xiaochendu Apr 23, 2025
5fe46a2
Apply suggestions from code review
xiaochendu Apr 23, 2025
31fbcfc
MAINT & STY: update based on PR #34 comments and style fixes
xiaochendu Apr 23, 2025
95937c9
MAINT: update shuffle default in `convert_chgnet_structure_data_to_nf…
xiaochendu May 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 0 additions & 74 deletions models/foundation_models/chgnet/0.2.0/README.md

This file was deleted.

Binary file not shown.
80 changes: 0 additions & 80 deletions models/foundation_models/chgnet/0.3.0/README.md

This file was deleted.

Binary file not shown.
22 changes: 14 additions & 8 deletions nff/analysis/loss_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@
from . import mpl_settings


def plot_loss(energy_history, forces_history, figname, train_key="train", val_key="val"):
def plot_loss(
energy_history: dict,
forces_history: dict,
figname: str,
train_key: str = "train",
val_key: str = "val",
) -> None:
"""Plot the loss history of the model.
Args:
energy_history (dict): energy loss history of the model for training and validation
forces_history (dict): forces loss history of the model for training and validation
figname (str): name of the figure

Returns:
None
Args:
energy_history: energy loss history of the model for training and validation
forces_history: forces loss history of the model for training and validation
figname: name of the figure
train_key: key for training data in the history dictionary
val_key: key for validation data in the history dictionary
"""
epochs = np.arange(1, len(energy_history[train_key]) + 1)
fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6), dpi=mpl_settings.DPI)
fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI)
ax_fig[0].semilogy(epochs, energy_history[train_key], label="train", color=mpl_settings.colors[1])
ax_fig[0].semilogy(epochs, energy_history[val_key], label="val", color=mpl_settings.colors[2])
ax_fig[0].legend()
Expand Down
120 changes: 64 additions & 56 deletions nff/analysis/mpl_settings.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,44 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import List, Optional
from typing import List

import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

plt.style.use("default")

DPI = 100
LINEWIDTH = 2
FONTSIZE = 20
LABELSIZE = 18
dir_name = Path(__file__).parent

DPI = 300
LINEWIDTH = 1.25
FONTSIZE = 8
LABELSIZE = 8
ALPHA = 0.8
LINE_MARKERSIZE = 15 * 25
MARKERSIZE = 15
GRIDSIZE = 40
MAJOR_TICKLEN = 6
MINOR_TICKLEN = 3
TICKPADDING = 5
MARKERSIZE = 25
GRIDSIZE = 20
MAJOR_TICKLEN = 4
MINOR_TICKLEN = 2
TICKPADDING = 3
SECONDARY_CMAP = "inferno"

params = {
custom_settings = {
"mathtext.default": "regular",
"font.family": "Arial",
"font.size": FONTSIZE,
"axes.labelsize": LABELSIZE,
"axes.titlesize": FONTSIZE,
"axes.linewidth": LINEWIDTH,
"grid.linewidth": LINEWIDTH,
"lines.linewidth": LINEWIDTH,
"lines.color": "black",
"axes.labelcolor": "black",
"axes.edgecolor": "black",
"axes.titlecolor": "black",
"axes.titleweight": "bold",
"axes.grid": False,
"lines.markersize": MARKERSIZE,
"xtick.major.size": MAJOR_TICKLEN,
"xtick.minor.size": MINOR_TICKLEN,
Expand All @@ -38,66 +48,67 @@
"ytick.minor.size": MINOR_TICKLEN,
"ytick.major.pad": TICKPADDING,
"ytick.minor.pad": TICKPADDING,
"axes.linewidth": LINEWIDTH,
"legend.fontsize": LABELSIZE,
"figure.dpi": DPI,
"savefig.dpi": DPI,
"ytick.major.width": LINEWIDTH,
"xtick.major.width": LINEWIDTH,
"ytick.minor.width": LINEWIDTH,
"xtick.minor.width": LINEWIDTH,
"legend.fontsize": LABELSIZE,
"figure.dpi": DPI,
"savefig.dpi": DPI,
"savefig.format": "png",
"savefig.bbox": "tight",
"savefig.pad_inches": 0.1,
"figure.facecolor": "white",
}
plt.rcParams.update(params)
plt.rcParams.update(custom_settings)


def update_custom_settings(custom_settings: dict | None = custom_settings) -> None:
"""Update the custom settings for Matplotlib.

def hex_to_rgb(value: str) -> tuple:
Args:
custom_settings: Custom settings for Matplotlib. Defaults to
custom_settings.
"""
Converts hex to rgb colours
current_settings = plt.rcParams.copy()
new_settings = current_settings | custom_settings
plt.rcParams.update(new_settings)

Parameters
----------
value: string of 6 characters representing a hex colour

Returns
----------
tuple of 3 integers representing the RGB values
"""
def hex_to_rgb(value: str) -> list[float]:
"""Converts hex to rgb colors.

Args:
value: string of 6 characters representing a hex color.
"""
value = value.strip("#") # removes hash symbol if present
lv = len(value)
return tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))


def rgb_to_dec(value: list):
"""
Converts rgb to decimal colours (i.e. divides each value by 256)
def rgb_to_dec(value: list[float]) -> list[float]:
"""Converts rgb to decimal colors (i.e. divides each value by 256).

Parameters
----------
value: list of 3 integers representing the RGB values
Args:
value: string of 6 characters representing a hex color.

Returns
----------
list of 3 floats representing the RGB values
Returns:
list: length 3 of RGB values
"""

return [v / 256 for v in value]


def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] = None) -> matplotlib.colors.Colormap:
"""
Creates and returns a color map that can be used in heat map figures.
If float_list is not provided, colour map graduates linearly between each color in hex_list.
If float_list is provided, each color in hex_list is mapped to the respective location in float_list.

Parameters
----------
hex_list: list of hex code strings
float_list: list of floats between 0 and 1, same length as hex_list. Must start with 0 and end with 1.

Returns
----------
Colormap
def get_continuous_cmap(
hex_list: list[str], float_list: list[float] | None = None
) -> mpl.colors.LinearSegmentedColormap:
"""Creates a color map that can be used in heat map figures. If float_list is not provided,
color map graduates linearly between each color in hex_list. If float_list is provided,
each color in hex_list is mapped to the respective location in float_list.

Args:
hex_list: list of hex code strings
float_list: list of floats between 0 and 1, same length as hex_list.
Must start with 0 and end with 1.
"""
rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list]
if float_list:
Expand All @@ -109,15 +120,12 @@ def get_continuous_cmap(hex_list: List[str], float_list: Optional[List[float]] =
for num, col in enumerate(["red", "green", "blue"]):
col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))]
cdict[col] = col_list
cmp = matplotlib.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256)
return cmp
return mpl.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256)


# colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier
# Colors taken from Johannes Dietschreit's script and interpolated with correct lightness and Bezier
# http://www.vis4.net/palettes/#/100|s|fce1a4,fabf7b,f08f6e,d12959,6e005f|ffffe0,ff005e,93003a|1|1
hex_list: List[str]
dir_name = Path(__file__).parent

with open(dir_name / "config/mpl_settings.json", "r") as f:
hex_list = json.load(f)["plot_colors"]

Expand Down
Loading