diff --git a/align_app/adm/decider/__init__.py b/align_app/adm/decider/__init__.py index c711656..9b51b25 100644 --- a/align_app/adm/decider/__init__.py +++ b/align_app/adm/decider/__init__.py @@ -1,11 +1,12 @@ from align_utils.models import ADMResult, Decision, ChoiceInfo from .decider import MultiprocessDecider -from .client import get_decision +from .client import get_decision, get_model_cache_status from .types import DeciderParams __all__ = [ "MultiprocessDecider", "get_decision", + "get_model_cache_status", "DeciderParams", "ADMResult", "Decision", diff --git a/align_app/adm/decider/client.py b/align_app/adm/decider/client.py index 67746be..6a4ae48 100644 --- a/align_app/adm/decider/client.py +++ b/align_app/adm/decider/client.py @@ -4,8 +4,10 @@ """ import atexit +from typing import Dict, Any from align_utils.models import ADMResult from .decider import MultiprocessDecider +from .worker import CacheQueryResult from .types import DeciderParams _decider = None @@ -32,6 +34,14 @@ async def get_decision(params: DeciderParams) -> ADMResult: return await process_manager.get_decision(params) +async def get_model_cache_status( + resolved_config: Dict[str, Any], +) -> CacheQueryResult | None: + """Get best-effort model cache status (memory + disk).""" + process_manager = _get_process_manager() + return await process_manager.get_model_cache_status(resolved_config) + + def cleanup(): """Clean up resources when the module is unloaded""" if _decider is not None: diff --git a/align_app/adm/decider/decider.py b/align_app/adm/decider/decider.py index d4ccbe6..50ad47a 100644 --- a/align_app/adm/decider/decider.py +++ b/align_app/adm/decider/decider.py @@ -1,6 +1,7 @@ +from typing import Dict, Any from align_utils.models import ADMResult from .types import DeciderParams -from .worker import decider_worker_func +from .worker import decider_worker_func, CacheQuery, CacheQueryResult from .multiprocess_worker import ( WorkerHandle, create_worker, @@ -13,6 +14,14 @@ class MultiprocessDecider: def __init__(self): self.worker: WorkerHandle = create_worker(decider_worker_func) + async def get_model_cache_status( + self, resolved_config: Dict[str, Any] + ) -> CacheQueryResult | None: + self.worker, result = await send(self.worker, CacheQuery(resolved_config)) + if isinstance(result, CacheQueryResult): + return result + return None + async def get_decision(self, params: DeciderParams) -> ADMResult: self.worker, result = await send(self.worker, params) diff --git a/align_app/adm/decider/worker.py b/align_app/adm/decider/worker.py index c80af6b..6e7fe51 100644 --- a/align_app/adm/decider/worker.py +++ b/align_app/adm/decider/worker.py @@ -2,8 +2,10 @@ import hashlib import json import logging +import os import traceback -from typing import Dict, Tuple, Callable, Any +from dataclasses import dataclass +from typing import Dict, Tuple, Callable, Any, Optional from multiprocessing import Queue from align_utils.models import ADMResult from .executor import instantiate_adm @@ -15,15 +17,89 @@ def extract_cache_key(resolved_config: Dict[str, Any]) -> str: return hashlib.md5(cache_str.encode()).hexdigest() +@dataclass +class CacheQuery: + resolved_config: Dict[str, Any] + + +@dataclass +class CacheQueryResult: + is_cached: bool + is_downloaded: Optional[bool] + + +def _extract_model_name(resolved_config: Dict[str, Any]) -> Optional[str]: + if not isinstance(resolved_config, dict): + return None + + if isinstance(resolved_config.get("model_name"), str): + return resolved_config["model_name"] + + structured = resolved_config.get("structured_inference_engine") + if isinstance(structured, dict) and isinstance(structured.get("model_name"), str): + return structured["model_name"] + + for value in resolved_config.values(): + if isinstance(value, dict): + found = _extract_model_name(value) + if found: + return found + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + found = _extract_model_name(item) + if found: + return found + return None + + +def _is_model_downloaded(model_name: Optional[str]) -> Optional[bool]: + if not model_name: + return None + + if os.path.exists(model_name): + return True + + try: + from huggingface_hub import snapshot_download + except Exception: + return None + + try: + snapshot_download(model_name, local_files_only=True) + return True + except Exception: + return False + + def decider_worker_func(task_queue: Queue, result_queue: Queue): root_logger = logging.getLogger() root_logger.setLevel("WARNING") + logger = logging.getLogger(__name__) model_cache: Dict[str, Tuple[Callable, Callable]] = {} try: for task in iter(task_queue.get, None): try: + if isinstance(task, CacheQuery): + cache_key = extract_cache_key(task.resolved_config) + is_cached = cache_key in model_cache + is_downloaded = ( + True + if is_cached + else _is_model_downloaded( + _extract_model_name(task.resolved_config) + ) + ) + result_queue.put( + CacheQueryResult( + is_cached=is_cached, + is_downloaded=is_downloaded, + ) + ) + continue + params: DeciderParams = task cache_key = extract_cache_key(params.resolved_config) @@ -54,7 +130,8 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue): except (KeyboardInterrupt, SystemExit): break except Exception as e: - error_msg = f"{str(e)}\n{traceback.format_exc()}" + logger.error("Worker error:\n%s", traceback.format_exc()) + error_msg = _format_worker_error(e) result_queue.put(Exception(error_msg)) finally: for _, (_, cleanup_func) in model_cache.items(): @@ -62,3 +139,21 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue): cleanup_func() except Exception: pass + + +def _format_worker_error(error: Exception) -> str: + error_text = str(error) + gated_tokens = ( + "GatedRepoError", + "gated repo", + "401 Client Error", + "Access to model", + "restricted", + "Please log in", + ) + if any(token in error_text for token in gated_tokens): + return ( + "Model access denied. Authenticate with Hugging Face or request access " + "to the gated repo." + ) + return f"{error_text}\n{traceback.format_exc()}" diff --git a/align_app/app/runs_state_adapter.py b/align_app/app/runs_state_adapter.py index df0576c..18c514d 100644 --- a/align_app/app/runs_state_adapter.py +++ b/align_app/app/runs_state_adapter.py @@ -1,3 +1,4 @@ +import logging from typing import Optional, Callable from trame.app import asynchronous from trame.app.file_upload import ClientFile @@ -7,6 +8,7 @@ from .runs_registry import RunsRegistry from .runs_table_filter import RunsTableFilter from ..adm.decider.types import DeciderParams +from ..adm.decider import get_model_cache_status from ..adm.system_adm_discovery import discover_system_adms from ..utils.utils import get_id from .runs_presentation import extract_base_scenarios @@ -15,6 +17,8 @@ from .import_experiments import import_experiments_from_zip from align_utils.models import AlignmentTarget +logger = logging.getLogger(__name__) + @TrameApp() class RunsStateAdapter: @@ -351,35 +355,32 @@ def delete_run_alignment_attribute(self, run_id: str, attr_index: int): new_run = self.runs_registry.delete_run_alignment_attribute(run_id, attr_index) self._handle_run_update(run_id, new_run) - @controller.set("update_run_probe_text") + @trigger("update_run_probe_text") def update_run_probe_text(self, run_id: str, text: str): if run_id in self.state.runs: self.state.runs[run_id]["prompt"]["probe"]["display_state"] = text - self.state.dirty("runs") choices = self.state.runs[run_id]["prompt"]["probe"]["choices"] self.state.probe_dirty[run_id] = self._is_probe_edited( run_id, text, choices ) self.state.dirty("probe_dirty") - @controller.set("update_run_choice_text") + @trigger("update_run_choice_text") def update_run_choice_text(self, run_id: str, index: int, text: str): if run_id in self.state.runs: choices = self.state.runs[run_id]["prompt"]["probe"]["choices"] if 0 <= index < len(choices): choices[index]["unstructured"] = text - self.state.dirty("runs") probe_text = self.state.runs[run_id]["prompt"]["probe"]["display_state"] self.state.probe_dirty[run_id] = self._is_probe_edited( run_id, probe_text, choices ) self.state.dirty("probe_dirty") - @controller.set("update_run_config_yaml") + @trigger("update_run_config_yaml") def update_run_config_yaml(self, run_id: str, yaml_text: str): if run_id in self.state.runs: self.state.runs[run_id]["prompt"]["resolved_config_yaml"] = yaml_text - self.state.dirty("runs") is_edited = self._is_config_edited(run_id, yaml_text) self.state.config_dirty[run_id] = is_edited self.state.dirty("config_dirty") @@ -521,7 +522,15 @@ def _create_run_with_edited_config( run = self.runs_registry.get_run(run_id) if not run: return None - new_config = yaml.safe_load(current_yaml) + try: + new_config = yaml.safe_load(current_yaml) + except yaml.YAMLError: + logger.exception("Invalid YAML syntax while saving config edits") + self._alerts.create_info_alert( + title="Invalid YAML syntax. Please fix and try again.", + timeout=8000, + ) + return None decider_options = self.decider_registry.get_decider_options( run.probe_id, run.decider_name @@ -614,13 +623,20 @@ async def _execute_run_decision(self, run_id: str): with self.state: self._add_pending_cache_key(cache_key) - is_cached = self.runs_registry.has_cached_decision(run_id) - if not is_cached: - alert_id = self._alerts.create_info_alert( - title="Loading model and deciding...", timeout=0 - ) + run = self.runs_registry.get_run(run_id) + is_cached_decision = self.runs_registry.has_cached_decision(run_id) + status = None + if run: + status = await get_model_cache_status(run.decider_params.resolved_config) + + if is_cached_decision or (status and status.is_cached): + alert_title = "Deciding..." + elif status and status.is_downloaded is False: + alert_title = "Downloading model and deciding..." else: - alert_id = self._alerts.create_info_alert(title="Deciding...", timeout=0) + alert_title = "Loading model and deciding..." + + alert_id = self._alerts.create_info_alert(title=alert_title, timeout=0) await self.server.network_completion try: @@ -629,7 +645,15 @@ async def _execute_run_decision(self, run_id: str): self._alerts.create_info_alert(title="Decision complete", timeout=3000) except Exception as e: self._alerts.remove_alert(alert_id) - self._alerts.create_info_alert(title=f"Decision failed: {e}", timeout=5000) + error_text = str(e) + if "Model access denied" in error_text: + message = ( + "Decision failed: Model access denied. " + "Authenticate with Hugging Face or request access to the model." + ) + else: + message = f"Decision failed: {e}" + self._alerts.create_info_alert(title=message, timeout=8000) with self.state: self._rebuild_comparison_runs() diff --git a/align_app/app/ui.py b/align_app/app/ui.py index 2fd2c2c..41a0550 100644 --- a/align_app/app/ui.py +++ b/align_app/app/ui.py @@ -396,8 +396,8 @@ def run_content(): vuetify3.VTextarea( model_value=("runs[id].prompt.resolved_config_yaml",), update_modelValue=( - ctrl.update_run_config_yaml, - r"[id, $event]", + r"runs[id].prompt.resolved_config_yaml = $event; " + r"trigger('update_run_config_yaml', [id, $event])" ), auto_grow=True, rows=1, @@ -568,7 +568,10 @@ def __init__(self, server): html.Div("Situation", classes="text-h6 pt-4") vuetify3.VTextarea( model_value=("runs[id].prompt.probe.display_state",), - update_modelValue=(ctrl.update_run_probe_text, "[id, $event]"), + update_modelValue=( + "runs[id].prompt.probe.display_state = $event; " + "trigger('update_run_probe_text', [id, $event])" + ), auto_grow=True, rows=3, hide_details="auto", @@ -588,8 +591,8 @@ def __init__(self, server): vuetify3.VTextarea( model_value=("choice.unstructured",), update_modelValue=( - ctrl.update_run_choice_text, - "[id, index, $event]", + "runs[id].prompt.probe.choices[index].unstructured = $event; " + "trigger('update_run_choice_text', [id, index, $event])" ), auto_grow=True, rows=1, @@ -1561,6 +1564,9 @@ def __init__( ".drop-zone-active { outline: 3px dashed #1976d2 !important; outline-offset: -3px; }" ".alert-popup-container { left: auto; right: 0; transform: none; width: fit-content; }" ".alert-popup-container .v-alert { --v-theme-info: 66, 66, 66; }" + ".alert-popup-container .v-alert__icon { display: none; }" + ".alert-popup-container .v-alert__prepend { display: none; }" + ".alert-popup-container .v-alert__prepend .v-icon { display: none; }" "'" ) ) diff --git a/tests/e2e/test_scenario_edit.py b/tests/e2e/test_scenario_edit.py index 86e98c2..a373139 100644 --- a/tests/e2e/test_scenario_edit.py +++ b/tests/e2e/test_scenario_edit.py @@ -64,3 +64,25 @@ def test_situation_text_revert_restores_original_scene(page, align_app_server): f"Expected scene to revert to original after restoring text. " f"Original: {original_scene}, Reverted: {reverted_scene}" ) + + +def test_situation_textarea_cursor_position_preserved(page, align_app_server): + """Regression test: typing in situation textarea should not jump cursor to end.""" + align_page = AlignPage(page) + align_page.goto(align_app_server) + align_page.expand_scenario_panel() + + textarea = align_page.situation_textarea + expect(textarea).to_be_visible() + textarea.click() + page.keyboard.press("Control+Home") + page.wait_for_timeout(200) + + page.keyboard.type("X", delay=50) + page.wait_for_timeout(1500) + + cursor_position = textarea.evaluate("el => el.selectionStart") + assert cursor_position <= 2, ( + f"Cursor jumped to position {cursor_position} after typing at start. " + f"Expected near position 1." + )