diff --git a/src/easyscience/global_object/map.py b/src/easyscience/global_object/map.py index 7fb224b..33636c6 100644 --- a/src/easyscience/global_object/map.py +++ b/src/easyscience/global_object/map.py @@ -74,6 +74,20 @@ def __init__(self): # A dict with object names as keys and a list of their object types as values, with weak references self.__type_dict = {} + def _snapshot_items(self): + """Return a stable snapshot of __type_dict items. + + Some callers iterate over __type_dict while other threads or + weakref finalizers may modify it. Creating a list snapshot (with + a retry loop) prevents RuntimeError: dictionary changed size during iteration. + """ + while True: + try: + return list(self.__type_dict.items()) + except RuntimeError: + # Dict changed during snapshot creation, retry + continue + def vertices(self) -> List[str]: """Returns the vertices of a map. @@ -109,7 +123,15 @@ def returned_objs(self) -> List[str]: def _nested_get(self, obj_type: str) -> List[str]: """Access a nested object in root by key sequence.""" - return [key for key, item in self.__type_dict.items() if obj_type in item.type] + # Create a stable snapshot of the dict items to avoid RuntimeError + # when the dict is modified during iteration (e.g., by finalizers). + while True: + try: + items = self._snapshot_items() + return [key for key, item in items if obj_type in item.type] + except RuntimeError: + # In case the snapshot itself raises (very rare), retry + continue def get_item_by_key(self, item_id: str) -> object: if item_id in self._store: @@ -143,10 +165,13 @@ def add_vertex(self, obj: object, obj_type: str = None): # but the finalizer hasn't run yet if name in self.__type_dict: del self.__type_dict[name] + self._store[name] = obj - self.__type_dict[name] = _EntryList() # Add objects type to the list of types - self.__type_dict[name].finalizer = weakref.finalize(self._store[name], self.prune, name) - self.__type_dict[name].type = obj_type + + entry_list = _EntryList() + entry_list.finalizer = weakref.finalize(obj, self.prune, name) + entry_list.type = obj_type + self.__type_dict[name] = entry_list # Add objects type to the list of types def add_edge(self, start_obj: object, end_obj: object): if start_obj.unique_name in self.__type_dict: @@ -167,8 +192,11 @@ def __generate_edges(self) -> list: vertices """ edges = [] - for vertex in self.__type_dict: - for neighbour in self.__type_dict[vertex]: + # Iterate over a snapshot of items and snapshot neighbour lists to + # avoid concurrent modification issues. + for vertex, neighbours in self._snapshot_items(): + neighbours_snapshot = list(neighbours) + for neighbour in neighbours_snapshot: if {neighbour, vertex} not in edges: edges.append({vertex, neighbour}) return edges @@ -190,12 +218,10 @@ def prune(self, key: str): def find_isolated_vertices(self) -> list: """returns a list of isolated vertices.""" - graph = self.__type_dict isolated = [] - for vertex in graph: - print(isolated, vertex) - if not graph[vertex]: - isolated += [vertex] + for vertex, neighbours in self._snapshot_items(): + if not list(neighbours): + isolated.append(vertex) return isolated def find_path(self, start_vertex: str, end_vertex: str, path=[]) -> list: @@ -247,9 +273,10 @@ def reverse_route(self, end_vertex: str, start_vertex: Optional[str] = None) -> path_length = sys.maxsize optimum_path = [] if start_vertex is None: - # We now have to find where to begin..... - for possible_start, vertices in self.__type_dict.items(): - if end_vertex in vertices: + # We now have to find where to begin..... Iterate over a snapshot + for possible_start, vertices in self._snapshot_items(): + vertices_snapshot = list(vertices) + if end_vertex in vertices_snapshot: temp_path = self.find_path(possible_start, end_vertex) if len(temp_path) < path_length: path_length = len(temp_path) @@ -270,7 +297,7 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool: start_vertex = vertices[0] vertices_encountered.add(start_vertex) if len(vertices_encountered) != len(vertices): - for vertex in graph[start_vertex]: + for vertex in list(graph[start_vertex]): if vertex not in vertices_encountered and self.is_connected(vertices_encountered, vertex): return True else: diff --git a/tests/unit_tests/global_object/test_map.py b/tests/unit_tests/global_object/test_map.py index 89537dd..a2783e6 100644 --- a/tests/unit_tests/global_object/test_map.py +++ b/tests/unit_tests/global_object/test_map.py @@ -241,6 +241,24 @@ def test_find_type_unknown_object(self, clear): # When/Then result = global_object.map.find_type(unknown_obj) assert result is None + + def test_returned_objs_access_safe_under_modification(self, clear): + """Ensure accessing returned_objs doesn't raise when entries change size during iteration.""" + objs = [ObjBase(name=f"race_{i}") for i in range(8)] + # Mark all as returned + for o in objs: + global_object.map.change_type(o, 'returned') + + # Repeatedly access returned_objs while deleting objects and forcing GC to + # try to trigger concurrent modification. This used to raise RuntimeError. + for _ in range(200): + _ = global_object.map.returned_objs # should not raise + if _ and objs: + # delete one object and collect to trigger finalizer/prune + del objs[0] + gc.collect() + # If we got here without exceptions, consider the access safe + assert True def test_reset_type(self, clear, base_object): """Test resetting object type"""