Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 55 additions & 26 deletions src/easyscience/base_classes/collection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Tuple
from typing import Union

from easyscience.base_classes.new_base import NewBase
from easyscience.global_object.undo_redo import NotarizedDict

from ..variable.descriptor_base import DescriptorBase
Expand All @@ -34,7 +35,7 @@
def __init__(
self,
name: str,
*args: Union[BasedBase, DescriptorBase],
*args: Union[BasedBase, DescriptorBase, NewBase],
interface: Optional[InterfaceFactoryTemplate] = None,
unique_name: Optional[str] = None,
**kwargs,
Expand Down Expand Up @@ -64,8 +65,10 @@
_kwargs[key] = item
kwargs = _kwargs
for item in list(kwargs.values()) + _args:
if not issubclass(type(item), (DescriptorBase, BasedBase)):
raise AttributeError('A collection can only be formed from easyscience objects.')
if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)):
raise AttributeError(
"A collection can only be formed from easyscience objects."
)
args = _args
_kwargs = {}
for key, item in kwargs.items():
Expand All @@ -79,30 +82,34 @@

for key in kwargs.keys():
if key in self.__dict__.keys() or key in self.__slots__:
raise AttributeError(f'Given kwarg: `{key}`, is an internal attribute. Please rename.')
raise AttributeError(
f"Given kwarg: `{key}`, is an internal attribute. Please rename."
)
if kwargs[key]: # Might be None (empty tuple or list)
self._global_object.map.add_edge(self, kwargs[key])
self._global_object.map.reset_type(kwargs[key], 'created_internal')
self._global_object.map.reset_type(kwargs[key], "created_internal")
if interface is not None:
kwargs[key].interface = interface
# TODO wrap getter and setter in Logger
if interface is not None:
self.interface = interface
self._kwargs._stack_enabled = True

def insert(self, index: int, value: Union[DescriptorBase, BasedBase]) -> None:
def insert(
self, index: int, value: Union[DescriptorBase, BasedBase, NewBase]
) -> None:
"""
Insert an object into the collection at an index.

:param index: Index for EasyScience object to be inserted.
:type index: int
:param value: Object to be inserted.
:type value: Union[BasedBase, DescriptorBase]
:type value: Union[BasedBase, DescriptorBase, NewBase]
:return: None
:rtype: None
"""
t_ = type(value)
if issubclass(t_, (BasedBase, DescriptorBase)):
if issubclass(t_, (BasedBase, DescriptorBase, NewBase)):
update_key = list(self._kwargs.keys())
values = list(self._kwargs.values())
# Update the internal dict
Expand All @@ -112,12 +119,16 @@
self._kwargs.reorder(**{k: v for k, v in zip(update_key, values)})
# ADD EDGE
self._global_object.map.add_edge(self, value)
self._global_object.map.reset_type(value, 'created_internal')
self._global_object.map.reset_type(value, "created_internal")
value.interface = self.interface
else:
raise AttributeError('Only EasyScience objects can be put into an EasyScience group')
raise AttributeError(
"Only EasyScience objects can be put into an EasyScience group"
)

def __getitem__(self, idx: Union[int, slice]) -> Union[DescriptorBase, BasedBase]:
def __getitem__(
self, idx: Union[int, slice]
) -> Union[DescriptorBase, BasedBase, NewBase]:
"""
Get an item in the collection based on its index.

Expand All @@ -128,30 +139,36 @@
"""
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
return self.__class__(getattr(self, 'name'), *[self[i] for i in range(start, stop, step)])
return self.__class__(
getattr(self, "name"), *[self[i] for i in range(start, stop, step)]
)
if str(idx) in self._kwargs.keys():
return self._kwargs[str(idx)]
if isinstance(idx, str):
idx = [index for index, item in enumerate(self) if item.name == idx]
noi = len(idx)
if noi == 0:
raise IndexError('Given index does not exist')
raise IndexError("Given index does not exist")

Check warning on line 151 in src/easyscience/base_classes/collection_base.py

View check run for this annotation

Codecov / codecov/patch

src/easyscience/base_classes/collection_base.py#L151

Added line #L151 was not covered by tests
elif noi == 1:
idx = idx[0]
else:
return self.__class__(getattr(self, 'name'), *[self[i] for i in idx])
return self.__class__(getattr(self, "name"), *[self[i] for i in idx])
elif not isinstance(idx, int) or isinstance(idx, bool):
if isinstance(idx, bool):
raise TypeError('Boolean indexing is not supported at the moment')
raise TypeError("Boolean indexing is not supported at the moment")
try:
if idx > len(self):
raise IndexError(f'Given index {idx} is out of bounds')
raise IndexError(f"Given index {idx} is out of bounds")
except TypeError:
raise IndexError('Index must be of type `int`/`slice` or an item name (`str`)')
raise IndexError(
"Index must be of type `int`/`slice` or an item name (`str`)"
)
keys = list(self._kwargs.keys())
return self._kwargs[keys[idx]]

def __setitem__(self, key: int, value: Union[BasedBase, DescriptorBase]) -> None:
def __setitem__(
self, key: int, value: Union[BasedBase, DescriptorBase, NewBase]
) -> None:
"""
Set an item via it's index.

Expand All @@ -163,7 +180,7 @@
if isinstance(value, Number): # noqa: S3827
item = self.__getitem__(key)
item.value = value
elif issubclass(type(value), (BasedBase, DescriptorBase)):
elif issubclass(type(value), (BasedBase, DescriptorBase, NewBase)):
update_key = list(self._kwargs.keys())
values = list(self._kwargs.values())
old_item = values[key]
Expand All @@ -172,12 +189,14 @@
self._kwargs.update(update_dict)
# ADD EDGE
self._global_object.map.add_edge(self, value)
self._global_object.map.reset_type(value, 'created_internal')
self._global_object.map.reset_type(value, "created_internal")
value.interface = self.interface
# REMOVE EDGE
self._global_object.map.prune_vertex_from_edge(self, old_item)
else:
raise NotImplementedError('At the moment only numerical values or EasyScience objects can be set.')
raise NotImplementedError(
"At the moment only numerical values or EasyScience objects can be set."
)

def __delitem__(self, key: int) -> None:
"""
Expand All @@ -202,18 +221,22 @@
"""
return len(self._kwargs.keys())

def _convert_to_dict(self, in_dict, encoder, skip: List[str] = [], **kwargs) -> dict:
def _convert_to_dict(
self, in_dict, encoder, skip: List[str] = [], **kwargs
) -> dict:
"""
Convert ones self into a serialized form.

:return: dictionary of ones self
:rtype: dict
"""
d = {}
if hasattr(self, '_modify_dict'):
if hasattr(self, "_modify_dict"):
# any extra keys defined on the inheriting class
d = self._modify_dict(skip=skip, **kwargs)
in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self]
in_dict["data"] = [
encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self
]
out_dict = {**in_dict, **d}
return out_dict

Expand All @@ -231,9 +254,15 @@
return tuple(self._kwargs.values())

def __repr__(self) -> str:
return f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}'
return (
f'{self.__class__.__name__} `{getattr(self, "name")}` of length {len(self)}'
)

def sort(self, mapping: Callable[[Union[BasedBase, DescriptorBase]], Any], reverse: bool = False) -> None:
def sort(
self,
mapping: Callable[[Union[BasedBase, DescriptorBase, NewBase]], Any],
reverse: bool = False,
) -> None:
"""
Sort the collection according to the given mapping.

Expand Down
Loading