Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions masfactory/adapters/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from .context.provider import ContextProvider, HistoryProvider
from .context.types import ContextBlock, ContextQuery
from masfactory.core.multimodal import MediaMessageBlock, TextMessageBlock, iter_media_message_blocks
from masfactory.checkpoint.checkpointable import Checkpointable
from copy import deepcopy


class Memory(ContextProvider, ABC):
class Memory(ContextProvider,Checkpointable,ABC):
"""Base interface for memory backends.

Memory is a long-lived stateful adapter that can both:
Expand Down Expand Up @@ -50,7 +51,19 @@ def reset(self):
def get_blocks(self, query: ContextQuery, *, top_k: int = 8) -> list[ContextBlock]:
"""Return context blocks relevant to the query."""
raise NotImplementedError


def get_checkpoint_state(self):
return {
"type":self.__class__.__name__,
"context_label":self._context_label,
"passive":self.passive,
"active":self.active,
}

def load_checkpoint_state(self, state):
self._context_label = state["context_label"]
self.passive = state["passive"]
self.active = state["active"]

class HistoryMemory(Memory, HistoryProvider):
"""Conversation history memory (list-of-dict message format)."""
Expand Down Expand Up @@ -146,6 +159,22 @@ def delete(self, key: str, index: int = -1):
def reset(self):
self._memory = []

def get_checkpoint_state(self):
state=super().get_checkpoint_state()
state.update({
"memory":deepcopy(self._memory),
"memory_size":self._memory_size ,
"top_k":self._top_k,
"merge_historical_media":self._merge_historical_media,
})
return state

def load_checkpoint_state(self, state):
super().load_checkpoint_state(state)
self._memory=deepcopy(state["memory"])
self._memory_size = int(state["memory_size"])
self._top_k = int(state["top_k"])
self._merge_historical_media = bool(state["merge_historical_media"])

class VectorMemory(Memory):
"""Semantic memory backed by embeddings and cosine similarity."""
Expand Down Expand Up @@ -238,3 +267,26 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
if norm1 == 0 or norm2 == 0:
return 0.0
return float(np.dot(vec1, vec2) / (norm1 * norm2))

def get_checkpoint_state(self)->dict:
state=super().get_checkpoint_state()
state.update({
"memory_size":self._memory_size,
"top_k":self._top_k,
"query_threshold":self._query_threshold,
"memory":deepcopy(self._memory),
"embeddings":{
key:value.tolist() for key,value in self._embeddings.items()
}
})
return state

def load_checkpoint_state(self, state:dict)->None:
super().load_checkpoint_state(state)
self._memory_size = int(state["memory_size"])
self._top_k = int(state["top_k"])
self._query_threshold =float(state["query_threshold"])
self._memory=deepcopy(state["memory"])
self._embeddings={
key:np.array(value) for key,value in state["embeddings"].items()
}
15 changes: 8 additions & 7 deletions masfactory/adapters/model/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,15 @@ def __init__(
"tool_choice": {"name": "tool_choice", "type": (str, dict)},
}

def _encode_responses_content(self, content: object) -> list[dict]:
def _encode_responses_content(self, content: object, role:str|None=None) -> list[dict]:
encoded: list[dict] = []
text_type="output_text" if role=="assistant" else "input_text"
for block in content_blocks(content):
if isinstance(block, str):
encoded.append({"type": "input_text", "text": block})
encoded.append({"type": text_type, "text": block})
continue
if isinstance(block, TextMessageBlock):
encoded.append({"type": "input_text", "text": block.text})
encoded.append({"type": text_type, "text": block.text})
continue
if isinstance(block, MediaMessageBlock):
validate_media_capability(
Expand All @@ -109,9 +110,9 @@ def _encode_responses_content(self, content: object) -> list[dict]:
item["file_data"] = asset_to_base64(asset)
encoded.append(item)
continue
encoded.append({"type": "input_text", "text": str(block)})
encoded.append({"type": text_type, "text": str(block)})
if not encoded:
encoded.append({"type": "input_text", "text": ""})
encoded.append({"type": text_type, "text": ""})
return encoded

def _encode_responses_input(self, messages: list[dict]) -> list[dict]:
Expand All @@ -135,7 +136,7 @@ def _encode_responses_input(self, messages: list[dict]) -> list[dict]:
items.append(
{
"role": "assistant",
"content": [{"type": "input_text", "text": assistant_text}],
"content": [{"type": "output_text", "text": assistant_text}],
}
)
for tool_call in tool_calls:
Expand All @@ -154,7 +155,7 @@ def _encode_responses_input(self, messages: list[dict]) -> list[dict]:
items.append(
{
"role": role,
"content": self._encode_responses_content(message.get("content")),
"content": self._encode_responses_content(message.get("content"),role=role),
}
)
return items
Expand Down
74 changes: 70 additions & 4 deletions masfactory/adapters/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

from .context.provider import ContextProvider
from .context.types import ContextBlock, ContextQuery
from masfactory.checkpoint.checkpointable import Checkpointable
from copy import deepcopy


class Retrieval(ContextProvider, ABC):
class Retrieval(ContextProvider, Checkpointable,ABC):
"""Read-only retrieval interface for external context (RAG)."""

supports_passive: bool = True
Expand All @@ -32,7 +34,19 @@ def context_label(self) -> str:
def get_blocks(self, query: ContextQuery, *, top_k: int = 8) -> list[ContextBlock]:
"""Return structured context blocks relevant to the query."""
raise NotImplementedError


def get_checkpoint_state(self):
return {
"type":self.__class__.__name__,
"context_label":self._context_label,
"passive":self.passive,
"active":self.active,
}

def load_checkpoint_state(self, state):
self._context_label =state["context_label"]
self.passive = state["passive"]
self.active = state["active"]

class VectorRetriever(Retrieval):
"""In-memory semantic retriever based on embedding cosine similarity."""
Expand Down Expand Up @@ -94,8 +108,26 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
if norm1 == 0 or norm2 == 0:
return 0.0
return float(np.dot(vec1, vec2) / (norm1 * norm2))



def get_checkpoint_state(self):
state=super().get_checkpoint_state()
state.update({
"documents":deepcopy(self._documents),
"similarity_threshold":float(self._similarity_threshold),
"doc_embeddings":{
key:value.tolist() for key,value in self._doc_embeddings.items()
},
})
return state

def load_checkpoint_state(self, state):
super().load_checkpoint_state(state)
self._documents =deepcopy(state["documents"])
self._similarity_threshold = float(state["similarity_threshold"])
self._doc_embeddings={
key:np.array(value) for key,value in state["doc_embeddings"].items()
}

class FileSystemRetriever(Retrieval):
"""File-system retriever that indexes files in a directory and retrieves by embeddings."""

Expand Down Expand Up @@ -202,6 +234,30 @@ def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
return 0.0
return float(np.dot(vec1, vec2) / (norm1 * norm2))

def get_checkpoint_state(self):
state=super().get_checkpoint_state()
state.update({
"documents":deepcopy(self._documents),
"similarity_threshold":float(self._similarity_threshold),
"doc_embeddings":{
key:value.tolist() for key,value in self._doc_embeddings.items()
},
"cache_path":str(self._cache_path) if self._cache_path else None,
"docs_dir": str(self._docs_dir),
"file_extension":self._file_extension,
})
return state

def load_checkpoint_state(self, state):
super().load_checkpoint_state(state)
self._documents =deepcopy(state["documents"])
self._similarity_threshold = float(state["similarity_threshold"])
self._doc_embeddings={
key:np.array(value) for key,value in state["doc_embeddings"].items()
}
self._cache_path = Path(state["cache_path"]) if state.get("cache_path") else None
self._docs_dir = Path(state["docs_dir"])
self._file_extension = state["file_extension"]

class SimpleKeywordRetriever(Retrieval):
"""Lightweight keyword-frequency retriever for small corpora."""
Expand Down Expand Up @@ -255,3 +311,13 @@ def _compute_relevance(self, query: str, document: str) -> float:
count += len(words) * 2
return count / (len(document_lower.split()) + 1)

def get_checkpoint_state(self):
state=super().get_checkpoint_state()
state.update({
"documents":deepcopy(self._documents)
})
return state

def load_checkpoint_state(self, state):
super().load_checkpoint_state(state)
self._documents=deepcopy(state["documents"])
3 changes: 3 additions & 0 deletions masfactory/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .checkpointable import Checkpointable

__all__ = ["Checkpointable"]
12 changes: 12 additions & 0 deletions masfactory/checkpoint/checkpointable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from abc import ABC,abstractmethod

class Checkpointable(ABC):
    """Interface for objects whose runtime state can be saved and restored.

    Implementations export their state as a plain dict (ideally serializable)
    and restore themselves from a dict with the same shape.
    """

    @abstractmethod
    def get_checkpoint_state(self) -> dict:
        """Export this object's runtime checkpoint state."""
        raise NotImplementedError

    @abstractmethod
    def load_checkpoint_state(self, state: dict) -> None:
        """Restore this object's runtime checkpoint state."""
        raise NotImplementedError
45 changes: 45 additions & 0 deletions masfactory/checkpoint/collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
class CheckpointCollector:
    """Recursively gathers checkpoint state from a graph hierarchy.

    Produces a flat snapshot dict with four sections ("graphs", "nodes",
    "edges", "components"), each keyed by a dotted path id rooted at "root".
    """

    # Node attributes (without the leading underscore) that hold
    # checkpointable context providers attached to a node.
    _COMPONENT_ATTRS = ("memories", "history_memories", "retrievers")

    def collect(self, root_graph) -> dict:
        """Walk *root_graph* and return the full checkpoint snapshot."""
        state = {
            "graphs": {},
            "nodes": {},
            "edges": {},
            "components": {},
        }
        self._collect_graph(root_graph, "root", state)
        return state

    def _collect_node_components(self, node, node_id: str, state: dict) -> None:
        """Collect state from the providers attached to *node*, if any."""
        for attr in self._COMPONENT_ATTRS:
            components = getattr(node, f"_{attr}", None)
            if not components:
                continue
            for index, component in enumerate(components):
                component_id = f"{node_id}.{attr}.{index}"
                state["components"][component_id] = component.get_checkpoint_state()

    def _collect_graph(self, graph, graph_id: str, state: dict) -> None:
        """Collect state from *graph*, its nodes/edges, and nested subgraphs."""
        state["graphs"][graph_id] = graph.get_checkpoint_state()

        for node_name, node in graph._nodes.items():
            node_id = f"{graph_id}.{node_name}"
            state["nodes"][node_id] = node.get_checkpoint_state()
            self._collect_node_components(node, node_id, state)
            # A node that itself exposes nodes and edges is a nested subgraph.
            if hasattr(node, "_nodes") and hasattr(node, "_edges"):
                self._collect_graph(node, node_id, state)

        for index, edge in enumerate(graph._edges):
            edge_id = f"{graph_id}.edge.{index}"
            state["edges"][edge_id] = edge.get_checkpoint_state()

Loading