Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions opencane/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ async def _run_agent_loop(
final_content = None
tools_used: list[str] = []
tool_call_counts: dict[str, int] = {}
text_only_retried = False
interim_content: str | None = None

while iteration < self.max_iterations:
iteration += 1
Expand Down Expand Up @@ -584,6 +586,25 @@ async def _run_agent_loop(
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
else:
final_content = response.content
interim_text = str(final_content or "").strip()
# Some providers return an interim text-only answer before issuing tool calls.
# Give one extra iteration only when tool calling is available.
if interim_text and not tools_used and not text_only_retried and bool(tool_defs):
text_only_retried = True
interim_content = interim_text
logger.debug(
"Interim text response (no tools used yet), retrying once: {}",
_shorten(interim_text, 120),
)
messages = self.context.add_assistant_message(
messages,
response.content,
reasoning_content=response.reasoning_content,
)
final_content = None
continue
if not interim_text and interim_content and not tools_used:
final_content = interim_content
if require_tool_use and not tools_used:
final_content = "NO_TOOL_USED"
break
Expand Down
8 changes: 4 additions & 4 deletions opencane/agent/subagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ async def _run_subagent(
# Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry()
allowed_dir = self.workspace if self.restrict_to_workspace else None
tools.register(ReadFileTool(allowed_dir=allowed_dir))
tools.register(WriteFileTool(allowed_dir=allowed_dir))
tools.register(EditFileTool(allowed_dir=allowed_dir))
tools.register(ListDirTool(allowed_dir=allowed_dir))
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
Expand Down
20 changes: 17 additions & 3 deletions opencane/agent/tools/cron.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def parameters(self) -> dict[str, Any]:
"type": "string",
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
},
"tz": {
"type": "string",
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')"
},
"at": {
"type": "string",
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
Expand All @@ -69,30 +73,40 @@ async def execute(
message: str = "",
every_seconds: int | None = None,
cron_expr: str | None = None,
tz: str | None = None,
at: str | None = None,
job_id: str | None = None,
**kwargs: Any
) -> str:
if action == "add":
return self._add_job(message, every_seconds, cron_expr, at)
return self._add_job(message, every_seconds, cron_expr, tz, at)
elif action == "list":
return self._list_jobs()
elif action == "remove":
return self._remove_job(job_id)
return f"Unknown action: {action}"

def _add_job(self, message: str, every_seconds: int | None, cron_expr: str | None, at: str | None) -> str:
def _add_job(
self,
message: str,
every_seconds: int | None,
cron_expr: str | None,
tz: str | None,
at: str | None,
) -> str:
if not message:
return "Error: message is required for add"
if not self._channel or not self._chat_id:
return "Error: no session context (channel/chat_id)"
if tz and not cron_expr:
return "Error: tz can only be used with cron_expr"

# Build schedule
delete_after = False
if every_seconds:
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
elif cron_expr:
schedule = CronSchedule(kind="cron", expr=cron_expr)
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
elif at:
from datetime import datetime
dt = datetime.fromisoformat(at)
Expand Down
17 changes: 14 additions & 3 deletions opencane/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,20 +1165,27 @@ def cron_list(
table.add_column("Next Run")

import time
from datetime import datetime as _dt
from zoneinfo import ZoneInfo

for job in jobs:
# Format schedule
if job.schedule.kind == "every":
sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
elif job.schedule.kind == "cron":
sched = job.schedule.expr or ""
sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
else:
sched = "one-time"

# Format next run
next_run = ""
if job.state.next_run_at_ms:
next_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(job.state.next_run_at_ms / 1000))
next_run = next_time
ts = job.state.next_run_at_ms / 1000
try:
tzinfo = ZoneInfo(job.schedule.tz) if job.schedule.tz else None
next_run = _dt.fromtimestamp(ts, tzinfo).strftime("%Y-%m-%d %H:%M")
except Exception:
next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))

status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"

Expand All @@ -1204,6 +1211,10 @@ def cron_add(
from opencane.cron.service import CronService
from opencane.cron.types import CronSchedule

if tz and not cron_expr:
console.print("[red]Error: --tz can only be used with --cron[/red]")
raise typer.Exit(1)

# Determine schedule type
if every:
schedule = CronSchedule(kind="every", every_ms=every * 1000)
Expand Down
3 changes: 2 additions & 1 deletion opencane/cron/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
from zoneinfo import ZoneInfo

from croniter import croniter
base_time = time.time()
# Use caller-provided reference time for deterministic scheduling.
base_time = now_ms / 1000
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
base_dt = datetime.fromtimestamp(base_time, tz=tz)
cron = croniter(schedule.expr, base_dt)
Expand Down
10 changes: 6 additions & 4 deletions opencane/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _load(self, key: str) -> Session | None:
created_at = None
last_consolidated = 0

with open(path) as f:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
Expand Down Expand Up @@ -220,9 +220,10 @@ def save(self, session: Session) -> None:
"""Save a session to disk."""
path = self._get_session_path(session.key)

with open(path, "w") as f:
with open(path, "w", encoding="utf-8") as f:
metadata_line = {
"_type": "metadata",
"key": session.key,
"created_at": session.created_at.isoformat(),
"updated_at": session.updated_at.isoformat(),
"metadata": session.metadata,
Expand Down Expand Up @@ -250,13 +251,14 @@ def list_sessions(self) -> list[dict[str, Any]]:
for path in self.sessions_dir.glob("*.jsonl"):
try:
# Read just the metadata line
with open(path) as f:
with open(path, encoding="utf-8") as f:
first_line = f.readline().strip()
if first_line:
data = json.loads(first_line)
if data.get("_type") == "metadata":
key = data.get("key") or path.stem.replace("_", ":", 1)
sessions.append({
"key": path.stem.replace("_", ":"),
"key": key,
"created_at": data.get("created_at"),
"updated_at": data.get("updated_at"),
"path": str(path)
Expand Down
190 changes: 190 additions & 0 deletions tests/test_agent_loop_interim_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from __future__ import annotations

import asyncio
from pathlib import Path
from typing import Any

import pytest

from opencane.agent.loop import AgentLoop
from opencane.bus.queue import MessageBus
from opencane.providers.base import LLMProvider, LLMResponse, ToolCallRequest


class _InterimThenToolProvider(LLMProvider):
def __init__(self) -> None:
super().__init__(api_key=None, api_base=None)
self.calls = 0

async def chat( # type: ignore[override]
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> LLMResponse:
del messages, tools, model, max_tokens, temperature
self.calls += 1
if self.calls == 1:
return LLMResponse(content="Let me check this first.")
if self.calls == 2:
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="msg-1",
name="message",
arguments={"content": "Tool progress update"},
)
],
)
return LLMResponse(content="Done after tool call.")

def get_default_model(self) -> str:
return "fake-model"


class _TextOnlyProvider(LLMProvider):
def __init__(self, outputs: list[str]) -> None:
super().__init__(api_key=None, api_base=None)
self.outputs = outputs
self.calls = 0

async def chat( # type: ignore[override]
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> LLMResponse:
del messages, tools, model, max_tokens, temperature
self.calls += 1
index = min(self.calls - 1, len(self.outputs) - 1)
return LLMResponse(content=self.outputs[index])

def get_default_model(self) -> str:
return "fake-model"


class _InterimToolThenEmptyProvider(LLMProvider):
def __init__(self) -> None:
super().__init__(api_key=None, api_base=None)
self.calls = 0

async def chat( # type: ignore[override]
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: int = 4096,
temperature: float = 0.7,
) -> LLMResponse:
del messages, tools, model, max_tokens, temperature
self.calls += 1
if self.calls == 1:
return LLMResponse(content="Interim text before tools.")
if self.calls == 2:
return LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="msg-2",
name="message",
arguments={"content": "progress from tool"},
)
],
)
return LLMResponse(content="")

def get_default_model(self) -> str:
return "fake-model"


@pytest.mark.asyncio
async def test_agent_loop_retries_interim_text_then_executes_tool(tmp_path: Path) -> None:
bus = MessageBus()
provider = _InterimThenToolProvider()
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)

result = await loop.process_direct(
"run tool flow",
session_key="cli:interim-retry-tools",
channel="cli",
chat_id="chat-1",
)

outbound = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert outbound.content == "Tool progress update"
assert result == "Done after tool call."
assert provider.calls == 3


@pytest.mark.asyncio
async def test_agent_loop_retries_interim_text_only_once(tmp_path: Path) -> None:
bus = MessageBus()
provider = _TextOnlyProvider(["first interim", "second final", "third should-not-run"])
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)

result = await loop.process_direct(
"text only",
session_key="cli:interim-retry-once",
channel="cli",
chat_id="chat-2",
)

assert result == "second final"
assert provider.calls == 2


@pytest.mark.asyncio
async def test_agent_loop_skips_interim_retry_without_available_tools(tmp_path: Path) -> None:
bus = MessageBus()
provider = _TextOnlyProvider(["first response", "second should-not-run"])
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)

result = await loop.process_direct(
"no tools allowed",
session_key="cli:no-tools-retry",
channel="cli",
chat_id="chat-3",
allowed_tool_names=set(),
)

assert result == "first response"
assert provider.calls == 1


@pytest.mark.asyncio
async def test_agent_loop_falls_back_to_interim_when_retry_returns_empty(tmp_path: Path) -> None:
bus = MessageBus()
provider = _TextOnlyProvider(["first interim answer", ""])
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)

result = await loop.process_direct(
"fallback please",
session_key="cli:interim-fallback",
channel="cli",
chat_id="chat-4",
)

assert result == "first interim answer"
assert provider.calls == 2


@pytest.mark.asyncio
async def test_agent_loop_does_not_use_interim_fallback_after_tool_usage(tmp_path: Path) -> None:
bus = MessageBus()
provider = _InterimToolThenEmptyProvider()
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)

result = await loop.process_direct(
"tool then empty",
session_key="cli:interim-no-fallback-after-tools",
channel="cli",
chat_id="chat-5",
)

assert result == ""
assert provider.calls == 3
Loading
Loading