Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 79 additions & 18 deletions opencane/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
self._mcp_connected = False
self._mcp_connecting = False
self._consolidating: set[str] = set() # Session keys with consolidation in progress
self._consolidation_locks: dict[str, asyncio.Lock] = {}
self._background_tasks: list[asyncio.Task[Any]] = []
self._register_default_tools()

Expand Down Expand Up @@ -407,15 +408,35 @@ def _schedule_consolidation(self, session: Session, *, archive_all: bool = False
if session.key in self._consolidating:
return
self._consolidating.add(session.key)
lock = self._get_consolidation_lock(session.key)

async def _run() -> None:
try:
await self._consolidate_memory(session, archive_all=archive_all)
async with lock:
await self._consolidate_memory(session, archive_all=archive_all)
finally:
self._consolidating.discard(session.key)
self._prune_consolidation_lock(session.key)

self._track_background_task(asyncio.create_task(_run()))

def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock:
lock = self._consolidation_locks.get(session_key)
if lock is None:
lock = asyncio.Lock()
self._consolidation_locks[session_key] = lock
return lock

def _prune_consolidation_lock(self, session_key: str) -> None:
"""Drop unused per-session lock entries to avoid unbounded growth."""
lock = self._consolidation_locks.get(session_key)
if lock is None:
return
waiters = getattr(lock, "_waiters", None)
if lock.locked() or bool(waiters):
return
self._consolidation_locks.pop(session_key, None)

def _track_background_task(self, task: asyncio.Task[Any]) -> None:
    """Track a background task so shutdown can drain in-flight work.

    Args:
        task: A fire-and-forget task (e.g. scheduled memory consolidation)
            appended to ``self._background_tasks`` so shutdown can await it.
    """
    self._background_tasks.append(task)
Expand Down Expand Up @@ -693,17 +714,39 @@ async def _process_message(
# Handle slash commands
cmd = msg.content.strip().lower()
if cmd == "/new":
# Capture messages before clearing (avoid race condition with background task)
messages_to_archive = session.messages.copy()
session.clear()
self.sessions.save(session)
self.sessions.invalidate(session.key)

temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive
self._schedule_consolidation(temp_session, archive_all=True)
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="New session started. Memory consolidation in progress.")
lock = self._get_consolidation_lock(session.key)
messages_to_archive: list[dict[str, Any]] = []
archive_succeeded = True
try:
async with lock:
messages_to_archive = session.messages[session.last_consolidated:].copy()
if messages_to_archive:
temp_session = Session(key=session.key)
temp_session.messages = messages_to_archive
archive_succeeded = await self._consolidate_memory(
temp_session,
archive_all=True,
)
if not archive_succeeded:
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=(
"Could not start a new session because memory archival failed. "
"Please try again."
),
)
session.clear()
self.sessions.save(session)
self.sessions.invalidate(session.key)
finally:
self._prune_consolidation_lock(session.key)

return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="New session started. Memory archived.",
)
if cmd == "/help":
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
content="🦯 OpenCane commands:\n/new — Start a new conversation\n/help — Show available commands")
Expand All @@ -712,6 +755,9 @@ async def _process_message(
self._schedule_consolidation(session)

self._set_tool_context(msg.channel, msg.chat_id)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
memory_context = await self._build_prompt_memory_context(
query=msg.content,
session_key=key,
Expand Down Expand Up @@ -769,6 +815,19 @@ async def _process_message(
logger.debug(f"layered memory record_turn failed: {e}")
self.sessions.save(session)

suppress_final_reply = False
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
sent_targets = set(message_tool.get_turn_sends())
suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets

if suppress_final_reply:
logger.info(
"Skipping final auto-reply because message tool already sent to "
f"{msg.channel}:{msg.chat_id} in this turn"
)
return None

return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
Expand Down Expand Up @@ -850,7 +909,7 @@ async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage
content=final_content
)

async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Consolidate old messages into MEMORY.md + HISTORY.md.

Args:
Expand All @@ -867,16 +926,16 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
keep_count = self.memory_window // 2
if len(session.messages) <= keep_count:
logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})")
return
return True

messages_to_process = len(session.messages) - session.last_consolidated
if messages_to_process <= 0:
logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})")
return
return True

old_messages = session.messages[session.last_consolidated:-keep_count]
if not old_messages:
return
return True
logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep")

lines = []
Expand Down Expand Up @@ -921,13 +980,13 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
text = (response.content or "").strip()
if not text:
logger.warning("Memory consolidation: LLM returned empty response, skipping")
return
return False
if text.startswith("```"):
text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
result = json_repair.loads(text)
if not isinstance(result, dict):
logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}")
return
return False

if entry := result.get("history_entry"):
if not isinstance(entry, str):
Expand All @@ -944,8 +1003,10 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
else:
session.last_consolidated = len(session.messages) - keep_count
logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}")
return True
except Exception as e:
logger.error(f"Memory consolidation failed: {e}")
return False

async def process_direct(
self,
Expand Down
10 changes: 10 additions & 0 deletions opencane/agent/tools/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
self._default_channel = default_channel
self._default_chat_id = default_chat_id
self._pre_send_filter = pre_send_filter
self._turn_sends: list[tuple[str, str]] = []

def set_context(self, channel: str, chat_id: str) -> None:
"""Set the current message context."""
Expand All @@ -34,6 +35,14 @@ def set_pre_send_filter(self, callback: Callable[[str, str, str], str] | None) -
"""Set optional content filter before outbound send."""
self._pre_send_filter = callback

def start_turn(self) -> None:
    """Begin a new turn: forget which targets were messaged so far."""
    del self._turn_sends[:]

def get_turn_sends(self) -> list[tuple[str, str]]:
    """Return a snapshot of (channel, chat_id) targets sent this turn."""
    # Copy so callers cannot mutate the tool's internal tracking list.
    return self._turn_sends.copy()

@property
def name(self) -> str:
return "message"
Expand Down Expand Up @@ -94,6 +103,7 @@ async def execute(

try:
await self._send_callback(msg)
self._turn_sends.append((channel, chat_id))
return f"Message sent to {channel}:{chat_id}"
except Exception as e:
return f"Error sending message: {str(e)}"
91 changes: 89 additions & 2 deletions tests/test_agent_loop_background_archival.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,104 @@ async def _fake_consolidate(target_session, archive_all: bool = False): # type:
del target_session, archive_all
await asyncio.sleep(0.05)
archived.set()
return True

monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate)

response = await loop._process_message(
InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new")
InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="hello")
)
assert response is not None
assert "new session started" in response.content.lower()
assert response.content == "ok"
assert not archived.is_set()

await loop.close_mcp()
assert archived.is_set()
assert loop._background_tasks == []


@pytest.mark.asyncio
async def test_new_keeps_session_when_archival_fails(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """/new must not clear the session when memory archival reports failure."""
    loop = AgentLoop(
        bus=MessageBus(),
        provider=_Provider(),
        workspace=tmp_path,
        memory_window=2,
    )

    # Seed a session with two messages that /new would try to archive.
    session = loop.sessions.get_or_create("cli:chat-archive")
    session.add_message("user", "u1")
    session.add_message("assistant", "a1")
    loop.sessions.save(session)

    # Simulate archival failure: consolidation returns False.
    async def _fake_consolidate(target_session, archive_all: bool = False):  # type: ignore[no-untyped-def]
        del target_session, archive_all
        return False

    monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate)

    response = await loop._process_message(
        InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new")
    )
    # The user is told the reset was aborted...
    assert response is not None
    assert "could not start a new session" in response.content.lower()

    # ...and both original messages survive untouched.
    session_after = loop.sessions.get_or_create("cli:chat-archive")
    assert len(session_after.messages) == 2


@pytest.mark.asyncio
async def test_new_waits_for_inflight_consolidation_and_archives_tail_only(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """/new blocks on an in-flight background consolidation, then archives
    only the messages that consolidation has not already covered."""
    loop = AgentLoop(
        bus=MessageBus(),
        provider=_Provider(),
        workspace=tmp_path,
        memory_window=4,
    )

    # Seed 12 messages so a normal message triggers background consolidation.
    session = loop.sessions.get_or_create("cli:chat-archive")
    for i in range(6):
        session.add_message("user", f"user-{i}")
        session.add_message("assistant", f"assistant-{i}")
    loop.sessions.save(session)

    started = asyncio.Event()
    release = asyncio.Event()
    archived_count = -1  # sentinel: archive path not reached yet

    # Fake consolidation: the background (non-archive) call blocks on
    # `release` so we can hold the per-session lock; the archive_all call
    # just records how many messages /new handed it.
    async def _fake_consolidate(target_session, archive_all: bool = False):  # type: ignore[no-untyped-def]
        nonlocal archived_count
        if archive_all:
            archived_count = len(target_session.messages)
            return True
        started.set()
        await release.wait()
        # Mark everything but the last 2 messages as consolidated.
        target_session.last_consolidated = len(target_session.messages) - 2
        return True

    monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate)

    # A normal message kicks off background consolidation (blocked on release).
    first = await loop._process_message(
        InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="hello")
    )
    assert first is not None
    assert first.content == "ok"
    await started.wait()

    # /new issued while consolidation is in flight must wait on the lock.
    pending_new = asyncio.create_task(
        loop._process_message(
            InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new")
        )
    )
    await asyncio.sleep(0.02)
    assert not pending_new.done()

    # Unblock consolidation; /new should then archive only the 2-message tail.
    release.set()
    response = await pending_new
    assert response is not None
    assert "new session started" in response.content.lower()
    assert archived_count == 2
2 changes: 1 addition & 1 deletion tests/test_agent_loop_interim_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def chat( # type: ignore[override]
ToolCallRequest(
id="msg-1",
name="message",
arguments={"content": "Tool progress update"},
arguments={"content": "Tool progress update", "chat_id": "chat-progress"},
)
],
)
Expand Down
Loading
Loading