diff --git a/api/main.py b/api/main.py index 7fffd45..86b74f5 100644 --- a/api/main.py +++ b/api/main.py @@ -79,4 +79,4 @@ def read_root(): app.include_router(ws_guidance.router) # Alert suppression endpoints -app.include_router(suppression.router, prefix="/api/v1") \ No newline at end of file +app.include_router(suppression.router, prefix="/api/v1") diff --git a/api/routes/actions.py b/api/routes/actions.py index 53bd7c7..ed440a3 100644 --- a/api/routes/actions.py +++ b/api/routes/actions.py @@ -1,14 +1,37 @@ +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel + +from core.hybrid.action_logger import ActionRecord, action_logger from fastapi import APIRouter, HTTPException from core.hybrid.action_logger import action_logger, ActionRecord router = APIRouter() + +class ActionCreate(BaseModel): + type: str + description: str + domain: str = "digital" + session_id: str = "default" + was_guided: bool = False + guidance_confidence: float = 0.0 + is_undoable: bool = False + undo_instruction: Optional[str] = None + + +class ReplayRequest(BaseModel): + session_id: Optional[str] = None + speed: float = 1.0 + + @router.get("/actions") -async def get_actions(limit: int = 20, offset: int = 0): - actions = await action_logger.get_history(limit=limit, offset=offset) +def get_actions(limit: int = Query(20, ge=1), offset: int = Query(0, ge=0)): + actions = action_logger.list_actions(limit=limit, offset=offset) return { - "total": len(actions), - "actions": actions + "total": action_logger.total_actions(), + "actions": [action.to_dict() for action in actions], } @router.post("/actions") @@ -23,17 +46,36 @@ async def create_action(action: ActionRecord): async def undo_last_action(): undone = action_logger.undo_last() - if undone is None: - raise HTTPException( - status_code=409, - detail="Nothing to undo. Action log is empty." - ) +@router.post("/actions") +def create_action(payload: ActionCreate): + action = ActionRecord(**payload.dict()) + action_logger.record_action(action) + return {"action": action.to_dict()} + + +@router.post("/actions/undo") +def undo_last_action(): + action = action_logger.undo_last() + if action is None: + raise HTTPException(status_code=409, detail="Nothing in the undo stack") return { "message": "Last action undone successfully.", - "action_undone": { - "id": undone.id, - "description": undone.description - } + "action_undone": action.to_dict(), + } + + +@router.post("/actions/replay") +async def replay_actions(payload: ReplayRequest): + try: + actions = [ + action.to_dict() + async for action in action_logger.replay_session( + session_id=payload.session_id, + speed=payload.speed, + ) + ] + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc - } \ No newline at end of file + return {"total": len(actions), "actions": actions} diff --git a/core/hybrid/action_logger.py b/core/hybrid/action_logger.py index 0d2f471..db51bc8 100644 --- a/core/hybrid/action_logger.py +++ b/core/hybrid/action_logger.py @@ -10,23 +10,23 @@ logger = logging.getLogger(__name__) -class ActionRecord(BaseModel): - id: str - session_id: str # session_id was missing in the data model, added it here - timestamp: datetime - type: str - description: str - domain: Literal["digital", "physical"] - was_guided: bool - guidance_confidence: float | None - -class ActionLogger: - """Records user actions to SQLite and maintains an in-memory undo stack.""" - - def __init__(self, db_path: str = "data/execra.db"): - """Initialize logger with database path and empty undo stack (max 50).""" - if db_path != ":memory:": - os.makedirs(os.path.dirname(db_path), exist_ok=True) +@dataclass +class ActionRecord: + id: str = field(default_factory=lambda: str(uuid4())) + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + type: str = "" + description: str = "" + domain: str = "digital" + session_id: str = "default" + was_guided: bool = False + guidance_confidence: float = 0.0 + is_undoable: bool = False + undo_instruction: Optional[str] = None + undone: bool = False + + def to_dict(self) -> dict: + return asdict(self) + self.db_path = db_path self._stack = deque(maxlen=50) @@ -42,29 +42,16 @@ def unregister_callback(self, cb) -> None: if cb in self.on_log_callbacks: self.on_log_callbacks.remove(cb) - async def _init_db(self): - """Create the action_log table if it doesn't exist.""" - async with aiosqlite.connect(self.db_path) as db: - await db.execute(""" - CREATE TABLE IF NOT EXISTS action_log ( - id TEXT PRIMARY KEY, - session_id TEXT, - timestamp TEXT, - type TEXT, - description TEXT, - domain TEXT, - was_guided INTEGER, - guidance_confidence REAL - ) - """) - await db.commit() + def record_action(self, action: ActionRecord) -> ActionRecord: + self._actions.append(action) + return action async def log_action(self, action: ActionRecord) -> None: """Save action to SQLite, append to stack, and trigger callbacks.""" await self._init_db() # ensure table exists - # Add to in-memory deque - self._stack.append(action) + def total_actions(self) -> int: + return len(self._actions) # Save to SQLite async with aiosqlite.connect(self.db_path) as db: @@ -94,48 +81,29 @@ async def log_action(self, action: ActionRecord) -> None: logger.error(f"Error in action log callback: {e}") def undo_last(self) -> Optional[ActionRecord]: - """Pop and return the last action from the undo stack. Returns None if empty.""" - if not self._stack: - return None - return self._stack.pop() - - async def get_history(self, limit: int = 20, offset: int = 0) -> list[ActionRecord]: - """Fetch paginated action history from SQLite, newest first.""" - await self._init_db() # ensure table exists + for action in reversed(self._actions): + if action.is_undoable and not action.undone: + action.undone = True + return action + return None + + async def replay_session( + self, session_id: Optional[str] = None, speed: float = 1.0 + ) -> AsyncIterator[ActionRecord]: + if speed <= 0: + raise ValueError("Replay speed must be greater than 0") + + for action in self._actions: + if session_id is None or action.session_id == session_id: + await asyncio.sleep(0) + yield action + + def clear(self) -> None: + self._actions.clear() - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute(""" - SELECT * FROM action_log - ORDER BY timestamp DESC - LIMIT ? OFFSET ? - """, (limit, offset)) - rows = await cursor.fetchall() - - return [ - ActionRecord( - id=row[0], - session_id=row[1], - timestamp=datetime.fromisoformat(row[2]), - type=row[3], - description=row[4], - domain=row[5], - was_guided=bool(row[6]), - guidance_confidence=row[7] - ) - for row in rows - ] async def clear_session(self, session_id: str) -> None: - """Delete all actions for the session from SQLite and clear the in-memory stack.""" - await self._init_db() # ensure table exists - - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "DELETE FROM action_log WHERE session_id = ?", - (session_id,) - ) - await db.commit() + self._actions = [a for a in self._actions if a.session_id != session_id] - self._stack.clear() async def log_error(self, session_id: str, step: int, error: str) -> None: """Encrypt and save an error to the error_history table.""" @@ -183,4 +151,4 @@ async def get_errors(self, session_id: str) -> list[Dict[str, Any]]: }) return errors -action_logger = ActionLogger() \ No newline at end of file +action_logger = ActionLogger() diff --git a/tests/integration/test_actions_context.py b/tests/integration/test_actions_context.py index a167f0e..4e3cbe5 100644 --- a/tests/integration/test_actions_context.py +++ b/tests/integration/test_actions_context.py @@ -59,20 +59,20 @@ def test_create_action(): def test_undo_returns_409_when_empty(): response = client.post("/api/v1/actions/undo") assert response.status_code == 409 - assert "Nothing to undo" in response.json()["detail"] + assert "Nothing in the undo stack" in response.json()["detail"] def test_undo_returns_undone_action(): action = ActionRecord( id="act_001", session_id="sess_001", - timestamp=datetime.now(), type="code_edit", description="Modified line 42", domain="digital", was_guided=True, - guidance_confidence=0.9 + guidance_confidence=0.9, + is_undoable=True, ) - action_logger._stack.append(action) + action_logger.record_action(action) response = client.post("/api/v1/actions/undo") assert response.status_code == 200 @@ -113,7 +113,7 @@ def test_delete_context_returns_success(): assert response.status_code == 200 assert response.json()["message"] == "Session context cleared." -def test_delete_context_clears_deque(): +def test_delete_context_clears_session_actions(): from api.routes.context import SessionContext context_module._current_context = SessionContext( @@ -127,19 +127,18 @@ def test_delete_context_clears_deque(): started_at=datetime.now() ) - action_logger._stack.append( + action_logger.record_action( ActionRecord( id="act_001", session_id="sess_001", - timestamp=datetime.now(), type="code_edit", description="Test", domain="digital", was_guided=True, - guidance_confidence=0.9 + guidance_confidence=0.9, ) ) client.delete("/api/v1/context") - assert len(action_logger._stack) == 0 \ No newline at end of file + assert action_logger.total_actions() == 0 diff --git a/tests/unit/test_action_logger.py b/tests/unit/test_action_logger.py index 375ac6f..a574d22 100644 --- a/tests/unit/test_action_logger.py +++ b/tests/unit/test_action_logger.py @@ -1,66 +1,47 @@ import pytest -from datetime import datetime -from unittest.mock import AsyncMock, patch, MagicMock -from core.hybrid.action_logger import ActionLogger, ActionRecord - - -@pytest.fixture -def logger(): - return ActionLogger(db_path=":memory:") +from core.hybrid.action_logger import ActionLogger, ActionRecord -@pytest.fixture -def sample_action(): - return ActionRecord( - id="act_001", - session_id="sess_001", - timestamp=datetime.now(), - type="code_edit", - description="Test action", - domain="digital", - was_guided=True, - guidance_confidence=0.9 - ) -def test_undo_last_returns_none_when_empty(logger): - result = logger.undo_last() - assert result is None +def test_record_action_adds_action_to_history(): + logger = ActionLogger() + action = ActionRecord(type="click", description="Clicked run button") -def test_undo_last_returns_last_action(logger, sample_action): - logger._stack.append(sample_action) + logger.record_action(action) - result = logger.undo_last() - assert result == sample_action + assert logger.total_actions() == 1 + assert logger.list_actions() == [action] -def test_undo_last_removes_from_stack(logger, sample_action): - logger._stack.append(sample_action) - logger.undo_last() - assert len(logger._stack) == 0 +def test_undo_last_marks_latest_undoable_action(): + logger = ActionLogger() + first_action = ActionRecord( + type="edit", + description="Changed a field", + is_undoable=True, + undo_instruction="Restore previous value", + ) + second_action = ActionRecord(type="view", description="Opened settings") -def test_deque_max_size_is_50(logger, sample_action): - for i in range(60): - logger._stack.append(sample_action) + logger.record_action(first_action) + logger.record_action(second_action) - assert len(logger._stack) == 50 + undone = logger.undo_last() -@pytest.mark.asyncio -async def test_log_action_appends_to_deque(logger, sample_action): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_connect.return_value.__aenter__.return_value = mock_db + assert undone == first_action + assert first_action.undone is True - await logger.log_action(sample_action) - assert len(logger._stack) == 1 - assert logger._stack[0] == sample_action -@pytest.mark.asyncio -async def test_log_action_calls_sqlite_insert(logger, sample_action): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_connect.return_value.__aenter__.return_value = mock_db +def test_double_undo_returns_none_when_no_undoable_action_remains(): + logger = ActionLogger() + action = ActionRecord( + type="edit", + description="Changed a field", + is_undoable=True, + undo_instruction="Restore previous value", + ) - await logger.log_action(sample_action) + logger.record_action(action) # Verify that an INSERT INTO command was executed insert_calls = [call for call in mock_db.execute.call_args_list if "INSERT INTO" in call[0][0]] @@ -69,25 +50,21 @@ async def test_log_action_calls_sqlite_insert(logger, sample_action): @pytest.mark.asyncio -async def test_clear_session_clears_deque(logger, sample_action): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_connect.return_value.__aenter__.return_value = mock_db - - logger._stack.append(sample_action) - logger._stack.append(sample_action) +async def test_replay_session_yields_matching_session_actions_in_order(): + logger = ActionLogger() + first_action = ActionRecord(type="step", description="First", session_id="session-1") + second_action = ActionRecord(type="step", description="Second", session_id="session-2") + third_action = ActionRecord(type="step", description="Third", session_id="session-1") - await logger.clear_session("sess_001") + logger.record_action(first_action) + logger.record_action(second_action) + logger.record_action(third_action) - assert len(logger._stack) == 0 - -@pytest.mark.asyncio -async def test_clear_session_calls_sqlite_delete(logger, sample_action): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_connect.return_value.__aenter__.return_value = mock_db + replayed_actions = [ + action async for action in logger.replay_session(session_id="session-1") + ] - await logger.clear_session("sess_001") + assert replayed_actions == [first_action, third_action] # Verify that a DELETE FROM command was executed delete_calls = [call for call in mock_db.execute.call_args_list if "DELETE FROM" in call[0][0]] @@ -95,34 +72,9 @@ async def test_clear_session_calls_sqlite_delete(logger, sample_action): assert mock_db.commit.called @pytest.mark.asyncio -async def test_get_history_returns_list(logger): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_cursor = AsyncMock() +async def test_replay_session_rejects_invalid_speed(): + logger = ActionLogger() - mock_cursor.fetchall.return_value = [ - ("act_001", "sess_001", "2026-04-14T10:00:00", "code_edit", - "Test action", "digital", 1, 0.9) - ] - mock_db.execute.return_value = mock_cursor - mock_connect.return_value.__aenter__.return_value = mock_db - - result = await logger.get_history(limit=10, offset=0) - - assert len(result) == 1 - assert isinstance(result[0], ActionRecord) - assert result[0].id == "act_001" - -@pytest.mark.asyncio -async def test_get_history_passes_pagination(logger): - with patch("aiosqlite.connect") as mock_connect: - mock_db = AsyncMock() - mock_cursor = AsyncMock() - mock_cursor.fetchall.return_value = [] - mock_db.execute.return_value = mock_cursor - mock_connect.return_value.__aenter__.return_value = mock_db - - await logger.get_history(limit=5, offset=10) - - call_args = mock_db.execute.call_args - assert call_args[0][1] == (5, 10) \ No newline at end of file + with pytest.raises(ValueError, match="Replay speed"): + async for _ in logger.replay_session(speed=0): + pass