Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 93 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def build_continuation_prompt(
"8. Do not leave placeholder text such as [In progress], [Pending], [TODO], [TBD], or similar unfinished markers.\n"
"9. If the existing stage work is partially correct, keep the correct parts and extend them rather than replacing them blindly."
),
"# Current Stage Context",
_build_inline_context(stage, paths),
"# New Feedback",
revision_feedback.strip()
if revision_feedback
Expand All @@ -457,6 +459,52 @@ def build_continuation_prompt(
return "\n\n".join(sections).strip() + "\n"


def _build_inline_context(stage: StageSpec, paths: RunPaths) -> str:
    """Assemble the short context block embedded in a continuation prompt.

    Up to three fragments are collected, in this order:

    1. the user goal read from ``paths.user_input``;
    2. the "Objective" section of the current stage, taken from the first of
       ``paths.stage_file(stage)`` / ``paths.stage_tmp_file(stage)`` that
       exists and yields a non-empty objective;
    3. summaries of up to two other stages found in ``paths.memory``.

    A fragment is omitted when its source file is missing or produces no
    text.

    Returns:
        The fragments joined by blank lines, or the fixed sentence
        ``"No additional context available."`` when there are none.
    """
    fragments: list[str] = []

    # User goal, passed through truncate_text with a 500-character limit.
    if paths.user_input.exists():
        goal = read_text(paths.user_input).strip()
        if goal:
            fragments.append(f"**User Goal**: {truncate_text(goal, 500)}")

    # Objective of the stage being continued.
    candidates = [paths.stage_file(stage), paths.stage_tmp_file(stage)]
    for candidate in candidates:
        if not candidate.exists():
            continue
        stage_objective = extract_markdown_section(read_text(candidate), "Objective")
        if stage_objective:
            fragments.append(
                f"**Current Stage Objective**: {truncate_text(stage_objective, 500)}"
            )
            # Stop at the first file that actually provides an objective.
            break

    # Summaries of recently approved stages recorded in the memory file.
    if paths.memory.exists():
        memory_body = read_text(paths.memory)
        if memory_body:
            approved = _extract_recent_stage_summaries(memory_body, stage, max_stages=2)
            if approved:
                fragments.append(f"**Recent Approved Context**:\n{approved}")

    if not fragments:
        return "No additional context available."
    return "\n\n".join(fragments)


def _extract_recent_stage_summaries(
    memory_text: str, current_stage: StageSpec, max_stages: int = 2
) -> str:
    """Summarise the last few stage entries recorded in the memory text.

    The memory text is split into entries that start with a
    ``### Stage NN: `` heading. Entries whose heading starts with the current
    stage's own title are dropped, and the last ``max_stages`` of the
    remaining entries are summarised as their heading line followed by the
    text of their ``#### Objective`` section (passed through
    ``truncate_text`` with a 300-character limit). An entry without an
    objective contributes its heading line only.

    Args:
        memory_text: Full text of the memory file.
        current_stage: Stage being worked on; only ``stage_title`` is read.
        max_stages: Maximum number of entries to include. Zero or a negative
            value yields an empty result.

    Returns:
        The summaries joined by blank lines, or ``""`` when nothing matches.
    """
    if max_stages <= 0:
        # Guard needed because ``entries[-0:]`` is the whole list, not an
        # empty one, so a zero limit would otherwise return every entry.
        return ""

    entry_pattern = re.compile(r"(### Stage \d{2}: .+?)(?=### Stage \d{2}: |\Z)", re.DOTALL)
    # Memory entries use #### Objective (h4), not ## Objective (h2)
    objective_pattern = re.compile(
        r"^####\s+Objective\s*$\n?(.*?)(?=^####\s|\Z)", re.MULTILINE | re.DOTALL
    )
    own_heading = f"### {current_stage.stage_title}"

    entries = [
        entry
        for entry in entry_pattern.findall(memory_text)
        if not entry.startswith(own_heading)
    ]

    summaries: list[str] = []
    for entry in entries[-max_stages:]:
        # The first line of an entry is its "### Stage NN: ..." heading.
        title = entry.strip().split("\n", 1)[0]
        found = objective_pattern.search(entry)
        objective = found.group(1).strip() if found else ""
        if objective:
            summaries.append(f"{title}\n{truncate_text(objective, 300)}")
        else:
            summaries.append(title)
    return "\n\n".join(summaries)


def truncate_text(text: str, max_chars: int = 12000) -> str:
stripped = text.strip()
if len(stripped) <= max_chars:
Expand Down Expand Up @@ -656,12 +704,55 @@ def render_approved_stage_entry(stage: StageSpec, stage_markdown: str) -> str:
)


def render_compact_stage_entry(
    stage: StageSpec,
    stage_markdown: str,
    max_section_chars: int = 2000,
) -> str:
    """Render a size-limited memory entry for an approved stage.

    Four sections are pulled out of ``stage_markdown`` with
    ``extract_markdown_section`` and re-emitted as ``####`` sub-headings
    under a ``### <stage title>`` heading. A section that cannot be
    extracted is rendered as ``"Not provided."``. Only "What I Did" and
    "Key Results" are shortened (via ``_truncate_section``); "Objective" and
    "Files Produced" are kept as extracted.

    Args:
        stage: Stage whose ``stage_title`` becomes the entry heading.
        stage_markdown: Markdown of the stage summary.
        max_section_chars: Limit handed to ``_truncate_section`` for the two
            shortened sections.

    Returns:
        The entry text, without a trailing newline.
    """
    section_names = ("Objective", "What I Did", "Key Results", "Files Produced")
    shortened = ("What I Did", "Key Results")

    rendered: list[str] = []
    for section_name in section_names:
        body = extract_markdown_section(stage_markdown, section_name) or "Not provided."
        if section_name in shortened:
            body = _truncate_section(body, max_section_chars)
        rendered.append(f"#### {section_name}\n{body}")

    return f"### {stage.stage_title}\n\n" + "\n\n".join(rendered)


def _truncate_section(text: str, max_chars: int) -> str:
    """Shorten ``text`` to roughly ``max_chars`` characters.

    Text that already fits is returned unchanged. Otherwise the text is cut
    at the cleanest boundary available inside the first ``max_chars``
    characters, tried in this order:

    1. the last paragraph break (blank line),
    2. the last line break, so bullet lists are not cut mid-item,
    3. a hard cut at ``max_chars``.

    A boundary is only used when it lies in the second half of the allowed
    window, so that a very early break does not discard most of the budget.
    Trailing whitespace is stripped from the kept part and a pointer to the
    workspace files is appended, so the result can be slightly longer than
    ``max_chars``.

    Args:
        text: Section text to shorten.
        max_chars: Character budget for the kept part; negative values are
            treated as 0.

    Returns:
        ``text`` itself when it fits, otherwise the kept prefix followed by
        the "see workspace files" marker.
    """
    # A negative budget would otherwise turn the slice below into
    # "drop characters from the end", which is never what is wanted.
    limit = max(max_chars, 0)
    if len(text) <= limit:
        return text

    cutoff = limit
    for boundary in ("\n\n", "\n"):
        position = text.rfind(boundary, 0, limit)
        # rfind returns -1 when absent, which also fails this check.
        if position >= limit // 2:
            cutoff = position
            break

    return text[:cutoff].rstrip() + "\n\n...(see workspace files for full details)"


def append_approved_stage_summary(memory_path: Path, stage: StageSpec, stage_markdown: str) -> None:
current = read_text(memory_path)
entry = render_approved_stage_entry(stage, stage_markdown)
entry = render_compact_stage_entry(stage, stage_markdown)

stage_heading = f"### {stage.stage_title}"
placeholder = "_None yet._"
if placeholder in current:

if stage_heading in current:
pattern = re.compile(
rf"### {re.escape(stage.stage_title)}\n.*?(?=### Stage \d|$)",
re.DOTALL,
)
updated = pattern.sub(entry, current, count=1)
elif placeholder in current:
updated = current.replace(placeholder, entry, 1)
else:
updated = current.rstrip() + "\n\n" + entry + "\n"
Expand Down
222 changes: 222 additions & 0 deletions tests/test_context_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
from __future__ import annotations

import tempfile
import unittest
from pathlib import Path

from src.utils import (
STAGES,
RunPaths,
StageSpec,
append_approved_stage_summary,
build_continuation_prompt,
build_run_paths,
ensure_run_layout,
initialize_memory,
read_text,
render_approved_stage_entry,
render_compact_stage_entry,
write_text,
_truncate_section,
)


# Stage specs used as fixtures throughout this module, selected from STAGES
# by slug. Each next() call has no default, so a slug that is missing from
# STAGES raises StopIteration when the module is imported.
STAGE_01 = next(s for s in STAGES if s.slug == "01_literature_survey")
STAGE_04 = next(s for s in STAGES if s.slug == "04_implementation")
STAGE_05 = next(s for s in STAGES if s.slug == "05_experimentation")
STAGE_06 = next(s for s in STAGES if s.slug == "06_analysis")


def _make_stage_markdown(stage: StageSpec, key_results_lines: int = 5) -> str:
    """Build a synthetic stage summary document for use in tests.

    The document has a ``# Stage NN: <display name>`` title followed by the
    sections Objective, Previously Approved Stage Summaries, What I Did,
    Key Results, Files Produced, Suggestions for Refinement and Your Options.

    Args:
        stage: Provides ``number`` and ``display_name`` for the generated text.
        key_results_lines: Number of ``- Result line i`` bullets placed in
            the "Key Results" section.

    Returns:
        The markdown text, ending with a newline.
    """
    name = stage.display_name
    bullets = [f"- Result line {idx}" for idx in range(1, key_results_lines + 1)]
    lines = [
        f"# Stage {stage.number:02d}: {name}",
        "",
        "## Objective",
        f"Test objective for {name}.",
        "",
        "## Previously Approved Stage Summaries",
        "_None yet._",
        "",
        "## What I Did",
        f"Implemented {name} with care.",
        "",
        "## Key Results",
        "\n".join(bullets),
        "",
        "## Files Produced",
        "- `workspace/test.txt`",
        "",
        "## Suggestions for Refinement",
        "1. Suggestion A",
        "2. Suggestion B",
        "3. Suggestion C",
        "",
        "## Your Options",
        "1. Use suggestion 1",
        "2. Use suggestion 2",
        "3. Use suggestion 3",
        "4. Refine with your own feedback",
        "5. Approve and continue",
        "6. Abort",
    ]
    return "\n".join(lines) + "\n"


class TestTruncateSection(unittest.TestCase):
    """Unit tests for ``_truncate_section``."""

    def test_short_text_unchanged(self) -> None:
        text = "Short text."
        self.assertEqual(_truncate_section(text, 100), text)

    def test_long_text_truncated(self) -> None:
        text = "A" * 500
        result = _truncate_section(text, 100)
        self.assertIn("...(see workspace files for full details)", result)
        self.assertLess(len(result), 200)

    def test_truncates_at_paragraph_boundary(self) -> None:
        text = "Paragraph one.\n\nParagraph two.\n\nParagraph three is very long " + "x" * 200
        result = _truncate_section(text, 50)
        self.assertIn("...(see workspace files", result)
        # The last blank line inside the 50-character window follows
        # "Paragraph two.", so the kept text must end exactly there rather
        # than part-way into the third paragraph.
        kept = result.split("\n\n...(see workspace files", 1)[0]
        self.assertEqual(kept, "Paragraph one.\n\nParagraph two.")
        self.assertNotIn("Paragraph three", result)


class TestCompactStageEntry(unittest.TestCase):
    """Tests for ``render_compact_stage_entry``."""

    def test_short_content_not_truncated(self) -> None:
        # With small sections the compact entry must equal the full entry.
        markdown = _make_stage_markdown(STAGE_01, key_results_lines=3)
        self.assertEqual(
            render_approved_stage_entry(STAGE_01, markdown),
            render_compact_stage_entry(STAGE_01, markdown),
        )

    def test_long_key_results_truncated(self) -> None:
        markdown = _make_stage_markdown(STAGE_04, key_results_lines=200)
        compact_entry = render_compact_stage_entry(STAGE_04, markdown, max_section_chars=500)
        full_entry = render_approved_stage_entry(STAGE_04, markdown)
        self.assertIn("...(see workspace files for full details)", compact_entry)
        self.assertLess(len(compact_entry), len(full_entry))

    def test_compact_preserves_objective_and_files(self) -> None:
        markdown = _make_stage_markdown(STAGE_05, key_results_lines=200)
        compact_entry = render_compact_stage_entry(STAGE_05, markdown, max_section_chars=500)
        for expected_fragment in ("Test objective for", "workspace/test.txt"):
            self.assertIn(expected_fragment, compact_entry)


class TestDuplicatePrevention(unittest.TestCase):
    """Tests that approving a stage again replaces its memory entry."""

    def setUp(self) -> None:
        # Every test starts from a fresh run directory with initialised memory.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.paths = build_run_paths(Path(scratch.name) / "run")
        ensure_run_layout(self.paths)
        write_text(self.paths.user_input, "test goal")
        initialize_memory(self.paths, "test goal")

    def test_same_stage_not_duplicated(self) -> None:
        markdown = _make_stage_markdown(STAGE_01)
        for _ in range(2):
            append_approved_stage_summary(self.paths.memory, STAGE_01, markdown)

        memory = read_text(self.paths.memory)
        count = memory.count(f"### {STAGE_01.stage_title}")
        self.assertEqual(count, 1, f"Stage heading appeared {count} times, expected 1")

    def test_replace_updates_content(self) -> None:
        first_version = _make_stage_markdown(STAGE_01, key_results_lines=2)
        second_version = first_version.replace("Result line 1", "Updated result")
        for version in (first_version, second_version):
            append_approved_stage_summary(self.paths.memory, STAGE_01, version)

        memory = read_text(self.paths.memory)
        self.assertIn("Updated result", memory)
        self.assertEqual(memory.count(f"### {STAGE_01.stage_title}"), 1)

    def test_multiple_stages_coexist(self) -> None:
        approved = [STAGE_01, STAGE_04, STAGE_05]
        for stage in approved:
            append_approved_stage_summary(
                self.paths.memory, stage, _make_stage_markdown(stage)
            )

        memory = read_text(self.paths.memory)
        self.assertIn(f"### {STAGE_01.stage_title}", memory)
        self.assertIn(f"### {STAGE_04.stage_title}", memory)
        self.assertIn(f"### {STAGE_05.stage_title}", memory)


class TestInlineContext(unittest.TestCase):
    """Tests for the inline context section of the continuation prompt."""

    def _prepare_run(self, goal: str) -> RunPaths:
        # Create a throwaway run directory whose user input and memory are
        # initialised with ``goal``; it is removed when the test finishes.
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        run_paths = build_run_paths(Path(scratch.name) / "run")
        ensure_run_layout(run_paths)
        write_text(run_paths.user_input, goal)
        initialize_memory(run_paths, goal)
        return run_paths

    def test_continuation_prompt_contains_user_goal(self) -> None:
        run_paths = self._prepare_run("My research goal about GNN")

        prompt = build_continuation_prompt(STAGE_01, "template", run_paths, "fix it")
        self.assertIn("My research goal about GNN", prompt)

    def test_continuation_prompt_contains_current_objective(self) -> None:
        run_paths = self._prepare_run("test goal")
        write_text(run_paths.stage_file(STAGE_05), _make_stage_markdown(STAGE_05))

        prompt = build_continuation_prompt(STAGE_05, "template", run_paths, "improve")
        self.assertIn("Current Stage Objective", prompt)
        self.assertIn("Test objective for", prompt)

    def test_continuation_prompt_contains_recent_summaries(self) -> None:
        run_paths = self._prepare_run("test goal")
        for stage in (STAGE_01, STAGE_04, STAGE_05):
            append_approved_stage_summary(
                run_paths.memory, stage, _make_stage_markdown(stage)
            )

        prompt = build_continuation_prompt(STAGE_06, "template", run_paths, "analyze")
        self.assertIn("Recent Approved Context", prompt)
        # Should contain the 2 most recent stages before 06 (04 and 05)
        for recent_stage in (STAGE_04, STAGE_05):
            self.assertIn(recent_stage.stage_title, prompt)


class TestMemorySizeReduction(unittest.TestCase):
    """Checks that compact entries make the memory file smaller."""

    def test_compact_memory_smaller_than_full(self) -> None:
        """Using compact entries should produce a smaller memory.md."""
        placeholder = "_None yet._"
        with tempfile.TemporaryDirectory() as workdir:
            run_paths = build_run_paths(Path(workdir) / "run")
            ensure_run_layout(run_paths)
            write_text(run_paths.user_input, "test goal")
            initialize_memory(run_paths, "test goal")

            # Stage documents with large key_results, shared by both passes.
            documents = [
                (stage, _make_stage_markdown(stage, key_results_lines=150))
                for stage in STAGES[:6]
            ]

            # Pass 1: the production code path, which writes compact entries.
            for stage, markdown in documents:
                append_approved_stage_summary(run_paths.memory, stage, markdown)
            compact_size = len(read_text(run_paths.memory))

            # Pass 2: compare to what full entries would produce, inserting
            # them by hand after re-initialising the memory file.
            initialize_memory(run_paths, "test goal")
            for stage, markdown in documents:
                before = read_text(run_paths.memory)
                full_entry = render_approved_stage_entry(stage, markdown)
                if placeholder in before:
                    after = before.replace(placeholder, full_entry, 1)
                else:
                    after = before.rstrip() + "\n\n" + full_entry + "\n"
                write_text(run_paths.memory, after)
            full_size = len(read_text(run_paths.memory))

            self.assertLess(compact_size, full_size)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()