From 545e94b69ca4372189580bd1e00b5500c547ace4 Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Sun, 3 May 2026 21:00:37 +0800 Subject: [PATCH 1/3] add checkpoint MVP --- masfactory/adapters/memory.py | 58 +++++++++++++++++++++-- masfactory/adapters/model/openai.py | 15 +++--- masfactory/adapters/retrieval.py | 50 ++++++++++++++++++-- masfactory/checkpoint/__init__.py | 0 masfactory/checkpoint/checkpointable.py | 12 +++++ masfactory/checkpoint/collector.py | 40 ++++++++++++++++ masfactory/checkpoint/manager.py | 61 +++++++++++++++++++++++++ masfactory/checkpoint/restorer.py | 37 +++++++++++++++ masfactory/checkpoint/storage.py | 27 +++++++++++ masfactory/core/edge.py | 21 ++++++++- masfactory/core/node.py | 26 ++++++++++- 11 files changed, 331 insertions(+), 16 deletions(-) create mode 100644 masfactory/checkpoint/__init__.py create mode 100644 masfactory/checkpoint/checkpointable.py create mode 100644 masfactory/checkpoint/collector.py create mode 100644 masfactory/checkpoint/manager.py create mode 100644 masfactory/checkpoint/restorer.py create mode 100644 masfactory/checkpoint/storage.py diff --git a/masfactory/adapters/memory.py b/masfactory/adapters/memory.py index b32bb5b..e912f50 100644 --- a/masfactory/adapters/memory.py +++ b/masfactory/adapters/memory.py @@ -8,9 +8,10 @@ from .context.provider import ContextProvider, HistoryProvider from .context.types import ContextBlock, ContextQuery from masfactory.core.multimodal import MediaMessageBlock, TextMessageBlock, iter_media_message_blocks +from masfactory.checkpoint.checkpointable import Checkpointable +from copy import deepcopy - -class Memory(ContextProvider, ABC): +class Memory(ContextProvider,Checkpointable,ABC): """Base interface for memory backends. Memory is a long-lived stateful adapter that can both: @@ -50,7 +51,19 @@ def reset(self): def get_blocks(self, query: ContextQuery, *, top_k: int = 8) -> list[ContextBlock]: """Return context blocks relevant to the query.""" raise NotImplementedError - + + def get_checkpoint_state(self): + return { + "type":self.__class__.__name__, + "context_label":self._context_label, + "passive":self.passive, + "active":self.active, + } + + def load_checkpoint_state(self, state): + self._context_label = state["context_label"] + self.passive = state["passive"] + self.active = state["active"] class HistoryMemory(Memory, HistoryProvider): """Conversation history memory (list-of-dict message format).""" @@ -146,6 +159,22 @@ def delete(self, key: str, index: int = -1): def reset(self): self._memory = [] + def get_checkpoint_state(self): + state=super().get_checkpoint_state() + state.update({ + "memory":deepcopy(self._memory), + "memory_size":self._memory_size , + "top_k":self._top_k, + "merge_historical_media":self._merge_historical_media, + }) + return state + + def load_checkpoint_state(self, state): + super().load_checkpoint_state(state) + self._memory=deepcopy(state["memory"]) + self._memory_size = int(state["memory_size"]) + self._top_k = int(state["top_k"]) + self._merge_historical_media = bool(state["merge_historical_media"]) class VectorMemory(Memory): """Semantic memory backed by embeddings and cosine similarity.""" @@ -238,3 +267,26 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: if norm1 == 0 or norm2 == 0: return 0.0 return float(np.dot(vec1, vec2) / (norm1 * norm2)) + + def get_checkpoint_state(self)->dict: + state=super().get_checkpoint_state() + state.update({ + "memory_size":self._memory_size, + "top_k":self._top_k, + "query_threshold":self._query_threshold, + "memory":deepcopy(self._memory), + "embeddings":{ + key:value.tolist() for key,value in self._embeddings.items() + } + }) + return state + + def load_checkpoint_state(self, state:dict)->None: + super().load_checkpoint_state(state) + self._memory_size = int(state["memory_size"]) + self._top_k = int(state["top_k"]) + self._query_threshold =float(state["query_threshold"]) + self._memory=deepcopy(state["memory"]) + self._embeddings={ + key:np.array(value) for key,value in state["embeddings"].items() + } \ No newline at end of file diff --git a/masfactory/adapters/model/openai.py b/masfactory/adapters/model/openai.py index 5200251..f0d7911 100644 --- a/masfactory/adapters/model/openai.py +++ b/masfactory/adapters/model/openai.py @@ -75,14 +75,15 @@ def __init__( "tool_choice": {"name": "tool_choice", "type": (str, dict)}, } - def _encode_responses_content(self, content: object) -> list[dict]: + def _encode_responses_content(self, content: object, role:str|None=None) -> list[dict]: encoded: list[dict] = [] + text_type="output_text" if role=="assistant" else "input_text" for block in content_blocks(content): if isinstance(block, str): - encoded.append({"type": "input_text", "text": block}) + encoded.append({"type": text_type, "text": block}) continue if isinstance(block, TextMessageBlock): - encoded.append({"type": "input_text", "text": block.text}) + encoded.append({"type": text_type, "text": block.text}) continue if isinstance(block, MediaMessageBlock): validate_media_capability( @@ -109,9 +110,9 @@ def _encode_responses_content(self, content: object) -> list[dict]: item["file_data"] = asset_to_base64(asset) encoded.append(item) continue - encoded.append({"type": "input_text", "text": str(block)}) + encoded.append({"type": text_type, "text": str(block)}) if not encoded: - encoded.append({"type": "input_text", "text": ""}) + encoded.append({"type": text_type, "text": ""}) return encoded def _encode_responses_input(self, messages: list[dict]) -> list[dict]: @@ -135,7 +136,7 @@ def _encode_responses_input(self, messages: list[dict]) -> list[dict]: items.append( { "role": "assistant", - "content": [{"type": "input_text", "text": assistant_text}], + "content": [{"type": "output_text", "text": assistant_text}], } ) for tool_call in tool_calls: @@ -154,7 +155,7 @@ def _encode_responses_input(self, messages: list[dict]) -> list[dict]: items.append( { "role": role, - "content": self._encode_responses_content(message.get("content")), + "content": self._encode_responses_content(message.get("content"),role=role), } ) return items diff --git a/masfactory/adapters/retrieval.py b/masfactory/adapters/retrieval.py index 36bdefe..f874f4e 100644 --- a/masfactory/adapters/retrieval.py +++ b/masfactory/adapters/retrieval.py @@ -11,9 +11,11 @@ from .context.provider import ContextProvider from .context.types import ContextBlock, ContextQuery +from masfactory.checkpoint.checkpointable import Checkpointable +from copy import deepcopy -class Retrieval(ContextProvider, ABC): +class Retrieval(ContextProvider, Checkpointable,ABC): """Read-only retrieval interface for external context (RAG).""" supports_passive: bool = True @@ -32,7 +34,19 @@ def context_label(self) -> str: def get_blocks(self, query: ContextQuery, *, top_k: int = 8) -> list[ContextBlock]: """Return structured context blocks relevant to the query.""" raise NotImplementedError - + + def get_checkpoint_state(self): + return { + "type":self.__class__.__name__, + "context_label":self._context_label, + "passive":self.passive, + "active":self.active, + } + + def load_checkpoint_state(self, state): + self._context_label =state["context_label"] + self.passive = state["passive"] + self.active = state["active"] class VectorRetriever(Retrieval): """In-memory semantic retriever based on embedding cosine similarity.""" @@ -94,8 +108,26 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: if norm1 == 0 or norm2 == 0: return 0.0 return float(np.dot(vec1, vec2) / (norm1 * norm2)) - - + + def get_checkpoint_state(self): + state=super().get_checkpoint_state() + state.update({ + "documents":deepcopy(self._documents), + "similarity_threshold":float(self._similarity_threshold), + "doc_embeddings":{ + key:value.tolist() for key,value in self._doc_embeddings.items() + }, + }) + return state + + def load_checkpoint_state(self, state): + super().load_checkpoint_state(state) + self._documents =deepcopy(state["documents"]) + self._similarity_threshold = float(state["similarity_threshold"]) + self._doc_embeddings={ + key:np.array(value) for key,value in state["doc_embeddings"].items() + } + class FileSystemRetriever(Retrieval): """File-system retriever that indexes files in a directory and retrieves by embeddings.""" @@ -255,3 +287,13 @@ def _compute_relevance(self, query: str, document: str) -> float: count += len(words) * 2 return count / (len(document_lower.split()) + 1) + def get_checkpoint_state(self): + state=super().get_checkpoint_state() + state.update({ + "documents":deepcopy(self._documents) + }) + return state + + def load_checkpoint_state(self, state): + super().load_checkpoint_state(state) + self._documents=deepcopy(state["documents"]) \ No newline at end of file diff --git a/masfactory/checkpoint/__init__.py b/masfactory/checkpoint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/masfactory/checkpoint/checkpointable.py b/masfactory/checkpoint/checkpointable.py new file mode 100644 index 0000000..863a76b --- /dev/null +++ b/masfactory/checkpoint/checkpointable.py @@ -0,0 +1,12 @@ +from abc import ABC,abstractmethod + +class Checkpointable(ABC): + @abstractmethod + def get_checkpoint_state(self) -> dict: + """Export this object's runtime checkpoint state.""" + raise NotImplementedError + + @abstractmethod + def load_checkpoint_state(self,state:dict) -> None: + """Restore this object's runtime checkpoint state.""" + raise NotImplementedError \ No newline at end of file diff --git a/masfactory/checkpoint/collector.py b/masfactory/checkpoint/collector.py new file mode 100644 index 0000000..0e9188a --- /dev/null +++ b/masfactory/checkpoint/collector.py @@ -0,0 +1,40 @@ +class CheckpointCollector: + + def collect(self,root_graph): + state={ + "graphs":{}, + "nodes":{}, + "edges":{}, + "components":{}, + } + state["graphs"]["root"]=root_graph.get_checkpoint_state() + + for node_name,node in root_graph._nodes.items(): + node_id=f'root.{node_name}' + state["nodes"][node_id]=node.get_checkpoint_state() + self._collect_node_components(node,node_id,state) + + for index,edge in enumerate(root_graph._edges): + edge_id=f'root.edge.{index}' + state["edges"][edge_id]=edge.get_checkpoint_state() + + return state + + def _collect_node_components(self, node, node_id: str, state: dict) -> None: + memories = getattr(node, "_memories", None) + if memories: + for index, memory in enumerate(memories): + component_id = f"{node_id}.memories.{index}" + state["components"][component_id] = memory.get_checkpoint_state() + + history_memories = getattr(node, "_history_memories", None) + if history_memories: + for index, memory in enumerate(history_memories): + component_id = f"{node_id}.history_memories.{index}" + state["components"][component_id] = memory.get_checkpoint_state() + + retrievers = getattr(node, "_retrievers", None) + if retrievers: + for index, retriever in enumerate(retrievers): + component_id = f"{node_id}.retrievers.{index}" + state["components"][component_id] = retriever.get_checkpoint_state() diff --git a/masfactory/checkpoint/manager.py b/masfactory/checkpoint/manager.py new file mode 100644 index 0000000..3717be9 --- /dev/null +++ b/masfactory/checkpoint/manager.py @@ -0,0 +1,61 @@ +from masfactory.core.node import Node +from masfactory.core.gate import Gate +from masfactory.checkpoint.collector import CheckpointCollector +from masfactory.checkpoint.restorer import CheckpointRestorer + +class CheckpointManager: + + def __init__(self,root_graph,storage): + self.collector=CheckpointCollector() + self.restorer=CheckpointRestorer() + self.root_graph=root_graph + self.storage=storage + self.last_checkpoint_path=None + + def save(self,trigger=None): + checkpoint_state=self.collector.collect(self.root_graph) + path_str=self.storage.save(checkpoint_state) + self.last_checkpoint_path=path_str + return path_str + + def load(self,checkpoint_path): + checkpoint_state=self.storage.load(checkpoint_path) + self.restorer.restore(self.root_graph,checkpoint_state) + self.last_checkpoint_path = checkpoint_path + return checkpoint_state + + def load_last(self): + path_str=self.storage.get_last_path() + if path_str is None: + raise FileNotFoundError("No checkpoint file found.") + return self.load(path_str) + + def attach_hooks(self): + self.root_graph.hook_register( + Node.Hook.EXECUTE.AFTER, + self._save_after_execute, + recursion=True + ) + + def _save_after_execute(self,node,result,outer_env=None): + self.save(trigger=node) + + def resume(self): + graph=self.root_graph + max_iterations=10000 + for _ in range(max_iterations): + if graph._exit.is_ready or graph._gate !=Gate.OPEN: + break + + executed_any = False + for node in graph._nodes.values(): + if node.is_ready and graph._gate==Gate.OPEN: + node.execute(graph.attributes) + executed_any=True + break + + if not executed_any: + break + if graph._exit.is_ready: + graph._exit.execute(graph.attributes) + return graph._exit.output.copy(),graph.attributes.copy() \ No newline at end of file diff --git a/masfactory/checkpoint/restorer.py b/masfactory/checkpoint/restorer.py new file mode 100644 index 0000000..8969e0f --- /dev/null +++ b/masfactory/checkpoint/restorer.py @@ -0,0 +1,37 @@ +class CheckpointRestorer: + + def restore(self,root_graph,state): + + root_graph.load_checkpoint_state(state["graphs"]["root"]) + + for node_name,node in root_graph._nodes.items(): + node_id=f'root.{node_name}' + node.load_checkpoint_state(state["nodes"][node_id]) + self._restore_node_components(node,node_id,state) + + for index,edge in enumerate(root_graph._edges): + edge_id=f'root.edge.{index}' + edge.load_checkpoint_state(state["edges"][edge_id]) + + def _restore_node_components(self,node,node_id,state): + + memories=getattr(node,"_memories",None) + if memories: + for index,memory in enumerate(memories): + component_id=f'{node_id}.memories.{index}' + if component_id in state["components"]: + memory.load_checkpoint_state(state["components"][component_id]) + + history_memories=getattr(node,"_history_memories",None) + if history_memories: + for index,history in enumerate(history_memories): + component_id=f'{node_id}.history_memories.{index}' + if component_id in state["components"]: + history.load_checkpoint_state(state["components"][component_id]) + + retrievers=getattr(node,"_retrievers",None) + if retrievers: + for index,retriever in enumerate(retrievers): + component_id=f'{node_id}.retrievers.{index}' + if component_id in state["components"]: + retriever.load_checkpoint_state(state["components"][component_id]) diff --git a/masfactory/checkpoint/storage.py b/masfactory/checkpoint/storage.py new file mode 100644 index 0000000..f8febe9 --- /dev/null +++ b/masfactory/checkpoint/storage.py @@ -0,0 +1,27 @@ +from pathlib import Path +import json +from datetime import datetime + +class FileCheckpointStorage: + def __init__(self,checkpoint_dir:str): + self.checkpoint_dir=Path(checkpoint_dir) + self.checkpoint_dir.mkdir(parents=True,exist_ok=True) + + def save(self,checkpoint_state:dict)->str: + timestamp=datetime.now().strftime("%Y%m%d_%H%M%S_%f") + path=self.checkpoint_dir/f"checkpoint{timestamp}.json" + with path.open('w',encoding="utf-8") as f: + json.dump(checkpoint_state,f,ensure_ascii=False,indent=2) + return str(path) + + def load(self,checkpoint_path:str)->dict: + path=Path(checkpoint_path) + with path.open('r',encoding="utf-8") as f: + return json.load(f) + + def get_last_path(self): + paths = list(self.checkpoint_dir.glob("checkpoint*.json")) + if not paths: + return None + latest_path = max(paths, key=lambda path: path.stat().st_mtime) + return str(latest_path) diff --git a/masfactory/core/edge.py b/masfactory/core/edge.py index 633cf79..fd96a9c 100644 --- a/masfactory/core/edge.py +++ b/masfactory/core/edge.py @@ -6,10 +6,12 @@ from .gate import Gate from .multimodal import FieldSpec, normalize_field_specs, validate_field_value +from masfactory.checkpoint.checkpointable import Checkpointable +from copy import deepcopy if TYPE_CHECKING: from .node import Node -class Edge: +class Edge(Checkpointable): """Directed message channel between two nodes. An `Edge` buffers at most one in-flight message at a time. It can be opened/closed via a @@ -158,3 +160,20 @@ def reset(self): self._gate = Gate.OPEN def reset_gate(self): self._gate = Gate.OPEN + + def get_checkpoint_state(self) -> dict: + return { + "type": self.__class__.__name__, + "sender": self._sender.name, + "receiver": self._receiver.name, + "is_congested" : self._is_congested, + "gate" : self._gate.value, + "keys": deepcopy(self._keys), + "message" : deepcopy(self._message), + } + + def load_checkpoint_state(self,state:dict) -> None: + self._is_congested=state["is_congested"] + self._gate=Gate(state["gate"]) + self._keys=deepcopy(state["keys"]) + self._message=deepcopy(state["message"]) diff --git a/masfactory/core/node.py b/masfactory/core/node.py index 3c8d21b..dc1eb83 100644 --- a/masfactory/core/node.py +++ b/masfactory/core/node.py @@ -6,6 +6,8 @@ from masfactory.utils.hook import masf_hook, HookManager, HookStage from masfactory.utils.naming import validate_name from masfactory.utils.selector import Selector, build_selector +from masfactory.checkpoint.checkpointable import Checkpointable +from copy import deepcopy if TYPE_CHECKING: from .edge import Edge def merge_message(input:dict[str,object], message:dict[str,object]): @@ -24,7 +26,7 @@ def merge_message(input:dict[str,object], message:dict[str,object]): else: input[key] = message[key] return input -class Node(ABC): +class Node(Checkpointable,ABC): """Base node type for MASFactory graphs. A node consumes attributes from an outer graph scope (via `pull_keys`), executes its core @@ -391,3 +393,25 @@ def hook_register( ) if selector.match(self): self._hooks.register(hook_key, func) + + def get_checkpoint_state(self) -> dict: + return { + "type":self.__class__.__name__, + "name":self.name, + "is_built": self._is_built, + "gate": self._gate.value, + "attributes": deepcopy(self._attributes_store), + "default_attributes": deepcopy(self._default_attributes), + "pull_keys": deepcopy(self._pull_keys), + "push_keys": deepcopy(self._push_keys), + } + + def load_checkpoint_state(self, state: dict) -> None: + self._is_built = state["is_built"] + self._gate = Gate(state["gate"]) + self._attributes_store = deepcopy(state["attributes"]) + self._default_attributes = deepcopy(state["default_attributes"]) + self._pull_keys = deepcopy(state["pull_keys"]) + self._push_keys = deepcopy(state["push_keys"]) + + From 3e5655c64753c1f094ff95b66099f527bc1fe13b Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Tue, 5 May 2026 13:23:09 +0800 Subject: [PATCH 2/3] complete checkpoint nested graph support --- masfactory/adapters/retrieval.py | 24 ++++++++++++++++++++++++ masfactory/checkpoint/collector.py | 27 ++++++++++++++++----------- masfactory/checkpoint/manager.py | 25 +++++++++++++++++++++---- masfactory/checkpoint/restorer.py | 27 +++++++++++++++------------ masfactory/components/graphs/graph.py | 1 + 5 files changed, 77 insertions(+), 27 deletions(-) diff --git a/masfactory/adapters/retrieval.py b/masfactory/adapters/retrieval.py index f874f4e..5584ff0 100644 --- a/masfactory/adapters/retrieval.py +++ b/masfactory/adapters/retrieval.py @@ -234,6 +234,30 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: return 0.0 return float(np.dot(vec1, vec2) / (norm1 * norm2)) + def get_checkpoint_state(self): + state=super().get_checkpoint_state() + state.update({ + "documents":deepcopy(self._documents), + "similarity_threshold":float(self._similarity_threshold), + "doc_embeddings":{ + key:value.tolist() for key,value in self._doc_embeddings.items() + }, + "cache_path":str(self._cache_path) if self._cache_path else None, + "docs_dir": str(self._docs_dir), + "file_extension":self._file_extension, + }) + return state + + def load_checkpoint_state(self, state): + super().load_checkpoint_state(state) + self._documents =deepcopy(state["documents"]) + self._similarity_threshold = float(state["similarity_threshold"]) + self._doc_embeddings={ + key:np.array(value) for key,value in state["doc_embeddings"].items() + } + self._cache_path = Path(state["cache_path"]) if state.get("cache_path") else None + self._docs_dir = Path(state["docs_dir"]) + self._file_extension = state["file_extension"] class SimpleKeywordRetriever(Retrieval): """Lightweight keyword-frequency retriever for small corpora.""" diff --git a/masfactory/checkpoint/collector.py b/masfactory/checkpoint/collector.py index 0e9188a..b42109f 100644 --- a/masfactory/checkpoint/collector.py +++ b/masfactory/checkpoint/collector.py @@ -7,17 +7,7 @@ def collect(self,root_graph): "edges":{}, "components":{}, } - state["graphs"]["root"]=root_graph.get_checkpoint_state() - - for node_name,node in root_graph._nodes.items(): - node_id=f'root.{node_name}' - state["nodes"][node_id]=node.get_checkpoint_state() - self._collect_node_components(node,node_id,state) - - for index,edge in enumerate(root_graph._edges): - edge_id=f'root.edge.{index}' - state["edges"][edge_id]=edge.get_checkpoint_state() - + self._collect_graph(root_graph,"root",state) return state def _collect_node_components(self, node, node_id: str, state: dict) -> None: @@ -38,3 +28,18 @@ def _collect_node_components(self, node, node_id: str, state: dict) -> None: for index, retriever in enumerate(retrievers): component_id = f"{node_id}.retrievers.{index}" state["components"][component_id] = retriever.get_checkpoint_state() + + def _collect_graph(self,graph,graph_id,state): + state["graphs"][graph_id]=graph.get_checkpoint_state() + + for node_name,node in graph._nodes.items(): + node_id=f'{graph_id}.{node_name}' + state["nodes"][node_id]=node.get_checkpoint_state() + self._collect_node_components(node,node_id,state) + if hasattr(node,"_nodes") and hasattr(node,"_edges"): + self._collect_graph(node,node_id,state) + + for index,edge in enumerate(graph._edges): + edge_id=f'{graph_id}.edge.{index}' + state["edges"][edge_id]=edge.get_checkpoint_state() + diff --git a/masfactory/checkpoint/manager.py b/masfactory/checkpoint/manager.py index 3717be9..1d6ac80 100644 --- a/masfactory/checkpoint/manager.py +++ b/masfactory/checkpoint/manager.py @@ -41,7 +41,14 @@ def _save_after_execute(self,node,result,outer_env=None): self.save(trigger=node) def resume(self): - graph=self.root_graph + self._resume_graph(self.root_graph) + + if self.root_graph._exit.is_ready: + self.root_graph._exit.execute(self.root_graph.attributes) + + return self.root_graph._exit.output.copy(),self.root_graph.attributes.copy() + + def _resume_graph(self,graph): max_iterations=10000 for _ in range(max_iterations): if graph._exit.is_ready or graph._gate !=Gate.OPEN: @@ -49,6 +56,17 @@ def resume(self): executed_any = False for node in graph._nodes.values(): + if hasattr(node,"_nodes") and hasattr(node,"_edges"): + before_exit_ready=node._exit.is_ready + self._resume_graph(node) + if node._exit.is_ready and not before_exit_ready: + node._exit.execute(node.attributes) + + if node._exit.output: + node._message_dispatch_out(node._exit.output) + executed_any=True + break + if node.is_ready and graph._gate==Gate.OPEN: node.execute(graph.attributes) executed_any=True @@ -56,6 +74,5 @@ def resume(self): if not executed_any: break - if graph._exit.is_ready: - graph._exit.execute(graph.attributes) - return graph._exit.output.copy(),graph.attributes.copy() \ No newline at end of file + + \ No newline at end of file diff --git a/masfactory/checkpoint/restorer.py b/masfactory/checkpoint/restorer.py index 8969e0f..cf0408a 100644 --- a/masfactory/checkpoint/restorer.py +++ b/masfactory/checkpoint/restorer.py @@ -1,20 +1,9 @@ class CheckpointRestorer: def restore(self,root_graph,state): - - root_graph.load_checkpoint_state(state["graphs"]["root"]) - - for node_name,node in root_graph._nodes.items(): - node_id=f'root.{node_name}' - node.load_checkpoint_state(state["nodes"][node_id]) - self._restore_node_components(node,node_id,state) - - for index,edge in enumerate(root_graph._edges): - edge_id=f'root.edge.{index}' - edge.load_checkpoint_state(state["edges"][edge_id]) + self._restore_graph(root_graph,"root",state) def _restore_node_components(self,node,node_id,state): - memories=getattr(node,"_memories",None) if memories: for index,memory in enumerate(memories): @@ -35,3 +24,17 @@ def _restore_node_components(self,node,node_id,state): component_id=f'{node_id}.retrievers.{index}' if component_id in state["components"]: retriever.load_checkpoint_state(state["components"][component_id]) + + def _restore_graph(self,graph,graph_id,state): + graph.load_checkpoint_state(state["graphs"][graph_id]) + + for node_name,node in graph._nodes.items(): + node_id=f'{graph_id}.{node_name}' + node.load_checkpoint_state(state["nodes"][node_id]) + self._restore_node_components(node,node_id,state) + if hasattr(node,"_nodes") and hasattr(node,"_edges"): + self._restore_graph(node,node_id,state) + + for index,edge in enumerate(graph._edges): + edge_id=f'{graph_id}.edge.{index}' + edge.load_checkpoint_state(state["edges"][edge_id]) \ No newline at end of file diff --git a/masfactory/components/graphs/graph.py b/masfactory/components/graphs/graph.py index 5a3efd2..6e4a146 100644 --- a/masfactory/components/graphs/graph.py +++ b/masfactory/components/graphs/graph.py @@ -161,6 +161,7 @@ def build(self): def check_built(self) -> bool: return super().check_built() and self._entry.check_built() and self._exit.check_built() + @masf_hook(Node.Hook.FORWARD) def _forward(self, input:dict[str,object]) -> dict[str,object]: """Run one graph invocation and return exit output.""" From eaf93e00118430e45e03c028d762b7dc19a02673 Mon Sep 17 00:00:00 2001 From: huangzhenhua111 Date: Tue, 5 May 2026 15:51:04 +0800 Subject: [PATCH 3/3] Add checkpoint save granularity --- masfactory/checkpoint/__init__.py | 3 +++ masfactory/checkpoint/manager.py | 30 ++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/masfactory/checkpoint/__init__.py b/masfactory/checkpoint/__init__.py index e69de29..c4f4795 100644 --- a/masfactory/checkpoint/__init__.py +++ b/masfactory/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .checkpointable import Checkpointable + +__all__ = ["Checkpointable"] diff --git a/masfactory/checkpoint/manager.py b/masfactory/checkpoint/manager.py index 1d6ac80..2748fe3 100644 --- a/masfactory/checkpoint/manager.py +++ b/masfactory/checkpoint/manager.py @@ -5,7 +5,10 @@ class CheckpointManager: - def __init__(self,root_graph,storage): + def __init__(self,root_graph,storage,save_granularity="node"): + if save_granularity not in {"node","graph"}: + raise ValueError("save_granularity must be 'node' or 'graph'") + self.save_granularity=save_granularity self.collector=CheckpointCollector() self.restorer=CheckpointRestorer() self.root_graph=root_graph @@ -31,12 +34,28 @@ def load_last(self): return self.load(path_str) def attach_hooks(self): - self.root_graph.hook_register( + if self.save_granularity=="node": + self.root_graph.hook_register( + Node.Hook.EXECUTE.AFTER, + self._save_after_execute, + recursion=True + ) + + elif self.save_granularity=="graph": + self._attach_graph_hooks(self.root_graph) + + else: + raise ValueError("save_granularity must be 'node' or 'graph'") + + def _attach_graph_hooks(self,graph): + graph.hook_register( Node.Hook.EXECUTE.AFTER, self._save_after_execute, - recursion=True ) - + for node in graph._nodes.values(): + if hasattr(node,"_nodes") and hasattr(node,"_edges"): + self._attach_graph_hooks(node) + def _save_after_execute(self,node,result,outer_env=None): self.save(trigger=node) @@ -74,5 +93,4 @@ def _resume_graph(self,graph): if not executed_any: break - - \ No newline at end of file + \ No newline at end of file