From a169969d652bc933892b74e5daaa83ac7093187c Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Fri, 26 Dec 2025 12:26:39 +0100 Subject: [PATCH 01/12] refact base and input-target condition --- pina/condition/condition_base.py | 210 +++++++++++++++++++++++ pina/condition/condition_interface.py | 94 +--------- pina/condition/input_target_condition.py | 186 +++++++++++++++++--- 3 files changed, 378 insertions(+), 112 deletions(-) create mode 100644 pina/condition/condition_base.py diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py new file mode 100644 index 000000000..361352d0f --- /dev/null +++ b/pina/condition/condition_base.py @@ -0,0 +1,210 @@ +import torch +from copy import deepcopy +from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor +from ..data.dummy_dataloader import DummyDataloader +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader +from functools import partial + + +class ConditionBase(ConditionInterface): + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + super().__init__() + self.data = self._store_data(**kwargs) + + @property + def problem(self): + return self._problem + + @problem.setter + def problem(self, value): + self._problem = value + + @staticmethod + def _check_graph_list_consistency(data_list): + """ + Check the consistency of the list of Data | Graph objects. + The following checks are performed: + + - All elements in the list must be of the same type (either + :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). + + - All elements in the list must have the same keys. + + - The data type of each tensor must be consistent across all elements. + + - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels + must also be consistent across all elements. + + :param data_list: The list of Data | Graph objects to check. + :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] + :raises ValueError: If the input types are invalid. + :raises ValueError: If all elements in the list do not have the same + keys. + :raises ValueError: If the type of each tensor is not consistent across + all elements in the list. + :raises ValueError: If the labels of the LabelTensors are not consistent + across all elements in the list. + """ + # If the data is a Graph or Data object, perform no checks + if isinstance(data_list, (Graph, Data)): + return + + # Check all elements in the list are of the same type + if not all(isinstance(i, (Graph, Data)) for i in data_list): + raise ValueError( + "Invalid input. Please, provide either Data or Graph objects." + ) + + # Store the keys, data types and labels of the first element + data = data_list[0] + keys = sorted(list(data.keys())) + data_types = {name: tensor.__class__ for name, tensor in data.items()} + labels = { + name: tensor.labels + for name, tensor in data.items() + if isinstance(tensor, LabelTensor) + } + + # Iterate over the list of Data | Graph objects + for data in data_list[1:]: + + # Check that all elements in the list have the same keys + if sorted(list(data.keys())) != keys: + raise ValueError( + "All elements in the list must have the same keys." + ) + + # Iterate over the tensors in the current element + for name, tensor in data.items(): + # Check that the type of each tensor is consistent + if tensor.__class__ is not data_types[name]: + raise ValueError( + f"Data {name} must be a {data_types[name]}, got " + f"{tensor.__class__}" + ) + + # Check that the labels of each LabelTensor are consistent + if isinstance(tensor, LabelTensor): + if tensor.labels != labels[name]: + raise ValueError( + "LabelTensor must have the same labels" + ) + + def _store_tensor_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + def _store_graph_data(self, graphs, tensors=None, key=None): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + tensor = tensors[i] + setattr(new_graph, key, tensor) + data.append(new_graph) + return {"data": data} + + def _store_data(self, **kwargs): + return self._store_tensor_data(**kwargs) + + def __len__(self): + return len(next(iter(self.data.values()))) + + def __getitem__(self, idx): + return {key: self.data[key][idx] for key in self.data} + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + + to_return = {key: [] for key in batch[0].keys()} + for item in batch: + for key, value in item.items(): + to_return[key].append(value) + for key, values in to_return.items(): + collate_function = cls.collate_fn_dict.get( + "label_tensor" + if isinstance(values[0], LabelTensor) + else ( + "label_tensor" + if isinstance(values[0], torch.Tensor) + else "graph" if isinstance(values[0], Graph) else "data" + ) + ) + to_return[key] = collate_function(values) + return to_return + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for automatic batching to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: list + """ + data = condition[batch] + return data + + def create_dataloader( + self, dataset, batch_size, shuffle, automatic_batching + ): + """ + Create a DataLoader for the condition. + + :param int batch_size: The batch size for the DataLoader. + :param bool shuffle: Whether to shuffle the data. Default is ``False``. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + if batch_size == len(dataset): + return DummyDataloader(dataset) + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + # collate_fn = self.automatic_batching_collate_fn + ) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index b0264517c..427b85502 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,6 +1,6 @@ """Module for the Condition interface.""" -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from torch_geometric.data import Data from ..label_tensor import LabelTensor from ..graph import Graph @@ -15,13 +15,14 @@ class ConditionInterface(metaclass=ABCMeta): description of all available conditions and how to instantiate them. """ - def __init__(self): + @abstractmethod + def __init__(self, **kwargs): """ Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property + @abstractmethod def problem(self): """ Return the problem associated with this condition. @@ -29,9 +30,9 @@ def problem(self): :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ - return self._problem @problem.setter + @abstractmethod def problem(self, value): """ Set the problem associated with this condition. @@ -39,88 +40,3 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ - self._problem = value - - @staticmethod - def _check_graph_list_consistency(data_list): - """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: - - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. - """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." - ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - - def __getattribute__(self, name): - """ - Get an attribute from the object. - - :param str name: The name of the attribute to get. - :return: The requested attribute. - :rtype: Any - """ - to_return = super().__getattribute__(name) - if isinstance(to_return, (Graph, Data)): - to_return = [to_return] - return to_return diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 07b07bb7b..965eeecfc 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,13 +3,15 @@ """ import torch +from copy import deepcopy from torch_geometric.data import Data from ..label_tensor import LabelTensor -from ..graph import Graph -from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from .condition_base import ConditionBase +from torch_geometric.data import Batch -class InputTargetCondition(ConditionInterface): +class InputTargetCondition(ConditionBase): """ The :class:`InputTargetCondition` class represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to @@ -55,7 +57,7 @@ class InputTargetCondition(ConditionInterface): """ # Available input and target data types - __slots__ = ["input", "target"] + __fields__ = ["input", "target"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) @@ -109,16 +111,6 @@ def __new__(cls, input, target): subclass = GraphInputTensorTargetCondition return subclass.__new__(subclass, input, target) - # Graph - Graph - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(input) - cls._check_graph_list_consistency(target) - subclass = GraphInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # If the input and/or target are not of the correct type raise an error raise ValueError( "Invalid input | target types." "Please provide either torch_geometric.data.Data, Graph, " @@ -143,10 +135,8 @@ def __init__(self, input, target): objects, all elements in the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() self._check_input_target_len(input, target) - self.input = input - self.target = target + super().__init__(input=input, target=target) @staticmethod def _check_input_target_len(input, target): @@ -181,6 +171,26 @@ class TensorInputTensorTargetCondition(InputTargetCondition): :class:`~pina.label_tensor.LabelTensor` objects. """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["input"] + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["target"] + class TensorInputGraphTargetCondition(InputTargetCondition): """ @@ -190,6 +200,65 @@ class TensorInputGraphTargetCondition(InputTargetCondition): :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ + def _store_data(self, **kwargs): + return self._store_graph_data( + kwargs["target"], kwargs["input"], key="x" + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].x, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.x) + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + data = self.batch_fn([self.data["data"][i] for i in indices]) + x = data.x + del data.x # Avoid duplication of y on GPU memory + return { + "input": x, + "target": data, + } + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch) + x = collated_graphs["data"].x + del collated_graphs["data"].x # Avoid duplication of y on GPU memory + to_return = {"input": x, "input": collated_graphs["data"]} + return to_return + class GraphInputTensorTargetCondition(InputTargetCondition): """ @@ -199,10 +268,81 @@ class GraphInputTensorTargetCondition(InputTargetCondition): :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`GraphInputTensorTargetCondition` class. -class GraphInputGraphTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` objects. - """ + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param target: The target data for the condition. + :type target: torch.Tensor | LabelTensor + """ + super().__init__(input=input, target=target) + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(input[0], Graph) + else Batch.from_data_list + ) + + def _store_data(self, **kwargs): + return self._store_graph_data( + kwargs["input"], kwargs["target"], key="y" + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].y, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.y) + + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + data = self.batch_fn([self.data["data"][i] for i in indices]) + y = data.y + del data.y # Avoid duplication of y on GPU memory + return { + "input": data, + "target": y, + } + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch) + y = collated_graphs["data"].y + del collated_graphs["data"].y # Avoid duplication of y on GPU memory + print("y shape:", y.shape) + print(y.labels) + to_return = {"target": y, "input": collated_graphs["data"]} + return to_return From 52ee3e78a687b688aa7944274a871e92f47a5122 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 08:58:01 +0100 Subject: [PATCH 02/12] refact condition factory --- pina/condition/condition.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pina/condition/condition.py b/pina/condition/condition.py index ad8764c9f..3c43f7176 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -86,12 +86,12 @@ class Condition: """ # Combine all possible keyword arguments from the different Condition types - __slots__ = list( + available_kwargs = list( set( - InputTargetCondition.__slots__ - + InputEquationCondition.__slots__ - + DomainEquationCondition.__slots__ - + DataCondition.__slots__ + InputTargetCondition.__fields__ + + InputEquationCondition.__fields__ + + DomainEquationCondition.__fields__ + + DataCondition.__fields__ ) ) @@ -112,28 +112,28 @@ def __new__(cls, *args, **kwargs): if len(args) != 0: raise ValueError( "Condition takes only the following keyword " - f"arguments: {Condition.__slots__}." + f"arguments: {Condition.available_kwargs}." ) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition - if sorted_keys == sorted(InputTargetCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__fields__): return InputTargetCondition(**kwargs) # Input - Equation Condition - if sorted_keys == sorted(InputEquationCondition.__slots__): + if sorted_keys == sorted(InputEquationCondition.__fields__): return InputEquationCondition(**kwargs) # Domain - Equation Condition - if sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(DomainEquationCondition.__fields__): return DomainEquationCondition(**kwargs) # Data Condition if ( - sorted_keys == sorted(DataCondition.__slots__) - or sorted_keys[0] == DataCondition.__slots__[0] + sorted_keys == sorted(DataCondition.__fields__) + or sorted_keys[0] == DataCondition.__fields__[0] ): return DataCondition(**kwargs) From 8ae31089022ffe2e3cadc6b07ad4c298d9202014 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 10:00:44 +0100 Subject: [PATCH 03/12] fix TensorInputGraphTargetCondition --- pina/condition/input_target_condition.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 965eeecfc..ece2f22e6 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -200,6 +200,23 @@ class TensorInputGraphTargetCondition(InputTargetCondition): :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`TensorInputGraphTargetCondition` class. + + :param input: The input data for the condition. + :type input: torch.Tensor | LabelTensor + :param target: The target data for the condition. + :type target: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + """ + super().__init__(input=input, target=target) + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(target[0], Graph) + else Batch.from_data_list + ) + def _store_data(self, **kwargs): return self._store_graph_data( kwargs["target"], kwargs["input"], key="x" From 5ed425dc5fc62ee7fcbea806f254540feb41a1b0 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 10:01:20 +0100 Subject: [PATCH 04/12] Implement test for InputTargetCondition --- .../test_input_target_condition.py | 294 ++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 tests/test_condition/test_input_target_condition.py diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py new file mode 100644 index 000000000..033f7094a --- /dev/null +++ b/tests/test_condition/test_input_target_condition.py @@ -0,0 +1,294 @@ +import torch +import pytest +from torch_geometric.data import Batch +from pina import LabelTensor, Condition +from pina.condition import ( + TensorInputGraphTargetCondition, + TensorInputTensorTargetCondition, + GraphInputTensorTargetCondition, +) +from pina.graph import RadiusGraph, LabelBatch + + +def _create_tensor_data(use_lt=False): + if use_lt: + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + return input_tensor, target_tensor + input_tensor = torch.rand((10, 3)) + target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor + + +def _create_graph_data(tensor_input=True, use_lt=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + graph = [ + RadiusGraph( + pos=pos[i], + radius=radius, + x=x[i] if not tensor_input else None, + y=x[i] if tensor_input else None, + ) + for i in range(len(x)) + ] + if use_lt: + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + tensor = torch.rand(10, 20, 1) + return graph, tensor + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + assert isinstance(condition, TensorInputTensorTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputTensorTargetCondition input failed" + assert torch.allclose( + condition.target, target_tensor + ), "TensorInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputTensorTargetCondition input type failed" + assert condition.input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition input labels failed" + assert isinstance( + condition.target, LabelTensor + ), "TensorInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + assert isinstance(condition, TensorInputGraphTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputGraphTargetCondition input failed" + for i in range(len(target_graph)): + assert torch.allclose( + condition.target[i].y, target_graph[i].y + ), "TensorInputGraphTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target[i].y, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.target[i].y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition target labels failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.input.labels == [ + "f" + ], "TensorInputGraphTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + assert isinstance(condition, GraphInputTensorTargetCondition) + for i in range(len(input_graph)): + assert torch.allclose( + condition.input[i].x, input_graph[i].x + ), "GraphInputTensorTargetCondition input failed" + if use_lt: + assert isinstance( + condition.input[i].x, LabelTensor + ), "GraphInputTensorTargetCondition input type failed" + assert ( + condition.input[i].x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition labels failed" + + assert torch.allclose( + condition.target[i], target_tensor[i] + ), "GraphInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target, LabelTensor + ), "GraphInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + for i in range(len(input_tensor)): + item = condition[i] + assert torch.allclose( + item["input"], input_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item["target"], target_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + for i in range(len(input_tensor)): + item = condition[i]["data"] + assert torch.allclose( + item.x, input_tensor[i] + ), "TensorInputGraphTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_graph[i].y + ), "TensorInputGraphTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitem__ target type failed" + assert item.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitem__ target labels failed" + + +def test_getitem_graph_input_tensor_target_condition(): + input_graph, target_tensor = _create_graph_data(False) + condition = Condition(input=input_graph, target=target_tensor) + for i in range(len(input_graph)): + item = condition[i]["data"] + print(item) + assert torch.allclose( + item.x, input_graph[i].x + ), "GraphInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_tensor[i] + ), "GraphInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelBatch.from_data_list([input_graph[i] for i in indices]) + target_ = LabelTensor.cat([target_tensor[i] for i in indices], dim=0) + else: + input_ = Batch.from_data_list([input_graph[i] for i in indices]) + target_ = torch.cat([target_tensor[i] for i in indices], dim=0) + assert torch.allclose( + candidate_input.x, input_.x + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + if use_lt: + assert isinstance( + candidate_target, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitems__ target labels failed" + + assert isinstance( + candidate_input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ input type failed" + assert ( + candidate_input.x.labels == input_graph[0].x.labels + ), "GraphInputTensorTargetCondition __getitems__ input labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_tensor_target_condition(use_lt): + + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + indices = [1, 3, 5, 7] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + target_ = LabelTensor.stack([target_tensor[i] for i in indices]) + else: + input_ = torch.stack([input_tensor[i] for i in indices]) + target_ = torch.stack([target_tensor[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputTensorTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "TensorInputTensorTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition __getitems__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(True, use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + if use_lt: + input_ = LabelTensor.cat([input_tensor[i] for i in indices], dim=0) + target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) + else: + input_ = torch.cat([input_tensor[i] for i in indices], dim=0) + target_ = Batch.from_data_list([target_graph[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputGraphTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target.y, target_.y + ), "TensorInputGraphTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "f" + ], "TensorInputGraphTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ target type failed" + assert candidate_target.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitems__ target labels failed" From ea26f3dae4f6cc3794d4edfc1b04686b352b9afe Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 14:50:46 +0100 Subject: [PATCH 05/12] fix codacy --- pina/condition/condition_base.py | 37 ++++++++++++++++++++---- pina/condition/condition_interface.py | 21 ++++++++++++-- pina/condition/input_target_condition.py | 29 +++++++++++++++---- 3 files changed, 74 insertions(+), 13 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 361352d0f..0232375e3 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -1,15 +1,25 @@ -import torch +""" +Base class for conditions. +""" + from copy import deepcopy +from functools import partial +import torch +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader from .condition_interface import ConditionInterface from ..graph import Graph, LabelBatch from ..label_tensor import LabelTensor from ..data.dummy_dataloader import DummyDataloader -from torch_geometric.data import Data, Batch -from torch.utils.data import DataLoader -from functools import partial class ConditionBase(ConditionInterface): + """ + Base abstract class for all conditions in PINA. + This class provides common functionality for handling data storage, + batching, and interaction with the associated problem. + """ + collate_fn_dict = { "tensor": torch.stack, "label_tensor": LabelTensor.stack, @@ -18,15 +28,32 @@ class ConditionBase(ConditionInterface): } def __init__(self, **kwargs): + """ + Initialization of the :class:`ConditionBase` class. + + :param kwargs: Keyword arguments representing the data to be stored. + """ super().__init__() self.data = self._store_data(**kwargs) @property def problem(self): + """ + Return the problem associated with this condition. + + :return: Problem associated with this condition. + :rtype: ~pina.problem.abstract_problem.AbstractProblem + """ return self._problem @problem.setter def problem(self, value): + """ + Set the problem associated with this condition. + + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition + """ self._problem = value @staticmethod @@ -141,7 +168,7 @@ def __len__(self): return len(next(iter(self.data.values()))) def __getitem__(self, idx): - return {key: self.data[key][idx] for key in self.data} + return {name: data[idx] for name, data in self.data.items()} @classmethod def automatic_batching_collate_fn(cls, batch): diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 427b85502..229b9a025 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,9 +1,6 @@ """Module for the Condition interface.""" from abc import ABCMeta, abstractmethod -from torch_geometric.data import Data -from ..label_tensor import LabelTensor -from ..graph import Graph class ConditionInterface(metaclass=ABCMeta): @@ -40,3 +37,21 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ + + @abstractmethod + def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ + + @abstractmethod + def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param int idx: Index of the data point(s) to retrieve. + :return: Data point(s) at the specified index. + """ diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index ece2f22e6..c90fcc8e3 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,12 +3,10 @@ """ import torch -from copy import deepcopy -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch from ..label_tensor import LabelTensor from ..graph import Graph, LabelBatch from .condition_base import ConditionBase -from torch_geometric.data import Batch class InputTargetCondition(ConditionBase): @@ -218,6 +216,13 @@ def __init__(self, input, target): ) def _store_data(self, **kwargs): + """ + Store the input and target data for the condition. + + :param kwargs: Keyword arguments containing 'input' and 'target'. + :return: Stored data dictionary. + :rtype: dict + """ return self._store_graph_data( kwargs["target"], kwargs["input"], key="x" ) @@ -252,6 +257,13 @@ def __getitem__(self, idx): return {"data": self.data["data"][idx]} def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ data = self.batch_fn([self.data["data"][i] for i in indices]) x = data.x del data.x # Avoid duplication of y on GPU memory @@ -266,14 +278,14 @@ def automatic_batching_collate_fn(cls, batch): Collate function to be used in DataLoader. :param batch: A list of items from the dataset. - :type batch: list + :type batch: List[dict] :return: A collated batch. :rtype: dict """ collated_graphs = super().automatic_batching_collate_fn(batch) x = collated_graphs["data"].x del collated_graphs["data"].x # Avoid duplication of y on GPU memory - to_return = {"input": x, "input": collated_graphs["data"]} + to_return = {"input": x, "target": collated_graphs["data"]} return to_return @@ -338,6 +350,13 @@ def __getitem__(self, idx): return {"data": self.data["data"][idx]} def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ data = self.batch_fn([self.data["data"][i] for i in indices]) y = data.y del data.y # Avoid duplication of y on GPU memory From 74749fc40a0434a6e59b0bd747e81aea421658b8 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 15:13:38 +0100 Subject: [PATCH 06/12] refact InputEquationCondition --- pina/condition/input_equation_condition.py | 30 +++------ .../test_input_equation_condition.py | 65 +++++++++++++++++++ 2 files changed, 75 insertions(+), 20 deletions(-) create mode 100644 tests/test_condition/test_input_equation_condition.py diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index d32597894..fa41f79e2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,13 +1,12 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase from ..label_tensor import LabelTensor from ..graph import Graph -from ..utils import check_consistency from ..equation.equation_interface import EquationInterface -class InputEquationCondition(ConditionInterface): +class InputEquationCondition(ConditionBase): """ The class :class:`InputEquationCondition` defines a condition based on ``input`` data and an ``equation``. This condition is typically used in @@ -41,7 +40,7 @@ class InputEquationCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "equation"] + __fields__ = ["input", "equation"] _avail_input_cls = (LabelTensor, Graph, list, tuple) _avail_equation_cls = EquationInterface @@ -97,27 +96,18 @@ def __init__(self, input, equation): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input + super().__init__(input=input) self.equation = equation - def __setattr__(self, key, value): + @property + def input(self): """ - Set the attribute value with type checking. + Return the input data for the condition. - :param str key: The attribute name. - :param any value: The value to set for the attribute. + :return: The input data. + :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] """ - if key == "input": - check_consistency(value, self._avail_input_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, self._avail_equation_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key in ("_problem"): - super().__setattr__(key, value) + return self.data["input"] class InputTensorEquationCondition(InputEquationCondition): diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py new file mode 100644 index 000000000..b6a687e2a --- /dev/null +++ b/tests/test_condition/test_input_equation_condition.py @@ -0,0 +1,65 @@ +import torch +from pina import Condition +from pina.condition.input_equation_condition import ( + InputTensorEquationCondition, + InputGraphEquationCondition, +) +from pina.equation import Equation +from pina import LabelTensor + + +def _create_pts_and_equation(): + def dummy_equation(pts): + return pts["x"] ** 2 + pts["y"] ** 2 - 1 + + pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + equation = Equation(dummy_equation) + return pts, equation + + +def _create_graph_and_equation(): + from pina.graph import KNNGraph + + def dummy_equation(pts): + return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 + + x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) + pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) + equation = Equation(dummy_equation) + return graph, equation + + +def test_init_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + assert isinstance(condition, InputTensorEquationCondition) + assert condition.input.shape == (100, 2) + assert condition.equation is equation + + +def test_init_graph_equation_condition(): + graph, equation = _create_graph_and_equation() + condition = Condition(input=graph, equation=equation) + assert isinstance(condition, InputGraphEquationCondition) + assert condition.input is graph + assert condition.equation is equation + + +def test_getitem_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (2,) + + +def test_getitems_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + idxs = [0, 1, 3] + item = condition[idxs] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (3, 2) From 8fbfde94af38e39164de7ea157c34f15d3de251b Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 29 Dec 2025 15:22:11 +0100 Subject: [PATCH 07/12] refact DomainEquationCondition --- pina/condition/domain_equation_condition.py | 45 ++++++++++--------- .../test_domain_equation_condition.py | 27 +++++++++++ 2 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 tests/test_condition/test_domain_equation_condition.py diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3565c0b41..3e4adbaee 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,12 +1,11 @@ """Module for the DomainEquationCondition class.""" -from .condition_interface import ConditionInterface -from ..utils import check_consistency +from .condition_base import ConditionBase from ..domain import DomainInterface from ..equation.equation_interface import EquationInterface -class DomainEquationCondition(ConditionInterface): +class DomainEquationCondition(ConditionBase): """ The class :class:`DomainEquationCondition` defines a condition based on a ``domain`` and an ``equation``. This condition is typically used in @@ -30,7 +29,7 @@ class DomainEquationCondition(ConditionInterface): """ # Available slots - __slots__ = ["domain", "equation"] + __fields__ = ["domain", "equation"] def __init__(self, domain, equation): """ @@ -41,24 +40,30 @@ def __init__(self, domain, equation): :param EquationInterface equation: The equation to be satisfied over the specified domain. """ + if not isinstance(domain, (DomainInterface, str)): + raise ValueError( + f"`domain` must be an instance of DomainInterface, " + f"got {type(domain)} instead." + ) + if not isinstance(equation, EquationInterface): + raise ValueError( + f"`equation` must be an instance of EquationInterface, " + f"got {type(equation)} instead." + ) super().__init__() self.domain = domain self.equation = equation - def __setattr__(self, key, value): - """ - Set the attribute value with type checking. - - :param str key: The attribute name. - :param any value: The value to set for the attribute. - """ - if key == "domain": - check_consistency(value, (DomainInterface, str)) - DomainEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, (EquationInterface)) - DomainEquationCondition.__dict__[key].__set__(self, value) + def __len__(self): + raise NotImplementedError( + "`__len__` method is not implemented for " + "`DomainEquationCondition` since the number of points is " + "determined by the domain sampling strategy." + ) - elif key in ("_problem"): - super().__setattr__(key, value) + def __getitem__(self, idx): + """ """ + raise NotImplementedError( + "`__getitem__` method is not implemented for " + "`DomainEquationCondition`" + ) diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py new file mode 100644 index 000000000..2b7c78b00 --- /dev/null +++ b/tests/test_condition/test_domain_equation_condition.py @@ -0,0 +1,27 @@ +import pytest +from pina import Condition +from pina.domain import CartesianDomain +from pina.equation.equation_factory import FixedValue +from pina.condition import DomainEquationCondition + +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_equation = FixedValue(0.0) + + +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=example_equation) + assert isinstance(cond, DomainEquationCondition) + assert cond.domain is example_domain + assert cond.equation is example_equation + + +def test_len_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + len(cond) + + +def test_getitem_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + cond[0] From c941444007947efa95f6a106a3d6d5cc7470ddaa Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:27:09 +0100 Subject: [PATCH 08/12] implement TensorCondition and GraphCondition --- pina/condition/condition_base.py | 112 ++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 38 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 0232375e3..b8b828767 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -13,6 +13,79 @@ from ..data.dummy_dataloader import DummyDataloader +class TensorCondition: + def store_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + +class GraphCondition: + def __init__(self, **kwargs): + super().__init__(**kwargs) + example = kwargs.get(self.graph_field)[0] + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(example, Graph) + else Batch.from_data_list + ) + + def store_data(self, **kwargs): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + graphs = kwargs.get(self.graph_field) + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + for key in self.tensor_fields: + tensor = kwargs[key][i] + mapping_key = self.keys_map.get(key) + setattr(new_graph, mapping_key, tensor) + data.append(new_graph) + return {"data": data} + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ + to_return_dict = {} + data = self.batch_fn([self.data["data"][i] for i in indices]) + to_return_dict[self.graph_field] = data + for key in self.tensor_fields: + mapping_key = self.keys_map.get(key) + y = getattr(data, mapping_key) + delattr(data, mapping_key) # Avoid duplication of y on GPU memory + to_return_dict[key] = y + return to_return_dict + + class ConditionBase(ConditionInterface): """ Base abstract class for all conditions in PINA. @@ -34,7 +107,7 @@ def __init__(self, **kwargs): :param kwargs: Keyword arguments representing the data to be stored. """ super().__init__() - self.data = self._store_data(**kwargs) + self.data = self.store_data(**kwargs) @property def problem(self): @@ -127,43 +200,6 @@ def _check_graph_list_consistency(data_list): "LabelTensor must have the same labels" ) - def _store_tensor_data(self, **kwargs): - """ - Store data for standard tensor condition - - :param kwargs: Keyword arguments representing the data to be stored. - :return: A dictionary containing the stored data. - :rtype: dict - """ - data = {} - for key, value in kwargs.items(): - data[key] = value - return data - - def _store_graph_data(self, graphs, tensors=None, key=None): - """ - Store data for graph condition - - :param graphs: List of graphs to store data in. - :type graphs: list[Graph] | list[Data] - :param tensors: List of tensors to store in the graphs. - :type tensors: list[torch.Tensor] | list[LabelTensor] - :param key: Key under which to store the tensors in the graphs. - :type key: str - :return: A dictionary containing the stored data. - :rtype: dict - """ - data = [] - for i, graph in enumerate(graphs): - new_graph = deepcopy(graph) - tensor = tensors[i] - setattr(new_graph, key, tensor) - data.append(new_graph) - return {"data": data} - - def _store_data(self, **kwargs): - return self._store_tensor_data(**kwargs) - def __len__(self): return len(next(iter(self.data.values()))) From 4fda9a7875ae69f3b11c9ad514bb6603dd4238fa Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:27:49 +0100 Subject: [PATCH 09/12] use GraphCondition and Tensor condition classes --- pina/condition/__init__.py | 4 +- pina/condition/data_condition.py | 113 +++++++++++++++++-- pina/condition/domain_equation_condition.py | 9 ++ pina/condition/input_equation_condition.py | 46 ++++++-- pina/condition/input_target_condition.py | 117 +++----------------- 5 files changed, 166 insertions(+), 123 deletions(-) diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..13429a829 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -8,7 +8,7 @@ "TensorInputTensorTargetCondition", "TensorInputGraphTargetCondition", "GraphInputTensorTargetCondition", - "GraphInputGraphTargetCondition", + # "GraphInputGraphTargetCondition", "InputEquationCondition", "InputTensorEquationCondition", "InputGraphEquationCondition", @@ -25,7 +25,7 @@ TensorInputTensorTargetCondition, TensorInputGraphTargetCondition, GraphInputTensorTargetCondition, - GraphInputGraphTargetCondition, + # GraphInputGraphTargetCondition, ) from .input_equation_condition import ( InputEquationCondition, diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 5f5e7d36b..b04166b51 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -1,13 +1,13 @@ """Module for the DataCondition class.""" import torch -from torch_geometric.data import Data -from .condition_interface import ConditionInterface +from torch_geometric.data import Data, Batch +from .condition_base import ConditionBase, GraphCondition, TensorCondition from ..label_tensor import LabelTensor -from ..graph import Graph +from ..graph import Graph, LabelBatch -class DataCondition(ConditionInterface): +class DataCondition(ConditionBase): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -38,7 +38,7 @@ class DataCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "conditional_variables"] + __fields__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) @@ -99,22 +99,115 @@ def __init__(self, input, conditional_variables=None): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input - self.conditional_variables = conditional_variables + if conditional_variables is None: + super().__init__(input=input) + else: + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def conditional_variables(self): + """ + Return the conditional variables for the condition. + + :return: The conditional variables. + :rtype: torch.Tensor | LabelTensor | None + """ + return self.data.get("conditional_variables", None) -class TensorDataCondition(DataCondition): +class TensorDataCondition(TensorCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a :class:`torch.Tensor` object. """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["input"] + -class GraphDataCondition(DataCondition): +class GraphDataCondition(GraphCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` object or a :class:`~torch_geometric.data.Data` object. """ + + def __init__(self, input, conditional_variables=None): + """ + Initialization of the :class:`GraphDataCondition` class. + + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param conditional_variables: The conditional variables for the + condition. Default is ``None``. + :type conditional_variables: torch.Tensor | LabelTensor + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data`, all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + if conditional_variables is not None: + self.tensor_fields.append("conditional_variables") + self.keys_map["conditional_variables"] = "cond_vars" + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | + tuple[Data] + """ + return self.data["data"] + + @property + def conditional_variables(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + + if not hasattr(self.data["data"][0], "cond_vars"): + return None + cond_vars = [] + is_lt = isinstance(self.data["data"][0].cond_vars, LabelTensor) + for graph in self.data["data"]: + cond_vars.append(graph.cond_vars) + return ( + torch.stack(cond_vars) + if not is_lt + else LabelTensor.stack(cond_vars) + ) + + def __getitem__(self, idx): + """ + Get item by index from the input data. + + :param int index: The index of the item to retrieve. + :return: The item at the specified index. + :rtype: Graph | Data + """ + input_ = self.batch_fn(self.data["input"][idx]) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3e4adbaee..0ce05eeab 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -67,3 +67,12 @@ def __getitem__(self, idx): "`__getitem__` method is not implemented for " "`DomainEquationCondition`" ) + + def store_data(self): + """ + Store the data for the condition by sampling points from the domain. + + :return: Sampled points from the domain. + :rtype: dict + """ + return {} diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index fa41f79e2..913cdc4d2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,6 +1,6 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_base import ConditionBase +from .condition_base import ConditionBase, TensorCondition, GraphCondition from ..label_tensor import LabelTensor from ..graph import Graph from ..equation.equation_interface import EquationInterface @@ -99,6 +99,13 @@ def __init__(self, input, equation): super().__init__(input=input) self.equation = equation + +class InputTensorEquationCondition(TensorCondition, InputEquationCondition): + """ + Specialization of the :class:`InputEquationCondition` class for the case + where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. + """ + @property def input(self): """ @@ -110,18 +117,31 @@ def input(self): return self.data["input"] -class InputTensorEquationCondition(InputEquationCondition): +class InputGraphEquationCondition(GraphCondition, InputEquationCondition): """ Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. + where ``input`` is a :class:`~pina.graph.Graph` object. """ + def __init__(self, input, equation): + """ + Initialization of the :class:`InputGraphEquationCondition` class. -class InputGraphEquationCondition(InputEquationCondition): - """ - Specialization of the :class:`InputEquationCondition` class for the case - where ``input`` is a :class:`~pina.graph.Graph` object. - """ + :param input: The input data for the condition. + :type input: Graph | list[Graph] | tuple[Graph] + :param EquationInterface equation: The equation to be satisfied over the + specified input points. + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + super().__init__(input=[input], equation=equation) @staticmethod def _check_label_tensor(input): @@ -145,3 +165,13 @@ def _check_label_tensor(input): return raise ValueError("The input must contain at least one LabelTensor.") + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index c90fcc8e3..3e041bf90 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -3,10 +3,10 @@ """ import torch -from torch_geometric.data import Data, Batch +from torch_geometric.data import Data from ..label_tensor import LabelTensor -from ..graph import Graph, LabelBatch -from .condition_base import ConditionBase +from ..graph import Graph +from .condition_base import ConditionBase, GraphCondition, TensorCondition class InputTargetCondition(ConditionBase): @@ -115,7 +115,7 @@ def __new__(cls, input, target): "LabelTensor or torch.Tensor objects." ) - def __init__(self, input, target): + def __init__(self, **kwargs): """ Initialization of the :class:`InputTargetCondition` class. @@ -133,36 +133,10 @@ def __init__(self, input, target): objects, all elements in the list must share the same structure, with matching keys and consistent data types. """ - self._check_input_target_len(input, target) - super().__init__(input=input, target=target) + super().__init__(**kwargs) - @staticmethod - def _check_input_target_len(input, target): - """ - Check that the length of the input and target lists are the same. - :param input: The input data. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :raises ValueError: If the lengths of the input and target lists do not - match. - """ - if isinstance(input, (Graph, Data)) or isinstance( - target, (Graph, Data) - ): - return - - # Raise an error if the lengths of the input and target do not match - if len(input) != len(target): - raise ValueError( - "The input and target lists must have the same length." - ) - - -class TensorInputTensorTargetCondition(InputTargetCondition): +class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where both ``input`` and ``target`` are :class:`torch.Tensor` or @@ -190,7 +164,7 @@ def target(self): return self.data["target"] -class TensorInputGraphTargetCondition(InputTargetCondition): +class TensorInputGraphTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`torch.Tensor` or a @@ -208,24 +182,10 @@ def __init__(self, input, target): :type target: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | tuple[Data] """ + self.graph_field = "target" + self.tensor_fields = ["input"] + self.keys_map = {"input": "x"} super().__init__(input=input, target=target) - self.batch_fn = ( - LabelBatch.from_data_list - if isinstance(target[0], Graph) - else Batch.from_data_list - ) - - def _store_data(self, **kwargs): - """ - Store the input and target data for the condition. - - :param kwargs: Keyword arguments containing 'input' and 'target'. - :return: Stored data dictionary. - :rtype: dict - """ - return self._store_graph_data( - kwargs["target"], kwargs["input"], key="x" - ) @property def input(self): @@ -251,27 +211,6 @@ def target(self): """ return self.data["data"] - def __getitem__(self, idx): - if isinstance(idx, list): - return self.get_multiple_data(idx) - return {"data": self.data["data"][idx]} - - def get_multiple_data(self, indices): - """ - Get multiple data items based on the provided indices. - - :param List[int] indices: List of indices to retrieve. - :return: Dictionary containing 'input' and 'target' data. - :rtype: dict - """ - data = self.batch_fn([self.data["data"][i] for i in indices]) - x = data.x - del data.x # Avoid duplication of y on GPU memory - return { - "input": x, - "target": data, - } - @classmethod def automatic_batching_collate_fn(cls, batch): """ @@ -289,7 +228,7 @@ def automatic_batching_collate_fn(cls, batch): return to_return -class GraphInputTensorTargetCondition(InputTargetCondition): +class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` or @@ -307,17 +246,10 @@ def __init__(self, input, target): :param target: The target data for the condition. :type target: torch.Tensor | LabelTensor """ + self.graph_field = "input" + self.tensor_fields = ["target"] + self.keys_map = {"target": "y"} super().__init__(input=input, target=target) - self.batch_fn = ( - LabelBatch.from_data_list - if isinstance(input[0], Graph) - else Batch.from_data_list - ) - - def _store_data(self, **kwargs): - return self._store_graph_data( - kwargs["input"], kwargs["target"], key="y" - ) @property def input(self): @@ -344,27 +276,6 @@ def target(self): return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) - def __getitem__(self, idx): - if isinstance(idx, list): - return self.get_multiple_data(idx) - return {"data": self.data["data"][idx]} - - def get_multiple_data(self, indices): - """ - Get multiple data items based on the provided indices. - - :param List[int] indices: List of indices to retrieve. - :return: Dictionary containing 'input' and 'target' data. - :rtype: dict - """ - data = self.batch_fn([self.data["data"][i] for i in indices]) - y = data.y - del data.y # Avoid duplication of y on GPU memory - return { - "input": data, - "target": y, - } - @classmethod def automatic_batching_collate_fn(cls, batch): """ From 598bce42cf09b62366926e51dec59ac3a49b826e Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Tue, 30 Dec 2025 12:28:04 +0100 Subject: [PATCH 10/12] fix tests --- tests/test_condition.py | 325 +++++++++--------- tests/test_condition/test_data_condition.py | 100 ++++++ .../test_input_equation_condition.py | 4 +- .../test_input_target_condition.py | 3 + .../test_ensemble_supervised_solver.py | 3 +- tests/test_solver/test_supervised_solver.py | 3 +- 6 files changed, 281 insertions(+), 157 deletions(-) create mode 100644 tests/test_condition/test_data_condition.py diff --git a/tests/test_condition.py b/tests/test_condition.py index 9199f2bd9..8a5480499 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -1,154 +1,171 @@ -import torch -import pytest - -from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - TensorInputTensorTargetCondition, - GraphInputGraphTargetCondition, - GraphInputTensorTargetCondition, -) -from pina.condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, - DomainEquationCondition, -) -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) -from pina.domain import CartesianDomain -from pina.equation.equation_factory import FixedValue -from pina.graph import RadiusGraph - -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -input_tensor = torch.rand((10, 3)) -target_tensor = torch.rand((10, 2)) -input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) - -x = torch.rand(10, 20, 2) -pos = torch.rand(10, 20, 2) -radius = 0.1 -input_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] -target_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] - -x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) -pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) -radius = 0.1 -input_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] -target_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] - -input_single_graph = input_graph[0] -target_single_graph = target_graph[0] - - -def test_init_input_target(): - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_graph, target=target_tensor) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - cond = Condition(input=input_lt, target=input_single_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_lt) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_single_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - with pytest.raises(ValueError): - Condition(input_tensor, input_tensor) - with pytest.raises(ValueError): - Condition(input=3.0, target="example") - with pytest.raises(ValueError): - Condition(input=example_domain, target=example_domain) - - # Test wrong graph condition initialisation - input = [input_graph[0], input_graph_lt[0]] - target = [target_graph[0], target_graph_lt[0]] - with pytest.raises(ValueError): - Condition(input=input, target=target) - - input_graph_lt[0].x.labels = ["a", "b"] - with pytest.raises(ValueError): - Condition(input=input_graph_lt, target=target_graph_lt) - input_graph_lt[0].x.labels = ["u", "v"] - - -def test_init_domain_equation(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - assert isinstance(cond, DomainEquationCondition) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(domain=3.0, equation="example") - with pytest.raises(ValueError): - Condition(domain=input_tensor, equation=input_graph) - - -def test_init_input_equation(): - cond = Condition(input=input_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputTensorEquationCondition) - cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputGraphEquationCondition) - with pytest.raises(ValueError): - cond = Condition(input=input_tensor, equation=FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(input=3.0, equation="example") - with pytest.raises(ValueError): - Condition(input=example_domain, equation=input_graph) - - -test_init_input_equation() - - -def test_init_data_condition(): - cond = Condition(input=input_lt) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_graph) - assert isinstance(cond, GraphDataCondition) - cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) - assert isinstance(cond, GraphDataCondition) +# import torch +# import pytest + +# from pina import LabelTensor, Condition +# from pina.condition import ( +# TensorInputGraphTargetCondition, +# TensorInputTensorTargetCondition, +# # GraphInputGraphTargetCondition, +# GraphInputTensorTargetCondition, +# ) +# from pina.condition import ( +# InputTensorEquationCondition, +# InputGraphEquationCondition, +# DomainEquationCondition, +# ) +# from pina.condition import ( +# TensorDataCondition, +# GraphDataCondition, +# ) +# from pina.domain import CartesianDomain +# from pina.equation.equation_factory import FixedValue +# from pina.graph import RadiusGraph + +# def _create_tensor_data(): +# input_tensor = torch.rand((10, 3)) +# target_tensor = torch.rand((10, 2)) +# return input_tensor, target_tensor + +# def _create_graph_data(): +# x = torch.rand(10, 20, 2) +# pos = torch.rand(10, 20, 2) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# target_graph = [ +# RadiusGraph( +# y=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# return input_graph, target_graph + +# def _create_lt_data(): +# input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) +# target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) +# return input_lt, target_lt + +# def _create_graph_lt_data(): +# x = LabelTensor(torch.rand((10, 20, 2)), ["u", "v"]) +# pos = LabelTensor(torch.rand((10, 20, 2)), ["x", "y"]) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# target_graph = [ +# RadiusGraph( +# y=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# return input_graph, target_graph + +# example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + +# def test_init_input_target(): +# input_tensor, target_tensor = _create_tensor_data() +# cond = Condition(input=input_tensor, target=target_tensor) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_lt, target_lt = _create_lt_data() +# cond = Condition(input=input_lt, target=target_lt) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_graph, target_graph = _create_graph_data() +# cond = Condition(input=input_tensor, target=target_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) + +# cond = Condition(input=input_graph, target=target_tensor) +# assert isinstance(cond, GraphInputTensorTargetCondition) + +# input_single_graph = input_graph[0] +# target_single_graph = target_graph[0] +# cond = Condition(input=input_lt, target=input_single_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) +# cond = Condition(input=input_single_graph, target=target_lt) +# assert isinstance(cond, GraphInputTensorTargetCondition) +# # cond = Condition(input=input_graph, target=target_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# # cond = Condition(input=input_single_graph, target=target_single_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# input_graph_lt, target_graph_lt = _create_graph_lt_data() + +# with pytest.raises(ValueError): +# Condition(input_tensor, input_tensor) +# with pytest.raises(ValueError): +# Condition(input=3.0, target="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, target=example_domain) + +# # Test wrong graph condition initialisation +# input = [input_graph[0], input_graph_lt[0]] +# target = [target_graph[0], target_graph_lt[0]] +# with pytest.raises(ValueError): +# Condition(input=input, target=target) + +# input_graph_lt[0].x.labels = ["a", "b"] +# with pytest.raises(ValueError): +# Condition(input=input_graph_lt, target=target_graph_lt) +# input_graph_lt[0].x.labels = ["u", "v"] + + +# def test_init_domain_equation(): +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(domain=example_domain, equation=FixedValue(0.0)) +# assert isinstance(cond, DomainEquationCondition) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(domain=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(domain=input_tensor, equation=input_graph) + + +# def test_init_input_equation(): +# input_lt, _ = _create_lt_data() +# input_graph_lt, _ = _create_graph_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputTensorEquationCondition) +# cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputGraphEquationCondition) +# with pytest.raises(ValueError): +# cond = Condition(input=input_tensor, equation=FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(input=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, equation=input_graph) + +# def test_init_data_condition(): +# input_lt, _ = _create_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_graph) +# assert isinstance(cond, GraphDataCondition) +# cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, GraphDataCondition) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py new file mode 100644 index 000000000..954e8f777 --- /dev/null +++ b/tests/test_condition/test_data_condition.py @@ -0,0 +1,100 @@ +import pytest +import torch +from pina import Condition, LabelTensor +from pina.condition import ( + TensorDataCondition, + GraphDataCondition, +) +from pina.graph import RadiusGraph +from torch_geometric.data import Data + + +def _create_tensor_data(use_lt=False, conditional_variables=False): + input_tensor = torch.rand((10, 3)) + if use_lt: + input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) + if conditional_variables: + cond_vars = torch.rand((10, 2)) + if use_lt: + cond_vars = LabelTensor(cond_vars, ["a", "b"]) + else: + cond_vars = None + return input_tensor, cond_vars + + +def _create_graph_data(use_lt=False, conditional_variables=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + input_graph = [ + RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) + ] + if conditional_variables: + if use_lt: + cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + cond_vars = torch.rand(10, 20, 1) + else: + cond_vars = None + return input_graph, cond_vars + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = TensorDataCondition( + input=input_tensor, conditional_variables=cond_vars + ) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["a", "b"] + else: + assert condition.conditional_variables is None + assert isinstance(condition.input, type_) + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + + +test_init_tensor_data_condition(False, False) + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = GraphDataCondition( + input=input_graph, conditional_variables=cond_vars + ) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["f"] + else: + assert condition.conditional_variables is None + # assert "conditional_variables" not in condition.data.keys() + assert isinstance(condition.input, list) + for graph in condition.input: + assert isinstance(graph, Data) + assert isinstance(graph.x, type_) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + if use_lt: + assert graph.pos.labels == ["x", "y"] + + +test_init_graph_data_condition(False, False) diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py index b6a687e2a..af11d382e 100644 --- a/tests/test_condition/test_input_equation_condition.py +++ b/tests/test_condition/test_input_equation_condition.py @@ -42,7 +42,9 @@ def test_init_graph_equation_condition(): graph, equation = _create_graph_and_equation() condition = Condition(input=graph, equation=equation) assert isinstance(condition, InputGraphEquationCondition) - assert condition.input is graph + assert isinstance(condition.input, list) + assert len(condition.input) == 1 + assert condition.input[0].x.shape == (100, 2) assert condition.equation is equation diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py index 033f7094a..81c3a9b24 100644 --- a/tests/test_condition/test_input_target_condition.py +++ b/tests/test_condition/test_input_target_condition.py @@ -292,3 +292,6 @@ def test_getitems_tensor_input_graph_target_condition(use_lt): "u", "v", ], "TensorInputGraphTargetCondition __getitems__ target labels failed" + + +test_init_graph_input_tensor_target_condition(use_lt=True) diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py index c5f0b9e52..4be2897d9 100644 --- a/tests/test_solver/test_ensemble_supervised_solver.py +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_models = [Models() for i in range(10)] diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..461130a6b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_model = Model() From f6619e101d2e0769ee5a4149968513d05c7fe11a Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 31 Dec 2025 08:38:20 +0100 Subject: [PATCH 11/12] fix DataCondition and relative tests --- pina/condition/condition_base.py | 29 ++++- pina/condition/data_condition.py | 14 +-- pina/condition/input_target_condition.py | 54 --------- tests/test_condition/test_data_condition.py | 122 ++++++++++++++++++-- 4 files changed, 141 insertions(+), 78 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index b8b828767..8365c51d4 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -14,6 +14,10 @@ class TensorCondition: + """ + Base class for tensor conditions. + """ + def store_data(self, **kwargs): """ Store data for standard tensor condition @@ -29,6 +33,10 @@ def store_data(self, **kwargs): class GraphCondition: + """ + Base class for graph conditions. + """ + def __init__(self, **kwargs): super().__init__(**kwargs) example = kwargs.get(self.graph_field)[0] @@ -85,6 +93,26 @@ def get_multiple_data(self, indices): to_return_dict[key] = y return to_return_dict + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch)["data"] + to_return_dict = {} + for key in cls.tensor_fields: + mapping_key = cls.keys_map.get(key) + tensor = getattr(collated_graphs, mapping_key) + to_return_dict[key] = tensor + delattr(collated_graphs, mapping_key) + to_return_dict[cls.graph_field] = collated_graphs + return to_return_dict + class ConditionBase(ConditionInterface): """ @@ -269,5 +297,4 @@ def create_dataloader( if not automatic_batching else self.automatic_batching_collate_fn ), - # collate_fn = self.automatic_batching_collate_fn ) diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index b04166b51..e78dc800f 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -1,10 +1,10 @@ """Module for the DataCondition class.""" import torch -from torch_geometric.data import Data, Batch +from torch_geometric.data import Data from .condition_base import ConditionBase, GraphCondition, TensorCondition from ..label_tensor import LabelTensor -from ..graph import Graph, LabelBatch +from ..graph import Graph class DataCondition(ConditionBase): @@ -201,13 +201,3 @@ def conditional_variables(self): if not is_lt else LabelTensor.stack(cond_vars) ) - - def __getitem__(self, idx): - """ - Get item by index from the input data. - - :param int index: The index of the item to retrieve. - :return: The item at the specified index. - :rtype: Graph | Data - """ - input_ = self.batch_fn(self.data["input"][idx]) diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 3e041bf90..064f4b5eb 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -115,26 +115,6 @@ def __new__(cls, input, target): "LabelTensor or torch.Tensor objects." ) - def __init__(self, **kwargs): - """ - Initialization of the :class:`InputTargetCondition` class. - - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - - .. note:: - - If either ``input`` or ``target`` is a list of - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` - objects, all elements in the list must share the same structure, - with matching keys and consistent data types. - """ - super().__init__(**kwargs) - class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): """ @@ -211,22 +191,6 @@ def target(self): """ return self.data["data"] - @classmethod - def automatic_batching_collate_fn(cls, batch): - """ - Collate function to be used in DataLoader. - - :param batch: A list of items from the dataset. - :type batch: List[dict] - :return: A collated batch. - :rtype: dict - """ - collated_graphs = super().automatic_batching_collate_fn(batch) - x = collated_graphs["data"].x - del collated_graphs["data"].x # Avoid duplication of y on GPU memory - to_return = {"input": x, "target": collated_graphs["data"]} - return to_return - class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): """ @@ -275,21 +239,3 @@ def target(self): targets.append(graph.y) return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) - - @classmethod - def automatic_batching_collate_fn(cls, batch): - """ - Collate function to be used in DataLoader. - - :param batch: A list of items from the dataset. - :type batch: list - :return: A collated batch. - :rtype: dict - """ - collated_graphs = super().automatic_batching_collate_fn(batch) - y = collated_graphs["data"].y - del collated_graphs["data"].y # Avoid duplication of y on GPU memory - print("y shape:", y.shape) - print(y.labels) - to_return = {"target": y, "input": collated_graphs["data"]} - return to_return diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py index 954e8f777..e922bdbcd 100644 --- a/tests/test_condition/test_data_condition.py +++ b/tests/test_condition/test_data_condition.py @@ -6,7 +6,8 @@ GraphDataCondition, ) from pina.graph import RadiusGraph -from torch_geometric.data import Data +from torch_geometric.data import Data, Batch +from pina.graph import Graph, LabelBatch def _create_tensor_data(use_lt=False, conditional_variables=False): @@ -49,9 +50,8 @@ def test_init_tensor_data_condition(use_lt, conditional_variables): input_tensor, cond_vars = _create_tensor_data( use_lt=use_lt, conditional_variables=conditional_variables ) - condition = TensorDataCondition( - input=input_tensor, conditional_variables=cond_vars - ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + assert isinstance(condition, TensorDataCondition) type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -65,18 +65,14 @@ def test_init_tensor_data_condition(use_lt, conditional_variables): assert condition.input.labels == ["x", "y", "z"] -test_init_tensor_data_condition(False, False) - - @pytest.mark.parametrize("use_lt", [False, True]) @pytest.mark.parametrize("conditional_variables", [False, True]) def test_init_graph_data_condition(use_lt, conditional_variables): input_graph, cond_vars = _create_graph_data( use_lt=use_lt, conditional_variables=conditional_variables ) - condition = GraphDataCondition( - input=input_graph, conditional_variables=cond_vars - ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + assert isinstance(condition, GraphDataCondition) type_ = LabelTensor if use_lt else torch.Tensor if conditional_variables: assert condition.conditional_variables is not None @@ -97,4 +93,108 @@ def test_init_graph_data_condition(use_lt, conditional_variables): assert graph.pos.labels == ["x", "y"] -test_init_graph_data_condition(False, False) +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(item["input"], type_) + assert item["input"].shape == (3,) + if type_ is LabelTensor: + assert item["input"].labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in item + assert isinstance(item["conditional_variables"], type_) + assert item["conditional_variables"].shape == (2,) + if type_ is LabelTensor: + assert item["conditional_variables"].labels == ["a", "b"] + else: + assert "conditional_variables" not in item + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "data" in item + graph = item["data"] + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + assert hasattr(graph, "cond_vars") + cond_var = graph.cond_vars + assert isinstance(cond_var, type_) + assert cond_var.shape == (20, 1) + if use_lt: + assert cond_var.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + graphs = items["input"] + assert isinstance(graphs, LabelBatch) + assert graphs.num_graphs == 3 + if conditional_variables: + type_ = LabelTensor if use_lt else torch.Tensor + assert "conditional_variables" in items + cond_vars_batch = items["conditional_variables"] + assert isinstance(cond_vars_batch, type_) + assert cond_vars_batch.shape == (60, 1) + if use_lt: + assert cond_vars_batch.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + type_ = LabelTensor if use_lt else torch.Tensor + inputs = items["input"] + assert isinstance(inputs, type_) + assert inputs.shape == (3, 3) + if use_lt: + assert inputs.labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in items + cond_vars_items = items["conditional_variables"] + assert isinstance(cond_vars_items, type_) + assert cond_vars_items.shape == (3, 2) + if use_lt: + assert cond_vars_items.labels == ["a", "b"] + else: + assert "conditional_variables" not in items From 7b9096ef3431ceb759cc91c67ffddfbc5d184655 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Mon, 12 Jan 2026 15:37:52 +0100 Subject: [PATCH 12/12] fixes --- pina/condition/condition_base.py | 61 ++++++++++++++------- pina/condition/domain_equation_condition.py | 19 ++++++- 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py index 8365c51d4..6283a78bb 100644 --- a/pina/condition/condition_base.py +++ b/pina/condition/condition_base.py @@ -10,7 +10,6 @@ from .condition_interface import ConditionInterface from ..graph import Graph, LabelBatch from ..label_tensor import LabelTensor -from ..data.dummy_dataloader import DummyDataloader class TensorCondition: @@ -229,48 +228,68 @@ def _check_graph_list_consistency(data_list): ) def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ return len(next(iter(self.data.values()))) def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param idx: Index(es) of the data point(s) to retrieve. + :type idx: int | list[int] + :return: Data point(s) at the specified index. + """ return {name: data[idx] for name, data in self.data.items()} @classmethod def automatic_batching_collate_fn(cls, batch): """ - Collate function to be used in DataLoader. - + Collate function for automatic batching to be used in DataLoader. :param batch: A list of items from the dataset. :type batch: list :return: A collated batch. :rtype: dict """ + if not batch: + return {} + keys = batch[0].keys() + columns = zip(*[item.values() for item in batch]) + + to_return = {} + + # 2. Process each column + for key, values in zip(keys, columns): + # Determine type based on the first sample only + first_val = values[0] + + if isinstance(first_val, (LabelTensor, torch.Tensor)): + lookup_key = "label_tensor" + elif isinstance(first_val, Graph): + lookup_key = "graph" + else: + lookup_key = "data" + + # Execute the specific collate function + to_return[key] = cls.collate_fn_dict[lookup_key](list(values)) - to_return = {key: [] for key in batch[0].keys()} - for item in batch: - for key, value in item.items(): - to_return[key].append(value) - for key, values in to_return.items(): - collate_function = cls.collate_fn_dict.get( - "label_tensor" - if isinstance(values[0], LabelTensor) - else ( - "label_tensor" - if isinstance(values[0], torch.Tensor) - else "graph" if isinstance(values[0], Graph) else "data" - ) - ) - to_return[key] = collate_function(values) return to_return @staticmethod def collate_fn(batch, condition): """ - Collate function for automatic batching to be used in DataLoader. + Collate function for custom batching to be used in DataLoader. :param batch: A list of items from the dataset. :type batch: list + :param condition: The condition instance. + :type condition: ConditionBase :return: A collated batch. - :rtype: list + :rtype: dict """ data = condition[batch] return data @@ -287,7 +306,7 @@ def create_dataloader( :rtype: torch.utils.data.DataLoader """ if batch_size == len(dataset): - return DummyDataloader(dataset) + pass # will be updated in the near future return DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 0ce05eeab..673bf1612 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -55,6 +55,13 @@ def __init__(self, domain, equation): self.equation = equation def __len__(self): + """ + Raise NotImplementedError since the number of points is determined by + the domain sampling strategy. + + :raises NotImplementedError: Always raised since the number of points is + determined by the domain sampling strategy. + """ raise NotImplementedError( "`__len__` method is not implemented for " "`DomainEquationCondition` since the number of points is " @@ -62,7 +69,13 @@ def __len__(self): ) def __getitem__(self, idx): - """ """ + """ + Raise NotImplementedError since data retrieval is not applicable. + + :param int idx: Index of the data point(s) to retrieve. + :raises NotImplementedError: Always raised since data retrieval is not + applicable for this condition. + """ raise NotImplementedError( "`__getitem__` method is not implemented for " "`DomainEquationCondition`" @@ -70,9 +83,9 @@ def __getitem__(self, idx): def store_data(self): """ - Store the data for the condition by sampling points from the domain. + Store data for the condition. No data is stored for this condition. - :return: Sampled points from the domain. + :return: An empty dictionary since no data is stored. :rtype: dict """ return {}