From 710c42011d17f4bb56b9e44119f13a0d8193e54f Mon Sep 17 00:00:00 2001 From: Guillaume Broggi <25569517+GuillaumeBroggi@users.noreply.github.com> Date: Mon, 28 Apr 2025 09:27:37 +0200 Subject: [PATCH 1/6] Update gitignore --- .gitignore | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 79eb2cc2..85d17b1e 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.python-version # Spyder project settings .spyderproject @@ -161,5 +162,20 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -# VSCode -.vscode/ \ No newline at end of file +# Visual Studio Code +.vscode/ + +# Local projects +projects/ + +# Benchmark +benchmarks/cfd/ + +# Visual Studio Code +.vscode/ + +# Local projects +projects/ + +# Benchmark +benchmarks/cfd/ \ No newline at end of file From 4caa77ce36d108c87b176d85fbb0e5d9b79d32f4 Mon Sep 17 00:00:00 2001 From: Guillaume Broggi <25569517+GuillaumeBroggi@users.noreply.github.com> Date: Tue, 29 Apr 2025 16:13:11 +0200 Subject: [PATCH 2/6] QoL: propagate `is_verbose` in data scaler fitting --- .../gnn_base_model/model/gnn_model.py | 30 +++++++++++-------- .../gnn_base_model/train/training.py | 2 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/graphorge/gnn_base_model/model/gnn_model.py b/src/graphorge/gnn_base_model/model/gnn_model.py index d0229224..b727d3b7 100644 --- a/src/graphorge/gnn_base_model/model/gnn_model.py +++ b/src/graphorge/gnn_base_model/model/gnn_model.py @@ -1571,13 +1571,15 @@ def fit_data_scalers(self, dataset, is_verbose=False): if self._n_node_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='node_features_in', - n_features=self._n_node_in * self._n_time_node) + n_features=self._n_node_in * self._n_time_node, + is_verbose=is_verbose) scaler_node_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: 
node output features if self._n_node_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='node_features_out', - n_features=self._n_node_out*self._n_time_node) + n_features=self._n_node_out*self._n_time_node, + is_verbose=is_verbose) scaler_node_out.set_mean_and_std(mean, std) else: # No time series data @@ -1585,13 +1587,13 @@ def fit_data_scalers(self, dataset, is_verbose=False): if self._n_node_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='node_features_in', - n_features=self._n_node_in) + n_features=self._n_node_in, is_verbose=is_verbose) scaler_node_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: node output features if self._n_node_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='node_features_out', - n_features=self._n_node_out) + n_features=self._n_node_out, is_verbose=is_verbose) scaler_node_out.set_mean_and_std(mean, std) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if self._n_time_edge > 0: @@ -1599,26 +1601,28 @@ def fit_data_scalers(self, dataset, is_verbose=False): if self._n_edge_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='edge_features_in', - n_features=self._n_edge_in*self._n_time_edge) + n_features=self._n_edge_in*self._n_time_edge, + is_verbose=is_verbose) scaler_edge_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: edge output features if self._n_edge_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='edge_features_out', - n_features=self._n_edge_out*self._n_time_edge) + n_features=self._n_edge_out*self._n_time_edge, + is_verbose=is_verbose) scaler_edge_out.set_mean_and_std(mean, std) else: # Get scaling parameters and fit data scalers: edge input features if self._n_edge_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='edge_features_in', - n_features=self._n_edge_in) + n_features=self._n_edge_in, is_verbose=is_verbose) 
scaler_edge_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: edge output features if self._n_edge_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='edge_features_out', - n_features=self._n_edge_out) + n_features=self._n_edge_out,is_verbose=is_verbose) scaler_edge_out.set_mean_and_std(mean, std) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if self._n_time_global > 0: @@ -1627,14 +1631,16 @@ def fit_data_scalers(self, dataset, is_verbose=False): if self._n_global_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='global_features_in', - n_features=self._n_global_in*self._n_time_global) + n_features=self._n_global_in*self._n_time_global, + is_verbose=is_verbose) scaler_global_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: # global output features if self._n_global_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='global_features_out', - n_features=self._n_global_out*self._n_time_global) + n_features=self._n_global_out*self._n_time_global, + is_verbose=is_verbose) scaler_global_out.set_mean_and_std(mean, std) else: # Get scaling parameters and fit data scalers: @@ -1642,14 +1648,14 @@ def fit_data_scalers(self, dataset, is_verbose=False): if self._n_global_in > 0: mean, std = graph_standard_partial_fit( dataset, features_type='global_features_in', - n_features=self._n_global_in) + n_features=self._n_global_in, is_verbose=is_verbose) scaler_global_in.set_mean_and_std(mean, std) # Get scaling parameters and fit data scalers: # global output features if self._n_global_out > 0: mean, std = graph_standard_partial_fit( dataset, features_type='global_features_out', - n_features=self._n_global_out) + n_features=self._n_global_out, is_verbose=is_verbose) scaler_global_out.set_mean_and_std(mean, std) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if is_verbose: diff --git 
a/src/graphorge/gnn_base_model/train/training.py b/src/graphorge/gnn_base_model/train/training.py index 60827fa9..46c2c7d9 100644 --- a/src/graphorge/gnn_base_model/train/training.py +++ b/src/graphorge/gnn_base_model/train/training.py @@ -222,7 +222,7 @@ class GNNEPDBaseModel). is_model_out_normalized = model.is_model_out_normalized # Fit model data scalers if is_model_in_normalized or is_model_out_normalized: - model.fit_data_scalers(dataset) + model.fit_data_scalers(dataset, is_verbose=is_verbose) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Save model initial state model.save_model_init_state() From f3a46939ab7c500b4b5288cfd26204059c0b60be Mon Sep 17 00:00:00 2001 From: Guillaume Broggi <25569517+GuillaumeBroggi@users.noreply.github.com> Date: Tue, 29 Apr 2025 16:52:30 +0200 Subject: [PATCH 3/6] Implements #25 #27 --- requirements.txt | 4 ++-- src/graphorge/gnn_base_model/model/gnn_model.py | 3 ++- src/graphorge/gnn_base_model/train/training.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index eabcface..ca2dde41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,8 +7,8 @@ pytest==7.1.2 scikit_learn==1.3.2 scipy==1.10.0 sphinx_rtd_theme==1.2.0 -torch==2.1.0+cu118 -torch_geometric==2.4.0 +torch==2.4.1 +torch_geometric==2.5.0 tqdm==4.65.0 shapely==2.0.2 torchinfo==1.8.0 diff --git a/src/graphorge/gnn_base_model/model/gnn_model.py b/src/graphorge/gnn_base_model/model/gnn_model.py index b727d3b7..34f7d109 100644 --- a/src/graphorge/gnn_base_model/model/gnn_model.py +++ b/src/graphorge/gnn_base_model/model/gnn_model.py @@ -1342,7 +1342,8 @@ def load_model_state(self, load_model_state=None, # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load model state self.load_state_dict(torch.load(model_path, - map_location=torch.device('cpu'))) + map_location=torch.device('cpu'), + weights_only=True)) # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ return epoch # ------------------------------------------------------------------------- diff --git a/src/graphorge/gnn_base_model/train/training.py b/src/graphorge/gnn_base_model/train/training.py index 46c2c7d9..73f164d7 100644 --- a/src/graphorge/gnn_base_model/train/training.py +++ b/src/graphorge/gnn_base_model/train/training.py @@ -753,7 +753,7 @@ def load_training_state(model, opt_algorithm, optimizer, raise RuntimeError('Unknown optimization algorithm') # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Load optimizer state - optimizer_state = torch.load(optimizer_path) + optimizer_state = torch.load(optimizer_path, weights_only=True) # Set loaded optimizer state optimizer.load_state_dict(optimizer_state['state']) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 39ddbbb592c1eb7d0efdbfe0818169c3c02e3fc6 Mon Sep 17 00:00:00 2001 From: Guillaume Broggi <25569517+GuillaumeBroggi@users.noreply.github.com> Date: Mon, 28 Apr 2025 10:40:43 +0200 Subject: [PATCH 4/6] Initial commit for cfd benchmark --- .gitignore | 7 +- benchmarks/cfd/cylinder_flow.ipynb | 481 +++++++++++++++++++++++++++++ benchmarks/cfd/utils.py | 226 ++++++++++++++ 3 files changed, 710 insertions(+), 4 deletions(-) create mode 100644 benchmarks/cfd/cylinder_flow.ipynb create mode 100644 benchmarks/cfd/utils.py diff --git a/.gitignore b/.gitignore index 85d17b1e..912044bc 100644 --- a/.gitignore +++ b/.gitignore @@ -169,13 +169,12 @@ cython_debug/ projects/ # Benchmark -benchmarks/cfd/ +benchmarks/cfd/* +!benchmarks/cfd/cylinder_flow.ipynb +!benchmarks/cfd/utils.py # Visual Studio Code .vscode/ # Local projects projects/ - -# Benchmark -benchmarks/cfd/ \ No newline at end of file diff --git a/benchmarks/cfd/cylinder_flow.ipynb b/benchmarks/cfd/cylinder_flow.ipynb new file mode 100644 index 00000000..93f10523 --- /dev/null +++ b/benchmarks/cfd/cylinder_flow.ipynb @@ -0,0 +1,481 
@@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "\n", + "from tqdm.notebook import tqdm\n", + "import logging\n", + "import pickle\n", + "\n", + "from itertools import count\n", + "from graphorge.gnn_base_model.data.graph_data import GraphData\n", + "from graphorge.gnn_base_model.data.graph_dataset import GNNGraphDataset\n", + "from graphorge.gnn_base_model.train.training import train_model\n", + "\n", + "import torch\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "from utils import *\n", + "\n", + "logger = logging.getLogger()\n", + "logger.setLevel(logging.INFO)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Step 1: Raw data\n", + "\n", + "The dataset is available as specified in [Deepmind's repository](https://github.com/google-deepmind/deepmind-research/tree/master/meshgraphnets) and was reported in [^1]. The dataset describes the turbulent flow of water around a cylinder obstacle. Each sample is simulated using COMSOL with irregular triangular 2D meshes over 600 time steps with a time step size of 0.01 seconds.\n", + "\n", + "[^1]: Learning Mesh-Based Simulation with Graph Networks, https://doi.org/10.48550/arXiv.2010.03409" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "base_url = \"https://storage.googleapis.com/dm-meshgraphnets/cylinder_flow\"\n", + "\n", + "download_file(file=\"meta.json\", base_url=base_url, dest_path=\"data\")\n", + "download_file(file=\"valid.tfrecord\", base_url=base_url, dest_path=\"data\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-04-07 19:41:13.495567: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. 
You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2025-04-07 19:41:13.549141: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2025-04-07 19:41:15.557545: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n" + ] + }, + { + "ename": "FileNotFoundError", + "evalue": "Dataset 'valid' not found in /home/guillaume/Documents/code/graphorge_fork/benchmarks/cfd/data.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mparse_tensorflow_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvalid\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset_directory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdata\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/code/graphorge_fork/benchmarks/cfd/utils.py:126\u001b[0m, in \u001b[0;36mparse_tensorflow_dataset\u001b[0;34m(dataset_name, dataset_directory)\u001b[0m\n\u001b[1;32m 123\u001b[0m dataset_path 
\u001b[38;5;241m=\u001b[39m dataset_directory \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.tfrecord\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m dataset_path\u001b[38;5;241m.\u001b[39mexists():\n\u001b[0;32m--> 126\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 127\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m not found in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset_directory\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 128\u001b[0m )\n\u001b[1;32m 130\u001b[0m \u001b[38;5;66;03m# Load the dataset\u001b[39;00m\n\u001b[1;32m 131\u001b[0m dataset \u001b[38;5;241m=\u001b[39m tf\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mTFRecordDataset(dataset_path)\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: Dataset 'valid' not found in /home/guillaume/Documents/code/graphorge_fork/benchmarks/cfd/data." 
+ ] + } + ], + "source": [ + "parse_tensorflow_dataset(dataset_name=\"valid\", dataset_directory=\"data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Step 2: Graph dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_directories = dict(\n", + " train=\"1_training_dataset\",\n", + " valid=\"2_validation_dataset\",\n", + " test=\"5_testing_id_dataset\",\n", + ")\n", + "\n", + "# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + "# Iterate over the datasets\n", + "for dataset_name, dataset_directory in dataset_directories.items():\n", + "\n", + " # Prepare a directory to store the dataset\n", + " dataset_directory = Path(dataset_directories[dataset_name])\n", + " dataset_directory.mkdir(parents=True, exist_ok=True)\n", + "\n", + " # Initialize a list to store the graph file paths\n", + " graph_file_paths = []\n", + "\n", + " # Locate the raw data downloaded from Google's server and parsed\n", + " raw_data_directory = Path(\"data\") / dataset_name\n", + "\n", + " # Search the raw data directory for sample files\n", + " sample_paths = list(\n", + " raw_data_directory.glob(f\"{dataset_name}_sample_*.pkl\")\n", + " )\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Iterate over the sample files\n", + " for sample_path in tqdm(sample_paths, desc=\"Generating graphs: \"):\n", + "\n", + " # Extract the sample id from the file name\n", + " sample_id = int(sample_path.stem.split(\"_\")[-1])\n", + "\n", + " # Load the sample\n", + " with open(sample_path, \"rb\") as file:\n", + " sample = pickle.load(file)\n", + "\n", + " # Initialize a list to store edge indexes\n", + " edge_indexes = []\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Iterate over the cells defined in the sample\n", + " for cell in sample[\"cells\"]:\n", + "\n", + 
" # Generate pairs of vertex indexes which define the edges\n", + " for i in range(len(cell) - 1):\n", + " edge_indexes.append((cell[i], cell[i + 1]))\n", + "\n", + " # Add the pair of vertices closing the cell\n", + " edge_indexes.append(sorted((cell[-1], cell[0])))\n", + "\n", + " # Cast as a numpy array as required by the GraphData class\n", + " edge_indexes = np.asarray(edge_indexes, dtype=int)\n", + "\n", + " # Remove duplicates and ensure edges are undirected\n", + " edge_indexes = GraphData.get_undirected_unique_edges(\n", + " edges_indexes=edge_indexes\n", + " )\n", + "\n", + " # Since the mesh is fixed, we can extract geometry features\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Get the node type values. We follow the original meshgraphnet\n", + " # implementation and encode the node types as one-hot vectors.\n", + " # Node types take 0, 4, 5 and 6 values. They should be postprocessed to\n", + " # take 0, 1, 2 and 3 values for compatibility with the one_hot encoding\n", + " node_types = np.max(\n", + " (sample[\"node_type\"] - 3, np.zeros_like(sample[\"node_type\"])),\n", + " axis=0,\n", + " )\n", + "\n", + " # Encode the node types as one-hot vectors\n", + " node_type_one_hot = torch.nn.functional.one_hot(\n", + " torch.Tensor(node_types.reshape(-1)).long(), num_classes=4\n", + " )\n", + "\n", + " # Calculate the edge length vectors\n", + " distance_vector = (\n", + " sample[\"mesh_pos\"][edge_indexes[:, 0]]\n", + " - sample[\"mesh_pos\"][edge_indexes[:, 1]]\n", + " )\n", + "\n", + " # Calculate the edge length (euclidian) norms\n", + " distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)\n", + "\n", + " # Prepare the edge features\n", + " edge_features = np.hstack((distance_vector, distance_norm))\n", + "\n", + " # Save the edge feature names and shapes for the metadata\n", + " edge_feature_names = (\"ditance\", \"distance_norm\")\n", + " edge_feature_shapes = ((2,), (1,))\n", + "\n", + " # 
Initialize a time step counter\n", + " time_step_counter = count(1)\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Iterate over the time steps. The time steps are not explicitly stored\n", + " # in the sample, instead the first data dimension is the time.\n", + " # The meshgraphnet implementation predicts a velocity update based on\n", + " # the velocity at the previous time step. Hence, a prediction cannot be\n", + " # made for the first time step. The total number of graphs is the\n", + " # number of time steps minus one.\n", + " for initial_velocity, updated_velocity, updated_pressure in zip(\n", + " sample[\"velocity\"][:-1],\n", + " sample[\"velocity\"][1:],\n", + " sample[\"pressure\"][1:],\n", + " ):\n", + "\n", + " # Initialize the graph data\n", + " graph_data = GraphData(n_dim=2, nodes_coords=sample[\"mesh_pos\"])\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Prepare the node features\n", + " node_features = np.hstack((initial_velocity, node_type_one_hot))\n", + "\n", + " # Save the node feature names and shapes for the metadata\n", + " node_feature_names = (\"velocity\", \"node_type_one_hot\")\n", + " node_feature_shapes = ((2,), (4,))\n", + "\n", + " # Set node features matrix\n", + " graph_data.set_node_features_matrix(\n", + " node_features_matrix=node_features\n", + " )\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Set graph edges, uniqueness has already been checked\n", + " graph_data.set_graph_edges_indexes(\n", + " edges_indexes_mesh=edge_indexes, is_unique=False\n", + " )\n", + "\n", + " # Set edge features\n", + " graph_data.set_edge_features_matrix(\n", + " edge_features_matrix=edge_features\n", + " )\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Compute the velocity update\n", + " velocity_update = updated_velocity - initial_velocity\n", + "\n", + " # 
Prepare the node targets\n", + " node_targets = np.hstack((velocity_update, updated_pressure))\n", + "\n", + " # Save the node target names and shapes for the metadata\n", + " node_target_names = (\"velocity_update\", \"pressure\")\n", + " node_target_shapes = ((2,), (1,))\n", + "\n", + " # Set node targets matrix\n", + " graph_data.set_node_targets_matrix(\n", + " node_targets_matrix=node_targets\n", + " )\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Get the time step id\n", + " time_step_id = next(time_step_counter)\n", + "\n", + " # Prepare the graph metadata\n", + " metadata_dict = dict(\n", + " dataset_name=dataset_name,\n", + " sample_id=int(sample_path.stem.split(\"_\")[-1]),\n", + " time_step_id=time_step_id,\n", + " edge_features=edge_feature_names,\n", + " edge_features_shapes=edge_feature_shapes,\n", + " node_features=node_feature_names,\n", + " node_features_shapes=node_feature_shapes,\n", + " node_targets=node_target_names,\n", + " node_targets_shapes=node_target_shapes,\n", + " )\n", + "\n", + " # Set the graph metadata\n", + " graph_data.set_metadata(metadata=metadata_dict)\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Extract the pyg graph data object\n", + " pyg_graph = graph_data.get_torch_data_object()\n", + "\n", + " # Cast edge indexes to int16, save memory\n", + " pyg_graph.edge_index = pyg_graph.edge_index.to(dtype=torch.int16)\n", + "\n", + " # Generate the graph file path\n", + " graph_file_path = (\n", + " dataset_directory\n", + " / f\"{dataset_name}_graph_{sample_id}_{time_step_id}.pt\"\n", + " )\n", + "\n", + " # Save the graph to file\n", + " torch.save(pyg_graph, graph_file_path)\n", + "\n", + " # Append the graph file path to the list of graph file paths\n", + " graph_file_paths.append(graph_file_path)\n", + "\n", + " # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", + " # Generate the GNN-based data set, use 
`is_store_dataset=False` as the\n", + " # graphs are already stored as individual files\n", + " dataset = GNNGraphDataset(\n", + " dataset_directory=dataset_directory,\n", + " dataset_sample_files=graph_file_paths,\n", + " dataset_basename=f\"meshgraphnet_{dataset_name}\",\n", + " is_store_dataset=False,\n", + " )\n", + "\n", + " # Save the dataset to file\n", + " _ = dataset.save_dataset(is_append_n_sample=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c1cf04702f424b32a845fdebcab426ce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddableWidget(value='