Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 159 additions & 6 deletions src/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import subprocess
import sys
import uuid
from datetime import datetime
from pathlib import Path
from typing import TextIO

Expand Down Expand Up @@ -65,6 +66,19 @@ def _run_real(
write_text(prompt_path, prompt)
session_id = self._resolve_stage_session_id(paths, stage, continue_session)
command = self._build_cli_command(prompt_path, session_id, resume=continue_session)
self._write_attempt_state(
paths,
stage,
attempt_no,
{
"status": "starting",
"mode": "resume" if continue_session else "start",
"session_id": session_id,
"prompt_path": str(prompt_path),
"command": command,
"started_at": self._now(),
},
)

append_jsonl(
paths.logs_raw,
Expand All @@ -80,7 +94,7 @@ def _run_real(
},
)

exit_code, stdout_text, stderr_text, observed_session_id = self._run_streaming_command(
exit_code, stdout_text, stderr_text, observed_session_id, stream_meta = self._run_streaming_command(
command=command,
cwd=paths.run_root,
stage=stage,
Expand Down Expand Up @@ -112,7 +126,8 @@ def _run_real(
}
},
)
exit_code, stdout_text, stderr_text, observed_session_id = self._run_streaming_command(
self._mark_session_broken(paths, stage, session_id, reason="resume_failure")
exit_code, stdout_text, stderr_text, observed_session_id, stream_meta = self._run_streaming_command(
command=fallback_command,
cwd=paths.run_root,
stage=stage,
Expand All @@ -125,6 +140,34 @@ def _run_real(
effective_session_id = observed_session_id or session_id
self._persist_stage_session_id(paths, stage, effective_session_id)
success = exit_code == 0 and stage_file.exists()
self._update_session_state(
paths,
stage,
effective_session_id,
{
"broken": not success and continue_session,
"last_exit_code": exit_code,
"last_mode": "resume" if continue_session else "start",
"updated_at": self._now(),
},
)
self._write_attempt_state(
paths,
stage,
attempt_no,
{
"status": "completed" if success else "failed",
"mode": "resume" if continue_session else "start",
"session_id": effective_session_id,
"prompt_path": str(prompt_path),
"command": command,
"exit_code": exit_code,
"stdout_excerpt": stdout_text[-2000:] if stdout_text else "",
"stderr_excerpt": stderr_text[-1000:] if stderr_text else "",
"stream_meta": stream_meta,
"finished_at": self._now(),
},
)

return OperatorResult(
success=success,
Expand Down Expand Up @@ -242,7 +285,21 @@ def repair_stage_summary(
},
)

exit_code, stdout_text, stderr_text, observed_session_id = self._run_streaming_command(
self._write_attempt_state(
paths,
stage,
attempt_no,
{
"status": "repair_starting",
"mode": "repair",
"session_id": session_id,
"prompt_path": str(recovery_prompt_path),
"command": command,
"started_at": self._now(),
},
)

exit_code, stdout_text, stderr_text, observed_session_id, stream_meta = self._run_streaming_command(
command=command,
cwd=paths.run_root,
stage=stage,
Expand Down Expand Up @@ -277,7 +334,8 @@ def repair_stage_summary(
}
},
)
exit_code, stdout_text, stderr_text, observed_session_id = self._run_streaming_command(
self._mark_session_broken(paths, stage, session_id, reason="repair_resume_failure")
exit_code, stdout_text, stderr_text, observed_session_id, stream_meta = self._run_streaming_command(
command=fallback_command,
cwd=paths.run_root,
stage=stage,
Expand All @@ -289,6 +347,34 @@ def repair_stage_summary(

effective_session_id = observed_session_id or session_id
self._persist_stage_session_id(paths, stage, effective_session_id)
self._update_session_state(
paths,
stage,
effective_session_id,
{
"broken": exit_code != 0 and not stage_file.exists(),
"last_exit_code": exit_code,
"last_mode": "repair",
"updated_at": self._now(),
},
)
self._write_attempt_state(
paths,
stage,
attempt_no,
{
"status": "repair_completed" if exit_code == 0 and stage_file.exists() else "repair_failed",
"mode": "repair",
"session_id": effective_session_id,
"prompt_path": str(recovery_prompt_path),
"command": command,
"exit_code": exit_code,
"stdout_excerpt": stdout_text[-2000:] if stdout_text else "",
"stderr_excerpt": stderr_text[-1000:] if stderr_text else "",
"stream_meta": stream_meta,
"finished_at": self._now(),
},
)

return OperatorResult(
success=exit_code == 0 and stage_file.exists(),
Expand All @@ -307,7 +393,7 @@ def _run_streaming_command(
attempt_no: int,
paths: RunPaths,
mode: str,
) -> tuple[int, str, str, str | None]:
) -> tuple[int, str, str, str | None, dict[str, object]]:
process = subprocess.Popen(
command,
cwd=str(cwd),
Expand All @@ -325,6 +411,7 @@ def _run_streaming_command(
non_json_lines: list[str] = []
ended_with_newline = True
observed_session_id: str | None = None
malformed_json_count = 0

try:
for raw_line in process.stdout:
Expand All @@ -341,6 +428,7 @@ def _run_streaming_command(
try:
payload = json.loads(stripped)
except json.JSONDecodeError:
malformed_json_count += 1
append_jsonl(
paths.logs_raw,
{
Expand Down Expand Up @@ -380,7 +468,12 @@ def _run_streaming_command(
non_json_lines=non_json_lines,
raw_lines=raw_lines,
)
return exit_code, stdout_text, "", observed_session_id
return exit_code, stdout_text, "", observed_session_id, {
"raw_line_count": len(raw_lines),
"non_json_line_count": len(non_json_lines),
"malformed_json_count": malformed_json_count,
"observed_session_id": observed_session_id,
}

def _compose_stdout_text(
self,
Expand Down Expand Up @@ -484,6 +577,14 @@ def _resolve_stage_session_id(
continue_session: bool,
allow_create: bool = True,
) -> str | None:
session_state_path = paths.stage_session_state_file(stage)
if session_state_path.exists():
payload = json.loads(read_text(session_state_path))
session_id = str(payload.get("session_id") or "").strip()
broken = bool(payload.get("broken", False))
if session_id and not broken:
return session_id

session_file = paths.stage_session_file(stage)
if session_file.exists():
session_id = read_text(session_file).strip()
Expand All @@ -499,6 +600,15 @@ def _persist_stage_session_id(self, paths: RunPaths, stage: StageSpec, session_i
if not session_id:
return
write_text(paths.stage_session_file(stage), session_id)
self._update_session_state(
paths,
stage,
session_id,
{
"broken": False,
"updated_at": self._now(),
},
)

def _extract_session_id(self, payload: dict[str, object]) -> str | None:
value = payload.get("session_id")
Expand Down Expand Up @@ -542,3 +652,46 @@ def _build_cli_command(
def _looks_like_resume_failure(self, stdout_text: str, stderr_text: str) -> bool:
combined = "\n".join(part for part in [stdout_text, stderr_text] if part).lower()
return "no conversation found with session id" in combined or "resume" in combined and "not found" in combined

def _write_attempt_state(
    self,
    paths: RunPaths,
    stage: StageSpec,
    attempt_no: int,
    payload: dict[str, object],
) -> None:
    """Write ``payload`` as JSON to the attempt-state file of this stage attempt.

    The destination is ``paths.stage_attempt_state_file(stage, attempt_no)``.
    The document is serialized with two-space indentation and ASCII-only
    escaping, then handed to ``write_text``.
    """
    destination = paths.stage_attempt_state_file(stage, attempt_no)
    document = json.dumps(payload, indent=2, ensure_ascii=True)
    write_text(destination, document)

def _update_session_state(
    self,
    paths: RunPaths,
    stage: StageSpec,
    session_id: str | None,
    changes: dict[str, object],
) -> None:
    """Merge ``changes`` into the stage's session-state JSON file.

    The existing file at ``paths.stage_session_state_file(stage)`` is loaded
    when present, ``changes`` is applied on top of it, and the result is
    written back with two-space indentation and ASCII-only escaping.

    Args:
        paths: Run paths used to locate the session-state file.
        stage: Stage whose session state is updated.
        session_id: When truthy, stored under the ``"session_id"`` key
            (overriding any ``"session_id"`` present in ``changes``). When
            falsy, the previously stored id (if any) is left untouched.
        changes: Key/value pairs to merge into the stored state.

    A state file that is not valid JSON, or that holds valid JSON which is not
    an object (e.g. ``[]`` or ``null``), is treated as empty so that a damaged
    file cannot block state updates.
    """
    path = paths.stage_session_state_file(stage)
    payload: dict[str, object] = {}
    if path.exists():
        try:
            loaded = json.loads(read_text(path))
        except json.JSONDecodeError:
            loaded = None
        # json.loads may legitimately return a list/str/number/None; only a
        # JSON object can be merged into, anything else is discarded.
        if isinstance(loaded, dict):
            payload = loaded
    payload.update(changes)
    if session_id:
        payload["session_id"] = session_id
    write_text(path, json.dumps(payload, indent=2, ensure_ascii=True))

def _mark_session_broken(self, paths: RunPaths, stage: StageSpec, session_id: str | None, reason: str) -> None:
self._update_session_state(
paths,
stage,
session_id,
{
"broken": True,
"broken_reason": reason,
"updated_at": self._now(),
},
)

def _now(self) -> str:
return datetime.now().isoformat(timespec="seconds")
6 changes: 6 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def stage_tmp_file(self, stage: StageSpec) -> Path:
def stage_session_file(self, stage: StageSpec) -> Path:
    """Return the path of the plain-text session-id file for ``stage``.

    The file is named ``<stage.slug>.session_id.txt`` inside
    ``operator_state_dir``.
    """
    filename = f"{stage.slug}.session_id.txt"
    return self.operator_state_dir / filename

def stage_session_state_file(self, stage: StageSpec) -> Path:
    """Return the path of the JSON session-state file for ``stage``.

    The file is named ``<stage.slug>.session.json`` inside
    ``operator_state_dir``.
    """
    filename = f"{stage.slug}.session.json"
    return self.operator_state_dir / filename

def stage_attempt_state_file(self, stage: StageSpec, attempt_no: int) -> Path:
    """Return the path of the JSON attempt-state file for one attempt of ``stage``.

    The file is named ``<stage.slug>.attempt_<NN>.json`` inside
    ``operator_state_dir``, where ``NN`` is ``attempt_no`` zero-padded to at
    least two digits.
    """
    filename = f"{stage.slug}.attempt_{attempt_no:02d}.json"
    return self.operator_state_dir / filename


@dataclass(frozen=True)
class OperatorResult:
Expand Down
115 changes: 115 additions & 0 deletions tests/test_operator_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import io
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from src.operator import ClaudeOperator
from src.utils import STAGES, build_run_paths, ensure_run_layout, initialize_memory, write_text


class OperatorRecoveryTests(unittest.TestCase):
    """Tests for ``ClaudeOperator`` session recovery and session-state bookkeeping."""

    def test_resume_failure_falls_back_to_new_session_and_records_attempt_state(self) -> None:
        """A resume that fails is retried, and the new session id and attempt state are persisted."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Build a throwaway run layout so the operator has real files to work with.
            run_root = Path(tmp_dir) / "run"
            paths = build_run_paths(run_root)
            ensure_run_layout(paths)
            write_text(paths.user_input, "Operator recovery goal")
            initialize_memory(paths, "Operator recovery goal")

            operator = ClaudeOperator(fake_mode=False, output_stream=io.StringIO())
            stage = STAGES[0]
            # Seed a previously persisted session id that the resume attempt will use.
            old_session_id = "old-session-id"
            operator._persist_stage_session_id(paths, stage, old_session_id)

            # Mutable counter shared with the closure below (a dict so it can be updated in place).
            call_count = {"value": 0}

            def fake_stream(*args, **kwargs):
                """Stand-in for ``_run_streaming_command``: fail the first call, succeed afterwards."""
                call_count["value"] += 1
                if call_count["value"] == 1:
                    # First call: non-zero exit with the "no conversation found" message
                    # and no observed session id.
                    return (
                        1,
                        "No conversation found with session id old-session-id",
                        "",
                        None,
                        {"raw_line_count": 1, "non_json_line_count": 1, "malformed_json_count": 1},
                    )

                # Later calls: write the stage draft file, then report success
                # together with a newly observed session id.
                stage_tmp_path = paths.stage_tmp_file(stage)
                write_text(
                    stage_tmp_path,
                    (
                        "# Stage 01: Literature Survey\n\n"
                        "## Objective\nRecovered.\n\n"
                        "## Previously Approved Stage Summaries\n_None yet._\n\n"
                        "## What I Did\nRecovered session.\n\n"
                        "## Key Results\nRecovered stage summary.\n\n"
                        "## Files Produced\n- `stages/01_literature_survey.tmp.md`\n\n"
                        "## Suggestions for Refinement\n"
                        "1. Refine one.\n2. Refine two.\n3. Refine three.\n\n"
                        "## Your Options\n"
                        "1. Use suggestion 1\n2. Use suggestion 2\n3. Use suggestion 3\n4. Refine with your own feedback\n5. Approve and continue\n6. Abort\n"
                    ),
                )
                return (
                    0,
                    "Recovered successfully.",
                    "",
                    "new-session-id",
                    {"raw_line_count": 2, "non_json_line_count": 0, "malformed_json_count": 0},
                )

            # Patch the CLI lookup to return a fixed path and route streaming through the fake.
            with patch("src.operator.shutil.which", return_value="/usr/bin/claude"), patch.object(
                operator,
                "_run_streaming_command",
                side_effect=fake_stream,
            ):
                result = operator._run_real(
                    stage=stage,
                    prompt="prompt",
                    paths=paths,
                    attempt_no=1,
                    continue_session=True,
                )

            # Exactly two streaming calls are expected: the failed resume and the fallback.
            self.assertTrue(result.success)
            self.assertEqual(result.session_id, "new-session-id")
            self.assertEqual(call_count["value"], 2)
            self.assertEqual(paths.stage_session_file(stage).read_text(encoding="utf-8").strip(), "new-session-id")

            # The attempt-state file keeps the originally requested mode ("resume")
            # but records the session id observed from the second call.
            attempt_state = json.loads(paths.stage_attempt_state_file(stage, 1).read_text(encoding="utf-8"))
            self.assertEqual(attempt_state["status"], "completed")
            self.assertEqual(attempt_state["mode"], "resume")
            self.assertEqual(attempt_state["session_id"], "new-session-id")

    def test_broken_session_is_not_reused(self) -> None:
        """A session id flagged ``broken`` in the session-state file is not returned by the resolver."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            run_root = Path(tmp_dir) / "run"
            paths = build_run_paths(run_root)
            ensure_run_layout(paths)
            write_text(paths.user_input, "Broken session test")
            initialize_memory(paths, "Broken session test")

            operator = ClaudeOperator(fake_mode=False, output_stream=io.StringIO())
            stage = STAGES[0]
            # Hand-write a session-state file whose session is marked as broken.
            write_text(
                paths.stage_session_state_file(stage),
                json.dumps(
                    {
                        "session_id": "broken-session-id",
                        "broken": True,
                    },
                    indent=2,
                ),
            )

            # The resolver must still yield some id, just not the broken one.
            resolved = operator._resolve_stage_session_id(paths, stage, continue_session=False)
            self.assertIsNotNone(resolved)
            self.assertNotEqual(resolved, "broken-session-id")


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()