diff --git a/backend/app/config.py b/backend/app/config.py index 9633d46..cb3d9cd 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -49,6 +49,11 @@ class Settings(BaseSettings): shake_min_distance: float = 15 # Min movement distance in pixels (reduced from 20) shake_check_interval: float = 0.02 # Check interval in seconds (faster sampling) + # Set-of-Mark visual grounding + som_enabled: bool = True + som_max_elements: int = 80 # Max labeled elements per screenshot + som_min_element_size: int = 15 # Skip elements smaller than this (px) + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/backend/app/services/action_executor.py b/backend/app/services/action_executor.py index fd92ba5..5bbad29 100644 --- a/backend/app/services/action_executor.py +++ b/backend/app/services/action_executor.py @@ -23,6 +23,7 @@ class ActionExecutor: def __init__(self, screen_width: int, screen_height: int): self._screen_width = screen_width self._screen_height = screen_height + self._elements: list = [] # Current SoM elements for this step self._actions: dict[str, callable] = { "click": self._click, "double_click": self._double_click, @@ -68,9 +69,51 @@ async def execute(self, response: LLMResponse) -> str: except Exception as e: return f"Action '{action}' failed: {e}" + def set_elements(self, elements: list) -> None: + """Set the current step's detected SoM elements for element-based clicking.""" + self._elements = elements + + # ── Element resolution ─────────────────────────────────────────── + + async def _click_element(self, params: dict) -> str: + """Click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + modifiers = params.get("modifiers") + computer_service.click(cx, cy, modifiers=modifiers) + title = f' "{elem.title}"' if elem.title else "" + return f'Clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + + async def _double_click_element(self, params: dict) -> str: + """Double-click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + computer_service.double_click(cx, cy) + title = f' "{elem.title}"' if elem.title else "" + return f'Double-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + + async def _right_click_element(self, params: dict) -> str: + """Right-click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + computer_service.right_click(cx, cy) + title = f' "{elem.title}"' if elem.title else "" + return f'Right-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + # ── Individual action handlers ────────────────────────────────────── async def _click(self, params: dict) -> str: + if "element" in params: + return await self._click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) modifiers = params.get("modifiers") computer_service.click(x, y, modifiers=modifiers) @@ -79,11 +122,15 @@ async def _click(self, params: dict) -> str: return f"Clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" async def _double_click(self, params: dict) -> str: + if "element" in params: + return await self._double_click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) computer_service.double_click(x, y) return f"Double-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" async def _right_click(self, params: dict) -> str: + if "element" in params: + return await self._right_click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) computer_service.right_click(x, y) return f"Right-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" diff --git a/backend/app/services/agent.py b/backend/app/services/agent.py index 13cd6ad..7f1caaa 100644 --- a/backend/app/services/agent.py +++ b/backend/app/services/agent.py @@ -154,27 +154,60 @@ async def _wait_if_paused(self, state: AgentState) -> bool: # ── Observe / Think / Act helpers ─────────────────────────────────── - async def _observe(self, state: AgentState) -> str: - """Capture a screenshot and save it to disk. Returns base64 data.""" + async def _observe(self, state: AgentState) -> tuple[str, list]: + """Capture screenshot, detect SoM elements, annotate. Returns (base64, elements).""" screenshot_b64 = screen_service.capture_screen_base64() state.last_screenshot = screenshot_b64 + # SoM detection + annotation + elements: list = [] + annotated_b64 = screenshot_b64 + if settings.som_enabled: + try: + from app.services.som import ( + detect_elements, filter_leaf_elements, + annotate_screenshot, format_element_list, + ) + screen_w, screen_h = screen_service.get_screen_size() + raw_elements = detect_elements(screen_w, screen_h) + elements = filter_leaf_elements( + raw_elements, + max_elements=settings.som_max_elements, + min_size=settings.som_min_element_size, + ) + if elements: + image_data = base64.b64decode(screenshot_b64) + image = Image.open(io.BytesIO(image_data)) + annotated = annotate_screenshot(image, elements) + buf = io.BytesIO() + annotated.save(buf, "JPEG", quality=SCREENSHOT_QUALITY) + annotated_b64 = base64.b64encode(buf.getvalue()).decode() + logger.info("SoM: detected %d elements", len(elements)) + except Exception as e: + logger.warning("SoM detection failed, using raw screenshot: %s", e) + + # Save raw screenshot to disk (unannotated) screenshots_dir = Path("./screenshots") screenshots_dir.mkdir(exist_ok=True) filename = f"task_{state.task_id}_step_{state.current_step}.jpg" path = screenshots_dir / filename - image_data = base64.b64decode(screenshot_b64) image = Image.open(io.BytesIO(image_data)) image.save(str(path), "JPEG", quality=SCREENSHOT_QUALITY) - state.screenshot_before_path = f"screenshots/{filename}" - return screenshot_b64 + + return annotated_b64, elements async def _think( - self, instruction: str, screenshot_b64: str, state: AgentState + self, instruction: str, screenshot_b64: str, state: AgentState, + elements: list | None = None, ) -> LLMResponse: """Ask the LLM to analyse the current screen.""" + element_list_text = "" + if elements: + from app.services.som import format_element_list + element_list_text = format_element_list(elements) + return await self.llm.analyze_screen( instruction=instruction, screenshot_base64=screenshot_b64, @@ -184,6 +217,7 @@ async def _think( notes=state.notes, references=state.task_references, lessons=state.task_lessons, + element_list_text=element_list_text, ) async def _handle_response( @@ -401,7 +435,8 @@ async def run( # 1. Observe try: - screenshot_b64 = await self._observe(state) + screenshot_b64, elements = await self._observe(state) + self._action_executor.set_elements(elements) except Exception as e: logger.error("Failed to capture screenshot: %s", e) raise @@ -415,7 +450,7 @@ async def run( # 2. Think try: - llm_response = await self._think(instruction, screenshot_b64, state) + llm_response = await self._think(instruction, screenshot_b64, state, elements) except Exception as e: logger.error("Failed to analyze screen with LLM: %s", e) raise diff --git a/backend/app/services/llm/base.py b/backend/app/services/llm/base.py index 1e86852..a2b48b9 100644 --- a/backend/app/services/llm/base.py +++ b/backend/app/services/llm/base.py @@ -113,6 +113,7 @@ async def analyze_screen( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> LLMResponse: """Analyze the current screen and determine the next action.""" pass @@ -168,6 +169,7 @@ def get_system_prompt( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> str: """Get the system prompt for this provider.""" from app.services.defaults import get_running_apps, load_system_defaults @@ -179,4 +181,5 @@ def get_system_prompt( references=references, lessons=lessons, system_defaults=system_defaults, running_apps=running_apps, + element_list_text=element_list_text or None, ) diff --git a/backend/app/services/llm/gemini_provider.py b/backend/app/services/llm/gemini_provider.py index 9bbf110..2dfe07d 100644 --- a/backend/app/services/llm/gemini_provider.py +++ b/backend/app/services/llm/gemini_provider.py @@ -80,6 +80,8 @@ def _append_full_steps( class GeminiProvider(BaseLLMProvider): """Google Gemini provider for vision-based computer control.""" + _MAX_PARSE_RETRIES = 2 + def __init__(self, model: str = None): super().__init__() self._model = model or settings.gemini_model @@ -111,6 +113,7 @@ async def analyze_screen( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> LLMResponse: """Analyze screen and return next action.""" contents = self._build_contents(instruction, screenshot_base64, history) @@ -122,6 +125,7 @@ async def analyze_screen( raw_response, token_usage = await self._call_with_retries( contents, skill, memories, plan=plan, notes=notes, references=references, lessons=lessons, + element_list_text=element_list_text, ) if raw_response and not raw_response.strip().endswith("}"): @@ -133,6 +137,14 @@ async def analyze_screen( logger.debug("Raw response: %s", raw_response) llm_response = self._parse_response(raw_response) + + if llm_response.reasoning.startswith("Failed to parse JSON"): + logger.info("Parse failed, attempting JSON correction retry") + retried = await self._retry_parse_with_correction(raw_response) + if retried is not None: + retried.token_usage = token_usage + return retried + llm_response.token_usage = token_usage return llm_response @@ -233,6 +245,7 @@ async def _call_with_retries( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> tuple[str, TokenUsage]: """Call Gemini API with exponential-backoff retries. Returns (raw_text, token_usage).""" last_error: Optional[Exception] = None @@ -242,6 +255,7 @@ async def _call_with_retries( system_prompt = self.get_system_prompt( skill=skill, memories=memories, plan=plan, notes=notes, references=references, lessons=lessons, + element_list_text=element_list_text, ) if skill: logger.info("Using skill: %s", skill) @@ -362,6 +376,69 @@ def _parse_response(self, raw_response: str) -> LLMResponse: raw_response=raw_response, ) + async def _retry_parse_with_correction(self, malformed_response: str) -> Optional[LLMResponse]: + """Re-prompt the LLM to fix a malformed JSON response. + + Returns a parsed LLMResponse on success, or None after all retries fail. + """ + truncated = malformed_response[:1000] + correction_prompt = ( + "Your previous response was not valid JSON. Here is what you returned: " + f"{truncated}. Return ONLY a valid JSON object with the required fields: " + "reasoning, action, params, needs_confirmation. Nothing else." + ) + + for attempt in range(self._MAX_PARSE_RETRIES): + try: + response = await self.client.aio.models.generate_content( + model=self._model, + contents=[types.Content( + role="user", + parts=[types.Part.from_text(text=correction_prompt)], + )], + config=types.GenerateContentConfig( + max_output_tokens=2048, + temperature=0.0, + ), + ) + + raw_text = response.text if hasattr(response, "text") else None + if raw_text is None: + logger.warning("JSON retry attempt %d/%d: empty response", attempt + 1, self._MAX_PARSE_RETRIES) + continue + + content = _strip_markdown_codeblock(raw_text) + data = json.loads(content) + + action = data.get("action", "") + is_complete = action == "done" + needs_clarification = data.get("needs_clarification", False) + + logger.info("JSON retry attempt %d/%d succeeded", attempt + 1, self._MAX_PARSE_RETRIES) + return LLMResponse( + reasoning=data.get("reasoning", ""), + action=action if not is_complete and not needs_clarification else None, + action_params=data.get("params", {}), + is_complete=is_complete, + needs_confirmation=data.get("needs_confirmation", False), + needs_human_help=data.get("needs_human_help", False), + human_help_reason=data.get("human_help_reason"), + needs_clarification=needs_clarification, + clarification_question=data.get("clarification_question"), + clarification_options=data.get("clarification_options"), + expected_outcome=data.get("expected_outcome"), + notes=data.get("notes"), + raw_response=raw_text, + ) + except (json.JSONDecodeError, Exception) as e: + logger.warning( + "JSON retry attempt %d/%d failed: %s", + attempt + 1, self._MAX_PARSE_RETRIES, e, + ) + + logger.warning("All %d JSON retry attempts failed", self._MAX_PARSE_RETRIES) + return None + async def extract_and_merge_memories( self, instruction: str, diff --git a/backend/app/services/llm/prompt_builder.py b/backend/app/services/llm/prompt_builder.py index 887b7dc..feebc23 100644 --- a/backend/app/services/llm/prompt_builder.py +++ b/backend/app/services/llm/prompt_builder.py @@ -24,6 +24,7 @@ def build( lessons: Optional[list[dict]] = None, system_defaults: Optional[dict[str, str]] = None, running_apps: Optional[list[str]] = None, + element_list_text: Optional[str] = None, ) -> str: """Build the full system prompt. @@ -70,6 +71,8 @@ def build( prompt += self._references_section(references) if skill: prompt += self._skill_section(skill) + if element_list_text: + prompt += self._labeled_elements_section(element_list_text) return prompt @@ -190,10 +193,14 @@ def _available_actions(self) -> str: ═══════════════════════════════════════════════════════════════════════════════ MOUSE ACTIONS: -- click: {"x": 0-1000, "y": 0-1000} +- click: {"element": N} ← PREFERRED — use the labeled element number from screenshot + OR: {"x": 0-1000, "y": 0-1000} ← only when target is NOT a labeled element Optional modifiers: {"x": N, "y": N, "modifiers": ["shift"|"command"|"option"|"control"]} -- double_click: {"x": 0-1000, "y": 0-1000} - For opening files, selecting words -- right_click: {"x": 0-1000, "y": 0-1000} - For context menus + ⚠ IMPORTANT: When interactive elements are labeled [1], [2], etc. on the screenshot, + ALWAYS use {"element": N} instead of guessing x,y coordinates. Element clicks are + precise; raw coordinates often miss. +- double_click: {"element": N} or {"x": 0-1000, "y": 0-1000} - For opening files, selecting words +- right_click: {"element": N} or {"x": 0-1000, "y": 0-1000} - For context menus - drag: {"start_x": N, "start_y": N, "end_x": N, "end_y": N} - For selecting, moving items KEYBOARD ACTIONS: @@ -275,7 +282,9 @@ def _available_actions(self) -> str: - y=0 is top edge, y=1000 is bottom edge - Example: Center of screen ≈ (500, 500) - Example: Top-right corner ≈ (950, 50) -- Example: Dock (bottom center) ≈ (500, 980)""" +- Example: Dock (bottom center) ≈ (500, 980) +- When elements are labeled on screen: {"element": 5} clicks element [5] precisely + This is MORE ACCURATE than guessing coordinates — always prefer element clicks.""" def _error_recovery(self) -> str: return """ @@ -530,6 +539,8 @@ def _response_format(self) -> str: - ALWAYS start reasoning with verification of the last step's expected_outcome (except on step 0). State: "Expected: X. Actual: Y. [Match/Mismatch]." - Verify coordinates are in 0-1000 range +- For click/double_click/right_click: when a labeled element [N] matches your target, + use "params": {"element": N} — not raw x,y coordinates - When task is complete, action must be "done" with a summary - Never guess when uncertain - ask for clarification - Learn from history - don't repeat failed actions""" @@ -760,7 +771,18 @@ def _skill_section(skill: "ResolvedSkill") -> str: ACTIVE SKILL: {label} ═══════════════════════════════════════════════════════════════════════════════ -The following specialized knowledge is available for this task. Use these +The following specialized knowledge is available for this task. Use these app-specific guidelines to perform actions more effectively. {skill.instructions}""" + + @staticmethod + def _labeled_elements_section(element_list_text: str) -> str: + return ( + "\n\n## LABELED ELEMENTS (IMPORTANT)\n\n" + "Interactive elements on the current screen are labeled [1], [2], etc.\n" + 'You MUST use click({"element": N}) for any target that matches a labeled element.\n' + "Do NOT guess x,y coordinates for labeled elements — element clicks are exact.\n" + "Only use raw x,y for unlabeled content (e.g., web page body, images, unlabeled areas).\n\n" + f"{element_list_text}" + ) diff --git a/backend/app/services/som.py b/backend/app/services/som.py new file mode 100644 index 0000000..90a2fd5 --- /dev/null +++ b/backend/app/services/som.py @@ -0,0 +1,408 @@ +"""Set-of-Mark element detection via macOS Accessibility API.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from PIL import Image, ImageDraw, ImageFont + +logger = logging.getLogger(__name__) + +# Roles considered interactive / clickable +INTERACTIVE_ROLES = frozenset({ + "AXButton", + "AXTextField", + "AXTextArea", + "AXLink", + "AXMenuItem", + "AXMenuBarItem", + "AXPopUpButton", + "AXCheckBox", + "AXRadioButton", + "AXTab", + "AXComboBox", + "AXSlider", + "AXIncrementor", + "AXSearchField", + "AXSecureTextField", + "AXStaticText", + "AXImage", + "AXRow", + "AXCell", +}) + +MAX_AX_DEPTH = 12 + + +@dataclass +class SoMElement: + """A single interactive UI element detected via Accessibility.""" + + label: int # display number [1], [2], ... + role: str # AX role (e.g. "AXButton") + title: str # AX title/description + x: int # screen x (px) + y: int # screen y (px) + width: int # element width (px) + height: int # element height (px) + screen_width: int # total screen width (px) + screen_height: int # total screen height (px) + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def center_pixel(self) -> tuple[int, int]: + """Return centre in screen pixels.""" + return (self.x + self.width // 2, self.y + self.height // 2) + + def center_normalized(self) -> tuple[int, int]: + """Return centre in 0-1000 normalised coords.""" + cx, cy = self.center_pixel() + nx = int(cx * 1000 / self.screen_width) if self.screen_width else 0 + ny = int(cy * 1000 / self.screen_height) if self.screen_height else 0 + return (nx, ny) + + def is_too_small(self, min_size: int = 15) -> bool: + """True when either dimension is below *min_size*.""" + return self.width < min_size or self.height < min_size + + def contains(self, other: SoMElement) -> bool: + """True if *self* fully encloses *other* and has strictly larger area.""" + if self.width * self.height <= other.width * other.height: + return False + return ( + self.x <= other.x + and self.y <= other.y + and self.x + self.width >= other.x + other.width + and self.y + self.height >= other.y + other.height + ) + + @property + def short_role(self) -> str: + """Role name without the ``AX`` prefix.""" + return self.role[2:] if self.role.startswith("AX") else self.role + + +# ------------------------------------------------------------------ +# AX tree walking +# ------------------------------------------------------------------ + +def _walk_ax_tree( + element, + results: list[dict], + depth: int, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, +) -> None: + """Recursively collect interactive elements from the AX tree.""" + if depth > MAX_AX_DEPTH: + return + + try: + err, role = AXUIElementCopyAttributeValue(element, "AXRole", None) + if err != 0 or role is None: + role = "" + except Exception: + role = "" + + if role in INTERACTIVE_ROLES: + try: + err_pos, pos_val = AXUIElementCopyAttributeValue( + element, "AXPosition", None + ) + err_size, size_val = AXUIElementCopyAttributeValue( + element, "AXSize", None + ) + if err_pos == 0 and err_size == 0 and pos_val and size_val: + _, point = AXValueGetValue(pos_val, kAXValueTypeCGPoint, None) + _, size = AXValueGetValue(size_val, kAXValueTypeCGSize, None) + + # title / description + title = "" + try: + err_t, t = AXUIElementCopyAttributeValue( + element, "AXTitle", None + ) + if err_t == 0 and t: + title = str(t) + except Exception: + pass + if not title: + try: + err_d, d = AXUIElementCopyAttributeValue( + element, "AXDescription", None + ) + if err_d == 0 and d: + title = str(d) + except Exception: + pass + + # Skip untitled AXStaticText — decorative, not interactive + if role == "AXStaticText" and not title: + pass + else: + results.append({ + "role": str(role), + "title": title, + "x": int(point.x), + "y": int(point.y), + "width": int(size.width), + "height": int(size.height), + }) + except Exception: + pass + + # recurse into children + try: + err, children = AXUIElementCopyAttributeValue( + element, "AXChildren", None + ) + if err == 0 and children: + for child in children: + _walk_ax_tree( + child, + results, + depth + 1, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, + ) + except Exception: + pass + + +def _walk_pid( + pid: int, + raw: list[dict], + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, +) -> None: + """Walk one process's AX tree and append results to *raw*.""" + try: + app_element = AXUIElementCreateApplication(pid) + _walk_ax_tree( + app_element, raw, 0, + AXUIElementCopyAttributeValue, AXValueGetValue, + kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) + except Exception: + pass + + +# System processes whose menu-bar items we always want to include. +_SYSTEM_UI_BUNDLE_IDS = frozenset({ + "com.apple.systemuiserver", # legacy menu extras + "com.apple.controlcenter", # WiFi, Battery, BT, etc. + "com.apple.Spotlight", # Spotlight icon + "com.apple.notificationcenterui", # Notification Center +}) + + +def detect_elements( + screen_width: int, screen_height: int +) -> list[SoMElement]: + """Detect interactive UI elements on screen via macOS Accessibility API. + + Walks the frontmost application AND system-UI processes so that + menu-bar items (WiFi, Battery, Clock, etc.) are always included. + + Returns elements with ``label=0``; the caller should assign labels + after filtering. + """ + try: + from ApplicationServices import ( + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, + ) + from AppKit import NSWorkspace + except ImportError: + logger.warning("macOS AX frameworks not available") + return [] + + raw: list[dict] = [] + + # 1. Frontmost app + try: + frontmost = NSWorkspace.sharedWorkspace().frontmostApplication() + frontmost_pid = frontmost.processIdentifier() + _walk_pid( + frontmost_pid, raw, + AXUIElementCreateApplication, AXUIElementCopyAttributeValue, + AXValueGetValue, kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) + except Exception: + logger.exception("Failed to get frontmost application AX element") + + # 2. System-UI processes (menu bar extras) + try: + for app in NSWorkspace.sharedWorkspace().runningApplications(): + bid = app.bundleIdentifier() + if bid and bid in _SYSTEM_UI_BUNDLE_IDS: + pid = app.processIdentifier() + if pid != frontmost_pid: + _walk_pid( + pid, raw, + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) + except Exception: + logger.warning("Failed to walk system-UI processes", exc_info=True) + + elements: list[SoMElement] = [] + for r in raw: + elements.append( + SoMElement( + label=0, + role=r["role"], + title=r["title"], + x=r["x"], + y=r["y"], + width=r["width"], + height=r["height"], + screen_width=screen_width, + screen_height=screen_height, + ) + ) + + logger.info("AX tree yielded %d interactive elements", len(elements)) + return elements + + +# ------------------------------------------------------------------ +# filtering +# ------------------------------------------------------------------ + +def filter_leaf_elements( + elements: list[SoMElement], + max_elements: int = 80, + min_size: int = 15, +) -> list[SoMElement]: + """Keep only visible leaf elements, sorted reading-order, capped & labelled.""" + raw_count = len(elements) + + # 1. remove too-small + sized = [e for e in elements if not e.is_too_small(min_size)] + + # 2. remove off-screen (outside viewport) + on_screen = [ + e for e in sized + if (e.x + e.width > 0 and e.y + e.height > 0 + and e.x < e.screen_width and e.y < e.screen_height) + ] + + # 3. deduplicate by (role, title, x, y, width, height) + seen: set[tuple] = set() + deduped: list[SoMElement] = [] + for e in on_screen: + key = (e.role, e.title, e.x, e.y, e.width, e.height) + if key not in seen: + seen.add(key) + deduped.append(e) + + # 4. remove parents (elements that contain another element) + parent_indices: set[int] = set() + for i, a in enumerate(deduped): + for j, b in enumerate(deduped): + if i != j and a.contains(b): + parent_indices.add(i) + break # a is a parent, no need to check more + leaves = [e for i, e in enumerate(deduped) if i not in parent_indices] + + # 5. sort by y then x (reading order) + leaves.sort(key=lambda e: (e.y, e.x)) + + # 6. cap + leaves = leaves[:max_elements] + + # 7. assign labels 1..N + for idx, elem in enumerate(leaves, start=1): + elem.label = idx + + logger.info( + "SoM filter: %d raw → %d sized → %d on-screen → %d deduped → %d leaves → %d capped", + raw_count, len(sized), len(on_screen), len(deduped), len(deduped) - len(parent_indices), len(leaves), + ) + return leaves + + +# ------------------------------------------------------------------ +# annotation +# ------------------------------------------------------------------ + +def annotate_screenshot( + image: Image.Image, elements: list[SoMElement] +) -> Image.Image: + """Return a copy of *image* with numbered badges on each element.""" + img = image.copy().convert("RGBA") + overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(overlay) + + # font + try: + font = ImageFont.truetype("/System/Library/Fonts/Menlo.ttc", 11) + except Exception: + font = ImageFont.load_default() + + badge_h = 16 + badge_color = (220, 40, 40, 200) + text_color = (255, 255, 255, 255) + rect_color = (220, 40, 40, 180) + + for elem in elements: + x0, y0 = elem.x, elem.y + x1, y1 = x0 + elem.width, y0 + elem.height + + # bounding rectangle + draw.rectangle([x0, y0, x1, y1], outline=rect_color, width=1) + + # badge + label_text = str(elem.label) + bbox = font.getbbox(label_text) + tw = bbox[2] - bbox[0] + badge_w = tw + 6 + bx0, by0 = x0, y0 + bx1, by1 = bx0 + badge_w, by0 + badge_h + draw.rectangle([bx0, by0, bx1, by1], fill=badge_color) + draw.text( + (bx0 + 3, by0 + 1), label_text, fill=text_color, font=font + ) + + img = Image.alpha_composite(img, overlay) + return img.convert("RGB") + + +# ------------------------------------------------------------------ +# prompt formatting +# ------------------------------------------------------------------ + +def format_element_list(elements: list[SoMElement]) -> str: + """Format elements for LLM prompt injection. + + Example line:: + + [1] Button "Settings" (450,30)-(520,55) + """ + lines: list[str] = [] + for e in elements: + nx0 = int(e.x * 1000 / e.screen_width) if e.screen_width else 0 + ny0 = int(e.y * 1000 / e.screen_height) if e.screen_height else 0 + nx1 = int((e.x + e.width) * 1000 / e.screen_width) if e.screen_width else 0 + ny1 = int((e.y + e.height) * 1000 / e.screen_height) if e.screen_height else 0 + + role = e.short_role + title_part = f' "{e.title}"' if e.title else "" + lines.append(f"[{e.label}] {role}{title_part} ({nx0},{ny0})-({nx1},{ny1})") + return "\n".join(lines) diff --git a/backend/requirements.txt b/backend/requirements.txt index 32d9c73..5ef4786 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,6 +11,7 @@ pyautogui>=0.9.54 pillow>=10.2.0 pyobjc-framework-Quartz>=10.1 # macOS screen capture pyobjc-framework-Vision>=10.1 # macOS Vision OCR +pyobjc-framework-ApplicationServices>=10.1 # macOS Accessibility API (AXUIElement) # AI/LLM clients google-genai>=1.0.0 diff --git a/backend/tests/unit/services/llm/test_json_retry.py b/backend/tests/unit/services/llm/test_json_retry.py new file mode 100644 index 0000000..47caca4 --- /dev/null +++ b/backend/tests/unit/services/llm/test_json_retry.py @@ -0,0 +1,150 @@ +"""Tests for JSON retry-on-parse-failure logic.""" + +import json + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.services.llm.gemini_provider import GeminiProvider + + +@pytest.fixture +def provider(): + """Create a GeminiProvider with mocked externals.""" + with ( + patch("app.services.llm.gemini_provider.get_gemini_api_key", return_value="fake-key"), + patch("app.services.llm.gemini_provider.genai") as mock_genai, + patch("app.services.llm.base.screen_service") as mock_screen, + ): + mock_genai.Client.return_value = MagicMock() + mock_screen.get_screen_size.return_value = (1728, 1117) + p = GeminiProvider() + return p + + +def _mock_response(text: str): + """Build a mock Gemini response object.""" + resp = MagicMock() + resp.text = text + resp.usage_metadata = None + return resp + + +VALID_JSON = json.dumps({ + "reasoning": "I see the button", + "action": "click", + "params": {"x": 100, "y": 200}, + "needs_confirmation": False, +}) + + +class TestRetryParseWithCorrection: + """Tests for _retry_parse_with_correction.""" + + @pytest.mark.asyncio + async def test_recovers_from_malformed_json(self, provider): + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + return_value=_mock_response(VALID_JSON), + ) + + result = await provider._retry_parse_with_correction('{"reasoning": "trunca') + assert result is not None + assert result.action == "click" + assert result.action_params == {"x": 100, "y": 200} + + @pytest.mark.asyncio + async def test_returns_none_after_max_retries(self, provider): + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + return_value=_mock_response("still not json!!!"), + ) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is None + assert provider.client.aio.models.generate_content.call_count == GeminiProvider._MAX_PARSE_RETRIES + + @pytest.mark.asyncio + async def test_returns_none_on_empty_responses(self, provider): + resp = MagicMock() + resp.text = None + resp.usage_metadata = None + + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock(return_value=resp) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is None + + @pytest.mark.asyncio + async def test_succeeds_on_second_attempt(self, provider): + bad_resp = _mock_response("not json") + good_resp = _mock_response(VALID_JSON) + + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + side_effect=[bad_resp, good_resp], + ) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is not None + assert result.action == "click" + assert provider.client.aio.models.generate_content.call_count == 2 + + +class TestAnalyzeScreenRetryIntegration: + """Tests that analyze_screen triggers retry on parse failure.""" + + def _mock_token_usage(self): + return MagicMock(prompt_tokens=0, candidates_tokens=0, thoughts_tokens=0, total_tokens=0) + + @pytest.mark.asyncio + async def test_valid_json_does_not_trigger_retry(self, provider): + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(VALID_JSON, self._mock_token_usage()), + ) + provider._retry_parse_with_correction = AsyncMock() + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + provider._retry_parse_with_correction.assert_not_called() + assert result.action == "click" + + @pytest.mark.asyncio + async def test_parse_failure_triggers_retry_and_uses_result(self, provider): + malformed = '{"reasoning": "trunca' + + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(malformed, self._mock_token_usage()), + ) + + retried_response = MagicMock() + retried_response.reasoning = "recovered" + retried_response.action = "click" + retried_response.token_usage = None + provider._retry_parse_with_correction = AsyncMock(return_value=retried_response) + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + provider._retry_parse_with_correction.assert_called_once_with(malformed) + assert result.action == "click" + + @pytest.mark.asyncio + async def test_parse_failure_retry_fails_returns_fallback(self, provider): + malformed = '{"reasoning": "trunca' + + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(malformed, self._mock_token_usage()), + ) + provider._retry_parse_with_correction = AsyncMock(return_value=None) + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + assert result.action == "wait" + assert "Failed to parse" in result.reasoning diff --git a/backend/tests/unit/services/llm/test_prompt_builder.py b/backend/tests/unit/services/llm/test_prompt_builder.py new file mode 100644 index 0000000..e9b4872 --- /dev/null +++ b/backend/tests/unit/services/llm/test_prompt_builder.py @@ -0,0 +1,22 @@ +"""Tests for prompt builder SoM section.""" + +from app.services.llm.prompt_builder import PromptBuilder + + +class TestSoMPromptSection: + def test_element_list_injected_when_provided(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build(element_list_text='[1] Button "OK" (100,200)-(160,230)') + assert "LABELED ELEMENTS" in prompt + assert '[1] Button "OK"' in prompt + assert 'click({"element": N})' in prompt + + def test_no_element_section_when_none(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build() + assert "LABELED ELEMENTS" not in prompt + + def test_no_element_section_when_empty_string(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build(element_list_text="") + assert "LABELED ELEMENTS" not in prompt diff --git a/backend/tests/unit/services/test_action_executor.py b/backend/tests/unit/services/test_action_executor.py index 04110c3..b68a81a 100644 --- a/backend/tests/unit/services/test_action_executor.py +++ b/backend/tests/unit/services/test_action_executor.py @@ -5,6 +5,7 @@ from app.services.action_executor import ActionExecutor from app.services.llm.base import LLMResponse +from app.services.som import SoMElement SCREEN_W = 1728 SCREEN_H = 1117 @@ -224,3 +225,87 @@ async def test_action_exception_caught(self, executor, mock_computer): result = await executor.execute(resp) assert "failed" in result assert "Mouse broken" in result + + +# --------------------------------------------------------------------------- +# SoM element resolution +# --------------------------------------------------------------------------- + +def _make_element(label: int = 1, role: str = "AXButton", title: str = "OK", + x: int = 100, y: int = 200, width: int = 80, height: int = 30) -> SoMElement: + return SoMElement( + label=label, role=role, title=title, + x=x, y=y, width=width, height=height, + screen_width=SCREEN_W, screen_height=SCREEN_H, + ) + + +class TestElementResolution: + """Tests for resolving SoM element labels to coordinates.""" + + @pytest.mark.asyncio + async def test_click_element_resolves_to_center(self, executor, mock_computer): + elem = _make_element(label=3, x=100, y=200, width=80, height=30) + executor.set_elements([elem]) + resp = _make_response(action="click", action_params={"element": 3}) + result = await executor.execute(resp) + # center: (100+40, 200+15) = (140, 215) + mock_computer.click.assert_called_once_with(140, 215, modifiers=None) + assert "element #3" in result + assert "(140, 215)" in result + + @pytest.mark.asyncio + async def test_click_xy_still_works(self, executor, mock_computer): + executor.set_elements([_make_element()]) + resp = _make_response(action="click", action_params={"x": 500, "y": 500}) + result = await executor.execute(resp) + mock_computer.click.assert_called_once() + assert "Clicked at" in result + + @pytest.mark.asyncio + async def test_click_unknown_element_returns_error(self, executor, mock_computer): + executor.set_elements([_make_element(label=1)]) + resp = _make_response(action="click", action_params={"element": 99}) + result = await executor.execute(resp) + assert "not found" in result + mock_computer.click.assert_not_called() + + @pytest.mark.asyncio + async def test_double_click_element(self, executor, mock_computer): + elem = _make_element(label=2, x=50, y=60, width=100, height=40) + executor.set_elements([elem]) + resp = _make_response(action="double_click", action_params={"element": 2}) + result = await executor.execute(resp) + # center: (50+50, 60+20) = (100, 80) + mock_computer.double_click.assert_called_once_with(100, 80) + assert "Double-clicked element #2" in result + + @pytest.mark.asyncio + async def test_right_click_element(self, executor, mock_computer): + elem = _make_element(label=5, x=200, y=300, width=60, height=20) + executor.set_elements([elem]) + resp = _make_response(action="right_click", action_params={"element": 5}) + result = await executor.execute(resp) + # center: (200+30, 300+10) = (230, 310) + mock_computer.right_click.assert_called_once_with(230, 310) + assert "Right-clicked element #5" in result + + @pytest.mark.asyncio + async def test_click_element_with_modifiers(self, executor, mock_computer): + elem = _make_element(label=1, x=100, y=200, width=80, height=30) + executor.set_elements([elem]) + resp = _make_response( + action="click", + action_params={"element": 1, "modifiers": ["command"]}, + ) + result = await executor.execute(resp) + mock_computer.click.assert_called_once_with(140, 215, modifiers=["command"]) + assert "element #1" in result + + @pytest.mark.asyncio + async def test_click_element_without_title(self, executor, mock_computer): + elem = _make_element(label=1, title="") + executor.set_elements([elem]) + resp = _make_response(action="click", action_params={"element": 1}) + result = await executor.execute(resp) + assert '"' not in result.split("(Button")[0] if "Button" in result else True diff --git a/backend/tests/unit/services/test_som.py b/backend/tests/unit/services/test_som.py new file mode 100644 index 0000000..72eebda --- /dev/null +++ b/backend/tests/unit/services/test_som.py @@ -0,0 +1,174 @@ +"""Tests for SoM element detection helpers (pure functions only).""" + +import pytest +from PIL import Image + +from app.services.som import ( + SoMElement, + filter_leaf_elements, + annotate_screenshot, + format_element_list, +) + +SW, SH = 2000, 1000 # screen dims used across tests + + +def _elem( + label=0, role="AXButton", title="", x=0, y=0, w=100, h=50, +) -> SoMElement: + return SoMElement( + label=label, role=role, title=title, + x=x, y=y, width=w, height=h, + screen_width=SW, screen_height=SH, + ) + + +# ------------------------------------------------------------------ +# TestSoMElement +# ------------------------------------------------------------------ + +class TestSoMElement: + + def test_center_pixel(self): + e = _elem(x=100, y=200, w=60, h=40) + assert e.center_pixel() == (130, 220) + + def test_center_normalized(self): + # center pixel = (130, 220); norm = (130*1000/2000, 220*1000/1000) + e = _elem(x=100, y=200, w=60, h=40) + nx, ny = e.center_normalized() + assert nx == 65 + assert ny == 220 + + def test_is_too_small_true(self): + e = _elem(w=10, h=50) + assert e.is_too_small() is True + + def test_is_too_small_false(self): + e = _elem(w=20, h=20) + assert e.is_too_small() is False + + def test_is_too_small_custom(self): + e = _elem(w=20, h=20) + assert e.is_too_small(min_size=25) is True + + def test_contains_true(self): + parent = _elem(x=0, y=0, w=200, h=200) + child = _elem(x=10, y=10, w=50, h=50) + assert parent.contains(child) is True + + def test_contains_false_no_overlap(self): + a = _elem(x=0, y=0, w=50, h=50) + b = _elem(x=100, y=100, w=50, h=50) + assert a.contains(b) is False + + def test_contains_false_same_area(self): + a = _elem(x=0, y=0, w=100, h=100) + b = _elem(x=0, y=0, w=100, h=100) + assert a.contains(b) is False # equal area → False + + def test_short_role(self): + assert _elem(role="AXButton").short_role == "Button" + assert _elem(role="AXTextField").short_role == "TextField" + assert _elem(role="PlainRole").short_role == "PlainRole" + + +# ------------------------------------------------------------------ +# TestFilterLeafElements +# ------------------------------------------------------------------ + +class TestFilterLeafElements: + + def test_parents_removed(self): + parent = _elem(x=0, y=0, w=300, h=300) + child = _elem(x=10, y=10, w=50, h=50) + result = filter_leaf_elements([parent, child]) + assert len(result) == 1 + assert result[0].width == 50 # child kept + + def test_non_overlapping_kept(self): + a = _elem(x=0, y=0, w=80, h=40) + b = _elem(x=200, y=0, w=80, h=40) + result = filter_leaf_elements([a, b]) + assert len(result) == 2 + + def test_labels_assigned_reading_order(self): + # b is above a → b should be label 1 + a = _elem(x=0, y=200, w=80, h=40) + b = _elem(x=0, y=50, w=80, h=40) + result = filter_leaf_elements([a, b]) + assert result[0].label == 1 + assert result[0].y == 50 # b first + assert result[1].label == 2 + assert result[1].y == 200 + + def test_too_small_removed(self): + small = _elem(w=5, h=5) + big = _elem(x=200, y=200, w=80, h=40) + result = filter_leaf_elements([small, big]) + assert len(result) == 1 + + def test_max_elements_cap(self): + elems = [_elem(x=i * 100, y=0, w=80, h=40) for i in range(10)] + result = filter_leaf_elements(elems, max_elements=3) + assert len(result) == 3 + assert result[-1].label == 3 + + +# ------------------------------------------------------------------ +# TestAnnotateScreenshot +# ------------------------------------------------------------------ + +class TestAnnotateScreenshot: + + def test_returns_same_size(self): + img = Image.new("RGB", (800, 600), (255, 255, 255)) + elem = _elem(label=1, x=10, y=10, w=100, h=50) + result = annotate_screenshot(img, [elem]) + assert result.size == img.size + + def test_returns_rgb(self): + img = Image.new("RGB", (800, 600)) + result = annotate_screenshot(img, [_elem(label=1, x=0, y=0, w=50, h=50)]) + assert result.mode == "RGB" + + def test_empty_elements(self): + img = Image.new("RGB", (400, 300)) + result = annotate_screenshot(img, []) + assert result.size == (400, 300) + + +# ------------------------------------------------------------------ +# TestFormatElementList +# ------------------------------------------------------------------ + +class TestFormatElementList: + + def test_basic_format(self): + e = _elem(label=1, role="AXButton", title="Settings", x=900, y=30, w=140, h=25) + out = format_element_list([e]) + assert out.startswith("[1] Button") + assert '"Settings"' in out + + def test_no_title(self): + e = _elem(label=2, role="AXLink", title="", x=0, y=0, w=100, h=50) + out = format_element_list([e]) + assert "[2] Link " in out + assert '""' not in out + + def test_coords_normalized(self): + # x=1000 on 2000px screen → 500 norm, y=500 on 1000px → 500 norm + e = _elem(label=1, role="AXButton", title="", x=1000, y=500, w=200, h=100) + out = format_element_list([e]) + assert "(500,500)" in out + + def test_multi_elements(self): + elems = [ + _elem(label=1, role="AXButton", title="A", x=0, y=0, w=100, h=50), + _elem(label=2, role="AXTextField", title="B", x=200, y=100, w=400, h=50), + ] + out = format_element_list(elems) + lines = out.strip().split("\n") + assert len(lines) == 2 + assert lines[0].startswith("[1]") + assert lines[1].startswith("[2]") diff --git a/backend/tests/unit/test_config.py b/backend/tests/unit/test_config.py new file mode 100644 index 0000000..f858271 --- /dev/null +++ b/backend/tests/unit/test_config.py @@ -0,0 +1,11 @@ +"""Tests for config defaults.""" + +from app.config import Settings + + +class TestSoMConfig: + def test_som_defaults(self): + s = Settings(gemini_api_key="fake") + assert s.som_enabled is True + assert s.som_max_elements == 80 + assert s.som_min_element_size == 15