Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion align_app/adm/decider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from align_utils.models import ADMResult, Decision, ChoiceInfo
from .decider import MultiprocessDecider
from .client import get_decision
from .client import get_decision, get_model_cache_status
from .types import DeciderParams

__all__ = [
"MultiprocessDecider",
"get_decision",
"get_model_cache_status",
"DeciderParams",
"ADMResult",
"Decision",
Expand Down
10 changes: 10 additions & 0 deletions align_app/adm/decider/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
"""

import atexit
from typing import Dict, Any
from align_utils.models import ADMResult
from .decider import MultiprocessDecider
from .worker import CacheQueryResult
from .types import DeciderParams

_decider = None
Expand All @@ -32,6 +34,14 @@ async def get_decision(params: DeciderParams) -> ADMResult:
return await process_manager.get_decision(params)


async def get_model_cache_status(
    resolved_config: Dict[str, Any],
) -> CacheQueryResult | None:
    """Best-effort model cache status (in-memory + on-disk) for a resolved config.

    Delegates to the shared worker-process manager; returns None when the
    status cannot be determined.
    """
    manager = _get_process_manager()
    return await manager.get_model_cache_status(resolved_config)


def cleanup():
"""Clean up resources when the module is unloaded"""
if _decider is not None:
Expand Down
11 changes: 10 additions & 1 deletion align_app/adm/decider/decider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Any
from align_utils.models import ADMResult
from .types import DeciderParams
from .worker import decider_worker_func
from .worker import decider_worker_func, CacheQuery, CacheQueryResult
from .multiprocess_worker import (
WorkerHandle,
create_worker,
Expand All @@ -13,6 +14,14 @@ class MultiprocessDecider:
def __init__(self):
self.worker: WorkerHandle = create_worker(decider_worker_func)

async def get_model_cache_status(
self, resolved_config: Dict[str, Any]
) -> CacheQueryResult | None:
self.worker, result = await send(self.worker, CacheQuery(resolved_config))
if isinstance(result, CacheQueryResult):
return result
return None

async def get_decision(self, params: DeciderParams) -> ADMResult:
self.worker, result = await send(self.worker, params)

Expand Down
99 changes: 97 additions & 2 deletions align_app/adm/decider/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import hashlib
import json
import logging
import os
import traceback
from typing import Dict, Tuple, Callable, Any
from dataclasses import dataclass
from typing import Dict, Tuple, Callable, Any, Optional
from multiprocessing import Queue
from align_utils.models import ADMResult
from .executor import instantiate_adm
Expand All @@ -15,15 +17,89 @@ def extract_cache_key(resolved_config: Dict[str, Any]) -> str:
return hashlib.md5(cache_str.encode()).hexdigest()


@dataclass
class CacheQuery:
    """Task-queue message asking the worker for a model's cache status."""

    # Fully resolved ADM configuration to check (same dict used for decisions).
    resolved_config: Dict[str, Any]

@dataclass
class CacheQueryResult:
    """Worker reply to a CacheQuery."""

    # True when the worker already holds an instantiated model for this config.
    is_cached: bool
    # True/False when local download state could be determined, None when unknown.
    is_downloaded: Optional[bool]


def _extract_model_name(resolved_config: Dict[str, Any]) -> Optional[str]:
if not isinstance(resolved_config, dict):
return None

if isinstance(resolved_config.get("model_name"), str):
return resolved_config["model_name"]

structured = resolved_config.get("structured_inference_engine")
if isinstance(structured, dict) and isinstance(structured.get("model_name"), str):
return structured["model_name"]

for value in resolved_config.values():
if isinstance(value, dict):
found = _extract_model_name(value)
if found:
return found
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
found = _extract_model_name(item)
if found:
return found
return None


def _is_model_downloaded(model_name: Optional[str]) -> Optional[bool]:
if not model_name:
return None

if os.path.exists(model_name):
return True

try:
from huggingface_hub import snapshot_download
except Exception:
return None

try:
snapshot_download(model_name, local_files_only=True)
return True
except Exception:
return False


def decider_worker_func(task_queue: Queue, result_queue: Queue):
root_logger = logging.getLogger()
root_logger.setLevel("WARNING")
logger = logging.getLogger(__name__)

model_cache: Dict[str, Tuple[Callable, Callable]] = {}

try:
for task in iter(task_queue.get, None):
try:
if isinstance(task, CacheQuery):
cache_key = extract_cache_key(task.resolved_config)
is_cached = cache_key in model_cache
is_downloaded = (
True
if is_cached
else _is_model_downloaded(
_extract_model_name(task.resolved_config)
)
)
result_queue.put(
CacheQueryResult(
is_cached=is_cached,
is_downloaded=is_downloaded,
)
)
continue

params: DeciderParams = task
cache_key = extract_cache_key(params.resolved_config)

Expand Down Expand Up @@ -54,11 +130,30 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
except (KeyboardInterrupt, SystemExit):
break
except Exception as e:
error_msg = f"{str(e)}\n{traceback.format_exc()}"
logger.error("Worker error:\n%s", traceback.format_exc())
error_msg = _format_worker_error(e)
result_queue.put(Exception(error_msg))
finally:
for _, (_, cleanup_func) in model_cache.items():
try:
cleanup_func()
except Exception:
pass


def _format_worker_error(error: Exception) -> str:
error_text = str(error)
gated_tokens = (
"GatedRepoError",
"gated repo",
"401 Client Error",
"Access to model",
"restricted",
"Please log in",
)
if any(token in error_text for token in gated_tokens):
return (
"Model access denied. Authenticate with Hugging Face or request access "
"to the gated repo."
)
return f"{error_text}\n{traceback.format_exc()}"
52 changes: 38 additions & 14 deletions align_app/app/runs_state_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Optional, Callable
from trame.app import asynchronous
from trame.app.file_upload import ClientFile
Expand All @@ -7,6 +8,7 @@
from .runs_registry import RunsRegistry
from .runs_table_filter import RunsTableFilter
from ..adm.decider.types import DeciderParams
from ..adm.decider import get_model_cache_status
from ..adm.system_adm_discovery import discover_system_adms
from ..utils.utils import get_id
from .runs_presentation import extract_base_scenarios
Expand All @@ -15,6 +17,8 @@
from .import_experiments import import_experiments_from_zip
from align_utils.models import AlignmentTarget

logger = logging.getLogger(__name__)


@TrameApp()
class RunsStateAdapter:
Expand Down Expand Up @@ -351,35 +355,32 @@ def delete_run_alignment_attribute(self, run_id: str, attr_index: int):
new_run = self.runs_registry.delete_run_alignment_attribute(run_id, attr_index)
self._handle_run_update(run_id, new_run)

@controller.set("update_run_probe_text")
@trigger("update_run_probe_text")
def update_run_probe_text(self, run_id: str, text: str):
if run_id in self.state.runs:
self.state.runs[run_id]["prompt"]["probe"]["display_state"] = text
self.state.dirty("runs")
choices = self.state.runs[run_id]["prompt"]["probe"]["choices"]
self.state.probe_dirty[run_id] = self._is_probe_edited(
run_id, text, choices
)
self.state.dirty("probe_dirty")

@controller.set("update_run_choice_text")
@trigger("update_run_choice_text")
def update_run_choice_text(self, run_id: str, index: int, text: str):
if run_id in self.state.runs:
choices = self.state.runs[run_id]["prompt"]["probe"]["choices"]
if 0 <= index < len(choices):
choices[index]["unstructured"] = text
self.state.dirty("runs")
probe_text = self.state.runs[run_id]["prompt"]["probe"]["display_state"]
self.state.probe_dirty[run_id] = self._is_probe_edited(
run_id, probe_text, choices
)
self.state.dirty("probe_dirty")

@controller.set("update_run_config_yaml")
@trigger("update_run_config_yaml")
def update_run_config_yaml(self, run_id: str, yaml_text: str):
if run_id in self.state.runs:
self.state.runs[run_id]["prompt"]["resolved_config_yaml"] = yaml_text
self.state.dirty("runs")
is_edited = self._is_config_edited(run_id, yaml_text)
self.state.config_dirty[run_id] = is_edited
self.state.dirty("config_dirty")
Expand Down Expand Up @@ -521,7 +522,15 @@ def _create_run_with_edited_config(
run = self.runs_registry.get_run(run_id)
if not run:
return None
new_config = yaml.safe_load(current_yaml)
try:
new_config = yaml.safe_load(current_yaml)
except yaml.YAMLError:
logger.exception("Invalid YAML syntax while saving config edits")
self._alerts.create_info_alert(
title="Invalid YAML syntax. Please fix and try again.",
timeout=8000,
)
return None

decider_options = self.decider_registry.get_decider_options(
run.probe_id, run.decider_name
Expand Down Expand Up @@ -614,13 +623,20 @@ async def _execute_run_decision(self, run_id: str):
with self.state:
self._add_pending_cache_key(cache_key)

is_cached = self.runs_registry.has_cached_decision(run_id)
if not is_cached:
alert_id = self._alerts.create_info_alert(
title="Loading model and deciding...", timeout=0
)
run = self.runs_registry.get_run(run_id)
is_cached_decision = self.runs_registry.has_cached_decision(run_id)
status = None
if run:
status = await get_model_cache_status(run.decider_params.resolved_config)

if is_cached_decision or (status and status.is_cached):
alert_title = "Deciding..."
elif status and status.is_downloaded is False:
alert_title = "Downloading model and deciding..."
else:
alert_id = self._alerts.create_info_alert(title="Deciding...", timeout=0)
alert_title = "Loading model and deciding..."

alert_id = self._alerts.create_info_alert(title=alert_title, timeout=0)
await self.server.network_completion

try:
Expand All @@ -629,7 +645,15 @@ async def _execute_run_decision(self, run_id: str):
self._alerts.create_info_alert(title="Decision complete", timeout=3000)
except Exception as e:
self._alerts.remove_alert(alert_id)
self._alerts.create_info_alert(title=f"Decision failed: {e}", timeout=5000)
error_text = str(e)
if "Model access denied" in error_text:
message = (
"Decision failed: Model access denied. "
"Authenticate with Hugging Face or request access to the model."
)
else:
message = f"Decision failed: {e}"
self._alerts.create_info_alert(title=message, timeout=8000)

with self.state:
self._rebuild_comparison_runs()
Expand Down
16 changes: 11 additions & 5 deletions align_app/app/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,8 @@ def run_content():
vuetify3.VTextarea(
model_value=("runs[id].prompt.resolved_config_yaml",),
update_modelValue=(
ctrl.update_run_config_yaml,
r"[id, $event]",
r"runs[id].prompt.resolved_config_yaml = $event; "
r"trigger('update_run_config_yaml', [id, $event])"
),
auto_grow=True,
rows=1,
Expand Down Expand Up @@ -568,7 +568,10 @@ def __init__(self, server):
html.Div("Situation", classes="text-h6 pt-4")
vuetify3.VTextarea(
model_value=("runs[id].prompt.probe.display_state",),
update_modelValue=(ctrl.update_run_probe_text, "[id, $event]"),
update_modelValue=(
"runs[id].prompt.probe.display_state = $event; "
"trigger('update_run_probe_text', [id, $event])"
),
auto_grow=True,
rows=3,
hide_details="auto",
Expand All @@ -588,8 +591,8 @@ def __init__(self, server):
vuetify3.VTextarea(
model_value=("choice.unstructured",),
update_modelValue=(
ctrl.update_run_choice_text,
"[id, index, $event]",
"runs[id].prompt.probe.choices[index].unstructured = $event; "
"trigger('update_run_choice_text', [id, index, $event])"
),
auto_grow=True,
rows=1,
Expand Down Expand Up @@ -1561,6 +1564,9 @@ def __init__(
".drop-zone-active { outline: 3px dashed #1976d2 !important; outline-offset: -3px; }"
".alert-popup-container { left: auto; right: 0; transform: none; width: fit-content; }"
".alert-popup-container .v-alert { --v-theme-info: 66, 66, 66; }"
".alert-popup-container .v-alert__icon { display: none; }"
".alert-popup-container .v-alert__prepend { display: none; }"
".alert-popup-container .v-alert__prepend .v-icon { display: none; }"
"</style>'"
)
)
Expand Down
22 changes: 22 additions & 0 deletions tests/e2e/test_scenario_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,25 @@ def test_situation_text_revert_restores_original_scene(page, align_app_server):
f"Expected scene to revert to original after restoring text. "
f"Original: {original_scene}, Reverted: {reverted_scene}"
)


def test_situation_textarea_cursor_position_preserved(page, align_app_server):
    """Regression test: typing in the situation textarea must keep the caret in place."""
    app = AlignPage(page)
    app.goto(align_app_server)
    app.expand_scenario_panel()

    situation_box = app.situation_textarea
    expect(situation_box).to_be_visible()
    situation_box.click()
    # Put the caret at the very start of the text before typing.
    page.keyboard.press("Control+Home")
    page.wait_for_timeout(200)

    page.keyboard.type("X", delay=50)
    # Give any debounced state sync time to round-trip so a cursor jump would show.
    page.wait_for_timeout(1500)

    caret = situation_box.evaluate("el => el.selectionStart")
    assert caret <= 2, (
        f"Cursor jumped to position {caret} after typing at start. "
        f"Expected near position 1."
    )
)