diff --git a/align_app/adm/decider/__init__.py b/align_app/adm/decider/__init__.py index c711656..9fbc8a8 100644 --- a/align_app/adm/decider/__init__.py +++ b/align_app/adm/decider/__init__.py @@ -1,11 +1,12 @@ from align_utils.models import ADMResult, Decision, ChoiceInfo from .decider import MultiprocessDecider -from .client import get_decision +from .client import get_decision, is_model_cached from .types import DeciderParams __all__ = [ "MultiprocessDecider", "get_decision", + "is_model_cached", "DeciderParams", "ADMResult", "Decision", diff --git a/align_app/adm/decider/client.py b/align_app/adm/decider/client.py index 67746be..0f955b5 100644 --- a/align_app/adm/decider/client.py +++ b/align_app/adm/decider/client.py @@ -4,6 +4,7 @@ """ import atexit +from typing import Dict, Any from align_utils.models import ADMResult from .decider import MultiprocessDecider from .types import DeciderParams @@ -19,6 +20,12 @@ def _get_process_manager(): return _decider +async def is_model_cached(resolved_config: Dict[str, Any]) -> bool: + """Check if model for this config is already loaded in worker.""" + process_manager = _get_process_manager() + return await process_manager.is_model_cached(resolved_config) + + async def get_decision(params: DeciderParams) -> ADMResult: """Get a decision using DeciderParams. diff --git a/align_app/adm/decider/decider.py b/align_app/adm/decider/decider.py index d4ccbe6..4540645 100644 --- a/align_app/adm/decider/decider.py +++ b/align_app/adm/decider/decider.py @@ -1,6 +1,7 @@ +from typing import Dict, Any from align_utils.models import ADMResult from .types import DeciderParams -from .worker import decider_worker_func +from .worker import decider_worker_func, CacheQuery, CacheQueryResult from .multiprocess_worker import ( WorkerHandle, create_worker, @@ -13,6 +14,12 @@ class MultiprocessDecider: def __init__(self): self.worker: WorkerHandle = create_worker(decider_worker_func) + async def is_model_cached(self, resolved_config: Dict[str, Any]) -> bool: + self.worker, result = await send(self.worker, CacheQuery(resolved_config)) + if isinstance(result, CacheQueryResult): + return result.is_cached + return False + async def get_decision(self, params: DeciderParams) -> ADMResult: self.worker, result = await send(self.worker, params) diff --git a/align_app/adm/decider/worker.py b/align_app/adm/decider/worker.py index c80af6b..5ed5749 100644 --- a/align_app/adm/decider/worker.py +++ b/align_app/adm/decider/worker.py @@ -3,6 +3,7 @@ import json import logging import traceback +from dataclasses import dataclass from typing import Dict, Tuple, Callable, Any from multiprocessing import Queue from align_utils.models import ADMResult @@ -15,6 +16,16 @@ def extract_cache_key(resolved_config: Dict[str, Any]) -> str: return hashlib.md5(cache_str.encode()).hexdigest() +@dataclass +class CacheQuery: + resolved_config: Dict[str, Any] + + +@dataclass +class CacheQueryResult: + is_cached: bool + + def decider_worker_func(task_queue: Queue, result_queue: Queue): root_logger = logging.getLogger() root_logger.setLevel("WARNING") @@ -24,6 +35,13 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue): try: for task in iter(task_queue.get, None): try: + if isinstance(task, CacheQuery): + cache_key = extract_cache_key(task.resolved_config) + result_queue.put( + CacheQueryResult(is_cached=cache_key in model_cache) + ) + continue + params: DeciderParams = task cache_key = extract_cache_key(params.resolved_config) diff --git a/align_app/app/alerts_controller.py b/align_app/app/alerts_controller.py index 304981c..8e55279 100644 --- a/align_app/app/alerts_controller.py +++ b/align_app/app/alerts_controller.py @@ -14,6 +14,3 @@ def show(self, message: str, timeout: int = -1): self.server.state.alert_message = message self.server.state.alert_timeout = timeout self.server.state.alert_visible = True - - def hide(self): - self.server.state.alert_visible = False diff --git a/align_app/app/runs_registry.py b/align_app/app/runs_registry.py index f29ba55..1264179 100644 --- a/align_app/app/runs_registry.py +++ b/align_app/app/runs_registry.py @@ -103,13 +103,6 @@ async def execute_run_decision(self, run_id: str) -> Optional[Run]: return await self._execute_with_cache(run, probe.choices or []) - def has_cached_decision(self, run_id: str) -> bool: - run = runs_core.get_run(self._runs, run_id) - if not run: - return False - cache_key = run.compute_cache_key() - return runs_core.get_cached_decision(self._runs, cache_key) is not None - def get_run(self, run_id: str) -> Optional[Run]: run = runs_core.get_run(self._runs, run_id) if run: diff --git a/align_app/app/runs_state_adapter.py b/align_app/app/runs_state_adapter.py index 5049d8f..5b2c0ec 100644 --- a/align_app/app/runs_state_adapter.py +++ b/align_app/app/runs_state_adapter.py @@ -6,6 +6,7 @@ from .runs_registry import RunsRegistry from .runs_table_filter import RunsTableFilter from ..adm.decider.types import DeciderParams +from ..adm.decider import is_model_cached from ..adm.system_adm_discovery import discover_system_adms from ..utils.utils import get_id from .runs_presentation import extract_base_scenarios @@ -616,12 +617,11 @@ async def _execute_run_decision(self, run_id: str): with self.state: self._add_pending_cache_key(cache_key) - is_cached = self.runs_registry.has_cached_decision(run_id) - if not is_cached: - self._alerts.show("Loading model...") - await self.server.network_completion - - self._alerts.show("Making decision...") + run = self.runs_registry.get_run(run_id) + if run and await is_model_cached(run.decider_params.resolved_config): + self._alerts.show("Deciding...") + else: + self._alerts.show("Loading model and deciding...") await self.server.network_completion try: diff --git a/align_app/app/ui.py b/align_app/app/ui.py index 742e59c..79a55c6 100644 --- a/align_app/app/ui.py +++ b/align_app/app/ui.py @@ -1544,7 +1544,7 @@ def __init__( with vuetify3.VSnackbar( v_model=("alert_visible", False), text=("alert_message", ""), - location="bottom left", + location="bottom right", color="white", timeout=("alert_timeout", -1), content_class="text-h6 font-weight-medium",