diff --git a/pina/condition/__init__.py b/pina/condition/__init__.py index 4e57811fb..13429a829 100644 --- a/pina/condition/__init__.py +++ b/pina/condition/__init__.py @@ -8,7 +8,7 @@ "TensorInputTensorTargetCondition", "TensorInputGraphTargetCondition", "GraphInputTensorTargetCondition", - "GraphInputGraphTargetCondition", + # "GraphInputGraphTargetCondition", "InputEquationCondition", "InputTensorEquationCondition", "InputGraphEquationCondition", @@ -25,7 +25,7 @@ TensorInputTensorTargetCondition, TensorInputGraphTargetCondition, GraphInputTensorTargetCondition, - GraphInputGraphTargetCondition, + # GraphInputGraphTargetCondition, ) from .input_equation_condition import ( InputEquationCondition, diff --git a/pina/condition/condition.py b/pina/condition/condition.py index ad8764c9f..3c43f7176 100644 --- a/pina/condition/condition.py +++ b/pina/condition/condition.py @@ -86,12 +86,12 @@ class Condition: """ # Combine all possible keyword arguments from the different Condition types - __slots__ = list( + available_kwargs = list( set( - InputTargetCondition.__slots__ - + InputEquationCondition.__slots__ - + DomainEquationCondition.__slots__ - + DataCondition.__slots__ + InputTargetCondition.__fields__ + + InputEquationCondition.__fields__ + + DomainEquationCondition.__fields__ + + DataCondition.__fields__ ) ) @@ -112,28 +112,28 @@ def __new__(cls, *args, **kwargs): if len(args) != 0: raise ValueError( "Condition takes only the following keyword " - f"arguments: {Condition.__slots__}." + f"arguments: {Condition.available_kwargs}." ) # Class specialization based on keyword arguments sorted_keys = sorted(kwargs.keys()) # Input - Target Condition - if sorted_keys == sorted(InputTargetCondition.__slots__): + if sorted_keys == sorted(InputTargetCondition.__fields__): return InputTargetCondition(**kwargs) # Input - Equation Condition - if sorted_keys == sorted(InputEquationCondition.__slots__): + if sorted_keys == sorted(InputEquationCondition.__fields__): return InputEquationCondition(**kwargs) # Domain - Equation Condition - if sorted_keys == sorted(DomainEquationCondition.__slots__): + if sorted_keys == sorted(DomainEquationCondition.__fields__): return DomainEquationCondition(**kwargs) # Data Condition if ( - sorted_keys == sorted(DataCondition.__slots__) - or sorted_keys[0] == DataCondition.__slots__[0] + sorted_keys == sorted(DataCondition.__fields__) + or sorted_keys[0] == DataCondition.__fields__[0] ): return DataCondition(**kwargs) diff --git a/pina/condition/condition_base.py b/pina/condition/condition_base.py new file mode 100644 index 000000000..6283a78bb --- /dev/null +++ b/pina/condition/condition_base.py @@ -0,0 +1,319 @@ +""" +Base class for conditions. +""" + +from copy import deepcopy +from functools import partial +import torch +from torch_geometric.data import Data, Batch +from torch.utils.data import DataLoader +from .condition_interface import ConditionInterface +from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor + + +class TensorCondition: + """ + Base class for tensor conditions. + """ + + def store_data(self, **kwargs): + """ + Store data for standard tensor condition + + :param kwargs: Keyword arguments representing the data to be stored. + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = {} + for key, value in kwargs.items(): + data[key] = value + return data + + +class GraphCondition: + """ + Base class for graph conditions. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + example = kwargs.get(self.graph_field)[0] + self.batch_fn = ( + LabelBatch.from_data_list + if isinstance(example, Graph) + else Batch.from_data_list + ) + + def store_data(self, **kwargs): + """ + Store data for graph condition + + :param graphs: List of graphs to store data in. + :type graphs: list[Graph] | list[Data] + :param tensors: List of tensors to store in the graphs. + :type tensors: list[torch.Tensor] | list[LabelTensor] + :param key: Key under which to store the tensors in the graphs. + :type key: str + :return: A dictionary containing the stored data. + :rtype: dict + """ + data = [] + graphs = kwargs.get(self.graph_field) + for i, graph in enumerate(graphs): + new_graph = deepcopy(graph) + for key in self.tensor_fields: + tensor = kwargs[key][i] + mapping_key = self.keys_map.get(key) + setattr(new_graph, mapping_key, tensor) + data.append(new_graph) + return {"data": data} + + def __getitem__(self, idx): + if isinstance(idx, list): + return self.get_multiple_data(idx) + return {"data": self.data["data"][idx]} + + def get_multiple_data(self, indices): + """ + Get multiple data items based on the provided indices. + + :param List[int] indices: List of indices to retrieve. + :return: Dictionary containing 'input' and 'target' data. + :rtype: dict + """ + to_return_dict = {} + data = self.batch_fn([self.data["data"][i] for i in indices]) + to_return_dict[self.graph_field] = data + for key in self.tensor_fields: + mapping_key = self.keys_map.get(key) + y = getattr(data, mapping_key) + delattr(data, mapping_key) # Avoid duplication of y on GPU memory + to_return_dict[key] = y + return to_return_dict + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + collated_graphs = super().automatic_batching_collate_fn(batch)["data"] + to_return_dict = {} + for key in cls.tensor_fields: + mapping_key = cls.keys_map.get(key) + tensor = getattr(collated_graphs, mapping_key) + to_return_dict[key] = tensor + delattr(collated_graphs, mapping_key) + to_return_dict[cls.graph_field] = collated_graphs + return to_return_dict + + +class ConditionBase(ConditionInterface): + """ + Base abstract class for all conditions in PINA. + This class provides common functionality for handling data storage, + batching, and interaction with the associated problem. + """ + + collate_fn_dict = { + "tensor": torch.stack, + "label_tensor": LabelTensor.stack, + "graph": LabelBatch.from_data_list, + "data": Batch.from_data_list, + } + + def __init__(self, **kwargs): + """ + Initialization of the :class:`ConditionBase` class. + + :param kwargs: Keyword arguments representing the data to be stored. + """ + super().__init__() + self.data = self.store_data(**kwargs) + + @property + def problem(self): + """ + Return the problem associated with this condition. + + :return: Problem associated with this condition. + :rtype: ~pina.problem.abstract_problem.AbstractProblem + """ + return self._problem + + @problem.setter + def problem(self, value): + """ + Set the problem associated with this condition. + + :param pina.problem.abstract_problem.AbstractProblem value: The problem + to associate with this condition + """ + self._problem = value + + @staticmethod + def _check_graph_list_consistency(data_list): + """ + Check the consistency of the list of Data | Graph objects. + The following checks are performed: + + - All elements in the list must be of the same type (either + :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). + + - All elements in the list must have the same keys. + + - The data type of each tensor must be consistent across all elements. + + - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels + must also be consistent across all elements. + + :param data_list: The list of Data | Graph objects to check. + :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] + :raises ValueError: If the input types are invalid. + :raises ValueError: If all elements in the list do not have the same + keys. + :raises ValueError: If the type of each tensor is not consistent across + all elements in the list. + :raises ValueError: If the labels of the LabelTensors are not consistent + across all elements in the list. + """ + # If the data is a Graph or Data object, perform no checks + if isinstance(data_list, (Graph, Data)): + return + + # Check all elements in the list are of the same type + if not all(isinstance(i, (Graph, Data)) for i in data_list): + raise ValueError( + "Invalid input. Please, provide either Data or Graph objects." + ) + + # Store the keys, data types and labels of the first element + data = data_list[0] + keys = sorted(list(data.keys())) + data_types = {name: tensor.__class__ for name, tensor in data.items()} + labels = { + name: tensor.labels + for name, tensor in data.items() + if isinstance(tensor, LabelTensor) + } + + # Iterate over the list of Data | Graph objects + for data in data_list[1:]: + + # Check that all elements in the list have the same keys + if sorted(list(data.keys())) != keys: + raise ValueError( + "All elements in the list must have the same keys." + ) + + # Iterate over the tensors in the current element + for name, tensor in data.items(): + # Check that the type of each tensor is consistent + if tensor.__class__ is not data_types[name]: + raise ValueError( + f"Data {name} must be a {data_types[name]}, got " + f"{tensor.__class__}" + ) + + # Check that the labels of each LabelTensor are consistent + if isinstance(tensor, LabelTensor): + if tensor.labels != labels[name]: + raise ValueError( + "LabelTensor must have the same labels" + ) + + def __len__(self): + """ + Return the number of data points in the condition. + + :return: Number of data points. + :rtype: int + """ + return len(next(iter(self.data.values()))) + + def __getitem__(self, idx): + """ + Return the data point(s) at the specified index. + + :param idx: Index(es) of the data point(s) to retrieve. + :type idx: int | list[int] + :return: Data point(s) at the specified index. + """ + return {name: data[idx] for name, data in self.data.items()} + + @classmethod + def automatic_batching_collate_fn(cls, batch): + """ + Collate function for automatic batching to be used in DataLoader. + :param batch: A list of items from the dataset. + :type batch: list + :return: A collated batch. + :rtype: dict + """ + if not batch: + return {} + keys = batch[0].keys() + columns = zip(*[item.values() for item in batch]) + + to_return = {} + + # 2. Process each column + for key, values in zip(keys, columns): + # Determine type based on the first sample only + first_val = values[0] + + if isinstance(first_val, (LabelTensor, torch.Tensor)): + lookup_key = "label_tensor" + elif isinstance(first_val, Graph): + lookup_key = "graph" + else: + lookup_key = "data" + + # Execute the specific collate function + to_return[key] = cls.collate_fn_dict[lookup_key](list(values)) + + return to_return + + @staticmethod + def collate_fn(batch, condition): + """ + Collate function for custom batching to be used in DataLoader. + + :param batch: A list of items from the dataset. + :type batch: list + :param condition: The condition instance. + :type condition: ConditionBase + :return: A collated batch. + :rtype: dict + """ + data = condition[batch] + return data + + def create_dataloader( + self, dataset, batch_size, shuffle, automatic_batching + ): + """ + Create a DataLoader for the condition. + + :param int batch_size: The batch size for the DataLoader. + :param bool shuffle: Whether to shuffle the data. Default is ``False``. + :return: The DataLoader for the condition. + :rtype: torch.utils.data.DataLoader + """ + if batch_size == len(dataset): + pass # will be updated in the near future + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=( + partial(self.collate_fn, condition=self) + if not automatic_batching + else self.automatic_batching_collate_fn + ), + ) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index b0264517c..229b9a025 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -1,9 +1,6 @@ """Module for the Condition interface.""" -from abc import ABCMeta -from torch_geometric.data import Data -from ..label_tensor import LabelTensor -from ..graph import Graph +from abc import ABCMeta, abstractmethod class ConditionInterface(metaclass=ABCMeta): @@ -15,13 +12,14 @@ class ConditionInterface(metaclass=ABCMeta): description of all available conditions and how to instantiate them. """ - def __init__(self): + @abstractmethod + def __init__(self, **kwargs): """ Initialization of the :class:`ConditionInterface` class. """ - self._problem = None @property + @abstractmethod def problem(self): """ Return the problem associated with this condition. @@ -29,9 +27,9 @@ def problem(self): :return: Problem associated with this condition. :rtype: ~pina.problem.abstract_problem.AbstractProblem """ - return self._problem @problem.setter + @abstractmethod def problem(self, value): """ Set the problem associated with this condition. @@ -39,88 +37,21 @@ def problem(self, value): :param pina.problem.abstract_problem.AbstractProblem value: The problem to associate with this condition """ - self._problem = value - @staticmethod - def _check_graph_list_consistency(data_list): + @abstractmethod + def __len__(self): """ - Check the consistency of the list of Data | Graph objects. - The following checks are performed: + Return the number of data points in the condition. - - All elements in the list must be of the same type (either - :class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`). - - - All elements in the list must have the same keys. - - - The data type of each tensor must be consistent across all elements. - - - If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels - must also be consistent across all elements. - - :param data_list: The list of Data | Graph objects to check. - :type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph] - :raises ValueError: If the input types are invalid. - :raises ValueError: If all elements in the list do not have the same - keys. - :raises ValueError: If the type of each tensor is not consistent across - all elements in the list. - :raises ValueError: If the labels of the LabelTensors are not consistent - across all elements in the list. + :return: Number of data points. + :rtype: int """ - # If the data is a Graph or Data object, perform no checks - if isinstance(data_list, (Graph, Data)): - return - - # Check all elements in the list are of the same type - if not all(isinstance(i, (Graph, Data)) for i in data_list): - raise ValueError( - "Invalid input. Please, provide either Data or Graph objects." - ) - - # Store the keys, data types and labels of the first element - data = data_list[0] - keys = sorted(list(data.keys())) - data_types = {name: tensor.__class__ for name, tensor in data.items()} - labels = { - name: tensor.labels - for name, tensor in data.items() - if isinstance(tensor, LabelTensor) - } - - # Iterate over the list of Data | Graph objects - for data in data_list[1:]: - - # Check that all elements in the list have the same keys - if sorted(list(data.keys())) != keys: - raise ValueError( - "All elements in the list must have the same keys." - ) - - # Iterate over the tensors in the current element - for name, tensor in data.items(): - # Check that the type of each tensor is consistent - if tensor.__class__ is not data_types[name]: - raise ValueError( - f"Data {name} must be a {data_types[name]}, got " - f"{tensor.__class__}" - ) - - # Check that the labels of each LabelTensor are consistent - if isinstance(tensor, LabelTensor): - if tensor.labels != labels[name]: - raise ValueError( - "LabelTensor must have the same labels" - ) - def __getattribute__(self, name): + @abstractmethod + def __getitem__(self, idx): """ - Get an attribute from the object. + Return the data point(s) at the specified index. - :param str name: The name of the attribute to get. - :return: The requested attribute. - :rtype: Any + :param int idx: Index of the data point(s) to retrieve. + :return: Data point(s) at the specified index. """ - to_return = super().__getattribute__(name) - if isinstance(to_return, (Graph, Data)): - to_return = [to_return] - return to_return diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 5f5e7d36b..e78dc800f 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -2,12 +2,12 @@ import torch from torch_geometric.data import Data -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase, GraphCondition, TensorCondition from ..label_tensor import LabelTensor from ..graph import Graph -class DataCondition(ConditionInterface): +class DataCondition(ConditionBase): """ The class :class:`DataCondition` defines an unsupervised condition based on ``input`` data. This condition is typically used in data-driven problems, @@ -38,7 +38,7 @@ class DataCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "conditional_variables"] + __fields__ = ["input", "conditional_variables"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_conditional_variables_cls = (torch.Tensor, LabelTensor) @@ -99,22 +99,105 @@ def __init__(self, input, conditional_variables=None): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input - self.conditional_variables = conditional_variables + if conditional_variables is None: + super().__init__(input=input) + else: + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def conditional_variables(self): + """ + Return the conditional variables for the condition. + + :return: The conditional variables. + :rtype: torch.Tensor | LabelTensor | None + """ + return self.data.get("conditional_variables", None) -class TensorDataCondition(DataCondition): +class TensorDataCondition(TensorCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.label_tensor.LabelTensor` object or a :class:`torch.Tensor` object. """ + @property + def input(self): + """ + Return the input data for the condition. -class GraphDataCondition(DataCondition): + :return: The input data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["input"] + + +class GraphDataCondition(GraphCondition, DataCondition): """ Specialization of the :class:`DataCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` object or a :class:`~torch_geometric.data.Data` object. """ + + def __init__(self, input, conditional_variables=None): + """ + Initialization of the :class:`GraphDataCondition` class. + + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param conditional_variables: The conditional variables for the + condition. Default is ``None``. + :type conditional_variables: torch.Tensor | LabelTensor + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` or + :class:`~torch_geometric.data.Data`, all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + if conditional_variables is not None: + self.tensor_fields.append("conditional_variables") + self.keys_map["conditional_variables"] = "cond_vars" + super().__init__( + input=input, conditional_variables=conditional_variables + ) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: Graph | Data | list[Graph] | list[Data] | tuple[Graph] | + tuple[Data] + """ + return self.data["data"] + + @property + def conditional_variables(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + + if not hasattr(self.data["data"][0], "cond_vars"): + return None + cond_vars = [] + is_lt = isinstance(self.data["data"][0].cond_vars, LabelTensor) + for graph in self.data["data"]: + cond_vars.append(graph.cond_vars) + return ( + torch.stack(cond_vars) + if not is_lt + else LabelTensor.stack(cond_vars) + ) diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index 3565c0b41..673bf1612 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -1,12 +1,11 @@ """Module for the DomainEquationCondition class.""" -from .condition_interface import ConditionInterface -from ..utils import check_consistency +from .condition_base import ConditionBase from ..domain import DomainInterface from ..equation.equation_interface import EquationInterface -class DomainEquationCondition(ConditionInterface): +class DomainEquationCondition(ConditionBase): """ The class :class:`DomainEquationCondition` defines a condition based on a ``domain`` and an ``equation``. This condition is typically used in @@ -30,7 +29,7 @@ class DomainEquationCondition(ConditionInterface): """ # Available slots - __slots__ = ["domain", "equation"] + __fields__ = ["domain", "equation"] def __init__(self, domain, equation): """ @@ -41,24 +40,52 @@ def __init__(self, domain, equation): :param EquationInterface equation: The equation to be satisfied over the specified domain. """ + if not isinstance(domain, (DomainInterface, str)): + raise ValueError( + f"`domain` must be an instance of DomainInterface, " + f"got {type(domain)} instead." + ) + if not isinstance(equation, EquationInterface): + raise ValueError( + f"`equation` must be an instance of EquationInterface, " + f"got {type(equation)} instead." + ) super().__init__() self.domain = domain self.equation = equation - def __setattr__(self, key, value): + def __len__(self): """ - Set the attribute value with type checking. + Raise NotImplementedError since the number of points is determined by + the domain sampling strategy. - :param str key: The attribute name. - :param any value: The value to set for the attribute. + :raises NotImplementedError: Always raised since the number of points is + determined by the domain sampling strategy. """ - if key == "domain": - check_consistency(value, (DomainInterface, str)) - DomainEquationCondition.__dict__[key].__set__(self, value) + raise NotImplementedError( + "`__len__` method is not implemented for " + "`DomainEquationCondition` since the number of points is " + "determined by the domain sampling strategy." + ) - elif key == "equation": - check_consistency(value, (EquationInterface)) - DomainEquationCondition.__dict__[key].__set__(self, value) + def __getitem__(self, idx): + """ + Raise NotImplementedError since data retrieval is not applicable. + + :param int idx: Index of the data point(s) to retrieve. + :raises NotImplementedError: Always raised since data retrieval is not + applicable for this condition. + """ + raise NotImplementedError( + "`__getitem__` method is not implemented for " + "`DomainEquationCondition`" + ) - elif key in ("_problem"): - super().__setattr__(key, value) + def store_data(self): + """ + Store data for the condition. No data is stored for this condition. + + :return: An empty dictionary since no data is stored. + :rtype: dict + """ + return {} diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index d32597894..913cdc4d2 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -1,13 +1,12 @@ """Module for the InputEquationCondition class and its subclasses.""" -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase, TensorCondition, GraphCondition from ..label_tensor import LabelTensor from ..graph import Graph -from ..utils import check_consistency from ..equation.equation_interface import EquationInterface -class InputEquationCondition(ConditionInterface): +class InputEquationCondition(ConditionBase): """ The class :class:`InputEquationCondition` defines a condition based on ``input`` data and an ``equation``. This condition is typically used in @@ -41,7 +40,7 @@ class InputEquationCondition(ConditionInterface): """ # Available input data types - __slots__ = ["input", "equation"] + __fields__ = ["input", "equation"] _avail_input_cls = (LabelTensor, Graph, list, tuple) _avail_equation_cls = EquationInterface @@ -97,42 +96,53 @@ def __init__(self, input, equation): the list must share the same structure, with matching keys and consistent data types. """ - super().__init__() - self.input = input + super().__init__(input=input) self.equation = equation - def __setattr__(self, key, value): - """ - Set the attribute value with type checking. - - :param str key: The attribute name. - :param any value: The value to set for the attribute. - """ - if key == "input": - check_consistency(value, self._avail_input_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - - elif key == "equation": - check_consistency(value, self._avail_equation_cls) - InputEquationCondition.__dict__[key].__set__(self, value) - elif key in ("_problem"): - super().__setattr__(key, value) - - -class InputTensorEquationCondition(InputEquationCondition): +class InputTensorEquationCondition(TensorCondition, InputEquationCondition): """ Specialization of the :class:`InputEquationCondition` class for the case where ``input`` is a :class:`~pina.label_tensor.LabelTensor` object. """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: LabelTensor | Graph | list[Graph] | tuple[Graph] + """ + return self.data["input"] + -class InputGraphEquationCondition(InputEquationCondition): +class InputGraphEquationCondition(GraphCondition, InputEquationCondition): """ Specialization of the :class:`InputEquationCondition` class for the case where ``input`` is a :class:`~pina.graph.Graph` object. """ + def __init__(self, input, equation): + """ + Initialization of the :class:`InputGraphEquationCondition` class. + + :param input: The input data for the condition. + :type input: Graph | list[Graph] | tuple[Graph] + :param EquationInterface equation: The equation to be satisfied over the + specified input points. + + .. note:: + + If ``input`` is a list of :class:`~pina.graph.Graph` all elements in + the list must share the same structure, with matching keys and + consistent data types. + """ + self.graph_field = "input" + self.tensor_fields = [] + self.keys_map = {} + super().__init__(input=[input], equation=equation) + @staticmethod def _check_label_tensor(input): """ @@ -155,3 +165,13 @@ def _check_label_tensor(input): return raise ValueError("The input must contain at least one LabelTensor.") + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] diff --git a/pina/condition/input_target_condition.py b/pina/condition/input_target_condition.py index 07b07bb7b..064f4b5eb 100644 --- a/pina/condition/input_target_condition.py +++ b/pina/condition/input_target_condition.py @@ -6,10 +6,10 @@ from torch_geometric.data import Data from ..label_tensor import LabelTensor from ..graph import Graph -from .condition_interface import ConditionInterface +from .condition_base import ConditionBase, GraphCondition, TensorCondition -class InputTargetCondition(ConditionInterface): +class InputTargetCondition(ConditionBase): """ The :class:`InputTargetCondition` class represents a supervised condition defined by both ``input`` and ``target`` data. The model is trained to @@ -55,7 +55,7 @@ class InputTargetCondition(ConditionInterface): """ # Available input and target data types - __slots__ = ["input", "target"] + __fields__ = ["input", "target"] _avail_input_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) _avail_output_cls = (torch.Tensor, LabelTensor, Data, Graph, list, tuple) @@ -109,80 +109,42 @@ def __new__(cls, input, target): subclass = GraphInputTensorTargetCondition return subclass.__new__(subclass, input, target) - # Graph - Graph - if isinstance(input, (Graph, Data, list, tuple)) and isinstance( - target, (Graph, Data, list, tuple) - ): - cls._check_graph_list_consistency(input) - cls._check_graph_list_consistency(target) - subclass = GraphInputGraphTargetCondition - return subclass.__new__(subclass, input, target) - - # If the input and/or target are not of the correct type raise an error raise ValueError( "Invalid input | target types." "Please provide either torch_geometric.data.Data, Graph, " "LabelTensor or torch.Tensor objects." ) - def __init__(self, input, target): - """ - Initialization of the :class:`InputTargetCondition` class. - :param input: The input data for the condition. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data for the condition. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - - .. note:: +class TensorInputTensorTargetCondition(InputTargetCondition, TensorCondition): + """ + Specialization of the :class:`InputTargetCondition` class for the case where + both ``input`` and ``target`` are :class:`torch.Tensor` or + :class:`~pina.label_tensor.LabelTensor` objects. + """ - If either ``input`` or ``target`` is a list of - :class:`~pina.graph.Graph` or :class:`~torch_geometric.data.Data` - objects, all elements in the list must share the same structure, - with matching keys and consistent data types. + @property + def input(self): """ - super().__init__() - self._check_input_target_len(input, target) - self.input = input - self.target = target + Return the input data for the condition. - @staticmethod - def _check_input_target_len(input, target): + :return: The input data. + :rtype: torch.Tensor | LabelTensor """ - Check that the length of the input and target lists are the same. + return self.data["input"] - :param input: The input data. - :type input: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :param target: The target data. - :type target: torch.Tensor | LabelTensor | Graph | Data | list[Graph] | - list[Data] | tuple[Graph] | tuple[Data] - :raises ValueError: If the lengths of the input and target lists do not - match. + @property + def target(self): """ - if isinstance(input, (Graph, Data)) or isinstance( - target, (Graph, Data) - ): - return - - # Raise an error if the lengths of the input and target do not match - if len(input) != len(target): - raise ValueError( - "The input and target lists must have the same length." - ) + Return the target data for the condition. - -class TensorInputTensorTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor` objects. - """ + :return: The target data. + :rtype: torch.Tensor | LabelTensor + """ + return self.data["target"] -class TensorInputGraphTargetCondition(InputTargetCondition): +class TensorInputGraphTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`torch.Tensor` or a @@ -190,8 +152,47 @@ class TensorInputGraphTargetCondition(InputTargetCondition): :class:`~pina.graph.Graph` or a :class:`torch_geometric.data.Data` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`TensorInputGraphTargetCondition` class. + + :param input: The input data for the condition. + :type input: torch.Tensor | LabelTensor + :param target: The target data for the condition. + :type target: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + """ + self.graph_field = "target" + self.tensor_fields = ["input"] + self.keys_map = {"input": "x"} + super().__init__(input=input, target=target) + + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].x, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.x) + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) + + @property + def target(self): + """ + Return the target data for the condition. -class GraphInputTensorTargetCondition(InputTargetCondition): + :return: The target data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + +class GraphInputTensorTargetCondition(GraphCondition, InputTargetCondition): """ Specialization of the :class:`InputTargetCondition` class for the case where ``input`` is either a :class:`~pina.graph.Graph` or @@ -199,10 +200,42 @@ class GraphInputTensorTargetCondition(InputTargetCondition): :class:`torch.Tensor` or a :class:`~pina.label_tensor.LabelTensor` object. """ + def __init__(self, input, target): + """ + Initialization of the :class:`GraphInputTensorTargetCondition` class. + + :param input: The input data for the condition. + :type input: Graph | Data | list[Graph] | list[Data] | + tuple[Graph] | tuple[Data] + :param target: The target data for the condition. + :type target: torch.Tensor | LabelTensor + """ + self.graph_field = "input" + self.tensor_fields = ["target"] + self.keys_map = {"target": "y"} + super().__init__(input=input, target=target) -class GraphInputGraphTargetCondition(InputTargetCondition): - """ - Specialization of the :class:`InputTargetCondition` class for the case where - both ``input`` and ``target`` are either :class:`~pina.graph.Graph` or - :class:`torch_geometric.data.Data` objects. - """ + @property + def input(self): + """ + Return the input data for the condition. + + :return: The input data. + :rtype: list[Graph] | list[Data] + """ + return self.data["data"] + + @property + def target(self): + """ + Return the target data for the condition. + + :return: The target data. + :rtype: list[torch.Tensor] | list[LabelTensor] + """ + targets = [] + is_lt = isinstance(self.data["data"][0].y, LabelTensor) + for graph in self.data["data"]: + targets.append(graph.y) + + return torch.stack(targets) if not is_lt else LabelTensor.stack(targets) diff --git a/tests/test_condition.py b/tests/test_condition.py index 9199f2bd9..8a5480499 100644 --- a/tests/test_condition.py +++ b/tests/test_condition.py @@ -1,154 +1,171 @@ -import torch -import pytest - -from pina import LabelTensor, Condition -from pina.condition import ( - TensorInputGraphTargetCondition, - TensorInputTensorTargetCondition, - GraphInputGraphTargetCondition, - GraphInputTensorTargetCondition, -) -from pina.condition import ( - InputTensorEquationCondition, - InputGraphEquationCondition, - DomainEquationCondition, -) -from pina.condition import ( - TensorDataCondition, - GraphDataCondition, -) -from pina.domain import CartesianDomain -from pina.equation.equation_factory import FixedValue -from pina.graph import RadiusGraph - -example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - -input_tensor = torch.rand((10, 3)) -target_tensor = torch.rand((10, 2)) -input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) -target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) - -x = torch.rand(10, 20, 2) -pos = torch.rand(10, 20, 2) -radius = 0.1 -input_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] -target_graph = [ - RadiusGraph( - x=x_, - pos=pos_, - radius=radius, - ) - for x_, pos_ in zip(x, pos) -] - -x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) -pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) -radius = 0.1 -input_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] -target_graph_lt = [ - RadiusGraph( - x=x[i], - pos=pos[i], - radius=radius, - ) - for i in range(len(x)) -] - -input_single_graph = input_graph[0] -target_single_graph = target_graph[0] - - -def test_init_input_target(): - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_tensor) - assert isinstance(cond, TensorInputTensorTargetCondition) - cond = Condition(input=input_tensor, target=target_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_graph, target=target_tensor) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - cond = Condition(input=input_lt, target=input_single_graph) - assert isinstance(cond, TensorInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_lt) - assert isinstance(cond, GraphInputTensorTargetCondition) - cond = Condition(input=input_graph, target=target_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - cond = Condition(input=input_single_graph, target=target_single_graph) - assert isinstance(cond, GraphInputGraphTargetCondition) - - with pytest.raises(ValueError): - Condition(input_tensor, input_tensor) - with pytest.raises(ValueError): - Condition(input=3.0, target="example") - with pytest.raises(ValueError): - Condition(input=example_domain, target=example_domain) - - # Test wrong graph condition initialisation - input = [input_graph[0], input_graph_lt[0]] - target = [target_graph[0], target_graph_lt[0]] - with pytest.raises(ValueError): - Condition(input=input, target=target) - - input_graph_lt[0].x.labels = ["a", "b"] - with pytest.raises(ValueError): - Condition(input=input_graph_lt, target=target_graph_lt) - input_graph_lt[0].x.labels = ["u", "v"] - - -def test_init_domain_equation(): - cond = Condition(domain=example_domain, equation=FixedValue(0.0)) - assert isinstance(cond, DomainEquationCondition) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(domain=3.0, equation="example") - with pytest.raises(ValueError): - Condition(domain=input_tensor, equation=input_graph) - - -def test_init_input_equation(): - cond = Condition(input=input_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputTensorEquationCondition) - cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) - assert isinstance(cond, InputGraphEquationCondition) - with pytest.raises(ValueError): - cond = Condition(input=input_tensor, equation=FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(example_domain, FixedValue(0.0)) - with pytest.raises(ValueError): - Condition(input=3.0, equation="example") - with pytest.raises(ValueError): - Condition(input=example_domain, equation=input_graph) - - -test_init_input_equation() - - -def test_init_data_condition(): - cond = Condition(input=input_lt) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) - assert isinstance(cond, TensorDataCondition) - cond = Condition(input=input_graph) - assert isinstance(cond, GraphDataCondition) - cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) - assert isinstance(cond, GraphDataCondition) +# import torch +# import pytest + +# from pina import LabelTensor, Condition +# from pina.condition import ( +# TensorInputGraphTargetCondition, +# TensorInputTensorTargetCondition, +# # GraphInputGraphTargetCondition, +# GraphInputTensorTargetCondition, +# ) +# from pina.condition import ( +# InputTensorEquationCondition, +# InputGraphEquationCondition, +# DomainEquationCondition, +# ) +# from pina.condition import ( +# TensorDataCondition, +# GraphDataCondition, +# ) +# from pina.domain import CartesianDomain +# from pina.equation.equation_factory import FixedValue +# from pina.graph import RadiusGraph + +# def _create_tensor_data(): +# input_tensor = torch.rand((10, 3)) +# target_tensor = torch.rand((10, 2)) +# return input_tensor, target_tensor + +# def _create_graph_data(): +# x = torch.rand(10, 20, 2) +# pos = torch.rand(10, 20, 2) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# target_graph = [ +# RadiusGraph( +# y=x_, +# pos=pos_, +# radius=radius, +# ) +# for x_, pos_ in zip(x, pos) +# ] +# return input_graph, target_graph + +# def _create_lt_data(): +# input_lt = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) +# target_lt = LabelTensor(torch.rand((10, 2)), ["a", "b"]) +# return input_lt, target_lt + +# def _create_graph_lt_data(): +# x = LabelTensor(torch.rand((10, 20, 2)), ["u", "v"]) +# pos = LabelTensor(torch.rand((10, 20, 2)), ["x", "y"]) +# radius = 0.1 +# input_graph = [ +# RadiusGraph( +# x=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# target_graph = [ +# RadiusGraph( +# y=x[i], +# pos=pos[i], +# radius=radius, +# ) +# for i in range(len(x)) +# ] +# return input_graph, target_graph + +# example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + +# def test_init_input_target(): +# input_tensor, target_tensor = _create_tensor_data() +# cond = Condition(input=input_tensor, target=target_tensor) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_lt, target_lt = _create_lt_data() +# cond = Condition(input=input_lt, target=target_lt) +# assert isinstance(cond, TensorInputTensorTargetCondition) + +# input_graph, target_graph = _create_graph_data() +# cond = Condition(input=input_tensor, target=target_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) + +# cond = Condition(input=input_graph, target=target_tensor) +# assert isinstance(cond, GraphInputTensorTargetCondition) + +# input_single_graph = input_graph[0] +# target_single_graph = target_graph[0] +# cond = Condition(input=input_lt, target=input_single_graph) +# assert isinstance(cond, TensorInputGraphTargetCondition) +# cond = Condition(input=input_single_graph, target=target_lt) +# assert isinstance(cond, GraphInputTensorTargetCondition) +# # cond = Condition(input=input_graph, target=target_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# # cond = Condition(input=input_single_graph, target=target_single_graph) +# # assert isinstance(cond, GraphInputGraphTargetCondition) +# input_graph_lt, target_graph_lt = _create_graph_lt_data() + +# with pytest.raises(ValueError): +# Condition(input_tensor, input_tensor) +# with pytest.raises(ValueError): +# Condition(input=3.0, target="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, target=example_domain) + +# # Test wrong graph condition initialisation +# input = [input_graph[0], input_graph_lt[0]] +# target = [target_graph[0], target_graph_lt[0]] +# with pytest.raises(ValueError): +# Condition(input=input, target=target) + +# input_graph_lt[0].x.labels = ["a", "b"] +# with pytest.raises(ValueError): +# Condition(input=input_graph_lt, target=target_graph_lt) +# input_graph_lt[0].x.labels = ["u", "v"] + + +# def test_init_domain_equation(): +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(domain=example_domain, equation=FixedValue(0.0)) +# assert isinstance(cond, DomainEquationCondition) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(domain=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(domain=input_tensor, equation=input_graph) + + +# def test_init_input_equation(): +# input_lt, _ = _create_lt_data() +# input_graph_lt, _ = _create_graph_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputTensorEquationCondition) +# cond = Condition(input=input_graph_lt, equation=FixedValue(0.0)) +# assert isinstance(cond, InputGraphEquationCondition) +# with pytest.raises(ValueError): +# cond = Condition(input=input_tensor, equation=FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(example_domain, FixedValue(0.0)) +# with pytest.raises(ValueError): +# Condition(input=3.0, equation="example") +# with pytest.raises(ValueError): +# Condition(input=example_domain, equation=input_graph) + +# def test_init_data_condition(): +# input_lt, _ = _create_lt_data() +# input_tensor, _ = _create_tensor_data() +# input_graph, _ = _create_graph_data() +# cond = Condition(input=input_lt) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_tensor, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, TensorDataCondition) +# cond = Condition(input=input_graph) +# assert isinstance(cond, GraphDataCondition) +# cond = Condition(input=input_graph, conditional_variables=torch.tensor(1)) +# assert isinstance(cond, GraphDataCondition) diff --git a/tests/test_condition/test_data_condition.py b/tests/test_condition/test_data_condition.py new file mode 100644 index 000000000..e922bdbcd --- /dev/null +++ b/tests/test_condition/test_data_condition.py @@ -0,0 +1,200 @@ +import pytest +import torch +from pina import Condition, LabelTensor +from pina.condition import ( + TensorDataCondition, + GraphDataCondition, +) +from pina.graph import RadiusGraph +from torch_geometric.data import Data, Batch +from pina.graph import Graph, LabelBatch + + +def _create_tensor_data(use_lt=False, conditional_variables=False): + input_tensor = torch.rand((10, 3)) + if use_lt: + input_tensor = LabelTensor(input_tensor, ["x", "y", "z"]) + if conditional_variables: + cond_vars = torch.rand((10, 2)) + if use_lt: + cond_vars = LabelTensor(cond_vars, ["a", "b"]) + else: + cond_vars = None + return input_tensor, cond_vars + + +def _create_graph_data(use_lt=False, conditional_variables=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + input_graph = [ + RadiusGraph(pos=pos[i], radius=radius, x=x[i]) for i in range(len(x)) + ] + if conditional_variables: + if use_lt: + cond_vars = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + cond_vars = torch.rand(10, 20, 1) + else: + cond_vars = None + return input_graph, cond_vars + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + assert isinstance(condition, TensorDataCondition) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["a", "b"] + else: + assert condition.conditional_variables is None + assert isinstance(condition.input, type_) + if use_lt: + assert condition.input.labels == ["x", "y", "z"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_init_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + assert isinstance(condition, GraphDataCondition) + type_ = LabelTensor if use_lt else torch.Tensor + if conditional_variables: + assert condition.conditional_variables is not None + assert isinstance(condition.conditional_variables, type_) + if use_lt: + assert condition.conditional_variables.labels == ["f"] + else: + assert condition.conditional_variables is None + # assert "conditional_variables" not in condition.data.keys() + assert isinstance(condition.input, list) + for graph in condition.input: + assert isinstance(graph, Data) + assert isinstance(graph.x, type_) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + if use_lt: + assert graph.pos.labels == ["x", "y"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(item["input"], type_) + assert item["input"].shape == (3,) + if type_ is LabelTensor: + assert item["input"].labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in item + assert isinstance(item["conditional_variables"], type_) + assert item["conditional_variables"].shape == (2,) + if type_ is LabelTensor: + assert item["conditional_variables"].labels == ["a", "b"] + else: + assert "conditional_variables" not in item + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitem_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + item = condition[0] + assert isinstance(item, dict) + assert "data" in item + graph = item["data"] + assert isinstance(graph, Data) + type_ = LabelTensor if use_lt else torch.Tensor + assert isinstance(graph.x, type_) + assert graph.x.shape == (20, 2) + if use_lt: + assert graph.x.labels == ["u", "v"] + assert isinstance(graph.pos, type_) + assert graph.pos.shape == (20, 2) + if use_lt: + assert graph.pos.labels == ["x", "y"] + if conditional_variables: + assert hasattr(graph, "cond_vars") + cond_var = graph.cond_vars + assert isinstance(cond_var, type_) + assert cond_var.shape == (20, 1) + if use_lt: + assert cond_var.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_graph_data_condition(use_lt, conditional_variables): + input_graph, cond_vars = _create_graph_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_graph, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + graphs = items["input"] + assert isinstance(graphs, LabelBatch) + assert graphs.num_graphs == 3 + if conditional_variables: + type_ = LabelTensor if use_lt else torch.Tensor + assert "conditional_variables" in items + cond_vars_batch = items["conditional_variables"] + assert isinstance(cond_vars_batch, type_) + assert cond_vars_batch.shape == (60, 1) + if use_lt: + assert cond_vars_batch.labels == ["f"] + + +@pytest.mark.parametrize("use_lt", [False, True]) +@pytest.mark.parametrize("conditional_variables", [False, True]) +def test_getitems_tensor_data_condition(use_lt, conditional_variables): + input_tensor, cond_vars = _create_tensor_data( + use_lt=use_lt, conditional_variables=conditional_variables + ) + condition = Condition(input=input_tensor, conditional_variables=cond_vars) + idxs = [0, 1, 3] + items = condition[idxs] + assert isinstance(items, dict) + assert "input" in items + type_ = LabelTensor if use_lt else torch.Tensor + inputs = items["input"] + assert isinstance(inputs, type_) + assert inputs.shape == (3, 3) + if use_lt: + assert inputs.labels == ["x", "y", "z"] + if conditional_variables: + assert "conditional_variables" in items + cond_vars_items = items["conditional_variables"] + assert isinstance(cond_vars_items, type_) + assert cond_vars_items.shape == (3, 2) + if use_lt: + assert cond_vars_items.labels == ["a", "b"] + else: + assert "conditional_variables" not in items diff --git a/tests/test_condition/test_domain_equation_condition.py b/tests/test_condition/test_domain_equation_condition.py new file mode 100644 index 000000000..2b7c78b00 --- /dev/null +++ b/tests/test_condition/test_domain_equation_condition.py @@ -0,0 +1,27 @@ +import pytest +from pina import Condition +from pina.domain import CartesianDomain +from pina.equation.equation_factory import FixedValue +from pina.condition import DomainEquationCondition + +example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) +example_equation = FixedValue(0.0) + + +def test_init_domain_equation(): + cond = Condition(domain=example_domain, equation=example_equation) + assert isinstance(cond, DomainEquationCondition) + assert cond.domain is example_domain + assert cond.equation is example_equation + + +def test_len_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + len(cond) + + +def test_getitem_not_implemented(): + cond = Condition(domain=example_domain, equation=FixedValue(0.0)) + with pytest.raises(NotImplementedError): + cond[0] diff --git a/tests/test_condition/test_input_equation_condition.py b/tests/test_condition/test_input_equation_condition.py new file mode 100644 index 000000000..af11d382e --- /dev/null +++ b/tests/test_condition/test_input_equation_condition.py @@ -0,0 +1,67 @@ +import torch +from pina import Condition +from pina.condition.input_equation_condition import ( + InputTensorEquationCondition, + InputGraphEquationCondition, +) +from pina.equation import Equation +from pina import LabelTensor + + +def _create_pts_and_equation(): + def dummy_equation(pts): + return pts["x"] ** 2 + pts["y"] ** 2 - 1 + + pts = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + equation = Equation(dummy_equation) + return pts, equation + + +def _create_graph_and_equation(): + from pina.graph import KNNGraph + + def dummy_equation(pts): + return pts.x[:, 0] ** 2 + pts.x[:, 1] ** 2 - 1 + + x = LabelTensor(torch.randn(100, 2), labels=["u", "v"]) + pos = LabelTensor(torch.randn(100, 2), labels=["x", "y"]) + graph = KNNGraph(x=x, pos=pos, neighbours=5, edge_attr=True) + equation = Equation(dummy_equation) + return graph, equation + + +def test_init_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + assert isinstance(condition, InputTensorEquationCondition) + assert condition.input.shape == (100, 2) + assert condition.equation is equation + + +def test_init_graph_equation_condition(): + graph, equation = _create_graph_and_equation() + condition = Condition(input=graph, equation=equation) + assert isinstance(condition, InputGraphEquationCondition) + assert isinstance(condition.input, list) + assert len(condition.input) == 1 + assert condition.input[0].x.shape == (100, 2) + assert condition.equation is equation + + +def test_getitem_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + item = condition[0] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (2,) + + +def test_getitems_tensor_equation_condition(): + pts, equation = _create_pts_and_equation() + condition = Condition(input=pts, equation=equation) + idxs = [0, 1, 3] + item = condition[idxs] + assert isinstance(item, dict) + assert "input" in item + assert item["input"].shape == (3, 2) diff --git a/tests/test_condition/test_input_target_condition.py b/tests/test_condition/test_input_target_condition.py new file mode 100644 index 000000000..81c3a9b24 --- /dev/null +++ b/tests/test_condition/test_input_target_condition.py @@ -0,0 +1,297 @@ +import torch +import pytest +from torch_geometric.data import Batch +from pina import LabelTensor, Condition +from pina.condition import ( + TensorInputGraphTargetCondition, + TensorInputTensorTargetCondition, + GraphInputTensorTargetCondition, +) +from pina.graph import RadiusGraph, LabelBatch + + +def _create_tensor_data(use_lt=False): + if use_lt: + input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"]) + target_tensor = LabelTensor(torch.rand((10, 2)), ["a", "b"]) + return input_tensor, target_tensor + input_tensor = torch.rand((10, 3)) + target_tensor = torch.rand((10, 2)) + return input_tensor, target_tensor + + +def _create_graph_data(tensor_input=True, use_lt=False): + if use_lt: + x = LabelTensor(torch.rand(10, 20, 2), ["u", "v"]) + pos = LabelTensor(torch.rand(10, 20, 2), ["x", "y"]) + else: + x = torch.rand(10, 20, 2) + pos = torch.rand(10, 20, 2) + radius = 0.1 + graph = [ + RadiusGraph( + pos=pos[i], + radius=radius, + x=x[i] if not tensor_input else None, + y=x[i] if tensor_input else None, + ) + for i in range(len(x)) + ] + if use_lt: + tensor = LabelTensor(torch.rand(10, 20, 1), ["f"]) + else: + tensor = torch.rand(10, 20, 1) + return graph, tensor + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + assert isinstance(condition, TensorInputTensorTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputTensorTargetCondition input failed" + assert torch.allclose( + condition.target, target_tensor + ), "TensorInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputTensorTargetCondition input type failed" + assert condition.input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition input labels failed" + assert isinstance( + condition.target, LabelTensor + ), "TensorInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + assert isinstance(condition, TensorInputGraphTargetCondition) + assert torch.allclose( + condition.input, input_tensor + ), "TensorInputGraphTargetCondition input failed" + for i in range(len(target_graph)): + assert torch.allclose( + condition.target[i].y, target_graph[i].y + ), "TensorInputGraphTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target[i].y, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.target[i].y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition target labels failed" + if use_lt: + assert isinstance( + condition.input, LabelTensor + ), "TensorInputGraphTargetCondition target type failed" + assert condition.input.labels == [ + "f" + ], "TensorInputGraphTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_init_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + assert isinstance(condition, GraphInputTensorTargetCondition) + for i in range(len(input_graph)): + assert torch.allclose( + condition.input[i].x, input_graph[i].x + ), "GraphInputTensorTargetCondition input failed" + if use_lt: + assert isinstance( + condition.input[i].x, LabelTensor + ), "GraphInputTensorTargetCondition input type failed" + assert ( + condition.input[i].x.labels == input_graph[i].x.labels + ), "GraphInputTensorTargetCondition labels failed" + + assert torch.allclose( + condition.target[i], target_tensor[i] + ), "GraphInputTensorTargetCondition target failed" + if use_lt: + assert isinstance( + condition.target, LabelTensor + ), "GraphInputTensorTargetCondition target type failed" + assert condition.target.labels == [ + "f" + ], "GraphInputTensorTargetCondition target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_tensor_target_condition(use_lt): + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + for i in range(len(input_tensor)): + item = condition[i] + assert torch.allclose( + item["input"], input_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item["target"], target_tensor[i] + ), "TensorInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitem_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + for i in range(len(input_tensor)): + item = condition[i]["data"] + assert torch.allclose( + item.x, input_tensor[i] + ), "TensorInputGraphTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_graph[i].y + ), "TensorInputGraphTargetCondition __getitem__ target failed" + if use_lt: + assert isinstance( + item.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitem__ target type failed" + assert item.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitem__ target labels failed" + + +def test_getitem_graph_input_tensor_target_condition(): + input_graph, target_tensor = _create_graph_data(False) + condition = Condition(input=input_graph, target=target_tensor) + for i in range(len(input_graph)): + item = condition[i]["data"] + print(item) + assert torch.allclose( + item.x, input_graph[i].x + ), "GraphInputTensorTargetCondition __getitem__ input failed" + assert torch.allclose( + item.y, target_tensor[i] + ), "GraphInputTensorTargetCondition __getitem__ target failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_graph_input_tensor_target_condition(use_lt): + input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt) + condition = Condition(input=input_graph, target=target_tensor) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelBatch.from_data_list([input_graph[i] for i in indices]) + target_ = LabelTensor.cat([target_tensor[i] for i in indices], dim=0) + else: + input_ = Batch.from_data_list([input_graph[i] for i in indices]) + target_ = torch.cat([target_tensor[i] for i in indices], dim=0) + assert torch.allclose( + candidate_input.x, input_.x + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "GraphInputTensorTargetCondition __geitemsem__ input failed" + if use_lt: + assert isinstance( + candidate_target, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "f" + ], "GraphInputTensorTargetCondition __getitems__ target labels failed" + + assert isinstance( + candidate_input.x, LabelTensor + ), "GraphInputTensorTargetCondition __getitems__ input type failed" + assert ( + candidate_input.x.labels == input_graph[0].x.labels + ), "GraphInputTensorTargetCondition __getitems__ input labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_tensor_target_condition(use_lt): + + input_tensor, target_tensor = _create_tensor_data(use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_tensor) + indices = [1, 3, 5, 7] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + + if use_lt: + input_ = LabelTensor.stack([input_tensor[i] for i in indices]) + target_ = LabelTensor.stack([target_tensor[i] for i in indices]) + else: + input_ = torch.stack([input_tensor[i] for i in indices]) + target_ = torch.stack([target_tensor[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputTensorTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target, target_ + ), "TensorInputTensorTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "x", + "y", + "z", + ], "TensorInputTensorTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target, LabelTensor + ), "TensorInputTensorTargetCondition __getitems__ target type failed" + assert candidate_target.labels == [ + "a", + "b", + ], "TensorInputTensorTargetCondition __getitems__ target labels failed" + + +@pytest.mark.parametrize("use_lt", [True, False]) +def test_getitems_tensor_input_graph_target_condition(use_lt): + target_graph, input_tensor = _create_graph_data(True, use_lt=use_lt) + condition = Condition(input=input_tensor, target=target_graph) + indices = [0, 2, 4] + items = condition[indices] + candidate_input = items["input"] + candidate_target = items["target"] + if use_lt: + input_ = LabelTensor.cat([input_tensor[i] for i in indices], dim=0) + target_ = LabelBatch.from_data_list([target_graph[i] for i in indices]) + else: + input_ = torch.cat([input_tensor[i] for i in indices], dim=0) + target_ = Batch.from_data_list([target_graph[i] for i in indices]) + assert torch.allclose( + candidate_input, input_ + ), "TensorInputGraphTargetCondition __getitems__ input failed" + assert torch.allclose( + candidate_target.y, target_.y + ), "TensorInputGraphTargetCondition __getitems__ target failed" + if use_lt: + assert isinstance( + candidate_input, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ input type failed" + assert candidate_input.labels == [ + "f" + ], "TensorInputGraphTargetCondition __getitems__ input labels failed" + assert isinstance( + candidate_target.y, LabelTensor + ), "TensorInputGraphTargetCondition __getitems__ target type failed" + assert candidate_target.y.labels == [ + "u", + "v", + ], "TensorInputGraphTargetCondition __getitems__ target labels failed" + + +test_init_graph_input_tensor_target_condition(use_lt=True) diff --git a/tests/test_solver/test_ensemble_supervised_solver.py b/tests/test_solver/test_ensemble_supervised_solver.py index c5f0b9e52..4be2897d9 100644 --- a/tests/test_solver/test_ensemble_supervised_solver.py +++ b/tests/test_solver/test_ensemble_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_models = [Models() for i in range(10)] diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..461130a6b 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -83,7 +83,8 @@ def forward(self, batch): y = self.conv(y, edge_index) y = self.activation(y) y = self.output(y) - return to_dense_batch(y, batch.batch)[0] + return y + # return to_dense_batch(y, batch.batch)[0] graph_model = Model()