Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class Settings(BaseSettings):
shake_min_distance: float = 15 # Min movement distance in pixels (reduced from 20)
shake_check_interval: float = 0.02 # Check interval in seconds (faster sampling)

# Set-of-Mark visual grounding
som_enabled: bool = True
som_max_elements: int = 80 # Max labeled elements per screenshot
som_min_element_size: int = 15 # Skip elements smaller than this (px)

class Config:
env_file = ".env"
env_file_encoding = "utf-8"
Expand Down
47 changes: 47 additions & 0 deletions backend/app/services/action_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ActionExecutor:
def __init__(self, screen_width: int, screen_height: int):
self._screen_width = screen_width
self._screen_height = screen_height
self._elements: list = [] # Current SoM elements for this step
self._actions: dict[str, callable] = {
"click": self._click,
"double_click": self._double_click,
Expand Down Expand Up @@ -68,9 +69,51 @@ async def execute(self, response: LLMResponse) -> str:
except Exception as e:
return f"Action '{action}' failed: {e}"

def set_elements(self, elements: list) -> None:
"""Set the current step's detected SoM elements for element-based clicking."""
self._elements = elements

# ── Element resolution ───────────────────────────────────────────

async def _click_element(self, params: dict) -> str:
"""Click a labeled SoM element by its number."""
label = params["element"]
elem = next((e for e in self._elements if e.label == label), None)
if not elem:
return f"Element #{label} not found in current screen elements"
cx, cy = elem.center_pixel()
modifiers = params.get("modifiers")
computer_service.click(cx, cy, modifiers=modifiers)
title = f' "{elem.title}"' if elem.title else ""
return f'Clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})'

async def _double_click_element(self, params: dict) -> str:
"""Double-click a labeled SoM element by its number."""
label = params["element"]
elem = next((e for e in self._elements if e.label == label), None)
if not elem:
return f"Element #{label} not found in current screen elements"
cx, cy = elem.center_pixel()
computer_service.double_click(cx, cy)
title = f' "{elem.title}"' if elem.title else ""
return f'Double-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})'

async def _right_click_element(self, params: dict) -> str:
"""Right-click a labeled SoM element by its number."""
label = params["element"]
elem = next((e for e in self._elements if e.label == label), None)
if not elem:
return f"Element #{label} not found in current screen elements"
cx, cy = elem.center_pixel()
computer_service.right_click(cx, cy)
title = f' "{elem.title}"' if elem.title else ""
return f'Right-clicked element #{label} ({elem.short_role}{title}) at ({cx}, {cy})'

# ── Individual action handlers ──────────────────────────────────────

async def _click(self, params: dict) -> str:
if "element" in params:
return await self._click_element(params)
x, y = self.normalize_to_screen_coords(params["x"], params["y"])
modifiers = params.get("modifiers")
computer_service.click(x, y, modifiers=modifiers)
Expand All @@ -79,11 +122,15 @@ async def _click(self, params: dict) -> str:
return f"Clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]"

async def _double_click(self, params: dict) -> str:
if "element" in params:
return await self._double_click_element(params)
x, y = self.normalize_to_screen_coords(params["x"], params["y"])
computer_service.double_click(x, y)
return f"Double-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]"

async def _right_click(self, params: dict) -> str:
if "element" in params:
return await self._right_click_element(params)
x, y = self.normalize_to_screen_coords(params["x"], params["y"])
computer_service.right_click(x, y)
return f"Right-clicked at ({x}, {y}) [normalized: ({params['x']}, {params['y']})]"
Expand Down
51 changes: 43 additions & 8 deletions backend/app/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,27 +154,60 @@ async def _wait_if_paused(self, state: AgentState) -> bool:

# ── Observe / Think / Act helpers ───────────────────────────────────

async def _observe(self, state: AgentState) -> tuple[str, list]:
    """Capture a screenshot, run Set-of-Mark detection, and annotate it.

    Saves the raw (unannotated) screenshot to disk for the task history.

    Returns:
        (screenshot_b64, elements): the base64 JPEG to send to the LLM
        (annotated when SoM found elements, raw otherwise) and the
        filtered list of detected elements (empty when SoM is disabled
        or detection fails).
    """
    screenshot_b64 = screen_service.capture_screen_base64()
    state.last_screenshot = screenshot_b64

    # SoM detection + annotation; any failure falls back to the raw shot
    # so a flaky accessibility layer never blocks the agent loop.
    elements: list = []
    annotated_b64 = screenshot_b64
    if settings.som_enabled:
        try:
            from app.services.som import (
                detect_elements, filter_leaf_elements, annotate_screenshot,
            )
            screen_w, screen_h = screen_service.get_screen_size()
            raw_elements = detect_elements(screen_w, screen_h)
            elements = filter_leaf_elements(
                raw_elements,
                max_elements=settings.som_max_elements,
                min_size=settings.som_min_element_size,
            )
            if elements:
                image_data = base64.b64decode(screenshot_b64)
                image = Image.open(io.BytesIO(image_data))
                annotated = annotate_screenshot(image, elements)
                buf = io.BytesIO()
                annotated.save(buf, "JPEG", quality=SCREENSHOT_QUALITY)
                annotated_b64 = base64.b64encode(buf.getvalue()).decode()
                logger.info("SoM: detected %d elements", len(elements))
        except Exception as e:
            logger.warning("SoM detection failed, using raw screenshot: %s", e)

    # Save the raw screenshot to disk (unannotated) for the step record.
    screenshots_dir = Path("./screenshots")
    screenshots_dir.mkdir(exist_ok=True)
    filename = f"task_{state.task_id}_step_{state.current_step}.jpg"
    path = screenshots_dir / filename

    image_data = base64.b64decode(screenshot_b64)
    image = Image.open(io.BytesIO(image_data))
    image.save(str(path), "JPEG", quality=SCREENSHOT_QUALITY)

    # Relative path (not absolute) so the frontend can resolve the file.
    state.screenshot_before_path = f"screenshots/{filename}"

    return annotated_b64, elements

async def _think(
self, instruction: str, screenshot_b64: str, state: AgentState
self, instruction: str, screenshot_b64: str, state: AgentState,
elements: list | None = None,
) -> LLMResponse:
"""Ask the LLM to analyse the current screen."""
element_list_text = ""
if elements:
from app.services.som import format_element_list
element_list_text = format_element_list(elements)

return await self.llm.analyze_screen(
instruction=instruction,
screenshot_base64=screenshot_b64,
Expand All @@ -184,6 +217,7 @@ async def _think(
notes=state.notes,
references=state.task_references,
lessons=state.task_lessons,
element_list_text=element_list_text,
)

async def _handle_response(
Expand Down Expand Up @@ -401,7 +435,8 @@ async def run(

# 1. Observe
try:
screenshot_b64 = await self._observe(state)
screenshot_b64, elements = await self._observe(state)
self._action_executor.set_elements(elements)
except Exception as e:
logger.error("Failed to capture screenshot: %s", e)
raise
Expand All @@ -415,7 +450,7 @@ async def run(

# 2. Think
try:
llm_response = await self._think(instruction, screenshot_b64, state)
llm_response = await self._think(instruction, screenshot_b64, state, elements)
except Exception as e:
logger.error("Failed to analyze screen with LLM: %s", e)
raise
Expand Down
3 changes: 3 additions & 0 deletions backend/app/services/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ async def analyze_screen(
notes: Optional[str] = None,
references: Optional[list[dict]] = None,
lessons: Optional[list[dict]] = None,
element_list_text: str = "",
) -> LLMResponse:
"""Analyze the current screen and determine the next action."""
pass
Expand Down Expand Up @@ -168,6 +169,7 @@ def get_system_prompt(
notes: Optional[str] = None,
references: Optional[list[dict]] = None,
lessons: Optional[list[dict]] = None,
element_list_text: str = "",
) -> str:
"""Get the system prompt for this provider."""
from app.services.defaults import get_running_apps, load_system_defaults
Expand All @@ -179,4 +181,5 @@ def get_system_prompt(
references=references, lessons=lessons,
system_defaults=system_defaults,
running_apps=running_apps,
element_list_text=element_list_text or None,
)
77 changes: 77 additions & 0 deletions backend/app/services/llm/gemini_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def _append_full_steps(
class GeminiProvider(BaseLLMProvider):
"""Google Gemini provider for vision-based computer control."""

_MAX_PARSE_RETRIES = 2

def __init__(self, model: str = None):
super().__init__()
self._model = model or settings.gemini_model
Expand Down Expand Up @@ -111,6 +113,7 @@ async def analyze_screen(
notes: Optional[str] = None,
references: Optional[list[dict]] = None,
lessons: Optional[list[dict]] = None,
element_list_text: str = "",
) -> LLMResponse:
"""Analyze screen and return next action."""
contents = self._build_contents(instruction, screenshot_base64, history)
Expand All @@ -122,6 +125,7 @@ async def analyze_screen(
raw_response, token_usage = await self._call_with_retries(
contents, skill, memories, plan=plan, notes=notes,
references=references, lessons=lessons,
element_list_text=element_list_text,
)

if raw_response and not raw_response.strip().endswith("}"):
Expand All @@ -133,6 +137,14 @@ async def analyze_screen(
logger.debug("Raw response: %s", raw_response)

llm_response = self._parse_response(raw_response)

if llm_response.reasoning.startswith("Failed to parse JSON"):
logger.info("Parse failed, attempting JSON correction retry")
retried = await self._retry_parse_with_correction(raw_response)
if retried is not None:
retried.token_usage = token_usage
return retried

llm_response.token_usage = token_usage
return llm_response

Expand Down Expand Up @@ -233,6 +245,7 @@ async def _call_with_retries(
notes: Optional[str] = None,
references: Optional[list[dict]] = None,
lessons: Optional[list[dict]] = None,
element_list_text: str = "",
) -> tuple[str, TokenUsage]:
"""Call Gemini API with exponential-backoff retries. Returns (raw_text, token_usage)."""
last_error: Optional[Exception] = None
Expand All @@ -242,6 +255,7 @@ async def _call_with_retries(
system_prompt = self.get_system_prompt(
skill=skill, memories=memories, plan=plan, notes=notes,
references=references, lessons=lessons,
element_list_text=element_list_text,
)
if skill:
logger.info("Using skill: %s", skill)
Expand Down Expand Up @@ -362,6 +376,69 @@ def _parse_response(self, raw_response: str) -> LLMResponse:
raw_response=raw_response,
)

async def _retry_parse_with_correction(self, malformed_response: str) -> Optional[LLMResponse]:
    """Re-prompt the LLM to fix a malformed JSON response.

    Sends a short correction prompt containing a truncated copy of the bad
    output and asks for a bare JSON object with the required fields. Retries
    up to ``_MAX_PARSE_RETRIES`` times.

    Args:
        malformed_response: The raw model text that failed JSON parsing.

    Returns:
        A parsed LLMResponse on success, or None after all retries fail.
    """
    # Truncate so a runaway response doesn't blow up the correction prompt.
    truncated = malformed_response[:1000]
    correction_prompt = (
        "Your previous response was not valid JSON. Here is what you returned: "
        f"{truncated}. Return ONLY a valid JSON object with the required fields: "
        "reasoning, action, params, needs_confirmation. Nothing else."
    )

    for attempt in range(self._MAX_PARSE_RETRIES):
        try:
            response = await self.client.aio.models.generate_content(
                model=self._model,
                contents=[types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=correction_prompt)],
                )],
                config=types.GenerateContentConfig(
                    max_output_tokens=2048,
                    # temperature 0: we want a faithful repair, not creativity
                    temperature=0.0,
                ),
            )

            # getattr with default replaces the hasattr/attribute dance.
            raw_text = getattr(response, "text", None)
            if raw_text is None:
                logger.warning("JSON retry attempt %d/%d: empty response", attempt + 1, self._MAX_PARSE_RETRIES)
                continue

            content = _strip_markdown_codeblock(raw_text)
            data = json.loads(content)

            action = data.get("action", "")
            is_complete = action == "done"
            needs_clarification = data.get("needs_clarification", False)

            logger.info("JSON retry attempt %d/%d succeeded", attempt + 1, self._MAX_PARSE_RETRIES)
            return LLMResponse(
                reasoning=data.get("reasoning", ""),
                # A terminal or clarification response carries no action.
                action=action if not is_complete and not needs_clarification else None,
                action_params=data.get("params", {}),
                is_complete=is_complete,
                needs_confirmation=data.get("needs_confirmation", False),
                needs_human_help=data.get("needs_human_help", False),
                human_help_reason=data.get("human_help_reason"),
                needs_clarification=needs_clarification,
                clarification_question=data.get("clarification_question"),
                clarification_options=data.get("clarification_options"),
                expected_outcome=data.get("expected_outcome"),
                notes=data.get("notes"),
                raw_response=raw_text,
            )
        except Exception as e:
            # One broad catch is deliberate: both API failures and
            # json.JSONDecodeError (a subclass of Exception — the original
            # `(json.JSONDecodeError, Exception)` tuple was redundant)
            # should trigger another retry attempt.
            logger.warning(
                "JSON retry attempt %d/%d failed: %s",
                attempt + 1, self._MAX_PARSE_RETRIES, e,
            )

    logger.warning("All %d JSON retry attempts failed", self._MAX_PARSE_RETRIES)
    return None

async def extract_and_merge_memories(
self,
instruction: str,
Expand Down
Loading