From 29b9e4d7f3dc8bad08bce6fcd675ef4e5d6cdda0 Mon Sep 17 00:00:00 2001 From: "Luma (Enclave AI)" Date: Mon, 11 May 2026 18:50:14 +0000 Subject: [PATCH] fix(router): add capability registry, fix tools routing, rename to PEP8 --- .gitignore | 3 - PLANS.md | 2 +- router/Dockerfile | 4 +- router/smart-model-router.py | 196 ------------------- router/smart_model_router.py | 358 +++++++++++++++++++++++++++++++++++ tests/conftest.py | 9 + tests/test_retriever.py | 132 +++++++++++++ tests/test_router.py | 312 ++++++++++++++++++++++++++++++ 8 files changed, 814 insertions(+), 202 deletions(-) delete mode 100644 router/smart-model-router.py create mode 100644 router/smart_model_router.py create mode 100644 tests/conftest.py create mode 100644 tests/test_retriever.py create mode 100644 tests/test_router.py diff --git a/.gitignore b/.gitignore index c0f3554..ab1debf 100644 --- a/.gitignore +++ b/.gitignore @@ -30,8 +30,5 @@ Thumbs.db # ── Logs *.log -# Opencode -AGENTS.md - # Obsidian workspace config (local + auto-generated by install.sh) .obsidian/ diff --git a/PLANS.md b/PLANS.md index b7f13c7..851a8a3 100644 --- a/PLANS.md +++ b/PLANS.md @@ -169,7 +169,7 @@ KHOJ_SYNC_SKIP_INITIAL, KHOJ_SYNC_LOG_LEVEL ``` RETRIEVER_PORT=42000 -RETRIEVER_VAULT_PATH=/home/netyeti/obsidian +RETRIEVER_VAULT_PATH=/home/${STACK_USER}/obsidian RETRIEVER_EMBED_MODEL=nomic-embed-text RETRIEVER_CHUNK_SIZE=512 RETRIEVER_CHUNK_OVERLAP=64 diff --git a/router/Dockerfile b/router/Dockerfile index f5ce68b..6f5843b 100644 --- a/router/Dockerfile +++ b/router/Dockerfile @@ -5,8 +5,8 @@ WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt -COPY smart-model-router.py . +COPY smart_model_router.py . EXPOSE 40115 -CMD ["python", "smart-model-router.py"] +CMD ["python", "smart_model_router.py"] diff --git a/router/smart-model-router.py b/router/smart-model-router.py deleted file mode 100644 index a4e9cca..0000000 --- a/router/smart-model-router.py +++ /dev/null @@ -1,196 +0,0 @@ -#!/usr/bin/env python3 -""" -Smart Model Router for Olla -Auto-routes queries to the best local model based on content analysis. -Sits between OpenCode and Olla: OpenCode -> Smart Router -> Olla -> ollama-arc - -Routing: -- Diagnostics -> qwen2.5:14b (fast, reliable for sysadmin) -- Scripting/Code -> qwen2.5-coder:14b (code generation) -- Reasoning -> deepseek-r1:14b (chain-of-thought) -- Longform/Logs -> gemma3:12b (long context, summaries) -- Heavy lifting -> gemma4:27b (complex analysis, large context) -- Tool calling -> mistral-small3.2:24b (strong function calling) -- Default -> qwen3.5:14b (improved reasoning, best all-rounder) -""" - -import os -import re -import json -import httpx -from typing import Tuple -from fastapi import FastAPI, Request, Response - -OLLA_URL = os.environ.get("OLLA_URL", "http://olla:40114") -LISTEN_HOST = os.environ.get("LISTEN_HOST", "0.0.0.0") -LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "40115")) - -MODELS = { - "diagnostics": "qwen2.5:14b", - "scripting": "qwen2.5-coder:14b", - "reasoning": "deepseek-r1:14b", - "longform": "gemma3:12b", - "heavy": "gemma4:27b", - "tools": "mistral-small3.2:24b", - "default": "qwen3.5:14b", -} - -PATTERNS = { - "diagnostics": [ - r"\b(diagnos|health|status|check|monitor|alert|reachable|unreachable|uptime)\b", - r"\b(system report|get_all|list models|loaded models|vram)\b", - r"\b(is .+ running|is .+ up|is .+ down|ping)\b", - r"\b(ollama|open.?webui|pipeline|container|docker)\b", - r"\b(gpu|cpu|memory|ram|disk usage)\b", - r"\b(logs? file|journal|syslog|dmesg|kern)\b", - ], - "scripting": [ - r"\b(script|bash|shell|command|cron|systemd|service|config)\b", - r"\b(yaml|compose|dockerfile|ansible|terraform)\b", - r"\b(fix|debug|error|traceback|exception|failed|exit code)\b", - r"\b(install|setup|configure|deploy|update|upgrade)\b", - r"\b(python|javascript|typescript|code|function|class|import)\b", - r"\b(write a|create a|generate|implement|refactor)\b.*\b(function|class|script|module)\b", - ], - "reasoning": [ - r"\b(why|root cause|explain|analyze|compare|optimize|recommend)\b", - r"\b(should i|what would you|best approach|pros and cons|trade.?off)\b", - r"\b(performance|bottleneck|slow|latency|memory leak|high cpu)\b", - r"\b(architecture|design|strategy|best practice|decouple|refactor)\b", - r"\b(math|calculate|derive|proof|theorem|logic|reason)\b", - ], - "longform": [ - r"\b(log|logs|summarize|summary|document|report)\b", - r"\b(what does this mean|walk me through|step by step|explain this)\b", - r"\b(write a|draft a|create a document|generate a report)\b", - r"\b(review|proofread|edit|rewrite|format|structure)\b", - ], - "heavy": [ - r"\b(analyze this entire|full analysis|comprehensive review)\b", - r"\b(large context|long document|big codebase|entire project)\b", - r"\b(complex|sophisticated|architectural|system.?wide)\b", - ], -} - - -def classify(text: str) -> Tuple[str, str]: - t = text.lower() - # Score each category - best_category = "default" - best_score = 0 - for category, patterns in PATTERNS.items(): - score = 0 - for pattern in patterns: - matches = re.findall(pattern, t) - score += len(matches) - if score > best_score: - best_score = score - best_category = category - return MODELS[best_category], best_category - - -def should_route(body: bytes, path: str) -> bool: - """Only route chat completion requests for local models.""" - if not path.startswith("v1/chat/completions"): - return False - try: - data = json.loads(body) - model = data.get("model", "") - # Skip cloud models — route those directly - if any(c in model for c in ("claude", "gemini", "gpt")): - return False - return True - except (json.JSONDecodeError, KeyError): - return False - - -async def handle_request(body: bytes) -> bytes: - try: - data = json.loads(body) - except json.JSONDecodeError: - return body - - messages = data.get("messages", []) - if not messages: - return body - - user_message = "" - for m in reversed(messages): - if m.get("role") == "user": - content = m.get("content", "") - # Use text content from various formats - if isinstance(content, str): - user_message = content - elif isinstance(content, list): - for part in content: - if isinstance(part, dict) and part.get("type") == "text": - user_message = part.get("text", "") - break - break - - if user_message: - model, reason = classify(user_message) - data["model"] = model - print(f"[SmartRouter] '{user_message[:80]}...' -> {model} ({reason})") - return json.dumps(data).encode() - - return body - - -app = FastAPI(title="Smart Model Router") - - -@app.api_route("/health", methods=["GET"]) -async def health(): - return {"status": "ok"} - - -@app.api_route("/v1/models", methods=["GET"]) -async def list_models(): - """Return available models from Olla.""" - async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client: - resp = await client.get(f"{OLLA_URL}/olla/ollama/v1/models") - return Response(content=resp.content, status_code=resp.status_code, headers=dict(resp.headers)) - - -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) -async def proxy(request: Request, path: str): - if path.startswith("v1/"): - body = await request.body() - - if body and should_route(body, path): - body = await handle_request(body) - - async with httpx.AsyncClient(timeout=httpx.Timeout(connect=10.0, read=300.0, write=10.0, pool=10.0)) as client: - url = f"{OLLA_URL}/olla/ollama/{path}" - headers = dict(request.headers) - headers.pop("host", None) - headers.pop("content-length", None) - - response = await client.request( - method=request.method, - url=url, - headers=headers, - content=body, - params=dict(request.query_params), - ) - - return Response( - content=response.content, - status_code=response.status_code, - headers=dict(response.headers), - ) - - return Response(status_code=404) - - -def main(): - import uvicorn - print(f"[SmartRouter] Listening on {LISTEN_HOST}:{LISTEN_PORT}") - print(f"[SmartRouter] Forwarding to Olla at {OLLA_URL}") - print(f"[SmartRouter] Models: {', '.join(MODELS.values())}") - uvicorn.run(app, host=LISTEN_HOST, port=LISTEN_PORT) - - -if __name__ == "__main__": - main() diff --git a/router/smart_model_router.py b/router/smart_model_router.py new file mode 100644 index 0000000..794d798 --- /dev/null +++ b/router/smart_model_router.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +""" +Smart Model Router for Olla +Auto-routes queries to the best local model based on content analysis. +Sits between OpenCode and Olla: OpenCode -> Smart Router -> Olla -> ollama-arc + +Routing: +- Diagnostics -> qwen2.5:14b (fast, reliable for sysadmin) +- Scripting/Code -> qwen2.5-coder:14b (code generation) +- Reasoning -> deepseek-r1:14b (chain-of-thought, no tools) +- Longform/Logs -> gemma3:12b (long context, summaries, no tools) +- Heavy lifting -> gemma4:27b (complex analysis, large context, no tools) +- Tool calling -> mistral-small3.2:24b (strong function calling) +- Default -> qwen3.5:14b (improved reasoning, best all-rounder) +""" + +import os +import re +import json +import asyncio +import time +import httpx +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Optional, Tuple +from fastapi import FastAPI, Request, Response + +OLLA_URL = os.environ.get("OLLA_URL", "http://olla:40114") +LISTEN_HOST = os.environ.get("LISTEN_HOST", "0.0.0.0") +LISTEN_PORT = int(os.environ.get("LISTEN_PORT", "40115")) +CAPABILITY_REFRESH_INTERVAL = int(os.environ.get("CAPABILITY_REFRESH_INTERVAL", "300")) + +MODELS = { + "diagnostics": "qwen2.5:14b", + "scripting": "qwen2.5-coder:14b", + "reasoning": "deepseek-r1:14b", + "longform": "gemma3:12b", + "heavy": "gemma4:27b", + "tools": "mistral-small3.2:24b", + "default": "qwen3.5:14b", +} + +# Model families known to support / not support tool calling. +# Used as fallback when the Olla API is unavailable. +_TOOL_CAPABLE_FAMILIES = { + "mistral", "llama3", "llama3.1", "llama3.2", "llama3.3", + "qwen2.5", "qwen3", "phi3.5", "phi4", "command-r", + "aya", "granite3", "nemotron", "hermes", +} +_TOOL_INCAPABLE_FAMILIES = { + "deepseek-r1", "deepseek-r2", + "gemma", "gemma2", "gemma3", "gemma4", + "nomic", "mxbai", "snowflake", "all-minilm", +} + +# Patterns compiled once at import — not on every request. +# re.search (not findall) is used at runtime: we only need presence, not count. +_RAW_PATTERNS: dict[str, list[str]] = { + "diagnostics": [ + r"\b(diagnos|health|status|check|monitor|alert|reachable|unreachable|uptime)\b", + r"\b(system report|get_all|list models|loaded models|vram)\b", + r"\b(is .+ running|is .+ up|is .+ down|ping)\b", + r"\b(ollama|open.?webui|pipeline|container|docker)\b", + r"\b(gpu|cpu|memory|ram|disk usage)\b", + r"\b(logs? file|journal|syslog|dmesg|kern)\b", + ], + "scripting": [ + r"\b(script|bash|shell|command|cron|systemd|service|config)\b", + r"\b(yaml|compose|dockerfile|ansible|terraform)\b", + r"\b(fix|debug|error|traceback|exception|failed|exit code)\b", + r"\b(install|setup|configure|deploy|update|upgrade)\b", + r"\b(python|javascript|typescript|code|function|class|import)\b", + r"\b(write a|create a|generate|implement|refactor)\b.*\b(function|class|script|module)\b", + ], + "reasoning": [ + r"\b(why|root cause|explain|analyze|compare|optimize|recommend)\b", + r"\b(should i|what would you|best approach|pros and cons|trade.?off)\b", + r"\b(performance|bottleneck|slow|latency|memory leak|high cpu)\b", + r"\b(architecture|design|strategy|best practice|decouple|refactor)\b", + r"\b(math|calculate|derive|proof|theorem|logic|reason)\b", + ], + "longform": [ + r"\b(log|logs|summarize|summary|document|report)\b", + r"\b(what does this mean|walk me through|step by step|explain this)\b", + r"\b(write a|draft a|create a document|generate a report)\b", + r"\b(review|proofread|edit|rewrite|format|structure)\b", + ], + "heavy": [ + r"\b(analyze this entire|full analysis|comprehensive review)\b", + r"\b(large context|long document|big codebase|entire project)\b", + r"\b(complex|sophisticated|architectural|system.?wide)\b", + ], +} + +PATTERNS: dict[str, list[re.Pattern]] = { + category: [re.compile(p) for p in patterns] + for category, patterns in _RAW_PATTERNS.items() +} + +# Cap text fed to classify() — avoids O(n) regex over large payloads +_CLASSIFY_MAX_CHARS = 500 + + +@dataclass +class ModelCapabilities: + name: str + tools: bool + available: bool = True + + +class CapabilityRegistry: + """ + Discovers which models are loaded in Olla and their tool-calling capability. + + On startup and periodically, queries Olla's /v1/models endpoint to build a + live list of available models. Tool support is inferred from the model family + name; this avoids requiring a separate /api/show call per model. + + If Olla is unreachable, the registry falls back to the static MODELS dict — + routing still works, it just won't know about availability. + """ + + def __init__(self): + self._registry: dict[str, ModelCapabilities] = {} + self._last_refresh: float = 0.0 + + @staticmethod + def _infer_tools(model_name: str) -> bool: + base = model_name.split(":")[0].lower() + for family in _TOOL_INCAPABLE_FAMILIES: + if base.startswith(family) or family in base: + return False + for family in _TOOL_CAPABLE_FAMILIES: + if base.startswith(family) or family in base: + return True + return True # optimistic default for unknown families + + async def refresh(self) -> None: + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client: + resp = await client.get(f"{OLLA_URL}/olla/ollama/v1/models") + resp.raise_for_status() + data = resp.json() + models = data.get("data", []) + fresh: dict[str, ModelCapabilities] = {} + for m in models: + name = m.get("id", "") + if name: + fresh[name] = ModelCapabilities( + name=name, + tools=self._infer_tools(name), + available=True, + ) + self._registry = fresh + self._last_refresh = time.monotonic() + print(f"[SmartRouter] Capability registry refreshed — {len(fresh)} models: " + f"{', '.join(fresh)}") + except Exception as exc: + print(f"[SmartRouter] Capability refresh failed ({exc}) — " + f"routing with static MODELS fallback") + + def supports_tools(self, model_name: str) -> bool: + cap = self._registry.get(model_name) + return cap.tools if cap else self._infer_tools(model_name) + + def is_available(self, model_name: str) -> bool: + if not self._registry: + return True # no data yet — assume available + return model_name in self._registry + + def best_tools_model(self) -> Optional[str]: + """Return the first available model that supports tool calling.""" + for name, cap in self._registry.items(): + if cap.tools and cap.available: + return name + return None + + @property + def stale(self) -> bool: + return (time.monotonic() - self._last_refresh) > CAPABILITY_REFRESH_INTERVAL + + +registry = CapabilityRegistry() + + +async def _periodic_refresh(): + while True: + await asyncio.sleep(CAPABILITY_REFRESH_INTERVAL) + if registry.stale: + await registry.refresh() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await registry.refresh() + task = asyncio.create_task(_periodic_refresh()) + yield + task.cancel() + + +def classify(text: str) -> Tuple[str, str]: + # Cap length before regex — avoids O(n) work on large payloads + t = text[:_CLASSIFY_MAX_CHARS].lower() + best_category = "default" + best_score = 0 + for category, compiled in PATTERNS.items(): + # re.search (not findall) — presence check is enough, avoids collecting all matches + score = sum(1 for p in compiled if p.search(t)) + if score > best_score: + best_score = score + best_category = category + return MODELS[best_category], best_category + + +def _parse_body(body: bytes) -> Optional[dict]: + """Parse request body once. Returns None on invalid JSON.""" + try: + return json.loads(body) + except (json.JSONDecodeError, ValueError): + return None + + +def should_route(data: dict, path: str) -> bool: + """Only route chat completion requests for local models.""" + if not path.startswith("v1/chat/completions"): + return False + model = data.get("model", "") + return not any(c in model for c in ("claude", "gemini", "gpt")) + + +async def handle_request(data: dict) -> dict: + messages = data.get("messages", []) + if not messages: + return data + + user_message = "" + for m in reversed(messages): + if m.get("role") == "user": + content = m.get("content", "") + if isinstance(content, str): + user_message = content + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + user_message = part.get("text", "") + break + break + + if not user_message: + return data + + needs_tools = bool(data.get("tools") or data.get("functions")) + + if needs_tools: + preferred = MODELS["tools"] + if not registry.supports_tools(preferred): + fallback = registry.best_tools_model() + if fallback: + print(f"[SmartRouter] {preferred} has no tool support, using {fallback}") + preferred = fallback + else: + print(f"[SmartRouter] WARNING: no tool-capable model available; " + f"sending {preferred} anyway (may fail)") + model, reason = preferred, "tools" + else: + model, reason = classify(user_message) + if not registry.is_available(model): + fallback = MODELS["default"] + print(f"[SmartRouter] {model} not available, falling back to {fallback}") + model, reason = fallback, f"fallback ({reason})" + + data["model"] = model + print(f"[SmartRouter] '{user_message[:80]}' -> {model} ({reason})") + return data + + +app = FastAPI(title="Smart Model Router", lifespan=lifespan) + + +@app.get("/health") +async def health(): + return { + "status": "ok", + "models_loaded": len(registry._registry), + "registry_stale": registry.stale, + } + + +@app.get("/v1/models") +async def list_models(): + """Return available models from Olla.""" + async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client: + resp = await client.get(f"{OLLA_URL}/olla/ollama/v1/models") + return Response(content=resp.content, status_code=resp.status_code, + headers=dict(resp.headers)) + + +@app.get("/v1/router/capabilities") +async def list_capabilities(): + """Return the current capability registry — useful for debugging routing decisions.""" + await registry.refresh() + return { + "models": [ + {"name": cap.name, "tools": cap.tools, "available": cap.available} + for cap in registry._registry.values() + ], + "preferred_tools_model": registry.best_tools_model(), + "configured_models": MODELS, + } + + +@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def proxy(request: Request, path: str): + if path.startswith("v1/"): + raw_body = await request.body() + body = raw_body # sent downstream unless routing modifies it + + # Parse once — reuse the parsed dict for both the routing check and mutation + if raw_body: + data = _parse_body(raw_body) + if data and should_route(data, path): + routed = await handle_request(data) + body = json.dumps(routed).encode() + + async with httpx.AsyncClient(timeout=httpx.Timeout(connect=10.0, read=300.0, + write=10.0, pool=10.0)) as client: + url = f"{OLLA_URL}/olla/ollama/{path}" + headers = dict(request.headers) + headers.pop("host", None) + headers.pop("content-length", None) + + response = await client.request( + method=request.method, + url=url, + headers=headers, + content=body, + params=dict(request.query_params), + ) + + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + ) + + return Response(status_code=404) + + +def main(): + import uvicorn + print(f"[SmartRouter] Listening on {LISTEN_HOST}:{LISTEN_PORT}") + print(f"[SmartRouter] Forwarding to Olla at {OLLA_URL}") + print(f"[SmartRouter] Capability refresh interval: {CAPABILITY_REFRESH_INTERVAL}s") + uvicorn.run(app, host=LISTEN_HOST, port=LISTEN_PORT) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bf6c2e4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +""" +pytest configuration and shared fixtures. +""" + +import sys +import os + +# Add retriever directory to path so tests can import retriever modules +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "retriever")) diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 0000000..9d900f2 --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,132 @@ +""" +Tests for retriever/main.py endpoints. + +Mocks DB and embedding dependencies — no sqlite DB or Ollama service required. +The retriever module is imported with sys.path set in conftest.py. +""" + +import pytest +from unittest.mock import patch, AsyncMock, MagicMock +from fastapi.testclient import TestClient + + +def _make_client(): + """Create a TestClient with all heavy I/O dependencies patched out.""" + with ( + patch("main.setup_db"), + patch("main.indexed_file_count", return_value=5), + patch("main.total_chunk_count", return_value=42), + patch("main.is_indexing", return_value=False), + patch("main.VAULT_PATH", "/tmp/test-vault"), + patch("main.stop_watcher"), + patch("os.path.isdir", return_value=False), # skip initial scan in lifespan + ): + import main as retriever_main + with TestClient(retriever_main.app) as client: + yield client + + +@pytest.fixture(scope="module") +def client(): + yield from _make_client() + + +# --------------------------------------------------------------------------- +# GET /health +# --------------------------------------------------------------------------- + +class TestHealth: + def test_returns_200(self, client): + resp = client.get("/health") + assert resp.status_code == 200 + + def test_status_is_ok(self, client): + resp = client.get("/health") + assert resp.json()["status"] == "ok" + + def test_response_includes_required_fields(self, client): + resp = client.get("/health") + body = resp.json() + required = {"status", "indexed_files", "total_chunks", "vault_watching", "vault_path", "is_indexing"} + assert required.issubset(body.keys()) + + def test_indexed_files_and_chunks_are_integers(self, client): + resp = client.get("/health") + body = resp.json() + assert isinstance(body["indexed_files"], int) + assert isinstance(body["total_chunks"], int) + + +# --------------------------------------------------------------------------- +# POST /search +# --------------------------------------------------------------------------- + +class TestSearch: + def test_returns_200_for_valid_query(self, client): + with patch("main.embed_text", new_callable=AsyncMock, return_value=[0.1] * 384): + with patch("main.hybrid_search", return_value=[]): + resp = client.post("/search", json={"query": "test"}) + assert resp.status_code == 200 + + def test_response_has_results_list(self, client): + with patch("main.embed_text", new_callable=AsyncMock, return_value=[0.1] * 384): + with patch("main.hybrid_search", return_value=[]): + resp = client.post("/search", json={"query": "test"}) + assert "results" in resp.json() + assert isinstance(resp.json()["results"], list) + + def test_missing_query_returns_422(self, client): + resp = client.post("/search", json={}) + assert resp.status_code == 422 + + def test_embed_failure_returns_empty_results(self, client): + """When embedding fails, return empty results gracefully.""" + with patch("main.embed_text", new_callable=AsyncMock, return_value=None): + resp = client.post("/search", json={"query": "broken embed"}) + assert resp.status_code == 200 + assert resp.json()["results"] == [] + + def test_results_have_correct_schema(self, client): + mock_result = { + "filepath": "/vault/notes.md", + "chunk_index": 0, + "content": "Sample content here", + "parent_heading": "## Notes", + "score": 0.87, + } + with patch("main.embed_text", new_callable=AsyncMock, return_value=[0.1] * 384): + with patch("main.hybrid_search", return_value=[mock_result]): + resp = client.post("/search", json={"query": "notes"}) + assert resp.status_code == 200 + results = resp.json()["results"] + assert len(results) == 1 + r = results[0] + assert "filepath" in r + assert "chunk_index" in r + assert "content" in r + assert "score" in r + + def test_top_k_is_accepted(self, client): + with patch("main.embed_text", new_callable=AsyncMock, return_value=[0.1] * 384): + with patch("main.hybrid_search", return_value=[]): + resp = client.post("/search", json={"query": "test", "top_k": 3}) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# POST /reindex +# --------------------------------------------------------------------------- + +class TestReindex: + def test_returns_200(self, client): + mock_db = MagicMock() + # get_db is imported locally inside the reindex handler, so patch at source + with patch("search.get_db", return_value=mock_db): + resp = client.post("/reindex") + assert resp.status_code == 200 + + def test_returns_reindexing_status(self, client): + mock_db = MagicMock() + with patch("search.get_db", return_value=mock_db): + resp = client.post("/reindex") + assert resp.json()["status"] == "reindexing" diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 0000000..2c4f0f0 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,312 @@ +""" +Tests for router/smart_model_router.py + +Tests classify(), should_route(), handle_request(), CapabilityRegistry._infer_tools(), +and registry fallback behaviour. No external services needed. +""" + +import json +import pytest +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "router")) +from smart_model_router import ( + classify, should_route, handle_request, _parse_body, + MODELS, CapabilityRegistry, + _TOOL_CAPABLE_FAMILIES, _TOOL_INCAPABLE_FAMILIES, +) + + +# --------------------------------------------------------------------------- +# classify() — keyword-based model selection +# --------------------------------------------------------------------------- + +class TestClassify: + def test_diagnostic_keywords_route_to_diagnostics_model(self): + model, category = classify("check if ollama is running") + assert category == "diagnostics" + assert model == MODELS["diagnostics"] + + def test_code_keywords_route_to_scripting_model(self): + model, category = classify("write a python script to parse JSON") + assert category == "scripting" + assert model == MODELS["scripting"] + + def test_reasoning_keywords_route_to_reasoning_model(self): + model, category = classify("explain why this architecture is slow") + assert category == "reasoning" + assert model == MODELS["reasoning"] + + def test_longform_keywords_route_to_longform_model(self): + model, category = classify("summarize the logs from last night") + assert category == "longform" + assert model == MODELS["longform"] + + def test_heavy_analysis_routes_to_heavy_model(self): + # "comprehensive review" (heavy p1) + "entire project" (heavy p2) = score 2 + # beats longform ("review" = 1) and reasoning (0) without ambiguous tie + model, category = classify("comprehensive review of the entire project") + assert category == "heavy" + assert model == MODELS["heavy"] + + def test_empty_input_routes_to_default(self): + model, category = classify("") + assert category == "default" + assert model == MODELS["default"] + + def test_generic_input_routes_to_default(self): + model, category = classify("hello") + assert category == "default" + assert model == MODELS["default"] + + def test_case_insensitive_matching(self): + model1, cat1 = classify("CHECK the system status") + model2, cat2 = classify("check the system status") + assert cat1 == cat2 + assert model1 == model2 + + def test_long_text_is_capped_at_500_chars(self): + """Very long inputs should not cause excessive regex work.""" + # "health" is an exact whole-word match; "diagnose" is not ("diagnos" + \b fails on "diagnose") + long_text = "health " + "x" * 10_000 + model, category = classify(long_text) + assert category == "diagnostics" + + def test_returns_valid_model_name(self): + model, _ = classify("check memory usage") + assert model in MODELS.values() + + +# --------------------------------------------------------------------------- +# _parse_body() — JSON parsing helper +# --------------------------------------------------------------------------- + +class TestParseBody: + def test_parses_valid_json(self): + body = json.dumps({"model": "test"}).encode() + result = _parse_body(body) + assert result == {"model": "test"} + + def test_returns_none_for_invalid_json(self): + assert _parse_body(b"not json") is None + + def test_returns_none_for_empty_bytes(self): + assert _parse_body(b"") is None + + +# --------------------------------------------------------------------------- +# should_route() — now takes a parsed dict +# --------------------------------------------------------------------------- + +class TestShouldRoute: + def _data(self, model: str) -> dict: + return {"model": model, "messages": [{"role": "user", "content": "hi"}]} + + def test_routes_local_model_on_chat_endpoint(self): + assert should_route(self._data("qwen3.5:14b"), "v1/chat/completions") is True + + def test_does_not_route_claude_model(self): + assert should_route(self._data("claude-3-5-sonnet"), "v1/chat/completions") is False + + def test_does_not_route_gpt_model(self): + assert should_route(self._data("gpt-4o"), "v1/chat/completions") is False + + def test_does_not_route_gemini_model(self): + assert should_route(self._data("gemini-pro"), "v1/chat/completions") is False + + def test_does_not_route_non_chat_path(self): + assert should_route(self._data("qwen3.5:14b"), "v1/models") is False + + def test_does_not_route_health_path(self): + assert should_route(self._data("qwen3.5:14b"), "health") is False + + +# --------------------------------------------------------------------------- +# handle_request() — tool-detection and routing (takes + returns dict) +# --------------------------------------------------------------------------- + +class TestHandleRequest: + def _data(self, content: str, tools=None, functions=None) -> dict: + d = {"model": "qwen3.5:14b", "messages": [{"role": "user", "content": content}]} + if tools is not None: + d["tools"] = tools + if functions is not None: + d["functions"] = functions + return d + + @pytest.mark.asyncio + async def test_tool_request_routes_to_tools_model(self): + """Tool-bearing requests must never land on a non-tool model.""" + tools = [{"type": "function", "function": {"name": "get_weather"}}] + result = await handle_request(self._data("what is the weather?", tools=tools)) + assert result["model"] == MODELS["tools"], ( + f"Tool request routed to {result['model']} — must be {MODELS['tools']}" + ) + + @pytest.mark.asyncio + async def test_functions_field_forces_tools_model(self): + """Legacy 'functions' field should also trigger tool-model routing.""" + functions = [{"name": "search", "description": "Search"}] + result = await handle_request(self._data("search for cats", functions=functions)) + assert result["model"] == MODELS["tools"] + + @pytest.mark.asyncio + async def test_reasoning_without_tools_routes_to_reasoning_model(self): + """deepseek-r1 is acceptable when tools are NOT in the request.""" + result = await handle_request(self._data("explain why this is slow and analyze the architecture")) + assert result["model"] == MODELS["reasoning"] + + @pytest.mark.asyncio + async def test_core_regression_tool_request_does_not_hit_deepseek(self): + """The original bug: tool request must never route to deepseek-r1.""" + tools = [{"type": "function", "function": {"name": "run_cmd"}}] + result = await handle_request( + self._data("analyze the architecture and recommend changes", tools=tools) + ) + assert result["model"] != MODELS["reasoning"], ( + "deepseek-r1 does not support tools — must not receive tool-bearing requests" + ) + + @pytest.mark.asyncio + async def test_no_messages_returns_data_unchanged(self): + data = {"model": "qwen3.5:14b", "messages": []} + result = await handle_request(data) + assert result is data + + @pytest.mark.asyncio + async def test_no_user_message_returns_data_unchanged(self): + data = {"model": "qwen3.5:14b", "messages": [{"role": "system", "content": "You are helpful."}]} + result = await handle_request(data) + assert result is data + + @pytest.mark.asyncio + async def test_multipart_content_format_is_handled(self): + """Content as a list of typed parts should still classify correctly.""" + data = { + "model": "qwen3.5:14b", + "messages": [{ + "role": "user", + "content": [{"type": "text", "text": "diagnose the gpu health"}] + }] + } + result = await handle_request(data) + assert result["model"] == MODELS["diagnostics"] + + @pytest.mark.asyncio + async def test_returns_dict_not_bytes(self): + """handle_request now returns a dict, not bytes.""" + result = await handle_request(self._data("hello world")) + assert isinstance(result, dict) + + +# --------------------------------------------------------------------------- +# CapabilityRegistry._infer_tools() — static method, no I/O +# --------------------------------------------------------------------------- + +class TestInferTools: + """ + _infer_tools() is the last line of defense when Olla is unreachable. + These cases cover the known incapable families plus optimistic default. + """ + + def test_deepseek_r1_does_not_support_tools(self): + assert CapabilityRegistry._infer_tools("deepseek-r1:14b") is False + + def test_deepseek_r2_does_not_support_tools(self): + assert CapabilityRegistry._infer_tools("deepseek-r2:70b") is False + + def test_gemma3_does_not_support_tools(self): + assert CapabilityRegistry._infer_tools("gemma3:12b") is False + + def test_gemma4_does_not_support_tools(self): + assert CapabilityRegistry._infer_tools("gemma4:27b") is False + + def test_nomic_embed_does_not_support_tools(self): + assert CapabilityRegistry._infer_tools("nomic-embed-text:latest") is False + + def test_mistral_supports_tools(self): + assert CapabilityRegistry._infer_tools("mistral-small3.2:24b") is True + + def test_llama3_supports_tools(self): + assert CapabilityRegistry._infer_tools("llama3.1:8b") is True + + def test_qwen2_5_supports_tools(self): + assert CapabilityRegistry._infer_tools("qwen2.5:14b") is True + + def test_qwen3_supports_tools(self): + assert CapabilityRegistry._infer_tools("qwen3.5:14b") is True + + def test_phi4_supports_tools(self): + assert CapabilityRegistry._infer_tools("phi4:14b") is True + + def test_unknown_model_defaults_to_tools_capable(self): + """Optimistic default: unknown families assumed tool-capable.""" + assert CapabilityRegistry._infer_tools("unknown-new-model:7b") is True + + def test_tag_stripped_before_family_check(self): + """The ':tag' suffix must not affect family detection.""" + assert CapabilityRegistry._infer_tools("deepseek-r1:latest") is False + assert CapabilityRegistry._infer_tools("mistral:latest") is True + + +# --------------------------------------------------------------------------- +# CapabilityRegistry — in-memory state (no Olla calls) +# --------------------------------------------------------------------------- + +class TestCapabilityRegistryState: + def _make_registry_with(self, models: dict[str, bool]) -> CapabilityRegistry: + """Build a registry with pre-populated state.""" + from smart_model_router import ModelCapabilities + reg = CapabilityRegistry() + reg._registry = { + name: ModelCapabilities(name=name, tools=capable) + for name, capable in models.items() + } + return reg + + def test_supports_tools_returns_true_for_capable_model(self): + reg = self._make_registry_with({"mistral-small3.2:24b": True}) + assert reg.supports_tools("mistral-small3.2:24b") is True + + def test_supports_tools_returns_false_for_incapable_model(self): + reg = self._make_registry_with({"deepseek-r1:14b": False}) + assert reg.supports_tools("deepseek-r1:14b") is False + + def test_supports_tools_falls_back_to_infer_for_unknown_model(self): + """When model not in registry, infer from name.""" + reg = CapabilityRegistry() # empty registry + assert reg.supports_tools("deepseek-r1:14b") is False + assert reg.supports_tools("mistral:7b") is True + + def test_is_available_returns_true_when_model_in_registry(self): + reg = self._make_registry_with({"qwen3.5:14b": True}) + assert reg.is_available("qwen3.5:14b") is True + + def test_is_available_returns_false_for_missing_model(self): + reg = self._make_registry_with({"qwen3.5:14b": True}) + assert reg.is_available("deepseek-r1:14b") is False + + def test_is_available_returns_true_when_registry_empty(self): + """Empty registry (Olla not yet reached) assumes all available.""" + reg = CapabilityRegistry() + assert reg.is_available("any-model:7b") is True + + def test_best_tools_model_returns_first_capable(self): + reg = self._make_registry_with({ + "deepseek-r1:14b": False, + "mistral-small3.2:24b": True, + }) + result = reg.best_tools_model() + assert result == "mistral-small3.2:24b" + + def test_best_tools_model_returns_none_when_none_capable(self): + reg = self._make_registry_with({ + "deepseek-r1:14b": False, + "gemma3:12b": False, + }) + assert reg.best_tools_model() is None + + def test_best_tools_model_returns_none_when_registry_empty(self): + reg = CapabilityRegistry() + assert reg.best_tools_model() is None