Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ def parse_args() -> argparse.Namespace:
"--redo-stage",
help="When resuming a run, restart from this stage slug or stage number (for example '06_analysis' or '6').",
)
parser.add_argument(
"--rollback-stage",
help="When resuming a run, roll back to this stage and mark downstream stages stale before continuing.",
)
return parser.parse_args()


Expand Down Expand Up @@ -120,6 +124,9 @@ def main() -> int:

if args.resume_run:
start_stage = resolve_stage(args.redo_stage)
rollback_stage = resolve_stage(args.rollback_stage)
if start_stage is not None and rollback_stage is not None:
raise ValueError("--redo-stage and --rollback-stage are mutually exclusive.")
run_root = resolve_resume_run(runs_dir, args.resume_run)
paths = build_run_paths(run_root)
existing_config = load_run_config(paths)
Expand All @@ -133,7 +140,8 @@ def main() -> int:
operator=operator,
ui=ui,
)
return 0 if manager.resume_run(run_root, start_stage=start_stage, venue=venue) else 1
manager.resume_run(run_root, start_stage=start_stage or rollback_stage, venue=venue, rollback_stage=rollback_stage)
return 0

model = args.model or "sonnet"
venue = resolve_venue_key(args.venue or DEFAULT_VENUE)
Expand Down
27 changes: 27 additions & 0 deletions src/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
mark_stage_failed_manifest,
mark_stage_human_review_manifest,
mark_stage_running_manifest,
rebuild_memory_from_manifest,
rollback_to_stage,
sync_stage_session_id,
update_manifest_run_status,
)
Expand Down Expand Up @@ -77,6 +79,7 @@ def resume_run(
self,
run_root: Path,
start_stage: StageSpec | None = None,
rollback_stage: StageSpec | None = None,
venue: str | None = None,
) -> bool:
paths = build_run_paths(run_root)
Expand All @@ -88,11 +91,17 @@ def resume_run(
if not paths.memory.exists():
raise FileNotFoundError(f"Missing memory.md in run: {run_root}")

if rollback_stage is not None:
self._print(self._format_rollback_preview(paths, rollback_stage))
rollback_to_stage(paths, rollback_stage)
start_stage = rollback_stage

append_log_entry(
paths.logs,
"run_resume",
f"Resumed run at: {paths.run_root}"
+ (f"\nRequested start stage: {start_stage.stage_title}" if start_stage else "")
+ (f"\nRequested rollback stage: {rollback_stage.stage_title}" if rollback_stage else "")
+ f"\nVenue: {config['venue']}",
)
self.ui.show_run_started(
Expand Down Expand Up @@ -526,6 +535,24 @@ def _materialize_missing_stage_draft(
)
return type("FallbackResult", (), {"stage_file_path": draft_path, "stdout": fallback_text, "stderr": ""})()

def _format_rollback_preview(self, paths: RunPaths, rollback_stage: StageSpec) -> str:
manifest = ensure_run_manifest(paths)
stale_candidates = [
entry.slug
for entry in manifest.stages
if entry.number > rollback_stage.number and (entry.approved or entry.status != "pending")
]
lines = [
f"Rolling back to {rollback_stage.stage_title}.",
f"Stage {rollback_stage.slug} will be marked pending/dirty.",
]
if stale_candidates:
lines.append("Downstream stages that will be marked stale:")
lines.extend(f"- {slug}" for slug in stale_candidates)
else:
lines.append("No downstream stages currently need invalidation.")
return "\n".join(lines)

def describe_run_status(self, run_root: Path) -> str:
paths = build_run_paths(run_root)
ensure_run_layout(paths)
Expand Down
101 changes: 100 additions & 1 deletion src/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ class StageManifestEntry:
status: str = "pending"
approved: bool = False
dirty: bool = False
stale: bool = False
attempt_count: int = 0
session_id: str | None = None
final_stage_path: str = ""
draft_stage_path: str = ""
artifact_paths: list[str] = field(default_factory=list)
last_error: str | None = None
invalidated_reason: str | None = None
invalidated_by_stage: str | None = None
updated_at: str = ""
approved_at: str | None = None

Expand All @@ -32,12 +35,15 @@ def to_dict(self) -> dict[str, object]:
"status": self.status,
"approved": self.approved,
"dirty": self.dirty,
"stale": self.stale,
"attempt_count": self.attempt_count,
"session_id": self.session_id,
"final_stage_path": self.final_stage_path,
"draft_stage_path": self.draft_stage_path,
"artifact_paths": list(self.artifact_paths),
"last_error": self.last_error,
"invalidated_reason": self.invalidated_reason,
"invalidated_by_stage": self.invalidated_by_stage,
"updated_at": self.updated_at,
"approved_at": self.approved_at,
}
Expand All @@ -51,12 +57,15 @@ def from_dict(cls, payload: dict[str, object]) -> "StageManifestEntry":
status=str(payload.get("status") or "pending"),
approved=bool(payload.get("approved", False)),
dirty=bool(payload.get("dirty", False)),
stale=bool(payload.get("stale", False)),
attempt_count=int(payload.get("attempt_count") or 0),
session_id=str(payload["session_id"]) if payload.get("session_id") is not None else None,
final_stage_path=str(payload.get("final_stage_path") or ""),
draft_stage_path=str(payload.get("draft_stage_path") or ""),
artifact_paths=[str(item) for item in payload.get("artifact_paths", []) if str(item).strip()],
last_error=str(payload["last_error"]) if payload.get("last_error") is not None else None,
invalidated_reason=str(payload["invalidated_reason"]) if payload.get("invalidated_reason") is not None else None,
invalidated_by_stage=str(payload["invalidated_by_stage"]) if payload.get("invalidated_by_stage") is not None else None,
updated_at=str(payload.get("updated_at") or ""),
approved_at=str(payload["approved_at"]) if payload.get("approved_at") is not None else None,
)
Expand Down Expand Up @@ -165,9 +174,17 @@ def format_manifest_status(manifest: RunManifest) -> str:
"Stages:",
]
for entry in manifest.stages:
flags = []
if entry.approved:
flags.append("approved")
if entry.dirty:
flags.append("dirty")
if entry.stale:
flags.append("stale")
suffix = f" [{' '.join(flags)}]" if flags else ""
lines.append(
f"- {entry.slug}: status={entry.status}, approved={entry.approved}, attempts={entry.attempt_count}, "
f"session_id={entry.session_id or 'none'}"
f"session_id={entry.session_id or 'none'}{suffix}"
)
return "\n".join(lines)

Expand Down Expand Up @@ -237,6 +254,7 @@ def mark_stage_running_manifest(paths: RunPaths, stage: StageSpec, attempt_no: i
status="running",
approved=False,
dirty=False,
stale=False,
attempt_count=attempt_no,
last_error=None,
)
Expand All @@ -260,6 +278,7 @@ def mark_stage_human_review_manifest(
status="human_review",
approved=False,
dirty=False,
stale=False,
attempt_count=attempt_no,
artifact_paths=artifact_paths,
)
Expand All @@ -283,6 +302,7 @@ def mark_stage_approved_manifest(
status="approved",
approved=True,
dirty=False,
stale=False,
attempt_count=attempt_no,
artifact_paths=artifact_paths,
approved_at=_now(),
Expand All @@ -303,9 +323,88 @@ def mark_stage_failed_manifest(paths: RunPaths, stage: StageSpec, error: str) ->
status="failed",
approved=False,
dirty=True,
stale=False,
last_error=error,
)


def sync_stage_session_id(paths: RunPaths, stage: StageSpec, session_id: str | None) -> RunManifest:
return update_stage_entry(paths, stage, session_id=session_id)


def rollback_to_stage(paths: RunPaths, rollback_stage: StageSpec, reason: str | None = None) -> RunManifest:
manifest = ensure_run_manifest(paths)
invalidated_reason = reason or f"Rolled back to {rollback_stage.stage_title}"
updated_stages: list[StageManifestEntry] = []

for entry in manifest.stages:
payload = entry.to_dict()
if entry.number < rollback_stage.number:
updated_stages.append(entry)
continue
if entry.number == rollback_stage.number:
payload.update(
{
"status": "pending",
"approved": False,
"dirty": True,
"stale": False,
"approved_at": None,
"invalidated_reason": invalidated_reason,
"invalidated_by_stage": rollback_stage.slug,
}
)
else:
payload.update(
{
"status": "stale",
"approved": False,
"dirty": True,
"stale": True,
"approved_at": None,
"invalidated_reason": invalidated_reason,
"invalidated_by_stage": rollback_stage.slug,
}
)
payload["updated_at"] = _now()
updated_stages.append(StageManifestEntry.from_dict(payload))

updated = RunManifest(
run_id=manifest.run_id,
created_at=manifest.created_at,
updated_at=_now(),
run_status="pending",
last_event="run.rolled_back",
current_stage_slug=rollback_stage.slug,
last_error=None,
completed_at=None,
stages=updated_stages,
)
save_run_manifest(paths.run_manifest, updated)
rebuild_memory_from_manifest(paths, updated)
return updated


def rebuild_memory_from_manifest(paths: RunPaths, manifest: RunManifest | None = None) -> None:
manifest = manifest or ensure_run_manifest(paths)
goal_text = paths.user_input.read_text(encoding="utf-8").strip()
entries: list[str] = []
from .utils import read_text, render_approved_stage_entry, write_text

for stage in STAGES:
entry = next(item for item in manifest.stages if item.slug == stage.slug)
if not entry.approved:
continue
stage_path = paths.stage_file(stage)
if not stage_path.exists():
continue
entries.append(render_approved_stage_entry(stage, read_text(stage_path)))

body = (
"# Approved Run Memory\n\n"
"## Original User Goal\n"
f"{goal_text}\n\n"
"## Approved Stage Summaries\n\n"
)
body += "\n\n".join(entries) + "\n" if entries else "_None yet._\n"
write_text(paths.memory, body)
59 changes: 59 additions & 0 deletions tests/test_stage_rollback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import tempfile
import unittest
from pathlib import Path

from src.manifest import (
initialize_run_manifest,
load_run_manifest,
mark_stage_approved_manifest,
rollback_to_stage,
)
from src.utils import STAGES, build_run_paths, ensure_run_config, ensure_run_layout, initialize_memory, write_text


class StageRollbackTests(unittest.TestCase):
def test_rollback_marks_downstream_stale_and_rebuilds_memory(self) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
runs_dir = Path(tmp_dir) / "runs"
run_root = runs_dir / "run"
paths = build_run_paths(run_root)
ensure_run_layout(paths)
write_text(paths.user_input, "Rollback validation workflow.")
initialize_memory(paths, "Rollback validation workflow.")
ensure_run_config(paths, model="sonnet", venue="neurips_2025")

initialize_run_manifest(paths)
for stage in STAGES[:5]:
write_text(
paths.stage_file(stage),
(
f"# Stage {stage.number:02d}: {stage.display_name}\n\n"
"## Objective\nDone.\n\n"
"## Previously Approved Stage Summaries\n_None yet._\n\n"
"## What I Did\nDid work.\n\n"
"## Key Results\nKey result.\n\n"
"## Files Produced\n- `stages/example.md`\n\n"
"## Suggestions for Refinement\n"
"1. Refine one.\n2. Refine two.\n3. Refine three.\n\n"
"## Your Options\n"
"1. Use suggestion 1\n2. Use suggestion 2\n3. Use suggestion 3\n4. Refine with your own feedback\n5. Approve and continue\n6. Abort\n"
),
)
mark_stage_approved_manifest(paths, stage, attempt_no=1, artifact_paths=[str(paths.stage_file(stage))])

rollback_to_stage(paths, STAGES[2], reason="Redo study design")
manifest = load_run_manifest(paths.run_manifest)
assert manifest is not None

by_slug = {entry.slug: entry for entry in manifest.stages}
self.assertEqual(by_slug["03_study_design"].status, "pending")
self.assertTrue(by_slug["03_study_design"].dirty)
self.assertEqual(by_slug["04_implementation"].status, "stale")
self.assertTrue(by_slug["04_implementation"].stale)
self.assertEqual(by_slug["05_experimentation"].status, "stale")


if __name__ == "__main__":
unittest.main()