From ad6f7ffc073db43212146832179f4791a827c936 Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:17:47 -0800 Subject: [PATCH 01/11] chore: add pyobjc-framework-ApplicationServices for AX element detection --- backend/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/requirements.txt b/backend/requirements.txt index 32d9c73..5ef4786 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,6 +11,7 @@ pyautogui>=0.9.54 pillow>=10.2.0 pyobjc-framework-Quartz>=10.1 # macOS screen capture pyobjc-framework-Vision>=10.1 # macOS Vision OCR +pyobjc-framework-ApplicationServices>=10.1 # macOS Accessibility API (AXUIElement) # AI/LLM clients google-genai>=1.0.0 From 6074e40ced3f6d30dcbf19df1af1d00f68e9e41b Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:17:50 -0800 Subject: [PATCH 02/11] feat: add SoM config settings --- backend/app/config.py | 5 +++++ backend/tests/unit/test_config.py | 11 +++++++++++ 2 files changed, 16 insertions(+) create mode 100644 backend/tests/unit/test_config.py diff --git a/backend/app/config.py b/backend/app/config.py index 9633d46..2484e36 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -49,6 +49,11 @@ class Settings(BaseSettings): shake_min_distance: float = 15 # Min movement distance in pixels (reduced from 20) shake_check_interval: float = 0.02 # Check interval in seconds (faster sampling) + # Set-of-Mark visual grounding + som_enabled: bool = True + som_max_elements: int = 40 # Max labeled elements per screenshot + som_min_element_size: int = 15 # Skip elements smaller than this (px) + class Config: env_file = ".env" env_file_encoding = "utf-8" diff --git a/backend/tests/unit/test_config.py b/backend/tests/unit/test_config.py new file mode 100644 index 0000000..13b8f5b --- /dev/null +++ b/backend/tests/unit/test_config.py @@ -0,0 +1,11 @@ +"""Tests for config defaults.""" + +from app.config import Settings + + +class TestSoMConfig: + def test_som_defaults(self): + s = Settings(gemini_api_key="fake") + assert s.som_enabled is True + assert s.som_max_elements == 40 + assert s.som_min_element_size == 15 From d1c0642b7d62ff16b0aefcef8ddcaca231acf59c Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:18:56 -0800 Subject: [PATCH 03/11] feat: retry LLM on JSON parse failure instead of defaulting to wait --- backend/app/services/llm/gemini_provider.py | 73 +++++++++ .../unit/services/llm/test_json_retry.py | 150 ++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 backend/tests/unit/services/llm/test_json_retry.py diff --git a/backend/app/services/llm/gemini_provider.py b/backend/app/services/llm/gemini_provider.py index 9bbf110..460ffcd 100644 --- a/backend/app/services/llm/gemini_provider.py +++ b/backend/app/services/llm/gemini_provider.py @@ -80,6 +80,8 @@ def _append_full_steps( class GeminiProvider(BaseLLMProvider): """Google Gemini provider for vision-based computer control.""" + _MAX_PARSE_RETRIES = 2 + def __init__(self, model: str = None): super().__init__() self._model = model or settings.gemini_model @@ -133,6 +135,14 @@ async def analyze_screen( logger.debug("Raw response: %s", raw_response) llm_response = self._parse_response(raw_response) + + if llm_response.reasoning.startswith("Failed to parse JSON"): + logger.info("Parse failed, attempting JSON correction retry") + retried = await self._retry_parse_with_correction(raw_response) + if retried is not None: + retried.token_usage = token_usage + return retried + llm_response.token_usage = token_usage return llm_response @@ -362,6 +372,69 @@ def _parse_response(self, raw_response: str) -> LLMResponse: raw_response=raw_response, ) + async def _retry_parse_with_correction(self, malformed_response: str) -> Optional[LLMResponse]: + """Re-prompt the LLM to fix a malformed JSON response. + + Returns a parsed LLMResponse on success, or None after all retries fail. + """ + truncated = malformed_response[:1000] + correction_prompt = ( + "Your previous response was not valid JSON. Here is what you returned: " + f"{truncated}. Return ONLY a valid JSON object with the required fields: " + "reasoning, action, params, needs_confirmation. Nothing else." + ) + + for attempt in range(self._MAX_PARSE_RETRIES): + try: + response = await self.client.aio.models.generate_content( + model=self._model, + contents=[types.Content( + role="user", + parts=[types.Part.from_text(text=correction_prompt)], + )], + config=types.GenerateContentConfig( + max_output_tokens=2048, + temperature=0.0, + ), + ) + + raw_text = response.text if hasattr(response, "text") else None + if raw_text is None: + logger.warning("JSON retry attempt %d/%d: empty response", attempt + 1, self._MAX_PARSE_RETRIES) + continue + + content = _strip_markdown_codeblock(raw_text) + data = json.loads(content) + + action = data.get("action", "") + is_complete = action == "done" + needs_clarification = data.get("needs_clarification", False) + + logger.info("JSON retry attempt %d/%d succeeded", attempt + 1, self._MAX_PARSE_RETRIES) + return LLMResponse( + reasoning=data.get("reasoning", ""), + action=action if not is_complete and not needs_clarification else None, + action_params=data.get("params", {}), + is_complete=is_complete, + needs_confirmation=data.get("needs_confirmation", False), + needs_human_help=data.get("needs_human_help", False), + human_help_reason=data.get("human_help_reason"), + needs_clarification=needs_clarification, + clarification_question=data.get("clarification_question"), + clarification_options=data.get("clarification_options"), + expected_outcome=data.get("expected_outcome"), + notes=data.get("notes"), + raw_response=raw_text, + ) + except (json.JSONDecodeError, Exception) as e: + logger.warning( + "JSON retry attempt %d/%d failed: %s", + attempt + 1, self._MAX_PARSE_RETRIES, e, + ) + + logger.warning("All %d JSON retry attempts failed", self._MAX_PARSE_RETRIES) + return None + async def extract_and_merge_memories( self, instruction: str, diff --git a/backend/tests/unit/services/llm/test_json_retry.py b/backend/tests/unit/services/llm/test_json_retry.py new file mode 100644 index 0000000..47caca4 --- /dev/null +++ b/backend/tests/unit/services/llm/test_json_retry.py @@ -0,0 +1,150 @@ +"""Tests for JSON retry-on-parse-failure logic.""" + +import json + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from app.services.llm.gemini_provider import GeminiProvider + + +@pytest.fixture +def provider(): + """Create a GeminiProvider with mocked externals.""" + with ( + patch("app.services.llm.gemini_provider.get_gemini_api_key", return_value="fake-key"), + patch("app.services.llm.gemini_provider.genai") as mock_genai, + patch("app.services.llm.base.screen_service") as mock_screen, + ): + mock_genai.Client.return_value = MagicMock() + mock_screen.get_screen_size.return_value = (1728, 1117) + p = GeminiProvider() + return p + + +def _mock_response(text: str): + """Build a mock Gemini response object.""" + resp = MagicMock() + resp.text = text + resp.usage_metadata = None + return resp + + +VALID_JSON = json.dumps({ + "reasoning": "I see the button", + "action": "click", + "params": {"x": 100, "y": 200}, + "needs_confirmation": False, +}) + + +class TestRetryParseWithCorrection: + """Tests for _retry_parse_with_correction.""" + + @pytest.mark.asyncio + async def test_recovers_from_malformed_json(self, provider): + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + return_value=_mock_response(VALID_JSON), + ) + + result = await provider._retry_parse_with_correction('{"reasoning": "trunca') + assert result is not None + assert result.action == "click" + assert result.action_params == {"x": 100, "y": 200} + + @pytest.mark.asyncio + async def test_returns_none_after_max_retries(self, provider): + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + return_value=_mock_response("still not json!!!"), + ) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is None + assert provider.client.aio.models.generate_content.call_count == GeminiProvider._MAX_PARSE_RETRIES + + @pytest.mark.asyncio + async def test_returns_none_on_empty_responses(self, provider): + resp = MagicMock() + resp.text = None + resp.usage_metadata = None + + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock(return_value=resp) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is None + + @pytest.mark.asyncio + async def test_succeeds_on_second_attempt(self, provider): + bad_resp = _mock_response("not json") + good_resp = _mock_response(VALID_JSON) + + provider.client = MagicMock() + provider.client.aio.models.generate_content = AsyncMock( + side_effect=[bad_resp, good_resp], + ) + + result = await provider._retry_parse_with_correction('{"bad": ') + assert result is not None + assert result.action == "click" + assert provider.client.aio.models.generate_content.call_count == 2 + + +class TestAnalyzeScreenRetryIntegration: + """Tests that analyze_screen triggers retry on parse failure.""" + + def _mock_token_usage(self): + return MagicMock(prompt_tokens=0, candidates_tokens=0, thoughts_tokens=0, total_tokens=0) + + @pytest.mark.asyncio + async def test_valid_json_does_not_trigger_retry(self, provider): + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(VALID_JSON, self._mock_token_usage()), + ) + provider._retry_parse_with_correction = AsyncMock() + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + provider._retry_parse_with_correction.assert_not_called() + assert result.action == "click" + + @pytest.mark.asyncio + async def test_parse_failure_triggers_retry_and_uses_result(self, provider): + malformed = '{"reasoning": "trunca' + + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(malformed, self._mock_token_usage()), + ) + + retried_response = MagicMock() + retried_response.reasoning = "recovered" + retried_response.action = "click" + retried_response.token_usage = None + provider._retry_parse_with_correction = AsyncMock(return_value=retried_response) + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + provider._retry_parse_with_correction.assert_called_once_with(malformed) + assert result.action == "click" + + @pytest.mark.asyncio + async def test_parse_failure_retry_fails_returns_fallback(self, provider): + malformed = '{"reasoning": "trunca' + + provider._build_contents = MagicMock(return_value=[]) + provider._call_with_retries = AsyncMock( + return_value=(malformed, self._mock_token_usage()), + ) + provider._retry_parse_with_correction = AsyncMock(return_value=None) + + with patch("app.services.llm.gemini_provider.get_all_memories_for_prompt", return_value=None): + result = await provider.analyze_screen("click the button", "base64img", []) + + assert result.action == "wait" + assert "Failed to parse" in result.reasoning From 9fd4d85648d8f7779ce76235805e1946f7a1a5f8 Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:21:09 -0800 Subject: [PATCH 04/11] feat: SoM element detection, leaf filtering, screenshot annotation --- backend/app/services/som.py | 338 ++++++++++++++++++++++++ backend/tests/unit/services/test_som.py | 174 ++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 backend/app/services/som.py create mode 100644 backend/tests/unit/services/test_som.py diff --git a/backend/app/services/som.py b/backend/app/services/som.py new file mode 100644 index 0000000..05d2d05 --- /dev/null +++ b/backend/app/services/som.py @@ -0,0 +1,338 @@ +"""Set-of-Mark element detection via macOS Accessibility API.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from PIL import Image, ImageDraw, ImageFont + +logger = logging.getLogger(__name__) + +# Roles considered interactive / clickable +INTERACTIVE_ROLES = frozenset({ + "AXButton", + "AXTextField", + "AXTextArea", + "AXLink", + "AXMenuItem", + "AXPopUpButton", + "AXCheckBox", + "AXRadioButton", + "AXTab", + "AXComboBox", + "AXSlider", + "AXIncrementor", + "AXSearchField", + "AXSecureTextField", +}) + +MAX_AX_DEPTH = 12 + + +@dataclass +class SoMElement: + """A single interactive UI element detected via Accessibility.""" + + label: int # display number [1], [2], ... + role: str # AX role (e.g. "AXButton") + title: str # AX title/description + x: int # screen x (px) + y: int # screen y (px) + width: int # element width (px) + height: int # element height (px) + screen_width: int # total screen width (px) + screen_height: int # total screen height (px) + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + def center_pixel(self) -> tuple[int, int]: + """Return centre in screen pixels.""" + return (self.x + self.width // 2, self.y + self.height // 2) + + def center_normalized(self) -> tuple[int, int]: + """Return centre in 0-1000 normalised coords.""" + cx, cy = self.center_pixel() + nx = int(cx * 1000 / self.screen_width) if self.screen_width else 0 + ny = int(cy * 1000 / self.screen_height) if self.screen_height else 0 + return (nx, ny) + + def is_too_small(self, min_size: int = 15) -> bool: + """True when either dimension is below *min_size*.""" + return self.width < min_size or self.height < min_size + + def contains(self, other: SoMElement) -> bool: + """True if *self* fully encloses *other* and has strictly larger area.""" + if self.width * self.height <= other.width * other.height: + return False + return ( + self.x <= other.x + and self.y <= other.y + and self.x + self.width >= other.x + other.width + and self.y + self.height >= other.y + other.height + ) + + @property + def short_role(self) -> str: + """Role name without the ``AX`` prefix.""" + return self.role[2:] if self.role.startswith("AX") else self.role + + +# ------------------------------------------------------------------ +# AX tree walking +# ------------------------------------------------------------------ + +def _walk_ax_tree( + element, + results: list[dict], + depth: int, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, +) -> None: + """Recursively collect interactive elements from the AX tree.""" + if depth > MAX_AX_DEPTH: + return + + try: + err, role = AXUIElementCopyAttributeValue(element, "AXRole", None) + if err != 0 or role is None: + role = "" + except Exception: + role = "" + + if role in INTERACTIVE_ROLES: + try: + err_pos, pos_val = AXUIElementCopyAttributeValue( + element, "AXPosition", None + ) + err_size, size_val = AXUIElementCopyAttributeValue( + element, "AXSize", None + ) + if err_pos == 0 and err_size == 0 and pos_val and size_val: + from Quartz import CGPoint, CGSize + + point = CGPoint() + size = CGSize() + AXValueGetValue(pos_val, kAXValueTypeCGPoint, point) + AXValueGetValue(size_val, kAXValueTypeCGSize, size) + + # title / description + title = "" + try: + err_t, t = AXUIElementCopyAttributeValue( + element, "AXTitle", None + ) + if err_t == 0 and t: + title = str(t) + except Exception: + pass + if not title: + try: + err_d, d = AXUIElementCopyAttributeValue( + element, "AXDescription", None + ) + if err_d == 0 and d: + title = str(d) + except Exception: + pass + + results.append({ + "role": str(role), + "title": title, + "x": int(point.x), + "y": int(point.y), + "width": int(size.width), + "height": int(size.height), + }) + except Exception: + pass + + # recurse into children + try: + err, children = AXUIElementCopyAttributeValue( + element, "AXChildren", None + ) + if err == 0 and children: + for child in children: + _walk_ax_tree( + child, + results, + depth + 1, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, + ) + except Exception: + pass + + +def detect_elements( + screen_width: int, screen_height: int +) -> list[SoMElement]: + """Detect interactive UI elements on screen via macOS Accessibility API. + + Returns elements with ``label=0``; the caller should assign labels + after filtering. + """ + try: + from ApplicationServices import ( + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + ) + from Quartz import ( + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, + ) + from AppKit import NSWorkspace + except ImportError: + logger.warning("macOS AX frameworks not available") + return [] + + try: + frontmost = NSWorkspace.sharedWorkspace().frontmostApplication() + pid = frontmost.processIdentifier() + app_element = AXUIElementCreateApplication(pid) + except Exception: + logger.exception("Failed to get frontmost application AX element") + return [] + + raw: list[dict] = [] + _walk_ax_tree( + app_element, + raw, + 0, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, + ) + + elements: list[SoMElement] = [] + for r in raw: + elements.append( + SoMElement( + label=0, + role=r["role"], + title=r["title"], + x=r["x"], + y=r["y"], + width=r["width"], + height=r["height"], + screen_width=screen_width, + screen_height=screen_height, + ) + ) + + logger.info("AX tree yielded %d interactive elements", len(elements)) + return elements + + +# ------------------------------------------------------------------ +# filtering +# ------------------------------------------------------------------ + +def filter_leaf_elements( + elements: list[SoMElement], + max_elements: int = 40, + min_size: int = 15, +) -> list[SoMElement]: + """Keep only visible leaf elements, sorted reading-order, capped & labelled.""" + # 1. remove too-small + sized = [e for e in elements if not e.is_too_small(min_size)] + + # 2. remove parents (elements that contain another element) + parent_indices: set[int] = set() + for i, a in enumerate(sized): + for j, b in enumerate(sized): + if i != j and a.contains(b): + parent_indices.add(i) + break # a is a parent, no need to check more + leaves = [e for i, e in enumerate(sized) if i not in parent_indices] + + # 3. sort by y then x (reading order) + leaves.sort(key=lambda e: (e.y, e.x)) + + # 4. cap + leaves = leaves[:max_elements] + + # 5. assign labels 1..N + for idx, elem in enumerate(leaves, start=1): + elem.label = idx + + return leaves + + +# ------------------------------------------------------------------ +# annotation +# ------------------------------------------------------------------ + +def annotate_screenshot( + image: Image.Image, elements: list[SoMElement] +) -> Image.Image: + """Return a copy of *image* with numbered badges on each element.""" + img = image.copy().convert("RGBA") + overlay = Image.new("RGBA", img.size, (0, 0, 0, 0)) + draw = ImageDraw.Draw(overlay) + + # font + try: + font = ImageFont.truetype("/System/Library/Fonts/Menlo.ttc", 11) + except Exception: + font = ImageFont.load_default() + + badge_h = 16 + badge_color = (220, 40, 40, 200) + text_color = (255, 255, 255, 255) + rect_color = (220, 40, 40, 180) + + for elem in elements: + x0, y0 = elem.x, elem.y + x1, y1 = x0 + elem.width, y0 + elem.height + + # bounding rectangle + draw.rectangle([x0, y0, x1, y1], outline=rect_color, width=1) + + # badge + label_text = str(elem.label) + bbox = font.getbbox(label_text) + tw = bbox[2] - bbox[0] + badge_w = tw + 6 + bx0, by0 = x0, y0 + bx1, by1 = bx0 + badge_w, by0 + badge_h + draw.rectangle([bx0, by0, bx1, by1], fill=badge_color) + draw.text( + (bx0 + 3, by0 + 1), label_text, fill=text_color, font=font + ) + + img = Image.alpha_composite(img, overlay) + return img.convert("RGB") + + +# ------------------------------------------------------------------ +# prompt formatting +# ------------------------------------------------------------------ + +def format_element_list(elements: list[SoMElement]) -> str: + """Format elements for LLM prompt injection. + + Example line:: + + [1] Button "Settings" (450,30)-(520,55) + """ + lines: list[str] = [] + for e in elements: + nx0 = int(e.x * 1000 / e.screen_width) if e.screen_width else 0 + ny0 = int(e.y * 1000 / e.screen_height) if e.screen_height else 0 + nx1 = int((e.x + e.width) * 1000 / e.screen_width) if e.screen_width else 0 + ny1 = int((e.y + e.height) * 1000 / e.screen_height) if e.screen_height else 0 + + role = e.short_role + title_part = f' "{e.title}"' if e.title else "" + lines.append(f"[{e.label}] {role}{title_part} ({nx0},{ny0})-({nx1},{ny1})") + return "\n".join(lines) diff --git a/backend/tests/unit/services/test_som.py b/backend/tests/unit/services/test_som.py new file mode 100644 index 0000000..72eebda --- /dev/null +++ b/backend/tests/unit/services/test_som.py @@ -0,0 +1,174 @@ +"""Tests for SoM element detection helpers (pure functions only).""" + +import pytest +from PIL import Image + +from app.services.som import ( + SoMElement, + filter_leaf_elements, + annotate_screenshot, + format_element_list, +) + +SW, SH = 2000, 1000 # screen dims used across tests + + +def _elem( + label=0, role="AXButton", title="", x=0, y=0, w=100, h=50, +) -> SoMElement: + return SoMElement( + label=label, role=role, title=title, + x=x, y=y, width=w, height=h, + screen_width=SW, screen_height=SH, + ) + + +# ------------------------------------------------------------------ +# TestSoMElement +# ------------------------------------------------------------------ + +class TestSoMElement: + + def test_center_pixel(self): + e = _elem(x=100, y=200, w=60, h=40) + assert e.center_pixel() == (130, 220) + + def test_center_normalized(self): + # center pixel = (130, 220); norm = (130*1000/2000, 220*1000/1000) + e = _elem(x=100, y=200, w=60, h=40) + nx, ny = e.center_normalized() + assert nx == 65 + assert ny == 220 + + def test_is_too_small_true(self): + e = _elem(w=10, h=50) + assert e.is_too_small() is True + + def test_is_too_small_false(self): + e = _elem(w=20, h=20) + assert e.is_too_small() is False + + def test_is_too_small_custom(self): + e = _elem(w=20, h=20) + assert e.is_too_small(min_size=25) is True + + def test_contains_true(self): + parent = _elem(x=0, y=0, w=200, h=200) + child = _elem(x=10, y=10, w=50, h=50) + assert parent.contains(child) is True + + def test_contains_false_no_overlap(self): + a = _elem(x=0, y=0, w=50, h=50) + b = _elem(x=100, y=100, w=50, h=50) + assert a.contains(b) is False + + def test_contains_false_same_area(self): + a = _elem(x=0, y=0, w=100, h=100) + b = _elem(x=0, y=0, w=100, h=100) + assert a.contains(b) is False # equal area → False + + def test_short_role(self): + assert _elem(role="AXButton").short_role == "Button" + assert _elem(role="AXTextField").short_role == "TextField" + assert _elem(role="PlainRole").short_role == "PlainRole" + + +# ------------------------------------------------------------------ +# TestFilterLeafElements +# ------------------------------------------------------------------ + +class TestFilterLeafElements: + + def test_parents_removed(self): + parent = _elem(x=0, y=0, w=300, h=300) + child = _elem(x=10, y=10, w=50, h=50) + result = filter_leaf_elements([parent, child]) + assert len(result) == 1 + assert result[0].width == 50 # child kept + + def test_non_overlapping_kept(self): + a = _elem(x=0, y=0, w=80, h=40) + b = _elem(x=200, y=0, w=80, h=40) + result = filter_leaf_elements([a, b]) + assert len(result) == 2 + + def test_labels_assigned_reading_order(self): + # b is above a → b should be label 1 + a = _elem(x=0, y=200, w=80, h=40) + b = _elem(x=0, y=50, w=80, h=40) + result = filter_leaf_elements([a, b]) + assert result[0].label == 1 + assert result[0].y == 50 # b first + assert result[1].label == 2 + assert result[1].y == 200 + + def test_too_small_removed(self): + small = _elem(w=5, h=5) + big = _elem(x=200, y=200, w=80, h=40) + result = filter_leaf_elements([small, big]) + assert len(result) == 1 + + def test_max_elements_cap(self): + elems = [_elem(x=i * 100, y=0, w=80, h=40) for i in range(10)] + result = filter_leaf_elements(elems, max_elements=3) + assert len(result) == 3 + assert result[-1].label == 3 + + +# ------------------------------------------------------------------ +# TestAnnotateScreenshot +# ------------------------------------------------------------------ + +class TestAnnotateScreenshot: + + def test_returns_same_size(self): + img = Image.new("RGB", (800, 600), (255, 255, 255)) + elem = _elem(label=1, x=10, y=10, w=100, h=50) + result = annotate_screenshot(img, [elem]) + assert result.size == img.size + + def test_returns_rgb(self): + img = Image.new("RGB", (800, 600)) + result = annotate_screenshot(img, [_elem(label=1, x=0, y=0, w=50, h=50)]) + assert result.mode == "RGB" + + def test_empty_elements(self): + img = Image.new("RGB", (400, 300)) + result = annotate_screenshot(img, []) + assert result.size == (400, 300) + + +# ------------------------------------------------------------------ +# TestFormatElementList +# ------------------------------------------------------------------ + +class TestFormatElementList: + + def test_basic_format(self): + e = _elem(label=1, role="AXButton", title="Settings", x=900, y=30, w=140, h=25) + out = format_element_list([e]) + assert out.startswith("[1] Button") + assert '"Settings"' in out + + def test_no_title(self): + e = _elem(label=2, role="AXLink", title="", x=0, y=0, w=100, h=50) + out = format_element_list([e]) + assert "[2] Link " in out + assert '""' not in out + + def test_coords_normalized(self): + # x=1000 on 2000px screen → 500 norm, y=500 on 1000px → 500 norm + e = _elem(label=1, role="AXButton", title="", x=1000, y=500, w=200, h=100) + out = format_element_list([e]) + assert "(500,500)" in out + + def test_multi_elements(self): + elems = [ + _elem(label=1, role="AXButton", title="A", x=0, y=0, w=100, h=50), + _elem(label=2, role="AXTextField", title="B", x=200, y=100, w=400, h=50), + ] + out = format_element_list(elems) + lines = out.strip().split("\n") + assert len(lines) == 2 + assert lines[0].startswith("[1]") + assert lines[1].startswith("[2]") From 2aa2a17fa213edd6f03bb08c8b154a4716dd56d8 Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:22:32 -0800 Subject: [PATCH 05/11] feat: add LABELED ELEMENTS section to system prompt for SoM --- backend/app/services/llm/prompt_builder.py | 18 ++++++++++++++- .../unit/services/llm/test_prompt_builder.py | 22 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 backend/tests/unit/services/llm/test_prompt_builder.py diff --git a/backend/app/services/llm/prompt_builder.py b/backend/app/services/llm/prompt_builder.py index 887b7dc..b225ff3 100644 --- a/backend/app/services/llm/prompt_builder.py +++ b/backend/app/services/llm/prompt_builder.py @@ -24,6 +24,7 @@ def build( lessons: Optional[list[dict]] = None, system_defaults: Optional[dict[str, str]] = None, running_apps: Optional[list[str]] = None, + element_list_text: Optional[str] = None, ) -> str: """Build the full system prompt. @@ -70,6 +71,8 @@ def build( prompt += self._references_section(references) if skill: prompt += self._skill_section(skill) + if element_list_text: + prompt += self._labeled_elements_section(element_list_text) return prompt @@ -760,7 +763,20 @@ def _skill_section(skill: "ResolvedSkill") -> str: ACTIVE SKILL: {label} ═══════════════════════════════════════════════════════════════════════════════ -The following specialized knowledge is available for this task. Use these +The following specialized knowledge is available for this task. Use these app-specific guidelines to perform actions more effectively. {skill.instructions}""" + + @staticmethod + def _labeled_elements_section(element_list_text: str) -> str: + return ( + "\n\n## LABELED ELEMENTS\n\n" + "Interactive elements on the current screen are labeled with numbers " + "[1], [2], etc. on the screenshot.\n" + 'When clicking a labeled element, use: click({"element": 1}) with ' + "the element's label number.\n" + "Only use raw x,y coordinates when your target is NOT labeled " + "(e.g., content inside a web page, unlabeled areas).\n\n" + f"{element_list_text}" + ) diff --git a/backend/tests/unit/services/llm/test_prompt_builder.py b/backend/tests/unit/services/llm/test_prompt_builder.py new file mode 100644 index 0000000..8117fab --- /dev/null +++ b/backend/tests/unit/services/llm/test_prompt_builder.py @@ -0,0 +1,22 @@ +"""Tests for prompt builder SoM section.""" + +from app.services.llm.prompt_builder import PromptBuilder + + +class TestSoMPromptSection: + def test_element_list_injected_when_provided(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build(element_list_text='[1] Button "OK" (100,200)-(160,230)') + assert "LABELED ELEMENTS" in prompt + assert '[1] Button "OK"' in prompt + assert 'click({"element": 1})' in prompt + + def test_no_element_section_when_none(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build() + assert "LABELED ELEMENTS" not in prompt + + def test_no_element_section_when_empty_string(self): + pb = PromptBuilder(screen_width=1728, screen_height=1117) + prompt = pb.build(element_list_text="") + assert "LABELED ELEMENTS" not in prompt From 84668a841f93da16389dae1ad9fcd69c4bc09c9b Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:24:20 -0800 Subject: [PATCH 06/11] feat: resolve SoM element labels to coordinates in action executor --- backend/app/services/action_executor.py | 47 ++++++++++ .../unit/services/test_action_executor.py | 85 +++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/backend/app/services/action_executor.py b/backend/app/services/action_executor.py index fd92ba5..5bbad29 100644 --- a/backend/app/services/action_executor.py +++ b/backend/app/services/action_executor.py @@ -23,6 +23,7 @@ class ActionExecutor: def __init__(self, screen_width: int, screen_height: int): self._screen_width = screen_width self._screen_height = screen_height + self._elements: list = [] # Current SoM elements for this step self._actions: dict[str, callable] = { "click": self._click, "double_click": self._double_click, @@ -68,9 +69,51 @@ async def execute(self, response: LLMResponse) -> str: except Exception as e: return f"Action '{action}' failed: {e}" + def set_elements(self, elements: list) -> None: + """Set the current step's detected SoM elements for element-based clicking.""" + self._elements = elements + + # ── Element resolution ─────────────────────────────────────────── + + async def _click_element(self, params: dict) -> str: + """Click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + modifiers = params.get("modifiers") + computer_service.click(cx, cy, modifiers=modifiers) + title = f' "{elem.title}"' if elem.title else "" + return f'Clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + + async def _double_click_element(self, params: dict) -> str: + """Double-click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + computer_service.double_click(cx, cy) + title = f' "{elem.title}"' if elem.title else "" + return f'Double-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + + async def _right_click_element(self, params: dict) -> str: + """Right-click a labeled SoM element by its number.""" + label = params["element"] + elem = next((e for e in self._elements if e.label == label), None) + if not elem: + return f"Element #{label} not found in current screen elements" + cx, cy = elem.center_pixel() + computer_service.right_click(cx, cy) + title = f' "{elem.title}"' if elem.title else "" + return f'Right-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})' + # ── Individual action handlers ────────────────────────────────────── async def _click(self, params: dict) -> str: + if "element" in params: + return await self._click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) modifiers = params.get("modifiers") computer_service.click(x, y, modifiers=modifiers) @@ -79,11 +122,15 @@ async def _click(self, params: dict) -> str: return f"Clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" async def _double_click(self, params: dict) -> str: + if "element" in params: + return await self._double_click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) computer_service.double_click(x, y) return f"Double-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" async def _right_click(self, params: dict) -> str: + if "element" in params: + return await self._right_click_element(params) x, y = self.normalize_to_screen_coords(params["x"], params["y"]) computer_service.right_click(x, y) return f"Right-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]" diff --git a/backend/tests/unit/services/test_action_executor.py b/backend/tests/unit/services/test_action_executor.py index 04110c3..b68a81a 100644 --- a/backend/tests/unit/services/test_action_executor.py +++ b/backend/tests/unit/services/test_action_executor.py @@ -5,6 +5,7 @@ from app.services.action_executor import ActionExecutor from app.services.llm.base import LLMResponse +from app.services.som import SoMElement SCREEN_W = 1728 SCREEN_H = 1117 @@ -224,3 +225,87 @@ async def test_action_exception_caught(self, executor, mock_computer): result = await executor.execute(resp) assert "failed" in result assert "Mouse broken" in result + + +# --------------------------------------------------------------------------- +# SoM element resolution +# --------------------------------------------------------------------------- + +def _make_element(label: int = 1, role: str = "AXButton", title: str = "OK", + x: int = 100, y: int = 200, width: int = 80, height: int = 30) -> SoMElement: + return SoMElement( + label=label, role=role, title=title, + x=x, y=y, width=width, height=height, + screen_width=SCREEN_W, screen_height=SCREEN_H, + ) + + +class TestElementResolution: + """Tests for resolving SoM element labels to coordinates.""" + + @pytest.mark.asyncio + async def test_click_element_resolves_to_center(self, executor, mock_computer): + elem = _make_element(label=3, x=100, y=200, width=80, height=30) + executor.set_elements([elem]) + resp = _make_response(action="click", action_params={"element": 3}) + result = await executor.execute(resp) + # center: (100+40, 200+15) = (140, 215) + mock_computer.click.assert_called_once_with(140, 215, modifiers=None) + assert "element #3" in result + assert "(140, 215)" in result + + @pytest.mark.asyncio + async def test_click_xy_still_works(self, executor, mock_computer): + executor.set_elements([_make_element()]) + resp = _make_response(action="click", action_params={"x": 500, "y": 500}) + result = await executor.execute(resp) + mock_computer.click.assert_called_once() + assert "Clicked at" in result + + @pytest.mark.asyncio + async def test_click_unknown_element_returns_error(self, executor, mock_computer): + executor.set_elements([_make_element(label=1)]) + resp = _make_response(action="click", action_params={"element": 99}) + result = await executor.execute(resp) + assert "not found" in result + mock_computer.click.assert_not_called() + + @pytest.mark.asyncio + async def test_double_click_element(self, executor, mock_computer): + elem = _make_element(label=2, x=50, y=60, width=100, height=40) + executor.set_elements([elem]) + resp = _make_response(action="double_click", action_params={"element": 2}) + result = await executor.execute(resp) + # center: (50+50, 60+20) = (100, 80) + mock_computer.double_click.assert_called_once_with(100, 80) + assert "Double-clicked element #2" in result + + @pytest.mark.asyncio + async def test_right_click_element(self, executor, mock_computer): + elem = _make_element(label=5, x=200, y=300, width=60, height=20) + executor.set_elements([elem]) + resp = _make_response(action="right_click", action_params={"element": 5}) + result = await executor.execute(resp) + # center: (200+30, 300+10) = (230, 310) + mock_computer.right_click.assert_called_once_with(230, 310) + assert "Right-clicked element #5" in result + + @pytest.mark.asyncio + async def test_click_element_with_modifiers(self, executor, mock_computer): + elem = _make_element(label=1, x=100, y=200, width=80, height=30) + executor.set_elements([elem]) + resp = _make_response( + action="click", + action_params={"element": 1, "modifiers": ["command"]}, + ) + result = await executor.execute(resp) + mock_computer.click.assert_called_once_with(140, 215, modifiers=["command"]) + assert "element #1" in result + + @pytest.mark.asyncio + async def test_click_element_without_title(self, executor, mock_computer): + elem = _make_element(label=1, title="") + executor.set_elements([elem]) + resp = _make_response(action="click", action_params={"element": 1}) + result = await executor.execute(resp) + assert '"' not in result.split("(Button")[0] if "Button" in result else True From 380afc093af8afbab71eed2e85901d0d6cd55cfb Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 12:27:04 -0800 Subject: [PATCH 07/11] =?UTF-8?q?feat:=20wire=20SoM=20into=20agent=20loop?= =?UTF-8?q?=20=E2=80=94=20detect,=20annotate,=20pass=20to=20LLM?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/services/agent.py | 51 +++++++++++++++++---- backend/app/services/llm/base.py | 3 ++ backend/app/services/llm/gemini_provider.py | 4 ++ 3 files changed, 50 insertions(+), 8 deletions(-) diff --git a/backend/app/services/agent.py b/backend/app/services/agent.py index 13cd6ad..7f1caaa 100644 --- a/backend/app/services/agent.py +++ b/backend/app/services/agent.py @@ -154,27 +154,60 @@ async def _wait_if_paused(self, state: AgentState) -> bool: # ── Observe / Think / Act helpers ─────────────────────────────────── - async def _observe(self, state: AgentState) -> str: - """Capture a screenshot and save it to disk. Returns base64 data.""" + async def _observe(self, state: AgentState) -> tuple[str, list]: + """Capture screenshot, detect SoM elements, annotate. Returns (base64, elements).""" screenshot_b64 = screen_service.capture_screen_base64() state.last_screenshot = screenshot_b64 + # SoM detection + annotation + elements: list = [] + annotated_b64 = screenshot_b64 + if settings.som_enabled: + try: + from app.services.som import ( + detect_elements, filter_leaf_elements, + annotate_screenshot, format_element_list, + ) + screen_w, screen_h = screen_service.get_screen_size() + raw_elements = detect_elements(screen_w, screen_h) + elements = filter_leaf_elements( + raw_elements, + max_elements=settings.som_max_elements, + min_size=settings.som_min_element_size, + ) + if elements: + image_data = base64.b64decode(screenshot_b64) + image = Image.open(io.BytesIO(image_data)) + annotated = annotate_screenshot(image, elements) + buf = io.BytesIO() + annotated.save(buf, "JPEG", quality=SCREENSHOT_QUALITY) + annotated_b64 = base64.b64encode(buf.getvalue()).decode() + logger.info("SoM: detected %d elements", len(elements)) + except Exception as e: + logger.warning("SoM detection failed, using raw screenshot: %s", e) + + # Save raw screenshot to disk (unannotated) screenshots_dir = Path("./screenshots") screenshots_dir.mkdir(exist_ok=True) filename = f"task_{state.task_id}_step_{state.current_step}.jpg" path = screenshots_dir / filename - image_data = base64.b64decode(screenshot_b64) image = Image.open(io.BytesIO(image_data)) image.save(str(path), "JPEG", quality=SCREENSHOT_QUALITY) - state.screenshot_before_path = f"screenshots/{filename}" - return screenshot_b64 + + return annotated_b64, elements async def _think( - self, instruction: str, screenshot_b64: str, state: AgentState + self, instruction: str, screenshot_b64: str, state: AgentState, + elements: list | None = None, ) -> LLMResponse: """Ask the LLM to analyse the current screen.""" + element_list_text = "" + if elements: + from app.services.som import format_element_list + element_list_text = format_element_list(elements) + return await self.llm.analyze_screen( instruction=instruction, screenshot_base64=screenshot_b64, @@ -184,6 +217,7 @@ async def _think( notes=state.notes, references=state.task_references, lessons=state.task_lessons, + element_list_text=element_list_text, ) async def _handle_response( @@ -401,7 +435,8 @@ async def run( # 1. Observe try: - screenshot_b64 = await self._observe(state) + screenshot_b64, elements = await self._observe(state) + self._action_executor.set_elements(elements) except Exception as e: logger.error("Failed to capture screenshot: %s", e) raise @@ -415,7 +450,7 @@ async def run( # 2. Think try: - llm_response = await self._think(instruction, screenshot_b64, state) + llm_response = await self._think(instruction, screenshot_b64, state, elements) except Exception as e: logger.error("Failed to analyze screen with LLM: %s", e) raise diff --git a/backend/app/services/llm/base.py b/backend/app/services/llm/base.py index 1e86852..a2b48b9 100644 --- a/backend/app/services/llm/base.py +++ b/backend/app/services/llm/base.py @@ -113,6 +113,7 @@ async def analyze_screen( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> LLMResponse: """Analyze the current screen and determine the next action.""" pass @@ -168,6 +169,7 @@ def get_system_prompt( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> str: """Get the system prompt for this provider.""" from app.services.defaults import get_running_apps, load_system_defaults @@ -179,4 +181,5 @@ def get_system_prompt( references=references, lessons=lessons, system_defaults=system_defaults, running_apps=running_apps, + element_list_text=element_list_text or None, ) diff --git a/backend/app/services/llm/gemini_provider.py b/backend/app/services/llm/gemini_provider.py index 460ffcd..2dfe07d 100644 --- a/backend/app/services/llm/gemini_provider.py +++ b/backend/app/services/llm/gemini_provider.py @@ -113,6 +113,7 @@ async def analyze_screen( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> LLMResponse: """Analyze screen and return next action.""" contents = self._build_contents(instruction, screenshot_base64, history) @@ -124,6 +125,7 @@ async def analyze_screen( raw_response, token_usage = await self._call_with_retries( contents, skill, memories, plan=plan, notes=notes, references=references, lessons=lessons, + element_list_text=element_list_text, ) if raw_response and not raw_response.strip().endswith("}"): @@ -243,6 +245,7 @@ async def _call_with_retries( notes: Optional[str] = None, references: Optional[list[dict]] = None, lessons: Optional[list[dict]] = None, + element_list_text: str = "", ) -> tuple[str, TokenUsage]: """Call Gemini API with exponential-backoff retries. Returns (raw_text, token_usage).""" last_error: Optional[Exception] = None @@ -252,6 +255,7 @@ async def _call_with_retries( system_prompt = self.get_system_prompt( skill=skill, memories=memories, plan=plan, notes=notes, references=references, lessons=lessons, + element_list_text=element_list_text, ) if skill: logger.info("Using skill: %s", skill) From ffbdfec1b3146ee277619f6fc110cea7805aad95 Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 13:28:47 -0800 Subject: [PATCH 08/11] =?UTF-8?q?fix:=20SoM=20AX=20detection=20=E2=80=94?= =?UTF-8?q?=20correct=20AXValueGetValue=20calling=20convention=20and=20imp?= =?UTF-8?q?orts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - AXValueGetValue returns (bool, value) tuple with None as 3rd arg, not pre-allocated struct (was silently failing, yielding 0 elements) - Import AX symbols from ApplicationServices instead of Quartz --- backend/app/services/som.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/backend/app/services/som.py b/backend/app/services/som.py index 05d2d05..1910ef3 100644 --- a/backend/app/services/som.py +++ b/backend/app/services/som.py @@ -113,12 +113,8 @@ def _walk_ax_tree( element, "AXSize", None ) if err_pos == 0 and err_size == 0 and pos_val and size_val: - from Quartz import CGPoint, CGSize - - point = CGPoint() - size = CGSize() - AXValueGetValue(pos_val, kAXValueTypeCGPoint, point) - AXValueGetValue(size_val, kAXValueTypeCGSize, size) + _, point = AXValueGetValue(pos_val, kAXValueTypeCGPoint, None) + _, size = AXValueGetValue(size_val, kAXValueTypeCGSize, None) # title / description title = "" @@ -183,8 +179,6 @@ def detect_elements( from ApplicationServices import ( AXUIElementCreateApplication, AXUIElementCopyAttributeValue, - ) - from Quartz import ( AXValueGetValue, kAXValueTypeCGPoint, kAXValueTypeCGSize, From e9ebf53757eeb895cb1f1d877d1b647d418edfaf Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 14:40:53 -0800 Subject: [PATCH 09/11] feat: SoM prompt rework + multi-process AX detection + benchmark - Prompt: element-first click definition, element reminder in CRITICAL REMINDERS and COORDINATE SYSTEM, stronger LABELED ELEMENTS section - AX detection: walk system-UI processes (Control Center, SystemUIServer) for menu bar items (WiFi, Battery, Clock); add AXMenuBarItem role - Benchmark: add coord_accuracy/bench.py using real GeminiProvider (parse retries, full agent prompt) for raw vs SoM comparison --- backend/app/services/llm/prompt_builder.py | 28 ++- backend/app/services/som.py | 72 +++++- backend/benchmark/__init__.py | 0 backend/benchmark/coord_accuracy/__init__.py | 0 backend/benchmark/coord_accuracy/bench.py | 227 ++++++++++++++++++ .../unit/services/llm/test_prompt_builder.py | 2 +- .../plans/2026-03-03-som-prompt-fix-design.md | 38 +++ 7 files changed, 342 insertions(+), 25 deletions(-) create mode 100644 backend/benchmark/__init__.py create mode 100644 backend/benchmark/coord_accuracy/__init__.py create mode 100644 backend/benchmark/coord_accuracy/bench.py create mode 100644 docs/plans/2026-03-03-som-prompt-fix-design.md diff --git a/backend/app/services/llm/prompt_builder.py b/backend/app/services/llm/prompt_builder.py index b225ff3..feebc23 100644 --- a/backend/app/services/llm/prompt_builder.py +++ b/backend/app/services/llm/prompt_builder.py @@ -193,10 +193,14 @@ def _available_actions(self) -> str: ═══════════════════════════════════════════════════════════════════════════════ MOUSE ACTIONS: -- click: {"x": 0-1000, "y": 0-1000} +- click: {"element": N} ← PREFERRED — use the labeled element number from screenshot + OR: {"x": 0-1000, "y": 0-1000} ← only when target is NOT a labeled element Optional modifiers: {"x": N, "y": N, "modifiers": ["shift"|"command"|"option"|"control"]} -- double_click: {"x": 0-1000, "y": 0-1000} - For opening files, selecting words -- right_click: {"x": 0-1000, "y": 0-1000} - For context menus + ⚠ IMPORTANT: When interactive elements are labeled [1], [2], etc. on the screenshot, + ALWAYS use {"element": N} instead of guessing x,y coordinates. Element clicks are + precise; raw coordinates often miss. +- double_click: {"element": N} or {"x": 0-1000, "y": 0-1000} - For opening files, selecting words +- right_click: {"element": N} or {"x": 0-1000, "y": 0-1000} - For context menus - drag: {"start_x": N, "start_y": N, "end_x": N, "end_y": N} - For selecting, moving items KEYBOARD ACTIONS: @@ -278,7 +282,9 @@ def _available_actions(self) -> str: - y=0 is top edge, y=1000 is bottom edge - Example: Center of screen ≈ (500, 500) - Example: Top-right corner ≈ (950, 50) -- Example: Dock (bottom center) ≈ (500, 980)""" +- Example: Dock (bottom center) ≈ (500, 980) +- When elements are labeled on screen: {"element": 5} clicks element [5] precisely + This is MORE ACCURATE than guessing coordinates — always prefer element clicks.""" def _error_recovery(self) -> str: return """ @@ -533,6 +539,8 @@ def _response_format(self) -> str: - ALWAYS start reasoning with verification of the last step's expected_outcome (except on step 0). State: "Expected: X. Actual: Y. [Match/Mismatch]." - Verify coordinates are in 0-1000 range +- For click/double_click/right_click: when a labeled element [N] matches your target, + use "params": {"element": N} — not raw x,y coordinates - When task is complete, action must be "done" with a summary - Never guess when uncertain - ask for clarification - Learn from history - don't repeat failed actions""" @@ -771,12 +779,10 @@ def _skill_section(skill: "ResolvedSkill") -> str: @staticmethod def _labeled_elements_section(element_list_text: str) -> str: return ( - "\n\n## LABELED ELEMENTS\n\n" - "Interactive elements on the current screen are labeled with numbers " - "[1], [2], etc. on the screenshot.\n" - 'When clicking a labeled element, use: click({"element": 1}) with ' - "the element's label number.\n" - "Only use raw x,y coordinates when your target is NOT labeled " - "(e.g., content inside a web page, unlabeled areas).\n\n" + "\n\n## LABELED ELEMENTS (IMPORTANT)\n\n" + "Interactive elements on the current screen are labeled [1], [2], etc.\n" + 'You MUST use click({"element": N}) for any target that matches a labeled element.\n' + "Do NOT guess x,y coordinates for labeled elements — element clicks are exact.\n" + "Only use raw x,y for unlabeled content (e.g., web page body, images, unlabeled areas).\n\n" f"{element_list_text}" ) diff --git a/backend/app/services/som.py b/backend/app/services/som.py index 1910ef3..c874e16 100644 --- a/backend/app/services/som.py +++ b/backend/app/services/som.py @@ -16,6 +16,7 @@ "AXTextArea", "AXLink", "AXMenuItem", + "AXMenuBarItem", "AXPopUpButton", "AXCheckBox", "AXRadioButton", @@ -167,11 +168,44 @@ def _walk_ax_tree( pass +def _walk_pid( + pid: int, + raw: list[dict], + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, + kAXValueTypeCGSize, +) -> None: + """Walk one process's AX tree and append results to *raw*.""" + try: + app_element = AXUIElementCreateApplication(pid) + _walk_ax_tree( + app_element, raw, 0, + AXUIElementCopyAttributeValue, AXValueGetValue, + kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) + except Exception: + pass + + +# System processes whose menu-bar items we always want to include. +_SYSTEM_UI_BUNDLE_IDS = frozenset({ + "com.apple.systemuiserver", # legacy menu extras + "com.apple.controlcenter", # WiFi, Battery, BT, etc. + "com.apple.Spotlight", # Spotlight icon + "com.apple.notificationcenterui", # Notification Center +}) + + def detect_elements( screen_width: int, screen_height: int ) -> list[SoMElement]: """Detect interactive UI elements on screen via macOS Accessibility API. + Walks the frontmost application AND system-UI processes so that + menu-bar items (WiFi, Battery, Clock, etc.) are always included. + Returns elements with ``label=0``; the caller should assign labels after filtering. """ @@ -188,24 +222,36 @@ def detect_elements( logger.warning("macOS AX frameworks not available") return [] + raw: list[dict] = [] + + # 1. Frontmost app try: frontmost = NSWorkspace.sharedWorkspace().frontmostApplication() - pid = frontmost.processIdentifier() - app_element = AXUIElementCreateApplication(pid) + frontmost_pid = frontmost.processIdentifier() + _walk_pid( + frontmost_pid, raw, + AXUIElementCreateApplication, AXUIElementCopyAttributeValue, + AXValueGetValue, kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) except Exception: logger.exception("Failed to get frontmost application AX element") - return [] - raw: list[dict] = [] - _walk_ax_tree( - app_element, - raw, - 0, - AXUIElementCopyAttributeValue, - AXValueGetValue, - kAXValueTypeCGPoint, - kAXValueTypeCGSize, - ) + # 2. System-UI processes (menu bar extras) + try: + for app in NSWorkspace.sharedWorkspace().runningApplications(): + bid = app.bundleIdentifier() + if bid and bid in _SYSTEM_UI_BUNDLE_IDS: + pid = app.processIdentifier() + if pid != frontmost_pid: + _walk_pid( + pid, raw, + AXUIElementCreateApplication, + AXUIElementCopyAttributeValue, + AXValueGetValue, + kAXValueTypeCGPoint, kAXValueTypeCGSize, + ) + except Exception: + logger.warning("Failed to walk system-UI processes", exc_info=True) elements: list[SoMElement] = [] for r in raw: diff --git a/backend/benchmark/__init__.py b/backend/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/benchmark/coord_accuracy/__init__.py b/backend/benchmark/coord_accuracy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/benchmark/coord_accuracy/bench.py b/backend/benchmark/coord_accuracy/bench.py new file mode 100644 index 0000000..0586e3b --- /dev/null +++ b/backend/benchmark/coord_accuracy/bench.py @@ -0,0 +1,227 @@ +"""Benchmark runner: raw coords vs SoM element clicks. + +Loads test cases from a directory, sends each through the real +GeminiProvider.analyze_screen() pipeline (including parse retries), +and compares click accuracy against ground-truth bounding boxes. + +Called by run_all.py — not intended to be run standalone. +""" + +import asyncio +import base64 +import io +import json +import math +from pathlib import Path + +from PIL import Image + +from app.services.llm.gemini_provider import GeminiProvider +from app.services.som import SoMElement, annotate_screenshot, format_element_list + + +def _build_som_elements(raw_elements: list[dict], screen_w: int, screen_h: int) -> list[SoMElement]: + """Convert test-case element dicts to SoMElement objects.""" + elements = [] + for e in raw_elements: + bbox = e["bbox"] + elements.append(SoMElement( + label=e["label"], + role=e.get("role", "AXButton"), + title=e.get("title", ""), + x=bbox["x"], + y=bbox["y"], + width=bbox["w"], + height=bbox["h"], + screen_width=screen_w, + screen_height=screen_h, + )) + return elements + + +def _hit_test( + params: dict, + target_bbox: dict, + elements: list[SoMElement], + screen_w: int, + screen_h: int, +) -> tuple[bool, float, dict]: + """Check whether the LLM's click lands inside the target bbox. + + Returns (hit, distance_to_center, parsed_response). + """ + # Resolve click point + if "element" in params: + label = params["element"] + elem = next((e for e in elements if e.label == label), None) + if elem: + px, py = elem.center_pixel() + else: + px, py = 0, 0 + parsed = {"action": "click", "raw": json.dumps(params), "element": label} + else: + nx = params.get("x", 0) + ny = params.get("y", 0) + px = int((nx / 1000.0) * screen_w) + py = int((ny / 1000.0) * screen_h) + parsed = {"action": "click", "raw": json.dumps(params), "x": nx, "y": ny} + + tx, ty, tw, th = target_bbox["x"], target_bbox["y"], target_bbox["w"], target_bbox["h"] + hit = tx <= px <= tx + tw and ty <= py <= ty + th + + cx = tx + tw / 2 + cy = ty + th / 2 + dist = math.sqrt((px - cx) ** 2 + (py - cy) ** 2) + + return hit, dist, parsed + + +async def _run_case( + provider: GeminiProvider, + case: dict, + screenshot_bytes: bytes, + screen_w: int, + screen_h: int, +) -> dict: + """Run one test case in both raw and SoM modes.""" + instruction = case["instruction"] + target_bbox = case["target_bbox"] + raw_element_dicts = case.get("elements", []) + + elements = _build_som_elements(raw_element_dicts, screen_w, screen_h) + + raw_b64 = base64.b64encode(screenshot_bytes).decode() + + # Annotated screenshot + image = Image.open(io.BytesIO(screenshot_bytes)) + annotated_img = annotate_screenshot(image, elements) + buf = io.BytesIO() + annotated_img.save(buf, "JPEG", quality=85) + som_b64 = base64.b64encode(buf.getvalue()).decode() + + element_list_text = format_element_list(elements) + + # --- Raw mode: real provider, no element list --- + try: + raw_llm = await provider.analyze_screen( + instruction=instruction, + screenshot_base64=raw_b64, + history=[], + element_list_text="", + ) + params = raw_llm.action_params or {} + raw_hit, raw_dist, raw_parsed = _hit_test( + params, target_bbox, elements, screen_w, screen_h, + ) + except Exception as e: + raw_hit, raw_dist = False, 9999.0 + raw_parsed = {"action": "error", "raw": str(e)} + + # --- SoM mode: annotated screenshot + element list --- + try: + som_llm = await provider.analyze_screen( + instruction=instruction, + screenshot_base64=som_b64, + history=[], + element_list_text=element_list_text, + ) + params = som_llm.action_params or {} + som_hit, som_dist, som_parsed = _hit_test( + params, target_bbox, elements, screen_w, screen_h, + ) + except Exception as e: + som_hit, som_dist = False, 9999.0 + som_parsed = {"action": "error", "raw": str(e)} + + return { + "name": case["name"], + "instruction": instruction, + "raw": {"hit": raw_hit, "distance": raw_dist, "response": raw_parsed}, + "som": {"hit": som_hit, "distance": som_dist, "response": som_parsed}, + } + + +async def run_benchmark( + test_dir: Path, + api_key: str | None = None, + model: str | None = None, +) -> dict: + """Run the full benchmark across all test cases in test_dir. + + Uses the real GeminiProvider (with parse retries, prompt builder, etc.) + so results reflect actual agent performance. + + Args: + test_dir: Directory containing test case JSON files and screenshot. + api_key: Unused, kept for backward compat with run_all.py. + model: Gemini model name (None = use default from settings). + + Returns: + Dict with "cases" list and "summary" stats. + """ + _ = api_key # GeminiProvider reads from env/settings + provider = GeminiProvider(model=model) + + case_files = sorted(test_dir.glob("case_*.json")) + if not case_files: + print(" No test cases found.") + return {"cases": [], "summary": {}} + + # Load screenshot + screenshot_path = None + for cf in case_files: + case_data = json.loads(cf.read_text()) + ss_name = case_data.get("screenshot", "") + if ss_name: + screenshot_path = test_dir / ss_name + break + if not screenshot_path or not screenshot_path.exists(): + print(" Screenshot not found.") + return {"cases": [], "summary": {}} + + screenshot_bytes = screenshot_path.read_bytes() + + results = [] + for cf in case_files: + case = json.loads(cf.read_text()) + screen_w = case["screen_width"] + screen_h = case["screen_height"] + print(f" Running: {case['name']}...", end=" ", flush=True) + + result = await _run_case( + provider, case, screenshot_bytes, screen_w, screen_h, + ) + raw_mark = "HIT" if result["raw"]["hit"] else "miss" + som_mark = "HIT" if result["som"]["hit"] else "miss" + print(f"raw={raw_mark} som={som_mark}") + results.append(result) + + await asyncio.sleep(0.5) + + # Summary + total = len(results) + raw_hits = sum(1 for r in results if r["raw"]["hit"]) + som_hits = sum(1 for r in results if r["som"]["hit"]) + raw_avg_dist = sum(r["raw"]["distance"] for r in results) / max(total, 1) + som_avg_dist = sum(r["som"]["distance"] for r in results) / max(total, 1) + + summary = { + "total": total, + "raw_hits": raw_hits, + "som_hits": som_hits, + "raw_accuracy": raw_hits / max(total, 1), + "som_accuracy": som_hits / max(total, 1), + "raw_avg_distance": raw_avg_dist, + "som_avg_distance": som_avg_dist, + } + + print(f"\n{'=' * 50}") + print("RESULTS SUMMARY") + print(f"{'=' * 50}") + print(f" Total cases: {total}") + print(f" Raw accuracy: {raw_hits}/{total} ({summary['raw_accuracy']:.1%})") + print(f" SoM accuracy: {som_hits}/{total} ({summary['som_accuracy']:.1%})") + print(f" Raw avg dist: {raw_avg_dist:.1f}px") + print(f" SoM avg dist: {som_avg_dist:.1f}px") + + return {"cases": results, "summary": summary} diff --git a/backend/tests/unit/services/llm/test_prompt_builder.py b/backend/tests/unit/services/llm/test_prompt_builder.py index 8117fab..e9b4872 100644 --- a/backend/tests/unit/services/llm/test_prompt_builder.py +++ b/backend/tests/unit/services/llm/test_prompt_builder.py @@ -9,7 +9,7 @@ def test_element_list_injected_when_provided(self): prompt = pb.build(element_list_text='[1] Button "OK" (100,200)-(160,230)') assert "LABELED ELEMENTS" in prompt assert '[1] Button "OK"' in prompt - assert 'click({"element": 1})' in prompt + assert 'click({"element": N})' in prompt def test_no_element_section_when_none(self): pb = PromptBuilder(screen_width=1728, screen_height=1117) diff --git a/docs/plans/2026-03-03-som-prompt-fix-design.md b/docs/plans/2026-03-03-som-prompt-fix-design.md new file mode 100644 index 0000000..b1350f2 --- /dev/null +++ b/docs/plans/2026-03-03-som-prompt-fix-design.md @@ -0,0 +1,38 @@ +# SoM Prompt Fix + Benchmark Design + +## Problem +SoM infrastructure works (elements detected, annotated screenshot sent to LLM, element list in prompt), but Gemini ignores element labels and outputs raw x,y coordinates. Root cause: the RESPONSE FORMAT section — which defines the output contract Gemini pays most attention to — never mentions `{"element": N}` in params. + +## Approach A: Minimal prompt tweak + benchmark + +### Part 1: Prompt change + +Add element reminder to CRITICAL REMINDERS in RESPONSE FORMAT section of `prompt_builder.py`: + +``` +- For click/double_click/right_click: when a labeled element [N] matches your target, + use "params": {"element": N} — not raw x,y coordinates +``` + +No other prompt changes needed — AVAILABLE ACTIONS and LABELED ELEMENTS sections already describe element-based clicking correctly. + +### Part 2: Benchmark runner + +Create `backend/benchmark/coord_accuracy/bench.py` — the missing module imported by `run_all.py`. + +For each test case: +1. **Raw mode**: raw screenshot + standard prompt (no element list, no annotation) +2. **SoM mode**: annotated screenshot + element list in prompt +3. Parse LLM response → extract click target (x,y or element N) +4. Compute hit/miss + distance to target center +5. Print summary table + save `results.json` + +Uses real `PromptBuilder` and Gemini API so it tests the actual prompt. + +## Files +- `backend/app/services/llm/prompt_builder.py` — add element reminder to RESPONSE FORMAT +- `backend/benchmark/coord_accuracy/bench.py` — new file + +## Verification +1. `pytest -v` passes +2. Run `python -m benchmark.coord_accuracy.run_all` — SoM hit rate > raw hit rate From 3236408b4a18f96585258d6b379900f32caa4e7e Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 14:42:49 -0800 Subject: [PATCH 10/11] chore: remove benchmark and design doc from PR Keep PR focused on agent changes only. --- backend/benchmark/__init__.py | 0 backend/benchmark/coord_accuracy/__init__.py | 0 backend/benchmark/coord_accuracy/bench.py | 227 ------------------ .../plans/2026-03-03-som-prompt-fix-design.md | 38 --- 4 files changed, 265 deletions(-) delete mode 100644 backend/benchmark/__init__.py delete mode 100644 backend/benchmark/coord_accuracy/__init__.py delete mode 100644 backend/benchmark/coord_accuracy/bench.py delete mode 100644 docs/plans/2026-03-03-som-prompt-fix-design.md diff --git a/backend/benchmark/__init__.py b/backend/benchmark/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/benchmark/coord_accuracy/__init__.py b/backend/benchmark/coord_accuracy/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/benchmark/coord_accuracy/bench.py b/backend/benchmark/coord_accuracy/bench.py deleted file mode 100644 index 0586e3b..0000000 --- a/backend/benchmark/coord_accuracy/bench.py +++ /dev/null @@ -1,227 +0,0 @@ -"""Benchmark runner: raw coords vs SoM element clicks. - -Loads test cases from a directory, sends each through the real -GeminiProvider.analyze_screen() pipeline (including parse retries), -and compares click accuracy against ground-truth bounding boxes. - -Called by run_all.py — not intended to be run standalone. -""" - -import asyncio -import base64 -import io -import json -import math -from pathlib import Path - -from PIL import Image - -from app.services.llm.gemini_provider import GeminiProvider -from app.services.som import SoMElement, annotate_screenshot, format_element_list - - -def _build_som_elements(raw_elements: list[dict], screen_w: int, screen_h: int) -> list[SoMElement]: - """Convert test-case element dicts to SoMElement objects.""" - elements = [] - for e in raw_elements: - bbox = e["bbox"] - elements.append(SoMElement( - label=e["label"], - role=e.get("role", "AXButton"), - title=e.get("title", ""), - x=bbox["x"], - y=bbox["y"], - width=bbox["w"], - height=bbox["h"], - screen_width=screen_w, - screen_height=screen_h, - )) - return elements - - -def _hit_test( - params: dict, - target_bbox: dict, - elements: list[SoMElement], - screen_w: int, - screen_h: int, -) -> tuple[bool, float, dict]: - """Check whether the LLM's click lands inside the target bbox. - - Returns (hit, distance_to_center, parsed_response). - """ - # Resolve click point - if "element" in params: - label = params["element"] - elem = next((e for e in elements if e.label == label), None) - if elem: - px, py = elem.center_pixel() - else: - px, py = 0, 0 - parsed = {"action": "click", "raw": json.dumps(params), "element": label} - else: - nx = params.get("x", 0) - ny = params.get("y", 0) - px = int((nx / 1000.0) * screen_w) - py = int((ny / 1000.0) * screen_h) - parsed = {"action": "click", "raw": json.dumps(params), "x": nx, "y": ny} - - tx, ty, tw, th = target_bbox["x"], target_bbox["y"], target_bbox["w"], target_bbox["h"] - hit = tx <= px <= tx + tw and ty <= py <= ty + th - - cx = tx + tw / 2 - cy = ty + th / 2 - dist = math.sqrt((px - cx) ** 2 + (py - cy) ** 2) - - return hit, dist, parsed - - -async def _run_case( - provider: GeminiProvider, - case: dict, - screenshot_bytes: bytes, - screen_w: int, - screen_h: int, -) -> dict: - """Run one test case in both raw and SoM modes.""" - instruction = case["instruction"] - target_bbox = case["target_bbox"] - raw_element_dicts = case.get("elements", []) - - elements = _build_som_elements(raw_element_dicts, screen_w, screen_h) - - raw_b64 = base64.b64encode(screenshot_bytes).decode() - - # Annotated screenshot - image = Image.open(io.BytesIO(screenshot_bytes)) - annotated_img = annotate_screenshot(image, elements) - buf = io.BytesIO() - annotated_img.save(buf, "JPEG", quality=85) - som_b64 = base64.b64encode(buf.getvalue()).decode() - - element_list_text = format_element_list(elements) - - # --- Raw mode: real provider, no element list --- - try: - raw_llm = await provider.analyze_screen( - instruction=instruction, - screenshot_base64=raw_b64, - history=[], - element_list_text="", - ) - params = raw_llm.action_params or {} - raw_hit, raw_dist, raw_parsed = _hit_test( - params, target_bbox, elements, screen_w, screen_h, - ) - except Exception as e: - raw_hit, raw_dist = False, 9999.0 - raw_parsed = {"action": "error", "raw": str(e)} - - # --- SoM mode: annotated screenshot + element list --- - try: - som_llm = await provider.analyze_screen( - instruction=instruction, - screenshot_base64=som_b64, - history=[], - element_list_text=element_list_text, - ) - params = som_llm.action_params or {} - som_hit, som_dist, som_parsed = _hit_test( - params, target_bbox, elements, screen_w, screen_h, - ) - except Exception as e: - som_hit, som_dist = False, 9999.0 - som_parsed = {"action": "error", "raw": str(e)} - - return { - "name": case["name"], - "instruction": instruction, - "raw": {"hit": raw_hit, "distance": raw_dist, "response": raw_parsed}, - "som": {"hit": som_hit, "distance": som_dist, "response": som_parsed}, - } - - -async def run_benchmark( - test_dir: Path, - api_key: str | None = None, - model: str | None = None, -) -> dict: - """Run the full benchmark across all test cases in test_dir. - - Uses the real GeminiProvider (with parse retries, prompt builder, etc.) - so results reflect actual agent performance. - - Args: - test_dir: Directory containing test case JSON files and screenshot. - api_key: Unused, kept for backward compat with run_all.py. - model: Gemini model name (None = use default from settings). - - Returns: - Dict with "cases" list and "summary" stats. - """ - _ = api_key # GeminiProvider reads from env/settings - provider = GeminiProvider(model=model) - - case_files = sorted(test_dir.glob("case_*.json")) - if not case_files: - print(" No test cases found.") - return {"cases": [], "summary": {}} - - # Load screenshot - screenshot_path = None - for cf in case_files: - case_data = json.loads(cf.read_text()) - ss_name = case_data.get("screenshot", "") - if ss_name: - screenshot_path = test_dir / ss_name - break - if not screenshot_path or not screenshot_path.exists(): - print(" Screenshot not found.") - return {"cases": [], "summary": {}} - - screenshot_bytes = screenshot_path.read_bytes() - - results = [] - for cf in case_files: - case = json.loads(cf.read_text()) - screen_w = case["screen_width"] - screen_h = case["screen_height"] - print(f" Running: {case['name']}...", end=" ", flush=True) - - result = await _run_case( - provider, case, screenshot_bytes, screen_w, screen_h, - ) - raw_mark = "HIT" if result["raw"]["hit"] else "miss" - som_mark = "HIT" if result["som"]["hit"] else "miss" - print(f"raw={raw_mark} som={som_mark}") - results.append(result) - - await asyncio.sleep(0.5) - - # Summary - total = len(results) - raw_hits = sum(1 for r in results if r["raw"]["hit"]) - som_hits = sum(1 for r in results if r["som"]["hit"]) - raw_avg_dist = sum(r["raw"]["distance"] for r in results) / max(total, 1) - som_avg_dist = sum(r["som"]["distance"] for r in results) / max(total, 1) - - summary = { - "total": total, - "raw_hits": raw_hits, - "som_hits": som_hits, - "raw_accuracy": raw_hits / max(total, 1), - "som_accuracy": som_hits / max(total, 1), - "raw_avg_distance": raw_avg_dist, - "som_avg_distance": som_avg_dist, - } - - print(f"\n{'=' * 50}") - print("RESULTS SUMMARY") - print(f"{'=' * 50}") - print(f" Total cases: {total}") - print(f" Raw accuracy: {raw_hits}/{total} ({summary['raw_accuracy']:.1%})") - print(f" SoM accuracy: {som_hits}/{total} ({summary['som_accuracy']:.1%})") - print(f" Raw avg dist: {raw_avg_dist:.1f}px") - print(f" SoM avg dist: {som_avg_dist:.1f}px") - - return {"cases": results, "summary": summary} diff --git a/docs/plans/2026-03-03-som-prompt-fix-design.md b/docs/plans/2026-03-03-som-prompt-fix-design.md deleted file mode 100644 index b1350f2..0000000 --- a/docs/plans/2026-03-03-som-prompt-fix-design.md +++ /dev/null @@ -1,38 +0,0 @@ -# SoM Prompt Fix + Benchmark Design - -## Problem -SoM infrastructure works (elements detected, annotated screenshot sent to LLM, element list in prompt), but Gemini ignores element labels and outputs raw x,y coordinates. Root cause: the RESPONSE FORMAT section — which defines the output contract Gemini pays most attention to — never mentions `{"element": N}` in params. - -## Approach A: Minimal prompt tweak + benchmark - -### Part 1: Prompt change - -Add element reminder to CRITICAL REMINDERS in RESPONSE FORMAT section of `prompt_builder.py`: - -``` -- For click/double_click/right_click: when a labeled element [N] matches your target, - use "params": {"element": N} — not raw x,y coordinates -``` - -No other prompt changes needed — AVAILABLE ACTIONS and LABELED ELEMENTS sections already describe element-based clicking correctly. - -### Part 2: Benchmark runner - -Create `backend/benchmark/coord_accuracy/bench.py` — the missing module imported by `run_all.py`. - -For each test case: -1. **Raw mode**: raw screenshot + standard prompt (no element list, no annotation) -2. **SoM mode**: annotated screenshot + element list in prompt -3. Parse LLM response → extract click target (x,y or element N) -4. Compute hit/miss + distance to target center -5. Print summary table + save `results.json` - -Uses real `PromptBuilder` and Gemini API so it tests the actual prompt. - -## Files -- `backend/app/services/llm/prompt_builder.py` — add element reminder to RESPONSE FORMAT -- `backend/benchmark/coord_accuracy/bench.py` — new file - -## Verification -1. `pytest -v` passes -2. Run `python -m benchmark.coord_accuracy.run_all` — SoM hit rate > raw hit rate From 63cb0c14fb4cc4e3891c02e68d34e217f65d25da Mon Sep 17 00:00:00 2001 From: Evan Takahashi Date: Tue, 3 Mar 2026 15:24:42 -0800 Subject: [PATCH 11/11] =?UTF-8?q?feat:=20SoM=20dedup,=20off-screen=20filte?= =?UTF-8?q?r,=20expanded=20roles,=20cap=2040=E2=86=9280?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add AXStaticText/AXImage/AXRow/AXCell to INTERACTIVE_ROLES - Skip untitled AXStaticText (decorative) - Filter off-screen elements (collapsed menu items) - Deduplicate by (role, title, x, y, w, h) - Raise default max_elements 40→80 - Add stage-by-stage filter logging --- backend/app/config.py | 2 +- backend/app/services/som.py | 62 +++++++++++++++++++++++-------- backend/tests/unit/test_config.py | 2 +- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/backend/app/config.py b/backend/app/config.py index 2484e36..cb3d9cd 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -51,7 +51,7 @@ class Settings(BaseSettings): # Set-of-Mark visual grounding som_enabled: bool = True - som_max_elements: int = 40 # Max labeled elements per screenshot + som_max_elements: int = 80 # Max labeled elements per screenshot som_min_element_size: int = 15 # Skip elements smaller than this (px) class Config: diff --git a/backend/app/services/som.py b/backend/app/services/som.py index c874e16..90a2fd5 100644 --- a/backend/app/services/som.py +++ b/backend/app/services/som.py @@ -26,6 +26,10 @@ "AXIncrementor", "AXSearchField", "AXSecureTextField", + "AXStaticText", + "AXImage", + "AXRow", + "AXCell", }) MAX_AX_DEPTH = 12 @@ -137,14 +141,18 @@ def _walk_ax_tree( except Exception: pass - results.append({ - "role": str(role), - "title": title, - "x": int(point.x), - "y": int(point.y), - "width": int(size.width), - "height": int(size.height), - }) + # Skip untitled AXStaticText — decorative, not interactive + if role == "AXStaticText" and not title: + pass + else: + results.append({ + "role": str(role), + "title": title, + "x": int(point.x), + "y": int(point.y), + "width": int(size.width), + "height": int(size.height), + }) except Exception: pass @@ -279,32 +287,54 @@ def detect_elements( def filter_leaf_elements( elements: list[SoMElement], - max_elements: int = 40, + max_elements: int = 80, min_size: int = 15, ) -> list[SoMElement]: """Keep only visible leaf elements, sorted reading-order, capped & labelled.""" + raw_count = len(elements) + # 1. remove too-small sized = [e for e in elements if not e.is_too_small(min_size)] - # 2. remove parents (elements that contain another element) + # 2. remove off-screen (outside viewport) + on_screen = [ + e for e in sized + if (e.x + e.width > 0 and e.y + e.height > 0 + and e.x < e.screen_width and e.y < e.screen_height) + ] + + # 3. deduplicate by (role, title, x, y, width, height) + seen: set[tuple] = set() + deduped: list[SoMElement] = [] + for e in on_screen: + key = (e.role, e.title, e.x, e.y, e.width, e.height) + if key not in seen: + seen.add(key) + deduped.append(e) + + # 4. remove parents (elements that contain another element) parent_indices: set[int] = set() - for i, a in enumerate(sized): - for j, b in enumerate(sized): + for i, a in enumerate(deduped): + for j, b in enumerate(deduped): if i != j and a.contains(b): parent_indices.add(i) break # a is a parent, no need to check more - leaves = [e for i, e in enumerate(sized) if i not in parent_indices] + leaves = [e for i, e in enumerate(deduped) if i not in parent_indices] - # 3. sort by y then x (reading order) + # 5. sort by y then x (reading order) leaves.sort(key=lambda e: (e.y, e.x)) - # 4. cap + # 6. cap leaves = leaves[:max_elements] - # 5. assign labels 1..N + # 7. assign labels 1..N for idx, elem in enumerate(leaves, start=1): elem.label = idx + logger.info( + "SoM filter: %d raw → %d sized → %d on-screen → %d deduped → %d leaves → %d capped", + raw_count, len(sized), len(on_screen), len(deduped), len(deduped) - len(parent_indices), len(leaves), + ) return leaves diff --git a/backend/tests/unit/test_config.py b/backend/tests/unit/test_config.py index 13b8f5b..f858271 100644 --- a/backend/tests/unit/test_config.py +++ b/backend/tests/unit/test_config.py @@ -7,5 +7,5 @@ class TestSoMConfig: def test_som_defaults(self): s = Settings(gemini_api_key="fake") assert s.som_enabled is True - assert s.som_max_elements == 40 + assert s.som_max_elements == 80 assert s.som_min_element_size == 15