diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..0f73110 --- /dev/null +++ b/api/__init__.py @@ -0,0 +1 @@ +"""HTTP API package for serving GMemory as an external memory backend.""" diff --git a/api/prompt_renderer.py b/api/prompt_renderer.py new file mode 100644 index 0000000..c2ef8e1 --- /dev/null +++ b/api/prompt_renderer.py @@ -0,0 +1,61 @@ +from mas.memory.common import MASMessage + + +TASK_SOLVE_WITH_INSIGHTS = """ +## Successful Examples (Reference Cases) +Below are some examples of similar tasks that were successfully completed. +Please use these as references to guide your thinking and approach to the current task: + +{few_shots} +--- + +## Your Own Past Successes (Execution Patterns) +Here are examples of successful execution processes you've previously used on similar tasks. +Pay special attention to the step-by-step procedures and strategies, especially when encountering obstacles: + +{memory_few_shots} +--- + +## Key Insights from Related Tasks +The following are insights gathered during the execution of similar tasks. You may refer to them during your task execution to improve problem-solving accuracy. + +{insights} +--- + +## Your Turn: Take Action! +Use the above examples and insights as a foundation, and now work on the following task: +{task_description} +""" + +TASK_CONTEXT = """ +### Task description: +{task_description} + +### Key steps: +{key_steps} + +### Detailed trajectory: +{trajectory} +""" + + +def render_memory_prompt(successful: list[MASMessage], insights: list[str], task_description: str) -> str: + if not successful and not insights: + return "" + + memory_few_shots = "\n\n".join( + f"Task {idx + 1}:\n" + + TASK_CONTEXT.format( + task_description=item.task_description, + key_steps=item.get_extra_field("key_steps"), + trajectory=item.task_trajectory, + ) + for idx, item in enumerate(successful) + ) + insight_text = "\n".join(f"{idx}. 
{insight}" for idx, insight in enumerate(insights, 1)) + return TASK_SOLVE_WITH_INSIGHTS.format( + few_shots="", + memory_few_shots=memory_few_shots, + insights=insight_text, + task_description=task_description, + ) diff --git a/api/schemas.py b/api/schemas.py new file mode 100644 index 0000000..9044035 --- /dev/null +++ b/api/schemas.py @@ -0,0 +1,66 @@ +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict, Field, StrictBool + + +class EpisodeStep(BaseModel): + subgoal: Optional[str] = None + action: str + observation: str + reward: float = 0.0 + + +class RetrieveRequest(BaseModel): + task_type: str + goal: str + initial_observation: str + max_chars: int = Field(default=4000, gt=0) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class EpisodeRequest(BaseModel): + task_type: str + goal: str + initial_observation: str + success: StrictBool + progress_rate: Optional[float] = None + steps: list[EpisodeStep] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class MemoryStats(BaseModel): + memory_size: int + successful_count: int + failed_count: int + insight_count: int + + +class RetrieveResponse(BaseModel): + memory_prompt: str + stats: MemoryStats + trace_id: str + error: Optional[str] = None + + +class EpisodeResponse(BaseModel): + stored: bool + episode_id: Optional[str] = None + trace_id: str + error: Optional[str] = None + + +class HealthResponse(BaseModel): + ok: bool + backend: str + namespace: str + memory_size: int + error: Optional[str] = None + + +class TraceArtifact(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + request: dict[str, Any] + derived: dict[str, Any] + response: dict[str, Any] + error: Optional[str] = None diff --git a/api/server.py b/api/server.py new file mode 100644 index 0000000..f0621c9 --- /dev/null +++ b/api/server.py @@ -0,0 +1,54 @@ +import os + +from dotenv import load_dotenv +from fastapi import FastAPI, Request +from 
fastapi.exceptions import RequestValidationError +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse + +load_dotenv() +os.environ.setdefault("OPENAI_API_BASE", "") +os.environ.setdefault("OPENAI_API_KEY", "") + +from .schemas import EpisodeRequest, HealthResponse, RetrieveRequest +from .service import GMemoryApiService + + +app = FastAPI(title="GMemory API", version="0.1.0") +service = GMemoryApiService() + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + trace_id = service.tracer.new_trace_id() + errors = jsonable_encoder(exc.errors()) + error = f"RequestValidationError: {errors}" + response = {"detail": errors, "trace_id": trace_id, "error": error} + try: + body = await request.json() + except Exception: + body = {} + service.tracer.record( + trace_id, + request.url.path, + body, + {"validation_error": True}, + response, + error, + ) + return JSONResponse(status_code=422, content=response) + + +@app.get("/api/v1/memory/health", response_model=HealthResponse) +def health(): + return service.health() + + +@app.post("/api/v1/memory/retrieve") +def retrieve_memory(request: RetrieveRequest): + return service.retrieve(request) + + +@app.post("/api/v1/memory/episodes") +def save_episode(request: EpisodeRequest): + return service.save_episode(request) diff --git a/api/service.py b/api/service.py new file mode 100644 index 0000000..8ed9b9c --- /dev/null +++ b/api/service.py @@ -0,0 +1,263 @@ +import os +from dataclasses import dataclass +from typing import Optional +from uuid import uuid4 + +from dotenv import load_dotenv + +from mas.memory.common import MASMessage + +from .prompt_renderer import render_memory_prompt +from .schemas import ( + EpisodeRequest, + EpisodeResponse, + MemoryStats, + RetrieveRequest, + RetrieveResponse, +) +from .tracing import ApiTracer + + +@dataclass +class GMemoryApiConfig: + namespace: str = "hiagent-cross-task" 
+ working_dir: str = "./.db/hiagent_gmemory_api" + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" + llm_model: str = "gpt-3.5-turbo-0125" + successful_topk: int = 1 + failed_topk: int = 0 + insights_topk: int = 3 + threshold: float = 0.0 + hop: int = 1 + strip_alfworld_prefix_for_retrieval: bool = False + + +class GMemoryApiService: + def __init__(self, config: Optional[GMemoryApiConfig] = None, tracer: Optional[ApiTracer] = None): + self.config = config or GMemoryApiConfig() + self._load_env_config() + self.tracer = tracer or ApiTracer() + self._memory = None + self._init_error = None + + @property + def namespace(self) -> str: + return self.config.namespace + + def health(self) -> dict: + try: + memory_size = self.memory_size + return { + "ok": self._init_error is None, + "backend": "g-memory", + "namespace": self.namespace, + "memory_size": memory_size, + "error": self._init_error, + } + except Exception as exc: + return { + "ok": False, + "backend": "g-memory", + "namespace": self.namespace, + "memory_size": 0, + "error": self._summarize_error(exc), + } + + def retrieve(self, request: RetrieveRequest) -> RetrieveResponse: + trace_id = self.tracer.new_trace_id() + request_dict = request.model_dump() + task_main, task_description, task_main_rule, raw_task_main = self._derive_task_fields( + request.task_type, + request.goal, + request.initial_observation, + request.metadata, + ) + derived = { + "query_task": task_main, + "raw_query_task": raw_task_main, + "task_main_rule": task_main_rule, + "task_description": task_description, + } + + error = None + memory_prompt = "" + stats = self._empty_stats() + + try: + memory_size = self.memory_size + stats.memory_size = memory_size + if memory_size == 0: + error = "empty memory" + else: + success, failed, insights = self._memory.retrieve_memory( + query_task=task_main, + successful_topk=self.config.successful_topk, + failed_topk=self.config.failed_topk, + insight_topk=self.config.insights_topk, + 
threshold=self.config.threshold, + ) + retrieval_debug = getattr(self._memory, "last_retrieval_debug", None) + if retrieval_debug: + derived["retrieval_debug"] = retrieval_debug + memory_prompt = self._render_memory_prompt(success, insights, task_description) + memory_prompt = memory_prompt[: request.max_chars] + stats = MemoryStats( + memory_size=memory_size, + successful_count=len(success), + failed_count=len(failed), + insight_count=len(insights), + ) + if not memory_prompt: + error = "no retrieval result" + except Exception as exc: + error = self._summarize_error(exc) + stats = self._safe_stats() + memory_prompt = "" + + response = RetrieveResponse( + memory_prompt=memory_prompt, + stats=stats, + trace_id=trace_id, + error=error, + ) + self.tracer.record(trace_id, "/retrieve", request_dict, derived, response.model_dump(), error) + return response + + def save_episode(self, request: EpisodeRequest) -> EpisodeResponse: + trace_id = self.tracer.new_trace_id() + request_dict = request.model_dump() + task_main, task_description, task_main_rule, raw_task_main = self._derive_task_fields( + request.task_type, + request.goal, + request.initial_observation, + request.metadata, + ) + label = request.success + mas_message = MASMessage(task_main=task_main, task_description=task_description, label=label) + mas_message.add_extra_field("task_type", request.task_type) + metadata = dict(request.metadata) + if raw_task_main != task_main: + metadata["raw_task_main"] = raw_task_main + mas_message.add_extra_field("metadata", metadata) + if request.progress_rate is not None: + mas_message.add_extra_field("progress_rate", request.progress_rate) + + for step in request.steps: + if step.subgoal is not None: + mas_message.add_extra_field("last_subgoal", step.subgoal) + mas_message.move_state(step.action, step.observation, reward=step.reward) + + derived = { + "task_main": task_main, + "raw_task_main": raw_task_main, + "task_description": task_description, + "task_main_rule": 
task_main_rule, + "label": label, + "step_count": len(request.steps), + } + + try: + _ = self.memory_size + self._memory.add_memory(mas_message) + response = EpisodeResponse(stored=True, episode_id=uuid4().hex, trace_id=trace_id) + error = None + except Exception as exc: + error = self._summarize_error(exc) + response = EpisodeResponse(stored=False, episode_id=None, trace_id=trace_id, error=error) + + self.tracer.record(trace_id, "/episodes", request_dict, derived, response.model_dump(), error) + return response + + @property + def memory_size(self) -> int: + try: + return int(self._memory.memory_size) + except Exception: + if self._memory is None: + self._build_memory() + return int(self._memory.memory_size) + raise + + def _build_memory(self) -> None: + load_dotenv() + + from mas.llm import GPTChat + from mas.memory.mas_memory.GMemory import GMemory + from mas.utils import EmbeddingFunc + + self._load_env_config() + + try: + os.makedirs(self.config.working_dir, exist_ok=True) + self._memory = GMemory( + namespace=self.config.namespace, + global_config={"working_dir": self.config.working_dir, "hop": self.config.hop}, + llm_model=GPTChat(model_name=self.config.llm_model), + embedding_func=EmbeddingFunc(self.config.embedding_model), + ) + self._init_error = None + except Exception as exc: + self._init_error = self._summarize_error(exc) + raise + + def _derive_task_fields( + self, + task_type: str, + goal: str, + initial_observation: str, + metadata: dict, + ) -> tuple[str, str, str, str]: + normalized_type = task_type.lower() + metadata_env = str(metadata.get("env", "")).lower() + if normalized_type.startswith("alfworld") or metadata_env == "alfworld": + raw_task_main = f"alfworld-{goal}" + if self.config.strip_alfworld_prefix_for_retrieval: + task_main = goal + rule = "alfworld-prefix-stripped" + else: + task_main = raw_task_main + rule = "alfworld-prefix-goal" + else: + task_main = goal + raw_task_main = task_main + rule = "pddl-goal" + task_description = f"Here 
is your initial observation: {initial_observation}\n**Here is your task: {goal}" + return task_main, task_description, rule, raw_task_main + + def _render_memory_prompt(self, successful: list[MASMessage], insights: list[str], task_description: str) -> str: + return render_memory_prompt(successful, insights, task_description) + + def _empty_stats(self) -> MemoryStats: + return MemoryStats(memory_size=0, successful_count=0, failed_count=0, insight_count=0) + + def _safe_stats(self) -> MemoryStats: + try: + memory_size = self.memory_size + except Exception: + memory_size = 0 + return MemoryStats(memory_size=memory_size, successful_count=0, failed_count=0, insight_count=0) + + def _summarize_error(self, exc: Exception) -> str: + return f"{exc.__class__.__name__}: {str(exc)[:500]}" + + def _load_env_config(self) -> None: + load_dotenv() + self.config.llm_model = os.getenv("GMEMORY_API_MODEL", self.config.llm_model) + self.config.working_dir = os.getenv("GMEMORY_API_WORKING_DIR", self.config.working_dir) + self.config.namespace = os.getenv("GMEMORY_API_NAMESPACE", self.config.namespace) + self.config.embedding_model = os.getenv("GMEMORY_API_EMBEDDING_MODEL", self.config.embedding_model) + self.config.successful_topk = int(os.getenv("GMEMORY_API_SUCCESSFUL_TOPK", self.config.successful_topk)) + self.config.failed_topk = int(os.getenv("GMEMORY_API_FAILED_TOPK", self.config.failed_topk)) + self.config.insights_topk = int(os.getenv("GMEMORY_API_INSIGHTS_TOPK", self.config.insights_topk)) + self.config.threshold = float(os.getenv("GMEMORY_API_THRESHOLD", self.config.threshold)) + self.config.hop = int(os.getenv("GMEMORY_API_HOP", self.config.hop)) + self.config.strip_alfworld_prefix_for_retrieval = self._env_bool( + "GMEMORY_API_STRIP_ALFWORLD_PREFIX_FOR_RETRIEVAL", + self.config.strip_alfworld_prefix_for_retrieval, + ) + + def _env_bool(self, name: str, default: bool) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in 
{"1", "true", "yes", "on"} diff --git a/api/tracing.py b/api/tracing.py new file mode 100644 index 0000000..bf737b7 --- /dev/null +++ b/api/tracing.py @@ -0,0 +1,66 @@ +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional +from uuid import uuid4 + + +class ApiTracer: + def __init__( + self, + trace_dir: str = "./.logs/hiagent_gmemory_api", + enabled: bool = True, + full_payload: bool = True, + max_artifact_chars: int = 20000, + ): + self.trace_dir = Path(trace_dir) + self.artifact_dir = self.trace_dir / "artifacts" + self.enabled = enabled + self.full_payload = full_payload + self.max_artifact_chars = max_artifact_chars + + def new_trace_id(self) -> str: + return uuid4().hex + + def record( + self, + trace_id: str, + endpoint: str, + request: dict[str, Any], + derived: dict[str, Any], + response: dict[str, Any], + error: Optional[str] = None, + ) -> None: + if not self.enabled: + return + + self.artifact_dir.mkdir(parents=True, exist_ok=True) + artifact_name = f"{trace_id}.{endpoint.strip('/').split('/')[-1]}.json" + artifact_path = self.artifact_dir / artifact_name + artifact = { + "request": request if self.full_payload else self._truncate_obj(request), + "derived": derived, + "response": response if self.full_payload else self._truncate_obj(response), + "error": error, + } + artifact_text = json.dumps(artifact, ensure_ascii=False, indent=2, default=str) + if len(artifact_text) > self.max_artifact_chars: + artifact_text = artifact_text[: self.max_artifact_chars] + "\n..." 
+ artifact_path.write_text(artifact_text, encoding="utf-8") + + summary = { + "trace_id": trace_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + "endpoint": endpoint, + "artifact": os.fspath(artifact_path), + "error": error, + } + with (self.trace_dir / "traces.jsonl").open("a", encoding="utf-8") as handle: + handle.write(json.dumps(summary, ensure_ascii=False, default=str) + "\n") + + def _truncate_obj(self, obj: Any) -> Any: + text = json.dumps(obj, ensure_ascii=False, default=str) + if len(text) <= self.max_artifact_chars: + return obj + return {"truncated": True, "preview": text[: self.max_artifact_chars]} diff --git a/mas/memory/mas_memory/GMemory.py b/mas/memory/mas_memory/GMemory.py index 75e0251..328a8be 100644 --- a/mas/memory/mas_memory/GMemory.py +++ b/mas/memory/mas_memory/GMemory.py @@ -3,6 +3,7 @@ from langchain.docstore.document import Document import os import copy +import json import re from typing import Iterable import random @@ -60,6 +61,7 @@ def __post_init__(self): ) self.insights_cache: list[str] = [] + self.last_retrieval_debug: dict = {} print(self._get_hyperparams_dict()) @@ -119,6 +121,27 @@ def _retrieve_memory_raw( threshold: float = 0.3 ) -> tuple[list, list, list]: + def get_raw_task_main(doc: Document) -> str | None: + extra_fields = doc.metadata.get("extra_fields", "{}") + try: + parsed_extra_fields = json.loads(extra_fields) + except Exception: + return None + metadata = parsed_extra_fields.get("metadata", {}) + if isinstance(metadata, dict): + return metadata.get("raw_task_main") + return None + + def summarize_doc(doc: Document, similarity: float, passed_threshold: bool = True) -> dict: + return { + "task_main": doc.metadata.get("task_main"), + "raw_task_main": get_raw_task_main(doc), + "comparison_text": doc.page_content, + "similarity": float(similarity), + "passed_threshold": passed_threshold, + "label": doc.metadata.get("label"), + } + def sort_and_filter_by_similarity(docs: list[Document], threshold: float = 0.3) 
-> list[tuple[Document, float]]: result = [] for doc in docs: @@ -167,6 +190,16 @@ def sort_and_filter_by_similarity(docs: list[Document], threshold: float = 0.3) origin_embedding: list[float] = self.embedding_func.embed_query(query_task) true_tasks_doc_with_score = sort_and_filter_by_similarity(true_tasks_doc, threshold)[:successful_topk] false_tasks_doc_with_score = sort_and_filter_by_similarity(false_tasks_doc, threshold)[:failed_topk] + self.last_retrieval_debug = { + "query_task": query_task, + "threshold": threshold, + "successful_candidates": [ + summarize_doc(doc, score) for doc, score in true_tasks_doc_with_score + ], + "failed_candidates": [ + summarize_doc(doc, score) for doc, score in false_tasks_doc_with_score + ], + } true_task_messages: list[MASMessage] = [] false_task_messages: list[MASMessage] = [] @@ -183,6 +216,9 @@ def sort_and_filter_by_similarity(docs: list[Document], threshold: float = 0.3) # get insights and order by relelvance insights_with_score = self.insights_layer.query_insights_with_score(query_task, top_k=insight_windows) insights = [insight for insight, _ in insights_with_score][:insight_windows] + self.last_retrieval_debug["insights"] = [ + {"text": insight, "score": float(score)} for insight, score in insights_with_score[:insight_windows] + ] return true_task_messages, false_task_messages, insights @@ -237,6 +273,28 @@ def retrieve_memory( # directlt get insights top_k_insights = insights[:insight_topk] self.insights_cache = top_k_insights + debug = getattr(self, "last_retrieval_debug", {}) + if debug: + debug["llm_importance_scores"] = [ + {"task_main": task.task_main, "score": float(score)} + for task, score in zip(successful_task_trajectories, importance_score) + ] + selected_successful = {task.task_main for task in top_success_task_trajectories} + selected_failed = {task.task_main for task in top_fail_task_trajectories} + selected_insights = set(top_k_insights) + debug["selected_successful"] = [ + item for item in 
debug.get("successful_candidates", []) + if item.get("task_main") in selected_successful + ] + debug["selected_failed"] = [ + item for item in debug.get("failed_candidates", []) + if item.get("task_main") in selected_failed + ] + debug["selected_insights"] = [ + item for item in debug.get("insights", []) + if item.get("text") in selected_insights + ] + self.last_retrieval_debug = debug return top_success_task_trajectories, top_fail_task_trajectories, top_k_insights @@ -889,4 +947,4 @@ def _retrieve_rule_index(self, operation_rule_text: str) -> int: for idx, insight in enumerate(self.insights_memory): if insight['rule'] in operation_rule_text: return idx - return -1 \ No newline at end of file + return -1 diff --git a/requirements.txt b/requirements.txt index ff7e002..cc1946c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ alfworld==0.3.5 attr==0.3.2 camel==0.1.2 datasets==3.5.0 +fastapi==0.135.1 finch_clust==0.2.0 finchpy==0.0.1 graphviz==0.20.3 @@ -26,4 +27,5 @@ seaborn==0.13.2 sentence_transformers==3.4.1 skimage==0.0 tqdm==4.66.5 +uvicorn==0.41.0 wikipedia==1.4.0 diff --git a/tools/analyze_retrieve_similarity.py b/tools/analyze_retrieve_similarity.py new file mode 100644 index 0000000..d47b4e7 --- /dev/null +++ b/tools/analyze_retrieve_similarity.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +"""Recompute similarity for GMemory API retrieve artifacts. + +The script reads .logs/hiagent_gmemory_api/artifacts/*.retrieve.json, +extracts the query_task and the historical tasks rendered into memory_prompt, +then computes cosine similarity using the same EmbeddingFunc used by GMemory. + +Run this on a server where the configured sentence-transformers model is +available locally or can be loaded by SentenceTransformer. 
+""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import re +import sys +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import numpy as np +from dotenv import load_dotenv +from sentence_transformers import SentenceTransformer + + +DEFAULT_ARTIFACT_DIR = Path(".logs/hiagent_gmemory_api/artifacts") +DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" + +TASK_BLOCK_RE = re.compile( + r"Task\s+(?P<index>\d+):\s*" + r"### Task description:\s*" + r"(?P<description>.*?)(?=\n### Key steps:|\nTask\s+\d+:|\Z)", + re.DOTALL, +) +TASK_GOAL_RE = re.compile( + r"\*\*Here is your task:\s*(?P<goal>.*?)(?=\n|$)", + re.DOTALL, +) +GOAL_PREFIX_RE = re.compile( + r"^\s*The goal is to satisfy the following conditions:\s*", + re.IGNORECASE, +) +ALFWORLD_PREFIX_RE = re.compile(r"^\s*alfworld-", re.IGNORECASE) + + +@dataclass +class SimilarityRow: + trace_id: str + artifact: str + memory_size: int + successful_count: int + failed_count: int + insight_count: int + returned_task_index: int + similarity: float | None + status: str + query_task: str + returned_task: str + query_embedding_text: str + returned_embedding_text: str + error: str + + +class EmbeddingFunc: + def __init__(self, model_type: str): + self.model = SentenceTransformer(model_type) + + def embed_query(self, query: str) -> list[float]: + return self.model.encode(query).tolist() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Compute embedding cosine similarity for GMemory retrieve artifacts." + ) + parser.add_argument( + "--artifact-dir", + default=str(DEFAULT_ARTIFACT_DIR), + help="Directory containing *.retrieve.json artifacts.", + ) + parser.add_argument( + "--embedding-model", + default=None, + help=( + "SentenceTransformer model/path. Defaults to GMEMORY_API_EMBEDDING_MODEL " + "from .env, then sentence-transformers/all-MiniLM-L6-v2."
+ ), + ) + parser.add_argument( + "--format", + choices=("table", "json", "csv"), + default="table", + help="Output format.", + ) + parser.add_argument( + "--include-empty", + action="store_true", + help="Include retrieve artifacts that returned no historical task.", + ) + parser.add_argument( + "--strip-goal-prefix", + action="store_true", + help=( + "Remove the fixed PDDL prefix 'The goal is to satisfy the following " + "conditions:' before embedding query and returned tasks." + ), + ) + parser.add_argument( + "--strip-alfworld-prefix", + action="store_true", + help=( + "Remove the fixed ALFWorld namespace prefix 'alfworld-' before " + "embedding query and returned tasks." + ), + ) + parser.add_argument( + "--fail-fast", + action="store_true", + help="Stop on the first malformed artifact instead of reporting a row with status=error.", + ) + return parser.parse_args() + + +def normalize_space(text: str) -> str: + return re.sub(r"\s+", " ", text or "").strip() + + +def normalize_embedding_text( + text: str, + strip_goal_prefix: bool, + strip_alfworld_prefix: bool, +) -> str: + text = normalize_space(text) + if strip_alfworld_prefix: + text = ALFWORLD_PREFIX_RE.sub("", text) + if strip_goal_prefix: + text = GOAL_PREFIX_RE.sub("", text) + return normalize_space(text) + + +def cosine_similarity(vec1: list[float], vec2: list[float]) -> float: + left = np.array(vec1) + right = np.array(vec2) + left_norm = np.linalg.norm(left) + right_norm = np.linalg.norm(right) + if left_norm == 0 or right_norm == 0: + return 0.0 + return float(np.dot(left, right) / (left_norm * right_norm)) + + +def load_artifact(path: Path) -> dict[str, Any]: + return json.loads(path.read_text(encoding="utf-8")) + + +def normalize_task_for_comparison(text: str, query_task: str) -> str: + text = normalize_space(text) + if ALFWORLD_PREFIX_RE.match(query_task) and not ALFWORLD_PREFIX_RE.match(text): + return f"alfworld-{text}" + return text + + +def extract_returned_tasks(memory_prompt: str, query_task: 
str) -> list[tuple[int, str]]: + tasks: list[tuple[int, str]] = [] + for match in TASK_BLOCK_RE.finditer(memory_prompt or ""): + index = int(match.group("index")) + description = match.group("description").strip() + goal_match = TASK_GOAL_RE.search(description) + task = goal_match.group("goal") if goal_match else description + tasks.append((index, normalize_task_for_comparison(task, query_task))) + return tasks + + +def get_trace_id(path: Path, data: dict[str, Any]) -> str: + response = data.get("response", {}) + return response.get("trace_id") or data.get("trace_id") or path.name.replace(".retrieve.json", "") + + +def get_stats(data: dict[str, Any]) -> dict[str, int]: + stats = data.get("response", {}).get("stats", {}) or {} + return { + "memory_size": int(stats.get("memory_size", 0) or 0), + "successful_count": int(stats.get("successful_count", 0) or 0), + "failed_count": int(stats.get("failed_count", 0) or 0), + "insight_count": int(stats.get("insight_count", 0) or 0), + } + + +def empty_row(path: Path, data: dict[str, Any], status: str, error: str = "") -> SimilarityRow: + stats = get_stats(data) + return SimilarityRow( + trace_id=get_trace_id(path, data), + artifact=str(path), + returned_task_index=0, + similarity=None, + status=status, + query_task=normalize_space(data.get("derived", {}).get("query_task", "")), + returned_task="", + query_embedding_text="", + returned_embedding_text="", + error=error, + **stats, + ) + + +def analyze_artifact( + path: Path, + embedder: EmbeddingFunc, + include_empty: bool, + strip_goal_prefix: bool, + strip_alfworld_prefix: bool, +) -> list[SimilarityRow]: + data = load_artifact(path) + query_task = normalize_space(data.get("derived", {}).get("query_task", "")) + memory_prompt = data.get("response", {}).get("memory_prompt", "") + returned_tasks = extract_returned_tasks(memory_prompt, query_task) + + if not query_task: + return [empty_row(path, data, "parse_failed", "missing derived.query_task")] + + if not returned_tasks: + 
return [empty_row(path, data, "empty", "no returned task parsed")] if include_empty else [] + + query_embedding_text = normalize_embedding_text( + query_task, + strip_goal_prefix, + strip_alfworld_prefix, + ) + query_embedding = embedder.embed_query(query_embedding_text) + stats = get_stats(data) + rows: list[SimilarityRow] = [] + for index, returned_task in returned_tasks: + returned_embedding_text = normalize_embedding_text( + returned_task, + strip_goal_prefix, + strip_alfworld_prefix, + ) + returned_embedding = embedder.embed_query(returned_embedding_text) + similarity = cosine_similarity(query_embedding, returned_embedding) + rows.append( + SimilarityRow( + trace_id=get_trace_id(path, data), + artifact=str(path), + returned_task_index=index, + similarity=similarity, + status="ok", + query_task=query_embedding_text, + returned_task=returned_embedding_text, + query_embedding_text=query_embedding_text, + returned_embedding_text=returned_embedding_text, + error="", + **stats, + ) + ) + return rows + + +def collect_rows( + artifact_dir: Path, + embedder: EmbeddingFunc, + include_empty: bool, + fail_fast: bool, + strip_goal_prefix: bool, + strip_alfworld_prefix: bool, +) -> list[SimilarityRow]: + rows: list[SimilarityRow] = [] + paths = sorted(artifact_dir.glob("*.retrieve.json"), key=lambda item: item.stat().st_mtime) + for path in paths: + try: + rows.extend( + analyze_artifact( + path, + embedder, + include_empty, + strip_goal_prefix, + strip_alfworld_prefix, + ) + ) + except Exception as exc: + if fail_fast: + raise + rows.append( + SimilarityRow( + trace_id=path.name.replace(".retrieve.json", ""), + artifact=str(path), + memory_size=0, + successful_count=0, + failed_count=0, + insight_count=0, + returned_task_index=0, + similarity=None, + status="error", + query_task="", + returned_task="", + query_embedding_text="", + returned_embedding_text="", + error=f"{exc.__class__.__name__}: {exc}", + ) + ) + return rows + + +def print_table(rows: list[SimilarityRow]) -> 
None: + if not rows: + print("No rows found.") + return + print("\t".join(["trace_id", "mem", "succ", "task", "similarity", "status", "error"])) + for row in rows: + similarity = "" if row.similarity is None else f"{row.similarity:.6f}" + print( + "\t".join( + [ + row.trace_id, + str(row.memory_size), + str(row.successful_count), + str(row.returned_task_index), + similarity, + row.status, + row.error, + ] + ) + ) + + +def print_csv(rows: list[SimilarityRow]) -> None: + fieldnames = list(SimilarityRow.__dataclass_fields__.keys()) + writer = csv.DictWriter(sys.stdout, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(asdict(row)) + + +def main() -> None: + args = parse_args() + load_dotenv() + + model = ( + args.embedding_model + or os.getenv("GMEMORY_API_EMBEDDING_MODEL") + or DEFAULT_EMBEDDING_MODEL + ) + embedder = EmbeddingFunc(model) + rows = collect_rows( + Path(args.artifact_dir), + embedder, + args.include_empty, + args.fail_fast, + args.strip_goal_prefix, + args.strip_alfworld_prefix, + ) + + if args.format == "json": + print(json.dumps([asdict(row) for row in rows], indent=2, ensure_ascii=False)) + elif args.format == "csv": + print_csv(rows) + else: + print_table(rows) + + +if __name__ == "__main__": + main()