From 6c333467353552a64cbb7299017794a8efd9e6b8 Mon Sep 17 00:00:00 2001 From: bitifirefly Date: Sun, 22 Mar 2026 02:56:05 +0000 Subject: [PATCH 1/2] fix(agent): suppress duplicate final reply after message tool send --- opencane/agent/loop.py | 16 ++++ opencane/agent/tools/message.py | 10 +++ tests/test_agent_loop_interim_retry.py | 2 +- tests/test_agent_loop_message_dedup.py | 103 +++++++++++++++++++++++++ tests/test_agent_loop_safety.py | 2 +- 5 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 tests/test_agent_loop_message_dedup.py diff --git a/opencane/agent/loop.py b/opencane/agent/loop.py index da4d6fe541..e3961e8952 100644 --- a/opencane/agent/loop.py +++ b/opencane/agent/loop.py @@ -712,6 +712,9 @@ async def _process_message( self._schedule_consolidation(session) self._set_tool_context(msg.channel, msg.chat_id) + if message_tool := self.tools.get("message"): + if isinstance(message_tool, MessageTool): + message_tool.start_turn() memory_context = await self._build_prompt_memory_context( query=msg.content, session_key=key, @@ -769,6 +772,19 @@ async def _process_message( logger.debug(f"layered memory record_turn failed: {e}") self.sessions.save(session) + suppress_final_reply = False + if message_tool := self.tools.get("message"): + if isinstance(message_tool, MessageTool): + sent_targets = set(message_tool.get_turn_sends()) + suppress_final_reply = (msg.channel, msg.chat_id) in sent_targets + + if suppress_final_reply: + logger.info( + "Skipping final auto-reply because message tool already sent to " + f"{msg.channel}:{msg.chat_id} in this turn" + ) + return None + return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, diff --git a/opencane/agent/tools/message.py b/opencane/agent/tools/message.py index c27a7fb711..6ac77df7c9 100644 --- a/opencane/agent/tools/message.py +++ b/opencane/agent/tools/message.py @@ -20,6 +20,7 @@ def __init__( self._default_channel = default_channel self._default_chat_id = default_chat_id self._pre_send_filter = pre_send_filter + self._turn_sends: list[tuple[str, str]] = [] def set_context(self, channel: str, chat_id: str) -> None: """Set the current message context.""" @@ -34,6 +35,14 @@ def set_pre_send_filter(self, callback: Callable[[str, str, str], str] | None) - """Set optional content filter before outbound send.""" self._pre_send_filter = callback + def start_turn(self) -> None: + """Reset per-turn send tracking.""" + self._turn_sends.clear() + + def get_turn_sends(self) -> list[tuple[str, str]]: + """Get (channel, chat_id) targets sent in the current turn.""" + return list(self._turn_sends) + @property def name(self) -> str: return "message" @@ -94,6 +103,7 @@ async def execute( try: await self._send_callback(msg) + self._turn_sends.append((channel, chat_id)) return f"Message sent to {channel}:{chat_id}" except Exception as e: return f"Error sending message: {str(e)}" diff --git a/tests/test_agent_loop_interim_retry.py b/tests/test_agent_loop_interim_retry.py index 3d7af55dd3..35f476b2c1 100644 --- a/tests/test_agent_loop_interim_retry.py +++ b/tests/test_agent_loop_interim_retry.py @@ -35,7 +35,7 @@ async def chat( # type: ignore[override] ToolCallRequest( id="msg-1", name="message", - arguments={"content": "Tool progress update"}, + arguments={"content": "Tool progress update", "chat_id": "chat-progress"}, ) ], ) diff --git a/tests/test_agent_loop_message_dedup.py b/tests/test_agent_loop_message_dedup.py new file mode 100644 index 0000000000..cf5129f010 --- /dev/null +++ b/tests/test_agent_loop_message_dedup.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +import pytest + +from opencane.agent.loop import AgentLoop +from opencane.bus.events import InboundMessage +from opencane.bus.queue import MessageBus +from opencane.providers.base import LLMProvider, LLMResponse, ToolCallRequest + + +class _MessageThenFinalProvider(LLMProvider): + def __init__( + self, + *, + tool_channel: str | None = None, + tool_chat_id: str | None = None, + ) -> None: + super().__init__(api_key=None, api_base=None) + self._turn = 0 + self._tool_channel = tool_channel + self._tool_chat_id = tool_chat_id + + async def chat( # type: ignore[override] + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + ) -> LLMResponse: + del messages, tools, model, max_tokens, temperature + self._turn += 1 + if self._turn == 1: + args: dict[str, Any] = {"content": "tool-message"} + if self._tool_channel: + args["channel"] = self._tool_channel + if self._tool_chat_id: + args["chat_id"] = self._tool_chat_id + return LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="tool-msg-1", + name="message", + arguments=args, + ) + ], + ) + return LLMResponse(content="final-message") + + def get_default_model(self) -> str: + return "fake-model" + + +@pytest.mark.asyncio +async def test_process_message_suppresses_final_when_message_tool_sent_same_target( + tmp_path: Path, +) -> None: + bus = MessageBus() + loop = AgentLoop( + bus=bus, + provider=_MessageThenFinalProvider(), + workspace=tmp_path, + ) + + response = await loop._process_message( + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-a", content="hello") + ) + assert response is None + + outbound = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert outbound.channel == "cli" + assert outbound.chat_id == "chat-a" + assert outbound.content == "tool-message" + + +@pytest.mark.asyncio +async def test_process_message_keeps_final_when_message_tool_sent_other_target( + tmp_path: Path, +) -> None: + bus = MessageBus() + loop = AgentLoop( + bus=bus, + provider=_MessageThenFinalProvider(tool_channel="cli", tool_chat_id="chat-b"), + workspace=tmp_path, + ) + + response = await loop._process_message( + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-a", content="hello") + ) + assert response is not None + assert response.channel == "cli" + assert response.chat_id == "chat-a" + assert response.content == "final-message" + + outbound = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + assert outbound.channel == "cli" + assert outbound.chat_id == "chat-b" + assert outbound.content == "tool-message" diff --git a/tests/test_agent_loop_safety.py b/tests/test_agent_loop_safety.py index a53a4a3427..5860a44099 100644 --- a/tests/test_agent_loop_safety.py +++ b/tests/test_agent_loop_safety.py @@ -206,7 +206,7 @@ async def test_agent_loop_message_tool_outbound_is_safety_filtered(tmp_path: Pat channel="cli", chat_id="chat-99", ) - assert result.startswith("safe:") + assert result == "" outbound = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) assert outbound.channel == "cli" assert outbound.chat_id == "chat-99" From b2b371b290520fa217bd0ce41c28770f1cdf4df2 Mon Sep 17 00:00:00 2001 From: bitifirefly Date: Sun, 22 Mar 2026 03:05:35 +0000 Subject: [PATCH 2/2] fix(loop): guard /new archival with per-session consolidation lock --- opencane/agent/loop.py | 81 +++++++++++++---- tests/test_agent_loop_background_archival.py | 91 +++++++++++++++++++- 2 files changed, 152 insertions(+), 20 deletions(-) diff --git a/opencane/agent/loop.py b/opencane/agent/loop.py index e3961e8952..2a74fc5931 100644 --- a/opencane/agent/loop.py +++ b/opencane/agent/loop.py @@ -124,6 +124,7 @@ def __init__( self._mcp_connected = False self._mcp_connecting = False self._consolidating: set[str] = set() # Session keys with consolidation in progress + self._consolidation_locks: dict[str, asyncio.Lock] = {} self._background_tasks: list[asyncio.Task[Any]] = [] self._register_default_tools() @@ -407,15 +408,35 @@ def _schedule_consolidation(self, session: Session, *, archive_all: bool = False if session.key in self._consolidating: return self._consolidating.add(session.key) + lock = self._get_consolidation_lock(session.key) async def _run() -> None: try: - await self._consolidate_memory(session, archive_all=archive_all) + async with lock: + await self._consolidate_memory(session, archive_all=archive_all) finally: self._consolidating.discard(session.key) + self._prune_consolidation_lock(session.key) self._track_background_task(asyncio.create_task(_run())) + def _get_consolidation_lock(self, session_key: str) -> asyncio.Lock: + lock = self._consolidation_locks.get(session_key) + if lock is None: + lock = asyncio.Lock() + self._consolidation_locks[session_key] = lock + return lock + + def _prune_consolidation_lock(self, session_key: str) -> None: + """Drop unused per-session lock entries to avoid unbounded growth.""" + lock = self._consolidation_locks.get(session_key) + if lock is None: + return + waiters = getattr(lock, "_waiters", None) + if lock.locked() or bool(waiters): + return + self._consolidation_locks.pop(session_key, None) + def _track_background_task(self, task: asyncio.Task[Any]) -> None: """Track a background task so shutdown can drain in-flight work.""" self._background_tasks.append(task) @@ -693,17 +714,39 @@ async def _process_message( # Handle slash commands cmd = msg.content.strip().lower() if cmd == "/new": - # Capture messages before clearing (avoid race condition with background task) - messages_to_archive = session.messages.copy() - session.clear() - self.sessions.save(session) - self.sessions.invalidate(session.key) - - temp_session = Session(key=session.key) - temp_session.messages = messages_to_archive - self._schedule_consolidation(temp_session, archive_all=True) - return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, - content="New session started. Memory consolidation in progress.") + lock = self._get_consolidation_lock(session.key) + messages_to_archive: list[dict[str, Any]] = [] + archive_succeeded = True + try: + async with lock: + messages_to_archive = session.messages[session.last_consolidated:].copy() + if messages_to_archive: + temp_session = Session(key=session.key) + temp_session.messages = messages_to_archive + archive_succeeded = await self._consolidate_memory( + temp_session, + archive_all=True, + ) + if not archive_succeeded: + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=( + "Could not start a new session because memory archival failed. " + "Please try again." + ), + ) + session.clear() + self.sessions.save(session) + self.sessions.invalidate(session.key) + finally: + self._prune_consolidation_lock(session.key) + + return OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content="New session started. Memory archived.", + ) if cmd == "/help": return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="🦯 OpenCane commands:\n/new — Start a new conversation\n/help — Show available commands") @@ -866,7 +909,7 @@ async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage content=final_content ) - async def _consolidate_memory(self, session, archive_all: bool = False) -> None: + async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: """Consolidate old messages into MEMORY.md + HISTORY.md. Args: @@ -883,16 +926,16 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None: keep_count = self.memory_window // 2 if len(session.messages) <= keep_count: logger.debug(f"Session {session.key}: No consolidation needed (messages={len(session.messages)}, keep={keep_count})") - return + return True messages_to_process = len(session.messages) - session.last_consolidated if messages_to_process <= 0: logger.debug(f"Session {session.key}: No new messages to consolidate (last_consolidated={session.last_consolidated}, total={len(session.messages)})") - return + return True old_messages = session.messages[session.last_consolidated:-keep_count] if not old_messages: - return + return True logger.info(f"Memory consolidation started: {len(session.messages)} total, {len(old_messages)} new to consolidate, {keep_count} keep") lines = [] @@ -937,13 +980,13 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None: text = (response.content or "").strip() if not text: logger.warning("Memory consolidation: LLM returned empty response, skipping") - return + return False if text.startswith("```"): text = text.split("\n", 1)[-1].rsplit("```", 1)[0].strip() result = json_repair.loads(text) if not isinstance(result, dict): logger.warning(f"Memory consolidation: unexpected response type, skipping. Response: {text[:200]}") - return + return False if entry := result.get("history_entry"): if not isinstance(entry, str): @@ -960,8 +1003,10 @@ async def _consolidate_memory(self, session, archive_all: bool = False) -> None: else: session.last_consolidated = len(session.messages) - keep_count logger.info(f"Memory consolidation done: {len(session.messages)} messages, last_consolidated={session.last_consolidated}") + return True except Exception as e: logger.error(f"Memory consolidation failed: {e}") + return False async def process_direct( self, diff --git a/tests/test_agent_loop_background_archival.py b/tests/test_agent_loop_background_archival.py index 19c480ac4c..a68b01f933 100644 --- a/tests/test_agent_loop_background_archival.py +++ b/tests/test_agent_loop_background_archival.py @@ -50,17 +50,104 @@ async def _fake_consolidate(target_session, archive_all: bool = False): # type: del target_session, archive_all await asyncio.sleep(0.05) archived.set() + return True monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate) response = await loop._process_message( - InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new") + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="hello") ) assert response is not None - assert "new session started" in response.content.lower() + assert response.content == "ok" assert not archived.is_set() await loop.close_mcp() assert archived.is_set() assert loop._background_tasks == [] + +@pytest.mark.asyncio +async def test_new_keeps_session_when_archival_fails( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + loop = AgentLoop( + bus=MessageBus(), + provider=_Provider(), + workspace=tmp_path, + memory_window=2, + ) + + session = loop.sessions.get_or_create("cli:chat-archive") + session.add_message("user", "u1") + session.add_message("assistant", "a1") + loop.sessions.save(session) + + async def _fake_consolidate(target_session, archive_all: bool = False): # type: ignore[no-untyped-def] + del target_session, archive_all + return False + + monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate) + + response = await loop._process_message( + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new") + ) + assert response is not None + assert "could not start a new session" in response.content.lower() + + session_after = loop.sessions.get_or_create("cli:chat-archive") + assert len(session_after.messages) == 2 + + +@pytest.mark.asyncio +async def test_new_waits_for_inflight_consolidation_and_archives_tail_only( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + loop = AgentLoop( + bus=MessageBus(), + provider=_Provider(), + workspace=tmp_path, + memory_window=4, + ) + + session = loop.sessions.get_or_create("cli:chat-archive") + for i in range(6): + session.add_message("user", f"user-{i}") + session.add_message("assistant", f"assistant-{i}") + loop.sessions.save(session) + + started = asyncio.Event() + release = asyncio.Event() + archived_count = -1 + + async def _fake_consolidate(target_session, archive_all: bool = False): # type: ignore[no-untyped-def] + nonlocal archived_count + if archive_all: + archived_count = len(target_session.messages) + return True + started.set() + await release.wait() + target_session.last_consolidated = len(target_session.messages) - 2 + return True + + monkeypatch.setattr(loop, "_consolidate_memory", _fake_consolidate) + + first = await loop._process_message( + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="hello") + ) + assert first is not None + assert first.content == "ok" + await started.wait() + + pending_new = asyncio.create_task( + loop._process_message( + InboundMessage(channel="cli", sender_id="u1", chat_id="chat-archive", content="/new") + ) + ) + await asyncio.sleep(0.02) + assert not pending_new.done() + + release.set() + response = await pending_new + assert response is not None + assert "new session started" in response.content.lower() + assert archived_count == 2