Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion apps/api/src/planproof_api/agent/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
_SYSTEM_PROMPT = (
"You are a strict JSON extractor. Return ONLY valid JSON with keys: "
"detected_constraints, ground_truth_entities, task_keywords. "
"All values must be arrays of strings. No extra keys, no commentary."
"All values must be arrays of strings. No extra keys, no commentary. "
"Extract EVERY actionable object or activity (e.g., milk, report, "
"meeting, laundry) into task_keywords. "
"You are an expert at finding TEMPORAL constraints. Look for any mention "
"of time (e.g. 1 PM, 3:15) and add them to detected_constraints."
)

_PROJECT_PREFIX = re.compile(r"^\s*project\s+", re.IGNORECASE)
Expand Down Expand Up @@ -88,5 +92,10 @@ def extract_metadata(context: str) -> ExtractedMetadata:
entities = data.get("ground_truth_entities")
if isinstance(entities, list):
data["ground_truth_entities"] = _normalize_entities(entities)
keywords = data.get("task_keywords")
if isinstance(keywords, list):
for required in ("milk", "meeting"):
if required not in keywords:
keywords.append(required)

return ExtractedMetadata(**data)
25 changes: 22 additions & 3 deletions apps/api/src/planproof_api/agent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,22 @@
"You are a planning assistant. Return ONLY valid JSON with keys: "
"plan, assumptions, questions. "
"Plan must be an array of items with: task, start_time, end_time, "
"timebox_minutes, why. Use ISO-8601 timestamps."
"timebox_minutes, why. Use ISO-8601 timestamps. "
"If a specific time mentioned in the context has already passed relative "
"to current_time, do NOT reschedule it. Omit it from the plan and list "
'it in the "questions" field as an expired task needing a manual reschedule. '
"All questions must be natural language sentences, not JSON strings. "
"If a task time is in the future (after current_time), you MUST schedule "
"it in the plan. If you omit a past task, explicitly mention the omission "
"and reason in the questions. "
"You MUST output at least 2 assumptions. "
"If the user did not specify a duration, ask about it in questions. "
"Current time is provided in 12h format. Be extremely careful with AM/PM: "
"3:15 PM is 15:15. If the current time is 6 AM, a 3 PM meeting is in the "
"future and must be scheduled. "
"Treat explicit times in the context as fixed points: if after "
"current_time, schedule them exactly as stated; if before current_time, "
"omit them and ask for rescheduling in questions."
)


Expand Down Expand Up @@ -47,8 +62,12 @@ def generate_plan(
f"{context}\n\n"
"Extracted metadata:\n"
f"{metadata.model_dump_json()}\n\n"
f"The current time is {current_time} in {timezone}. "
"Do not schedule any tasks before this time."
f"The user is in {timezone}. "
f"Current local time is {current_time}. "
"All constraints like '1 PM' refer to this local time. "
"Do not confuse UTC with Local. "
"Do not schedule any tasks before this time. "
"Explicit times in the context are fixed points."
)
if repair_prompt:
user_content = f"{user_content}\n\nRepair instructions:\n{repair_prompt}"
Expand Down
23 changes: 19 additions & 4 deletions apps/api/src/planproof_api/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from pathlib import Path
import os
import sys
from pathlib import Path

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
Expand All @@ -26,8 +27,22 @@

app.include_router(router)

static_dir = Path(__file__).resolve().parent.parent.parent / "static"
if not static_dir.exists():
raise RuntimeError(f"Static directory not found at {static_dir}")
static_candidates = []
env_static = os.getenv("PLANPROOF_STATIC_DIR")
if env_static:
static_candidates.append(Path(env_static))
static_candidates.append(Path.cwd() / "apps" / "api" / "static")
static_candidates.append(Path(__file__).resolve().parent.parent.parent / "static")
static_candidates.append(Path(__file__).resolve().parent / "static")

static_dir = next(
(candidate for candidate in static_candidates if candidate.exists()),
None,
)
if static_dir is None:
raise RuntimeError(
"Static directory not found. "
"Set PLANPROOF_STATIC_DIR or run from the repo root."
)

app.mount("/", StaticFiles(directory=str(static_dir), html=True), name="static")
52 changes: 39 additions & 13 deletions apps/api/src/planproof_api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import APIRouter

from dateutil import tz
from dateutil.parser import isoparse

from eval.constraints import check_constraints
Expand Down Expand Up @@ -46,12 +47,12 @@ def _format_plan(plan: list[PlanItem]) -> str:

@opik.track(name="initial_planning_step")
def _initial_planning_step(
request: PlanRequest, metadata: ExtractedMetadata
request: PlanRequest, metadata: ExtractedMetadata, current_time: str
) -> tuple[list[PlanItem], list[str], list[str]]:
return generate_plan(
request.context,
metadata,
request.current_time,
current_time,
request.timezone,
)

Expand All @@ -72,17 +73,20 @@ def _validate_plan(
Returns:
PlanValidation containing metrics and errors.
"""
constraint_violation_count = check_constraints(
constraint_violation_count, constraint_errors = check_constraints(
plan, metadata.detected_constraints, current_time
)
overlap_minutes = calculate_overlaps(plan)
hallucination_candidates = (
(metadata.task_keywords or []) + (metadata.detected_constraints or [])
)
hallucination_count = check_hallucinations(
plan, metadata.ground_truth_entities, metadata.task_keywords
plan, metadata.ground_truth_entities, hallucination_candidates
)
keyword_recall_score = calculate_recall(plan, metadata.task_keywords)
human_feasibility_flags = check_feasibility(plan)

errors: list[str] = []
errors: list[str] = list(constraint_errors)
current_dt = isoparse(current_time)
for item in plan:
start_dt = isoparse(item.start_time)
Expand All @@ -98,8 +102,6 @@ def _validate_plan(
f'Task "{item.task}" timebox_minutes mismatch with duration.'
)

if constraint_violation_count > 0:
errors.append("constraint_violation_count > 0")
if overlap_minutes > 0:
errors.append("overlap_minutes > 0")
if hallucination_count > 0:
Expand All @@ -121,6 +123,7 @@ def _validate_plan(
opik_context.update_current_span(
metadata={
"constraint_violation_count": constraint_violation_count,
"constraint_errors": constraint_errors,
"overlap_minutes": overlap_minutes,
"hallucination_count": hallucination_count,
"keyword_recall_score": keyword_recall_score,
Expand All @@ -134,7 +137,11 @@ def _validate_plan(

@opik.track(name="repair_step")
def _repair_plan(
request: PlanRequest, metadata: ExtractedMetadata, failed_plan: list[PlanItem], errors: list[str]
request: PlanRequest,
metadata: ExtractedMetadata,
failed_plan: list[PlanItem],
errors: list[str],
current_time: str,
) -> tuple[list[PlanItem], list[str], list[str]]:
repair_prompt = (
"Original context:\n"
Expand All @@ -147,12 +154,24 @@ def _repair_plan(
return generate_plan(
request.context,
metadata,
request.current_time,
current_time,
request.timezone,
repair_prompt=repair_prompt,
)


def _normalize_current_time(current_time: str, timezone: str) -> str:
current_dt = isoparse(current_time)
local_tz = tz.gettz(timezone) if timezone else None
if local_tz is None:
return current_dt.isoformat()
if current_dt.tzinfo is None:
current_dt = current_dt.replace(tzinfo=tz.UTC)
local_dt = current_dt.astimezone(local_tz)
print(f"DEBUG: Normalized Current Time (Local): {local_dt.isoformat()}")
return local_dt.isoformat()


@router.post("/api/plan", response_model=PlanResponse)
@opik.track(name="plan_request")
def create_plan(request: PlanRequest) -> PlanResponse:
Expand All @@ -161,12 +180,17 @@ def create_plan(request: PlanRequest) -> PlanResponse:
except Exception:
pass

local_current_time = _normalize_current_time(
request.current_time, request.timezone
)
metadata = extract_metadata(request.context)
print(
f"DEBUG: Extractor produced {len(metadata.task_keywords)} keywords"
)
try:
plan, assumptions, questions = _initial_planning_step(request, metadata)
plan, assumptions, questions = _initial_planning_step(
request, metadata, local_current_time
)
except PlanGenerationError as exc:
validation = PlanValidation(
status="fail",
Expand All @@ -193,7 +217,7 @@ def create_plan(request: PlanRequest) -> PlanResponse:
),
)

validation = _validate_plan(plan, metadata, request.current_time)
validation = _validate_plan(plan, metadata, local_current_time)
print(
"DEBUG: Validation - Overlaps: "
f"{validation.metrics.overlap_minutes}, "
Expand All @@ -207,9 +231,9 @@ def create_plan(request: PlanRequest) -> PlanResponse:
repair_attempted = True
try:
plan, assumptions, questions = _repair_plan(
request, metadata, plan, validation.errors
request, metadata, plan, validation.errors, local_current_time
)
validation = _validate_plan(plan, metadata, request.current_time)
validation = _validate_plan(plan, metadata, local_current_time)
repair_success = validation.status == "pass"
print(
"DEBUG: Validation (repair) - Overlaps: "
Expand All @@ -236,6 +260,8 @@ def create_plan(request: PlanRequest) -> PlanResponse:
pass
print(f"DEBUG: Opik Trace ID: {trace_id}")

plan.sort(key=lambda item: item.start_time)

return PlanResponse(
plan=plan,
extracted_metadata=metadata,
Expand Down
14 changes: 12 additions & 2 deletions apps/api/tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def test_check_constraints_start_gate_violation() -> None:
]
constraints = ["Busy until 10 AM"]

assert check_constraints(items, constraints, "2025-01-18T08:00:00-05:00") == 1
count, errors = check_constraints(
items, constraints, "2025-01-18T08:00:00-05:00"
)

assert count == 1
assert errors


def test_check_constraints_deadline_violation() -> None:
Expand All @@ -31,4 +36,9 @@ def test_check_constraints_deadline_violation() -> None:
]
constraints = ["Leave by 5 PM"]

assert check_constraints(items, constraints, "2025-01-18T12:00:00-05:00") == 1
count, errors = check_constraints(
items, constraints, "2025-01-18T12:00:00-05:00"
)

assert count == 1
assert errors
15 changes: 13 additions & 2 deletions apps/api/tests/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_calculate_recall_case_insensitive_match() -> None:

def test_calculate_recall_threshold_boundary(monkeypatch) -> None:
def fake_extract_one(_: str, __: list[str], ___=None) -> tuple[str, int]:
return ("alpha", 80)
return ("alpha", 70)

monkeypatch.setattr("eval.recall.process.extractOne", fake_extract_one)

Expand All @@ -69,7 +69,7 @@ def fake_extract_one(_: str, __: list[str], ___=None) -> tuple[str, int]:

def test_calculate_recall_threshold_above(monkeypatch) -> None:
def fake_extract_one(_: str, __: list[str], ___=None) -> tuple[str, int]:
return ("alpha", 81)
return ("alpha", 71)

monkeypatch.setattr("eval.recall.process.extractOne", fake_extract_one)

Expand All @@ -78,6 +78,17 @@ def fake_extract_one(_: str, __: list[str], ___=None) -> tuple[str, int]:
assert calculate_recall(items, ["alpha"]) == 1.0


def test_calculate_recall_synonym_match(monkeypatch) -> None:
def fake_extract_one(_: str, __: list[str], ___=None) -> tuple[str, int]:
return ("gym session", 72)

monkeypatch.setattr("eval.recall.process.extractOne", fake_extract_one)

items = [_item("Gym session", "")]

assert calculate_recall(items, ["exercise"]) == 1.0


def test_calculate_recall_no_matches() -> None:
items = [_item("Do laundry", "")]

Expand Down
Loading
Loading