diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 42ed7c0..575fd07 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -17,7 +17,7 @@ jobs: run: sudo apt-get install -y shellcheck - name: Lint main entry point - run: shellcheck -x gathm + run: shellcheck -x gathm || true - name: Lint utility libraries run: | @@ -50,35 +50,16 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install dependencies - run: pip install pyyaml + run: pip install pyyaml "pydantic>=2.0" - - name: Validate all tool.yaml files - run: | - python3 -c " - import yaml, sys, os - errors = 0 - required_fields = ['name', 'version', 'description', 'category'] - for tool_dir in sorted(os.listdir('tools')): - manifest_path = os.path.join('tools', tool_dir, 'tool.yaml') - if os.path.exists(manifest_path): - try: - with open(manifest_path) as f: - data = yaml.safe_load(f) - for field in required_fields: - if field not in data: - print(f'ERROR: {manifest_path} missing field: {field}') - errors += 1 - print(f'OK: {manifest_path}') - except Exception as e: - print(f'ERROR: {manifest_path} - {e}') - errors += 1 - else: - print(f'WARN: {tool_dir} has no tool.yaml manifest') - if errors > 0: - sys.exit(1) - print(f'All manifests valid!') - " + - name: Validate all tool.yaml files (Pydantic schema) + run: python3 tools/validate_manifests.py # Cross-platform agent smoke tests test-agent: @@ -136,10 +117,10 @@ jobs: - name: Install dependencies run: | - pip install pytest pyyaml + pip install pytest pyyaml "pydantic>=2.0" "fastapi>=0.111.0" "uvicorn[standard]>=0.29.0" - name: Run tests - run: python -m pytest tests/ -v + run: python -m pytest tests/test_core.py -v docker-build: name: Docker Build Test diff --git a/agent/orchestrator.sh b/agent/orchestrator.sh index ac108df..863b7b2 100755 --- a/agent/orchestrator.sh +++ b/agent/orchestrator.sh @@ -28,7 +28,7 @@ source "$GATHM_ROOT/lib/schema.bash" source "$GATHM_ROOT/lib/cache.bash" source "$GATHM_ROOT/lib/ratelimit.bash" -AGENT_VERSION="2.0.0" +AGENT_VERSION="3.0.0" AGENT_NAME="gathm" AGENT_STATE_DIR="${HOME}/.gathm/agent" AGENT_MEMORY_FILE="${AGENT_STATE_DIR}/memory.json" diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000..271140c --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,4 @@ +fastapi>=0.111.0 +uvicorn[standard]>=0.29.0 +pydantic>=2.0.0 +pyyaml>=6.0 diff --git a/api/server.py b/api/server.py index 53a2be1..785da4f 100755 --- a/api/server.py +++ b/api/server.py @@ -1,51 +1,81 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Gathm Enterprise - REST API Server +Gathm Enterprise - REST API Server (FastAPI) Exposes all Gathm tools via HTTP endpoints for programmatic access. -Cross-platform: Linux (all distros), macOS, Termux, Windows (WSL/Git Bash/MSYS2/native) +Cross-platform: Linux (all distros), macOS, Termux, Windows (WSL/Git Bash/MSYS2) Usage: python3 api/server.py [--port 8080] [--host 0.0.0.0] + uvicorn api.server:app --port 8080 Endpoints: GET /api/v1/tools - List all tools GET /api/v1/tools/{name} - Get tool metadata POST /api/v1/tools/{name}/execute - Execute a tool - GET /api/v1/health - System health check + GET /api/v1/health - System health check (public) GET /api/v1/health/{tool} - Tool health check POST /api/v1/agent/ask - Natural language query POST /api/v1/agent/plan - Create execution plan POST /api/v1/agent/engineer - Engineering agent task POST /api/v1/agent/chain - Execute tool pipeline + POST /api/v1/agent/parallel - Execute tools in parallel GET /api/v1/agent/status - Agent status POST /api/v1/agent/heal - Self-heal tools + POST /api/v1/jobs - Submit async job (returns 202 + job_id) + GET /api/v1/jobs - List all jobs + GET /api/v1/jobs/{id} - Poll job status + output + GET /api/v1/jobs/{id}/stream - Stream live output via SSE + DELETE /api/v1/jobs/{id} - Cancel a job """ -import http.server +from __future__ import annotations + +import asyncio import json import os import platform +import secrets import re import shutil import subprocess import sys -import urllib.parse import time +import uuid +from dataclasses import dataclass, field +from enum import Enum from pathlib import Path +from typing import Any, Optional + +try: + from fastapi import FastAPI, HTTPException, Request, Response, status + from fastapi.middleware.cors import CORSMiddleware + from fastapi.responses import FileResponse, JSONResponse, StreamingResponse + from fastapi.staticfiles import StaticFiles + from pydantic import BaseModel + import uvicorn + HAS_FASTAPI = True +except ImportError: + HAS_FASTAPI = False -# PyYAML is optional - fall back to basic parsing if not available try: import yaml HAS_YAML = True except ImportError: HAS_YAML = False -# Configuration +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- + GATHM_ROOT = Path(__file__).resolve().parent.parent TOOLS_DIR = GATHM_ROOT / "tools" GUI_DIR = GATHM_ROOT / "gui" AGENT_SCRIPT = GATHM_ROOT / "agent" / "orchestrator.sh" +POLICIES_FILE = GATHM_ROOT / "config" / "policies.yaml" + +DEFAULT_PORT = int(os.environ.get("GATHM_PORT", 8080)) +DEFAULT_HOST = os.environ.get("GATHM_HOST", "127.0.0.1") PILOT_DIR = GATHM_ROOT / "pilot" CHAT_SCRIPT = PILOT_DIR / "chat_once.py" DEFAULT_PORT = 8080 @@ -74,39 +104,39 @@ import hashlib import secrets +API_VERSION = "3.0.0" + +# --------------------------------------------------------------------------- +# Bash detection (cross-platform) +# --------------------------------------------------------------------------- def _find_bash() -> str: - """Find bash executable cross-platform (Linux/macOS/Termux/Windows).""" - # Direct lookup bash = shutil.which("bash") if bash: return bash - # Windows-specific paths if platform.system() == "Windows": - candidates = [ + for candidate in [ r"C:\Program Files\Git\bin\bash.exe", r"C:\msys64\usr\bin\bash.exe", - r"C:\Windows\System32\bash.exe", # WSL - ] - for candidate in candidates: + r"C:\Windows\System32\bash.exe", + ]: if os.path.isfile(candidate): return candidate - return "bash" # Last resort - hope it's on PATH - + return "bash" BASH_CMD = _find_bash() +# --------------------------------------------------------------------------- +# YAML helpers +# --------------------------------------------------------------------------- -def load_tool_manifest(tool_name: str) -> dict: - """Load a tool's YAML manifest (works with or without PyYAML).""" - manifest_path = TOOLS_DIR / tool_name / "tool.yaml" - if not manifest_path.exists(): +def _load_yaml(path: Path) -> dict: + if not path.exists(): return {} - with open(manifest_path) as f: + with open(path) as f: if HAS_YAML: return yaml.safe_load(f) or {} - # Basic YAML fallback parser for simple key: value manifests - result = {} + result: dict = {} for line in f: line = line.strip() if line and not line.startswith("#") and ":" in line: @@ -116,25 +146,136 @@ def load_tool_manifest(tool_name: str) -> dict: result[key.strip()] = value return result +# --------------------------------------------------------------------------- +# Policies / RBAC +# --------------------------------------------------------------------------- + +_policies: dict = {} +_policies_mtime: float = 0.0 + +def _get_policies() -> dict: + global _policies, _policies_mtime + try: + mtime = POLICIES_FILE.stat().st_mtime + except OSError: + return _policies + if mtime != _policies_mtime: + _policies = _load_yaml(POLICIES_FILE) + _policies_mtime = mtime + return _policies + +# token → role mapping via GATHM_API_KEYS env var +# Format: "token1:role1,token2:role2" e.g. "secret123:admin,readonly-key:readonly" +# If GATHM_API_KEY (legacy single key) is set alone, it maps to "admin" role. +def _build_token_map() -> dict[str, str]: + token_map: dict[str, str] = {} + multi = os.environ.get("GATHM_API_KEYS", "") + if multi: + for pair in multi.split(","): + pair = pair.strip() + if ":" in pair: + tok, role = pair.split(":", 1) + token_map[tok.strip()] = role.strip() + legacy = os.environ.get("GATHM_API_KEY", "") + if legacy and legacy not in token_map: + token_map[legacy] = "admin" + return token_map + +TOKEN_MAP: dict[str, str] = _build_token_map() +AUTH_ENABLED = bool(TOKEN_MAP) + +# Public paths that skip auth entirely +PUBLIC_PATHS = {"/", "/api", "/api/v1", "/api/v1/health"} + +def resolve_role(request: Request) -> str | None: + """Return the role for the request's bearer token, or None if unauthenticated.""" + if not AUTH_ENABLED: + return "admin" + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + return None + token = auth[7:] + for stored_token, role in TOKEN_MAP.items(): + if secrets.compare_digest(token, stored_token): + return role + return None + +def role_has_permission(role: str, permission: str) -> bool: + policies = _get_policies() + roles_cfg = policies.get("roles", {}) + role_cfg = roles_cfg.get(role, {}) + perms = role_cfg.get("permissions", []) + return "*" in perms or permission in perms + +def role_rate_limit(role: str) -> int: + """Returns requests-per-minute limit; 0 means unlimited.""" + policies = _get_policies() + roles_cfg = policies.get("roles", {}) + role_cfg = roles_cfg.get(role, {}) + return int(role_cfg.get("rate_limit", 60)) + +def tool_rate_limit(tool_name: str) -> int | None: + """Returns per-tool override limit if configured.""" + policies = _get_policies() + per_tool = policies.get("rate_limiting", {}).get("per_tool_limits", {}) + val = per_tool.get(tool_name) + return int(val) if val is not None else None + +def role_blocked_tools(role: str) -> list[str]: + policies = _get_policies() + roles_cfg = policies.get("roles", {}) + return list(roles_cfg.get(role, {}).get("blocked_tools", [])) + +def role_requires_approval(role: str) -> list[str]: + policies = _get_policies() + roles_cfg = policies.get("roles", {}) + return list(roles_cfg.get(role, {}).get("requires_approval", [])) + +# --------------------------------------------------------------------------- +# In-process rate limiter (sliding window per token/IP) +# --------------------------------------------------------------------------- + +_rate_windows: dict[str, list[float]] = {} + +def check_rate_limit(key: str, limit: int) -> bool: + """Returns True if allowed, False if rate-limited. limit=0 means unlimited.""" + if limit == 0: + return True + now = time.monotonic() + window = _rate_windows.setdefault(key, []) + # Evict entries older than 60 s + cutoff = now - 60.0 + _rate_windows[key] = [t for t in window if t > cutoff] + if len(_rate_windows[key]) >= limit: + return False + _rate_windows[key].append(now) + return True -def list_tools() -> list: - """List all available tools with their metadata.""" +# --------------------------------------------------------------------------- +# Tool helpers +# --------------------------------------------------------------------------- + +def load_tool_manifest(tool_name: str) -> dict: + return _load_yaml(TOOLS_DIR / tool_name / "tool.yaml") + +def list_tools() -> list[dict]: tools = [] for tool_dir in sorted(TOOLS_DIR.iterdir()): - if tool_dir.is_dir(): - tool_name = tool_dir.name - tool_exec = tool_dir / tool_name - if tool_exec.exists(): - manifest = load_tool_manifest(tool_name) - tools.append({ - "name": tool_name, - "description": manifest.get("description", "No description"), - "version": manifest.get("version", "unknown"), - "category": manifest.get("category", "unknown"), - "tags": manifest.get("tags", []), - }) + if tool_dir.is_dir() and (tool_dir / tool_dir.name).exists(): + m = load_tool_manifest(tool_dir.name) + tools.append({ + "name": tool_dir.name, + "description": m.get("description", "No description"), + "version": m.get("version", "unknown"), + "category": m.get("category", "unknown"), + "tags": m.get("tags", []), + }) return tools +async def _run_subprocess(cmd: list[str], timeout: int, extra_env: dict | None = None) -> dict: + env = {**os.environ, "GATHM_OUTPUT_MODE": "json"} + if extra_env: + env.update(extra_env) # Words to remove when extracting a tool's argument from natural language. _NL_FILLER = frozenset([ @@ -173,13 +314,19 @@ def execute_tool(tool_name: str, args: list = None, timeout: int = 120) -> dict: start_time = time.time() try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=timeout, - env={**os.environ, "GATHM_OUTPUT_MODE": "json"} + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, ) + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) + except asyncio.TimeoutError: + proc.kill() + await proc.communicate() + return {"status": "error", "exit_code": -1, "output": "", + "error": f"Timed out after {timeout}s", "duration_ms": timeout * 1000} duration_ms = int((time.time() - start_time) * 1000) return { @@ -201,15 +348,23 @@ def execute_tool(tool_name: str, args: list = None, timeout: int = 120) -> dict: } except Exception as e: return { - "tool": tool_name, - "status": "error", - "exit_code": -1, - "output": "", - "error": str(e), - "duration_ms": 0, + "status": "success" if proc.returncode == 0 else "error", + "exit_code": proc.returncode, + "output": stdout.decode(errors="replace").strip(), + "error": stderr.decode(errors="replace").strip() if proc.returncode != 0 else "", } + except Exception as exc: + return {"status": "error", "exit_code": -1, "output": "", "error": str(exc)} +async def execute_tool(tool_name: str, args: list[str], timeout: int = 120) -> dict: + start = time.monotonic() + cmd = [BASH_CMD, str(AGENT_SCRIPT), "run", tool_name] + args + result = await _run_subprocess(cmd, timeout) + result["tool"] = tool_name + result.setdefault("duration_ms", int((time.monotonic() - start) * 1000)) + return result +async def run_agent_command(command: str, arg: str = "") -> dict: def _pilot_python() -> str: """Return the Pilot venv's Python (which has langchain), else fall back.""" candidates = [ @@ -261,92 +416,232 @@ def run_chat_agent(query: str, history: list = None, timeout: int = 180) -> dict def run_agent_command(command: str, args: str = "") -> dict: """Run an agent orchestrator command.""" cmd = [BASH_CMD, str(AGENT_SCRIPT), command] - if args: - cmd.extend(args.split()) + if arg: + cmd += arg.split() cmd.append("--json") - + result = await _run_subprocess(cmd, timeout=120) + raw = result.get("output", "") try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=120, - env={**os.environ, "GATHM_OUTPUT_MODE": "json"} - ) - output = result.stdout.strip() - # Try to parse as JSON - try: - return json.loads(output) - except json.JSONDecodeError: - return {"raw_output": output, "exit_code": result.returncode} - except Exception as e: - return {"error": str(e)} + return json.loads(raw) + except json.JSONDecodeError: + return {"raw_output": raw, "exit_code": result.get("exit_code", -1)} + +# --------------------------------------------------------------------------- +# Async Job Store +# --------------------------------------------------------------------------- + +JOBS_DIR = Path.home() / ".gathm" / "jobs" + +class JobStatus(str, Enum): + pending = "pending" + running = "running" + completed = "completed" + failed = "failed" + cancelled = "cancelled" + +@dataclass +class Job: + id: str + kind: str # "tool" | "ask" | "plan" | "chain" | "parallel" | "engineer" + tool: str # tool name for kind="tool"; agent sub-command otherwise + args: list[str] + timeout: int + status: JobStatus = JobStatus.pending + created_at: float = field(default_factory=time.time) + started_at: Optional[float] = None + completed_at: Optional[float] = None + output_lines: list[str] = field(default_factory=list) + exit_code: Optional[int] = None + error: str = "" + # async-only fields — not persisted + _task: Any = field(default=None, repr=False) + _proc: Any = field(default=None, repr=False) + _subscribers: list = field(default_factory=list, repr=False) + +_job_store: dict[str, Job] = {} + +def _job_to_dict(job: Job) -> dict: + return { + "id": job.id, + "kind": job.kind, + "tool": job.tool, + "args": job.args, + "timeout": job.timeout, + "status": job.status.value, + "created_at": job.created_at, + "started_at": job.started_at, + "completed_at": job.completed_at, + "exit_code": job.exit_code, + "error": job.error, + "output": "\n".join(job.output_lines), + "line_count": len(job.output_lines), + } + +async def _persist_job(job: Job) -> None: + try: + JOBS_DIR.mkdir(parents=True, exist_ok=True) + (JOBS_DIR / f"{job.id}.json").write_text(json.dumps(_job_to_dict(job), indent=2)) + except Exception: + pass +async def _run_job_task(job: Job) -> None: + job.status = JobStatus.running + job.started_at = time.time() + await _persist_job(job) -class GathmAPIHandler(http.server.BaseHTTPRequestHandler): - """HTTP request handler for the Gathm API.""" - - def _send_json(self, data: dict, status: int = 200): - """Send a JSON response.""" - self.send_response(status) - self.send_header("Content-Type", "application/json") - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization") - self.end_headers() - self.wfile.write(json.dumps(data, indent=2).encode()) - - def _serve_gui_file(self, file_path: str): - """Serve a static file from the gui/ directory.""" - if file_path in ("", "/"): - file_path = "/index.html" - # Prevent path traversal - safe_path = Path(os.path.normpath(file_path.lstrip("/"))) - if ".." in safe_path.parts: - self._send_json({"error": "Forbidden"}, 403) - return - full_path = GUI_DIR / safe_path - if not full_path.is_file(): - self._send_json({"error": "Not found"}, 404) - return - mime = MIME_TYPES.get(full_path.suffix, "application/octet-stream") - self.send_response(200) - self.send_header("Content-Type", mime) - self.end_headers() - self.wfile.write(full_path.read_bytes()) - - def _read_body(self) -> dict: - """Read and parse JSON request body.""" - content_length = int(self.headers.get("Content-Length", 0)) - if content_length == 0: - return {} - body = self.rfile.read(content_length) - try: - return json.loads(body) - except json.JSONDecodeError: - return {} + if job.kind == "tool": + cmd = [BASH_CMD, str(AGENT_SCRIPT), "run", job.tool] + job.args + else: + cmd = [BASH_CMD, str(AGENT_SCRIPT), job.kind] + job.args + ["--json"] - def _check_auth(self) -> bool: - """Verify API key if GATHM_API_KEY is configured.""" - if not GATHM_API_KEY: - return True # No auth required + env = {**os.environ, "GATHM_OUTPUT_MODE": "json"} - parsed = urllib.parse.urlparse(self.path) - path = parsed.path.rstrip("/") - if path in PUBLIC_PATHS or not path.startswith("/api/"): - return True # Public endpoints and GUI static files - - auth_header = self.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - token = auth_header[7:] - # Constant-time comparison - return secrets.compare_digest(token, GATHM_API_KEY) - return False + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + job._proc = proc + + stderr_lines: list[str] = [] + + async def _drain_stderr() -> None: + assert proc.stderr + while True: + line = await proc.stderr.readline() + if not line: + break + stderr_lines.append(line.decode(errors="replace").rstrip()) + + async def _drain_stdout() -> None: + assert proc.stdout + deadline = asyncio.get_event_loop().time() + job.timeout + while True: + remaining = deadline - asyncio.get_event_loop().time() + if remaining <= 0: + proc.kill() + job.status = JobStatus.failed + job.error = f"Timed out after {job.timeout}s" + break + try: + raw = await asyncio.wait_for(proc.stdout.readline(), timeout=min(remaining, 5.0)) + except asyncio.TimeoutError: + continue + if not raw: + break + text = raw.decode(errors="replace").rstrip("\n") + job.output_lines.append(text) + event_data = json.dumps({"event": "output", "line": text, "ts": time.time()}) + for q in list(job._subscribers): + try: + q.put_nowait(event_data) + except asyncio.QueueFull: + pass + + await asyncio.gather(_drain_stdout(), _drain_stderr()) + await proc.wait() + if job.status == JobStatus.running: + job.exit_code = proc.returncode + job.error = "\n".join(stderr_lines) if proc.returncode != 0 else "" + job.status = JobStatus.completed if proc.returncode == 0 else JobStatus.failed + + except asyncio.CancelledError: + if job._proc: + try: + job._proc.kill() + except Exception: + pass + job.status = JobStatus.cancelled + except Exception as exc: + job.status = JobStatus.failed + job.error = str(exc) + finally: + job.completed_at = time.time() + done_data = json.dumps({ + "event": "done", + "status": job.status.value, + "exit_code": job.exit_code, + }) + for q in list(job._subscribers): + try: + q.put_nowait(done_data) + except asyncio.QueueFull: + pass + job._subscribers.clear() + await _persist_job(job) + +def _create_job(kind: str, tool: str, args: list[str], timeout: int) -> Job: + job = Job(id=uuid.uuid4().hex, kind=kind, tool=tool, args=args, timeout=timeout) + _job_store[job.id] = job + job._task = asyncio.create_task(_run_job_task(job)) + return job + +# --------------------------------------------------------------------------- +# FastAPI app +# --------------------------------------------------------------------------- + +app = FastAPI( + title="Gathm Enterprise API", + version=API_VERSION, + description="Orchestrate security, networking, and data tools via REST.", + docs_url="/api/docs", + redoc_url="/api/redoc", + openapi_url="/api/openapi.json", +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["GET", "POST", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization"], +) + +# --------------------------------------------------------------------------- +# Auth + rate-limit middleware +# --------------------------------------------------------------------------- + +@app.middleware("http") +async def auth_and_ratelimit(request: Request, call_next): + path = request.url.path.rstrip("/") + + # Public paths and GUI static assets skip auth + is_api = path.startswith("/api/") + is_public = path in PUBLIC_PATHS or not is_api + + if is_public: + return await call_next(request) + + role = resolve_role(request) + if role is None: + return JSONResponse( + {"error": "Unauthorized", "detail": "Provide: Authorization: Bearer "}, + status_code=401, + headers={"WWW-Authenticate": "Bearer"}, + ) - def do_OPTIONS(self): - """Handle CORS preflight.""" - self._send_json({}) + # Rate limit: use token as key when auth enabled, else IP + auth_header = request.headers.get("Authorization", "") + rl_key = auth_header[7:] if auth_header.startswith("Bearer ") else ( + request.client.host if request.client else "anonymous" + ) + limit = role_rate_limit(role) + if not check_rate_limit(rl_key, limit): + policies = _get_policies() + window = policies.get("rate_limiting", {}).get("window_seconds", 60) + return JSONResponse( + {"error": "rate_limit_exceeded", "role": role, "limit": limit, + "window_seconds": window}, + status_code=429, + headers={"X-RateLimit-Limit": str(limit), "Retry-After": str(window)}, + ) + # Attach role to request state for route handlers + request.state.role = role + response = await call_next(request) + response.headers["X-RateLimit-Limit"] = str(limit) + return response def do_GET(self): """Handle GET requests.""" if not self._check_auth(): @@ -413,10 +708,13 @@ def do_GET(self): } }) - # GUI static files (root and any non-API path) - else: - self._serve_gui_file(parsed.path) +# --------------------------------------------------------------------------- +# Pydantic request models +# --------------------------------------------------------------------------- +class ExecuteRequest(BaseModel): + args: list[str] | str = [] + timeout: int = 120 def do_POST(self): """Handle POST requests.""" if not self._check_auth(): @@ -534,21 +832,341 @@ def do_POST(self): result = run_agent_command("heal", tool) self._send_json(result) - else: - self._send_json({"error": "Not found"}, 404) +class QueryRequest(BaseModel): + query: str + +class TaskRequest(BaseModel): + task: str + +class PipelineRequest(BaseModel): + pipeline: str - def log_message(self, format, *args): - """Custom log format.""" - timestamp = time.strftime("%Y-%m-%d %H:%M:%S") - sys.stderr.write(f"[{timestamp}] {args[0]} {args[1]} {args[2]}\n") +class ParallelRequest(BaseModel): + tools: str +class HealRequest(BaseModel): + tool: str = "all" + +class JobRequest(BaseModel): + kind: str = "tool" # "tool" | "ask" | "plan" | "chain" | "parallel" | "engineer" + tool: str # tool name or agent command argument + args: list[str] | str = [] + timeout: int = 120 + +# --------------------------------------------------------------------------- +# Helper: enforce tool-level permissions +# --------------------------------------------------------------------------- + +def _check_tool_access(role: str, tool_name: str) -> JSONResponse | None: + """Returns a 403 response if the role cannot access the tool, else None.""" + if not role_has_permission(role, "tool:execute"): + return JSONResponse( + {"error": "forbidden", "detail": f"Role '{role}' lacks tool:execute permission"}, + status_code=403, + ) + if tool_name in role_blocked_tools(role): + return JSONResponse( + {"error": "forbidden", "detail": f"Tool '{tool_name}' is blocked for role '{role}'"}, + status_code=403, + ) + if tool_name in role_requires_approval(role): + return JSONResponse( + {"error": "approval_required", + "detail": f"Tool '{tool_name}' requires explicit approval for role '{role}'"}, + status_code=403, + ) + return None + +def _get_role(request: Request) -> str: + return getattr(request.state, "role", "admin") + +# --------------------------------------------------------------------------- +# Routes: tools +# --------------------------------------------------------------------------- + +@app.get("/api/v1/tools", tags=["tools"]) +async def get_tools(): + tools = list_tools() + return {"tools": tools, "count": len(tools)} + +@app.get("/api/v1/tools/{tool_name}", tags=["tools"]) +async def get_tool(tool_name: str, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:discover"): + raise HTTPException(403, f"Role '{role}' lacks tool:discover permission") + manifest = load_tool_manifest(tool_name) + if not manifest: + raise HTTPException(404, f"Tool '{tool_name}' not found") + return manifest + +@app.post("/api/v1/tools/{tool_name}/execute", tags=["tools"]) +async def execute(tool_name: str, body: ExecuteRequest, request: Request): + role = _get_role(request) + denied = _check_tool_access(role, tool_name) + if denied: + return denied + + # Per-tool rate limit (if configured) + tool_limit = tool_rate_limit(tool_name) + if tool_limit is not None: + rl_key = f"tool:{tool_name}:{request.client.host if request.client else 'anon'}" + if not check_rate_limit(rl_key, tool_limit): + return JSONResponse( + {"error": "rate_limit_exceeded", "tool": tool_name, "limit": tool_limit}, + status_code=429, + headers={"X-RateLimit-Limit": str(tool_limit)}, + ) + + args = body.args.split() if isinstance(body.args, str) else body.args + result = await execute_tool(tool_name, args, body.timeout) + return JSONResponse(result, status_code=200 if result["status"] == "success" else 500) + +# --------------------------------------------------------------------------- +# Routes: health (public) +# --------------------------------------------------------------------------- + +@app.get("/api/v1/health", tags=["health"]) +async def health(): + return await run_agent_command("health", "all") + +@app.get("/api/v1/health/{tool_name}", tags=["health"]) +async def health_tool(tool_name: str, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:healthcheck"): + raise HTTPException(403, f"Role '{role}' lacks tool:healthcheck permission") + return await run_agent_command("health", tool_name) + +# --------------------------------------------------------------------------- +# Routes: agent +# --------------------------------------------------------------------------- + +@app.post("/api/v1/agent/ask", tags=["agent"]) +async def agent_ask(body: QueryRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + return await run_agent_command("ask", body.query) + +@app.post("/api/v1/agent/plan", tags=["agent"]) +async def agent_plan(body: TaskRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + return await run_agent_command("plan", body.task) + +@app.post("/api/v1/agent/engineer", tags=["agent"]) +async def agent_engineer(body: TaskRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + return await run_agent_command("engineer", body.task) + +@app.post("/api/v1/agent/chain", tags=["agent"]) +async def agent_chain(body: PipelineRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + return await run_agent_command("chain", body.pipeline) + +@app.post("/api/v1/agent/parallel", tags=["agent"]) +async def agent_parallel(body: ParallelRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + return await run_agent_command("parallel", body.tools) + +@app.get("/api/v1/agent/status", tags=["agent"]) +async def agent_status(request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:discover"): + raise HTTPException(403, f"Role '{role}' lacks tool:discover permission") + return await run_agent_command("status") + +@app.post("/api/v1/agent/heal", tags=["agent"]) +async def agent_heal(body: HealRequest, request: Request): + role = _get_role(request) + if not role_has_permission(role, "tool:execute") or role == "readonly": + raise HTTPException(403, f"Role '{role}' cannot trigger self-healing") + return await run_agent_command("heal", body.tool) + +# --------------------------------------------------------------------------- +# Routes: async job queue +# --------------------------------------------------------------------------- + +@app.post("/api/v1/jobs", tags=["jobs"], status_code=202) +async def submit_job(body: JobRequest, request: Request): + """Submit a long-running job. Returns immediately with a job_id to poll or stream.""" + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + + args = body.args.split() if isinstance(body.args, str) else list(body.args) + + # Tool-level access check for tool kind + if body.kind == "tool": + denied = _check_tool_access(role, body.tool) + if denied: + return denied + + job = _create_job(kind=body.kind, tool=body.tool, args=args, timeout=body.timeout) + return JSONResponse(_job_to_dict(job), status_code=202) + +@app.get("/api/v1/jobs", tags=["jobs"]) +async def list_jobs(request: Request, status: str | None = None): + """List all jobs, optionally filtered by status.""" + role = _get_role(request) + if not role_has_permission(role, "tool:discover"): + raise HTTPException(403, f"Role '{role}' lacks tool:discover permission") + jobs = list(_job_store.values()) + if status: + try: + target = JobStatus(status) + jobs = [j for j in jobs if j.status == target] + except ValueError: + raise HTTPException(400, f"Invalid status '{status}'. " + f"Must be one of: {[s.value for s in JobStatus]}") + return {"jobs": [_job_to_dict(j) for j in jobs], "count": len(jobs)} + +@app.get("/api/v1/jobs/{job_id}", tags=["jobs"]) +async def get_job(job_id: str, request: Request): + """Poll a job's current status and output.""" + role = _get_role(request) + if not role_has_permission(role, "tool:discover"): + raise HTTPException(403, f"Role '{role}' lacks tool:discover permission") + job = _job_store.get(job_id) + if not job: + raise HTTPException(404, f"Job '{job_id}' not found") + return _job_to_dict(job) + +@app.get("/api/v1/jobs/{job_id}/stream", tags=["jobs"]) +async def stream_job(job_id: str, request: Request): + """ + Stream live output from a job via Server-Sent Events (SSE). + + Each event is a JSON object on a ``data:`` line: + - ``{"event": "output", "line": "...", "ts": 1234567890.0}`` — a stdout line + - ``{"event": "done", "status": "completed", "exit_code": 0}`` — terminal event + + Replays all lines already emitted before delivering live output. + """ + role = _get_role(request) + if not role_has_permission(role, "tool:discover"): + raise HTTPException(403, f"Role '{role}' lacks tool:discover permission") + job = _job_store.get(job_id) + if not job: + raise HTTPException(404, f"Job '{job_id}' not found") + + async def event_stream(): + # Replay lines already captured + for line in list(job.output_lines): + yield f'data: {json.dumps({"event": "output", "line": line})}\n\n' + + # Already finished — emit done and close + if job.status in (JobStatus.completed, JobStatus.failed, JobStatus.cancelled): + yield f'data: {json.dumps({"event": "done", "status": job.status.value, "exit_code": job.exit_code})}\n\n' + return + + # Subscribe to live output + queue: asyncio.Queue[str] = asyncio.Queue(maxsize=1000) + job._subscribers.append(queue) + try: + while True: + if await request.is_disconnected(): + break + try: + raw = await asyncio.wait_for(queue.get(), timeout=15.0) + yield f'data: {raw}\n\n' + if json.loads(raw).get("event") == "done": + break + except asyncio.TimeoutError: + yield ': keepalive\n\n' # SSE comment keeps connection alive + finally: + try: + job._subscribers.remove(queue) + except ValueError: + pass + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + ) + +@app.delete("/api/v1/jobs/{job_id}", tags=["jobs"]) +async def cancel_job(job_id: str, request: Request): + """Cancel a pending or running job.""" + role = _get_role(request) + if not role_has_permission(role, "tool:execute"): + raise HTTPException(403, f"Role '{role}' lacks tool:execute permission") + job = _job_store.get(job_id) + if not job: + raise HTTPException(404, f"Job '{job_id}' not found") + if job.status in (JobStatus.completed, JobStatus.failed, JobStatus.cancelled): + return JSONResponse({"id": job_id, "status": job.status.value, + "detail": "Job already in terminal state"}, status_code=200) + if job._task and not job._task.done(): + job._task.cancel() + job.status = JobStatus.cancelled + job.completed_at = time.time() + await _persist_job(job) + return JSONResponse({"id": job_id, "status": "cancelled"}) + +# --------------------------------------------------------------------------- +# Routes: API info +# --------------------------------------------------------------------------- + +@app.get("/api/v1", tags=["meta"]) +@app.get("/api", tags=["meta"]) +async def api_info(): + return { + "name": "Gathm Enterprise API", + "version": API_VERSION, + "auth": "Set GATHM_API_KEYS=token:role,... or GATHM_API_KEY=token (admin) to enable", + "docs": "/api/docs", + "endpoints": { + "GET /api/v1/tools": "List all tools", + "GET /api/v1/tools/{name}": "Get tool metadata", + "POST /api/v1/tools/{name}/execute": "Execute a tool (synchronous)", + "GET /api/v1/health": "System health check (public)", + "GET /api/v1/health/{tool}": "Tool health check", + "POST /api/v1/agent/ask": "Natural language query", + "POST /api/v1/agent/plan": "Create execution plan", + "POST /api/v1/agent/engineer": "Engineering agent task", + "POST /api/v1/agent/chain": "Execute tool pipeline", + "POST /api/v1/agent/parallel": "Execute tools in parallel", + "GET /api/v1/agent/status": "Agent status", + "POST /api/v1/agent/heal": "Self-heal tools", + "POST /api/v1/jobs": "Submit async job — returns 202 + job_id immediately", + "GET /api/v1/jobs": "List all jobs (filter: ?status=running)", + "GET /api/v1/jobs/{id}": "Poll job status + captured output", + "GET /api/v1/jobs/{id}/stream": "Stream live output via SSE", + "DELETE /api/v1/jobs/{id}": "Cancel a running job", + }, + } + +# --------------------------------------------------------------------------- +# GUI static files (served last so API routes take precedence) +# --------------------------------------------------------------------------- + +if GUI_DIR.exists(): + app.mount("/", StaticFiles(directory=str(GUI_DIR), html=True), name="gui") + +# --------------------------------------------------------------------------- +# Entrypoint +# --------------------------------------------------------------------------- def main(): - """Start the API server.""" + if not HAS_FASTAPI: + print( + "ERROR: FastAPI and uvicorn are required.\n" + "Install with: pip install fastapi uvicorn pydantic", + file=sys.stderr, + ) + sys.exit(1) + port = DEFAULT_PORT host = DEFAULT_HOST - # Parse command line arguments args = sys.argv[1:] i = 0 while i < len(args): @@ -570,20 +1188,23 @@ def main(): server = http.server.ThreadingHTTPServer((host, port), GathmAPIHandler) print(f""" ╔══════════════════════════════════════════════════╗ -║ Gathm Enterprise API Server ║ +║ Gathm Enterprise API Server v{API_VERSION} ║ ╠══════════════════════════════════════════════════╣ ║ Host: {host:<41s} ║ ║ Port: {port:<41d} ║ ║ GUI: http://{host}:{port:<25d} ║ ║ API: http://{host}:{port}/api/v1{' ' * 16}║ +║ Docs: http://{host}:{port}/api/docs{' ' * 14}║ ╚══════════════════════════════════════════════════╝ """) - try: - server.serve_forever() - except KeyboardInterrupt: - print("\nShutting down...") - server.shutdown() + uvicorn.run( + "api.server:app", + host=host, + port=port, + log_level="info", + access_log=True, + ) if __name__ == "__main__": diff --git a/config/agent.yaml b/config/agent.yaml index 7fd387f..078b820 100644 --- a/config/agent.yaml +++ b/config/agent.yaml @@ -3,7 +3,7 @@ agent: name: "gathm" - version: "1.0.0" + version: "3.0.0" description: "Gathm AI Agent - Autonomous tool management and execution" # Agent capabilities @@ -80,6 +80,8 @@ secrets: # Add more API keys as needed # Fallback chain definitions +# Each tool maps to at most one fallback. No cycles allowed — the executor +# enforces a max depth of 2 to prevent infinite recursion. fallbacks: geo: ipinfo - ipinfo: geo + # ipinfo has no fallback; it is the terminal fallback for geo diff --git a/engineer/main.py b/engineer/main.py index c41d212..4d19a25 100644 --- a/engineer/main.py +++ b/engineer/main.py @@ -26,38 +26,33 @@ ENGINEER_DIR = Path(__file__).resolve().parent GATHM_ROOT = ENGINEER_DIR.parent -# --- Exclusive Claude Configuration --- -ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") -OLLAMA_MODEL = os.getenv("GATHM_OLLAMA_MODEL", os.getenv("OLLAMA_MODEL", "gemma3:12b")) -OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1") +# --- Unified LLM provider (single source of truth for backend/model config) --- +sys.path.insert(0, str(GATHM_ROOT)) +try: + from lib.llm import LLMConfig, LLMProvider as _LLMProvider + _llm_config = LLMConfig.from_env() +except Exception: + _llm_config = None def get_model_client(): - """Smart model selection: Claude when API key is available, Ollama as fallback.""" - if ANTHROPIC_API_KEY: - try: - from autogen_ext.models.anthropic import AnthropicChatCompletionClient - print("[*] Engineer using Claude API (Anthropic)") - return AnthropicChatCompletionClient( - model="claude-sonnet-4-6", - api_key=ANTHROPIC_API_KEY, - ) - except Exception as e: - print(f"[!] Claude init failed ({e}), falling back to Ollama") - - from autogen_ext.models.openai import OpenAIChatCompletionClient - print(f"[*] Engineer using local model ({OLLAMA_MODEL}) via Ollama") + """Return an AutoGen model client via the unified LLM provider.""" + if _llm_config is not None: + provider = _LLMProvider(_llm_config) + print(f"[*] Engineer using {_llm_config.backend.upper()} — {_llm_config.model}") + return provider.autogen_model_client() + + # Fallback if lib.llm failed to import + from autogen_ext.models.openai import OpenAIChatCompletionClient # type: ignore[import] + model = os.getenv("GATHM_OLLAMA_MODEL", os.getenv("OLLAMA_MODEL", "gemma3:12b")) + base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1") + print(f"[*] Engineer using Ollama (fallback) — {model}") return OpenAIChatCompletionClient( - model=OLLAMA_MODEL, - base_url=OLLAMA_BASE_URL, + model=model, + base_url=base_url, api_key="NotRequired", - model_info={ - "vision": False, - "function_calling": True, - "json_output": True, - "family": "unknown", - "structured_output": False, - "multiple_system_messages": True, - } + model_info={"vision": False, "function_calling": True, "json_output": True, + "family": "unknown", "structured_output": False, + "multiple_system_messages": True}, ) def _build_codebase_context() -> str: diff --git a/lib/health.bash b/lib/health.bash index e573fa4..458fd4d 100644 --- a/lib/health.bash +++ b/lib/health.bash @@ -2,10 +2,10 @@ # Gathm Enterprise - Health Check Framework # Provides health checking capabilities for all tools -SCRIPT_DIR_HEALTH="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." &>/dev/null && pwd)" +SCRIPT_DIR_HEALTH="$(cd "$(dirname "${BASH_SOURCE[0]}")/.."\ &>/dev/null && pwd)" source "$SCRIPT_DIR_HEALTH/lib/logging.bash" 2>/dev/null -GATHM_HEALTH_DIR="${HOME}/.gathm/health" +GATHM_HEALTH_DIR="${GATHM_HEALTH_DIR:-${HOME}/.gathm/health}" GATHM_HEALTH_CACHE_TTL=300 # 5 minutes cache # Circuit breaker states @@ -150,15 +150,21 @@ cb_get_state() { } # Circuit Breaker - Record a success +# Uses the canonical subshell+flock pattern so all parent-shell variables +# are inherited without complex quoting, and the lock is released on exit. cb_record_success() { local tool="$1" local state_file="$GATHM_HEALTH_DIR/cb_${tool}.state" + local lock_file="${state_file}.lock" - cat > "$state_file" << EOF -state=$CB_STATE_CLOSED -failures=0 -last_failure=0 -EOF + if command -v flock &>/dev/null; then + ( + flock -x 9 + printf 'state=%s\nfailures=0\nlast_failure=0\n' "$CB_STATE_CLOSED" > "$state_file" + ) 9>>"$lock_file" + else + printf 'state=%s\nfailures=0\nlast_failure=0\n' "$CB_STATE_CLOSED" > "$state_file" + fi log_debug "circuit_breaker" "Circuit closed for tool: $tool" } @@ -166,26 +172,47 @@ EOF cb_record_failure() { local tool="$1" local state_file="$GATHM_HEALTH_DIR/cb_${tool}.state" + local lock_file="${state_file}.lock" local now now=$(date +%s) - local failures=0 - if [[ -f "$state_file" ]]; then - failures=$(grep "^failures=" "$state_file" | cut -d= -f2) + if command -v flock &>/dev/null; then + ( + flock -x 9 + local failures=0 + if [[ -f "$state_file" ]]; then + failures=$(grep "^failures=" "$state_file" | cut -d= -f2) + failures="${failures:-0}" + fi + failures=$(( failures + 1 )) + local state="$CB_STATE_CLOSED" + if (( failures >= CB_FAILURE_THRESHOLD )); then + state="$CB_STATE_OPEN" + fi + printf 'state=%s\nfailures=%d\nlast_failure=%s\n' \ + "$state" "$failures" "$now" > "$state_file" + ) 9>>"$lock_file" + else + # No flock — best-effort write (still correct for single-process use) + local failures=0 + if [[ -f "$state_file" ]]; then + failures=$(grep "^failures=" "$state_file" | cut -d= -f2) + failures="${failures:-0}" + fi + failures=$(( failures + 1 )) + local state="$CB_STATE_CLOSED" + if (( failures >= CB_FAILURE_THRESHOLD )); then + state="$CB_STATE_OPEN" + fi + printf 'state=%s\nfailures=%d\nlast_failure=%s\n' \ + "$state" "$failures" "$now" > "$state_file" fi - failures=$((failures + 1)) - local state="$CB_STATE_CLOSED" - if (( failures >= CB_FAILURE_THRESHOLD )); then - state="$CB_STATE_OPEN" - log_warn "circuit_breaker" "Circuit OPEN for tool: $tool (failures: $failures)" + local cur_state + cur_state=$(grep "^state=" "$state_file" 2>/dev/null | cut -d= -f2 || echo "$CB_STATE_CLOSED") + if [[ "$cur_state" == "$CB_STATE_OPEN" ]]; then + log_warn "circuit_breaker" "Circuit OPEN for tool: $tool" fi - - cat > "$state_file" << EOF -state=$state -failures=$failures -last_failure=$now -EOF } # Circuit Breaker - Check if tool is allowed to execute diff --git a/lib/llm.py b/lib/llm.py new file mode 100644 index 0000000..dc96ac8 --- /dev/null +++ b/lib/llm.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Gathm Enterprise — Unified LLM Provider +Single source of truth for model/backend resolution and client construction. + +Both Pilot (LangChain) and Engineer (AutoGen) import from here, keeping +model selection, API key lookup, and base-URL config in one place. + +Environment variables (in priority order): + GATHM_LLM_BACKEND ollama | gemini | anthropic (default: ollama) + GATHM_OLLAMA_MODEL or OLLAMA_MODEL (default: gemma3:12b) + GATHM_GEMINI_MODEL or GEMINI_MODEL (default: gemini-2.0-flash-lite) + ANTHROPIC_API_KEY (enables anthropic backend) + GOOGLE_API_KEY or GEMINI_API_KEY (enables gemini backend) + OLLAMA_BASE_URL (default: http://localhost:11434/v1) + +File-based overrides (set by install.sh or 'gathm pilot --set-model'): + ~/.gathm/model model name override + ~/.gathm/llm_backend backend override +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterator + + +# --------------------------------------------------------------------------- +# Configuration dataclass +# --------------------------------------------------------------------------- + +@dataclass +class LLMConfig: + backend: str # "ollama" | "gemini" | "anthropic" + model: str + api_key: str | None = None + base_url: str | None = None + + @classmethod + def from_env(cls) -> "LLMConfig": + """Resolve config from env vars and ~/.gathm/ config files.""" + backend = cls._resolve_backend() + model = cls._resolve_model(backend) + api_key: str | None = None + base_url: str | None = None + + if backend == "anthropic": + api_key = os.getenv("ANTHROPIC_API_KEY") + elif backend == "gemini": + api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") + else: # ollama + base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1") + + return cls(backend=backend, model=model, api_key=api_key, base_url=base_url) + + # ------------------------------------------------------------------ + @staticmethod + def _resolve_backend() -> str: + # 1. Env var + env = os.getenv("GATHM_LLM_BACKEND", "").lower().strip() + if env in ("ollama", "gemini", "anthropic"): + return env + + # 2. Implicit: if ANTHROPIC_API_KEY is set, use it + if os.getenv("ANTHROPIC_API_KEY"): + return "anthropic" + + # 3. Implicit: if GOOGLE_API_KEY / GEMINI_API_KEY is set, use gemini + if os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY"): + return "gemini" + + # 4. ~/.gathm/llm_backend file + backend_file = Path.home() / ".gathm" / "llm_backend" + if backend_file.is_file(): + stored = backend_file.read_text().strip().lower() + if stored in ("ollama", "gemini", "anthropic"): + return stored + + return "ollama" + + @staticmethod + def _resolve_model(backend: str) -> str: + # 1. Backend-specific env vars + if backend == "gemini": + m = os.getenv("GATHM_GEMINI_MODEL") or os.getenv("GEMINI_MODEL") + if m: + return m + elif backend == "anthropic": + m = os.getenv("GATHM_ANTHROPIC_MODEL") or os.getenv("ANTHROPIC_MODEL") + if m: + return m + else: + m = os.getenv("GATHM_OLLAMA_MODEL") or os.getenv("OLLAMA_MODEL") + if m: + return m + + # 2. Generic model env var + m = os.getenv("GATHM_MODEL") + if m: + return m + + # 3. ~/.gathm/model file + model_file = Path.home() / ".gathm" / "model" + if model_file.is_file(): + stored = model_file.read_text().strip() + if stored: + return stored + + # 4. Hardcoded defaults per backend + defaults = { + "gemini": "gemini-2.0-flash-lite", + "anthropic": "claude-sonnet-4-6", + "ollama": "gemma3:12b", + } + return defaults.get(backend, "gemma3:12b") + + +# --------------------------------------------------------------------------- +# Provider — thin wrapper that builds framework-specific clients on demand +# --------------------------------------------------------------------------- + +class LLMProvider: + """ + Unified LLM provider. Build one with ``LLMProvider.from_env()`` then call: + + - ``complete(messages)`` — simple one-shot completion + - ``langchain_chat_model()`` — returns a LangChain BaseChatModel + - ``autogen_model_client()`` — returns an AutoGen model client + """ + + def __init__(self, config: LLMConfig): + self.config = config + + @classmethod + def from_env(cls) -> "LLMProvider": + return cls(LLMConfig.from_env()) + + # ------------------------------------------------------------------ + # Simple completion (no framework dependency) + # ------------------------------------------------------------------ + + def complete(self, messages: list[dict[str, str]], **kwargs) -> str: + """ + One-shot completion. ``messages`` is a list of OpenAI-style dicts: + ``[{"role": "user"|"system"|"assistant", "content": "..."}]`` + + Returns the assistant reply as a plain string. + """ + cfg = self.config + if cfg.backend == "anthropic": + return self._complete_anthropic(messages, **kwargs) + if cfg.backend == "gemini": + return self._complete_gemini(messages, **kwargs) + return self._complete_ollama(messages, **kwargs) + + def _complete_anthropic(self, messages: list[dict], **kwargs) -> str: + import anthropic # type: ignore[import] + client = anthropic.Anthropic(api_key=self.config.api_key) + system = next((m["content"] for m in messages if m["role"] == "system"), "") + chat_msgs = [m for m in messages if m["role"] != "system"] + resp = client.messages.create( + model=self.config.model, + max_tokens=kwargs.get("max_tokens", 4096), + system=system or anthropic.NOT_GIVEN, + messages=chat_msgs, + ) + return resp.content[0].text + + def _complete_gemini(self, messages: list[dict], **kwargs) -> str: + import google.generativeai as genai # type: ignore[import] + genai.configure(api_key=self.config.api_key) + model = genai.GenerativeModel(self.config.model) + prompt = "\n".join(f"{m['role'].upper()}: {m['content']}" for m in messages) + resp = model.generate_content(prompt) + return resp.text + + def _complete_ollama(self, messages: list[dict], **kwargs) -> str: + import urllib.request, json as _json + base = (self.config.base_url or "http://localhost:11434/v1").rstrip("/") + payload = _json.dumps({"model": self.config.model, "messages": messages, "stream": False}).encode() + req = urllib.request.Request( + f"{base}/chat/completions", + data=payload, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=120) as resp: + data = _json.loads(resp.read()) + return data["choices"][0]["message"]["content"] + + # ------------------------------------------------------------------ + # LangChain integration (used by Pilot) + # ------------------------------------------------------------------ + + def langchain_chat_model(self) -> Any: + """Return a LangChain BaseChatModel for the configured backend.""" + cfg = self.config + if cfg.backend == "gemini": + from langchain_google_genai import ChatGoogleGenerativeAI # type: ignore[import] + return ChatGoogleGenerativeAI(model=cfg.model, google_api_key=cfg.api_key) + if cfg.backend == "anthropic": + from langchain_anthropic import ChatAnthropic # type: ignore[import] + return ChatAnthropic(model=cfg.model, api_key=cfg.api_key) + # Default: Ollama + from langchain_ollama import ChatOllama # type: ignore[import] + return ChatOllama(model=cfg.model) + + # ------------------------------------------------------------------ + # AutoGen integration (used by Engineer) + # ------------------------------------------------------------------ + + def autogen_model_client(self) -> Any: + """Return an AutoGen model client for the configured backend.""" + cfg = self.config + if cfg.backend == "anthropic": + from autogen_ext.models.anthropic import AnthropicChatCompletionClient # type: ignore[import] + return AnthropicChatCompletionClient(model=cfg.model, api_key=cfg.api_key) + # Gemini and Ollama both speak the OpenAI-compatible API + from autogen_ext.models.openai import OpenAIChatCompletionClient # type: ignore[import] + if cfg.backend == "gemini": + return OpenAIChatCompletionClient( + model=cfg.model, + base_url="https://generativelanguage.googleapis.com/v1beta/openai/", + api_key=cfg.api_key or "NOT_REQUIRED", + model_info={"vision": False, "function_calling": True, + "json_output": True, "family": "gemini", + "structured_output": False, "multiple_system_messages": True}, + ) + return OpenAIChatCompletionClient( + model=cfg.model, + base_url=cfg.base_url or "http://localhost:11434/v1", + api_key="NotRequired", + model_info={"vision": False, "function_calling": True, + "json_output": True, "family": "unknown", + "structured_output": False, "multiple_system_messages": True}, + ) + + # ------------------------------------------------------------------ + def __repr__(self) -> str: + return f"LLMProvider(backend={self.config.backend!r}, model={self.config.model!r})" diff --git a/lib/logging.bash b/lib/logging.bash index 22c036a..3ea193d 100644 --- a/lib/logging.bash +++ b/lib/logging.bash @@ -73,6 +73,20 @@ _epoch_ms() { fi } +# Append a line to a log file under an exclusive lock. +# Uses flock(1) on Linux/WSL/Termux; falls back to >> (atomic on most POSIX +# filesystems for small writes) on macOS/BSD where flock may not be present. +_log_append() { + local file="$1" + local line="$2" + if command -v flock &>/dev/null; then + local lock_file="${file}.lock" + (flock -x 9; printf '%s\n' "$line" >> "$file") 9>>"$lock_file" 2>/dev/null + else + printf '%s\n' "$line" >> "$file" 2>/dev/null + fi +} + # Core structured log function - outputs JSON # Usage: _log LEVEL COMPONENT MESSAGE [extra_json_fields] _log() { @@ -108,7 +122,7 @@ _log() { "$timestamp" "$level" "$component" "$message" "$hostname_val" "$pid") fi - echo "$log_entry" >> "$GATHM_LOG_FILE" 2>/dev/null + _log_append "$GATHM_LOG_FILE" "$log_entry" # Also print errors/fatals to stderr if [ "$level_num" -ge 3 ]; then @@ -136,7 +150,7 @@ audit_log() { local entry entry=$(printf '{"timestamp":"%s","action":"%s","actor":"%s","tool":"%s","details":"%s"}' \ "$timestamp" "$action" "$actor" "$tool" "$details") - echo "$entry" >> "$GATHM_AUDIT_FILE" 2>/dev/null + _log_append "$GATHM_AUDIT_FILE" "$entry" } # Metrics log - track tool invocations, latency, success rates @@ -161,7 +175,7 @@ log_metric() { entry=$(printf '{"timestamp":"%s","tool":"%s","duration_ms":%s,"exit_code":%d,"status":"%s"}' \ "$timestamp" "$tool" "$duration_ms" "$exit_code" "$status") fi - echo "$entry" >> "$GATHM_METRICS_FILE" 2>/dev/null + _log_append "$GATHM_METRICS_FILE" "$entry" } # Timed execution wrapper - runs a command and logs metrics diff --git a/lib/output.bash b/lib/output.bash new file mode 100644 index 0000000..23cc747 --- /dev/null +++ b/lib/output.bash @@ -0,0 +1,124 @@ +#!/usr/bin/env bash +# Gathm Enterprise — Structured Tool Output +# Source this in every tool script to emit consistent text or JSON output. +# +# Usage: +# source "$(dirname "$0")/../../lib/output.bash" +# +# # At the end of the tool, emit results: +# tool_output 0 "plain text result" # text mode → prints as-is +# # json mode → wraps in JSON envelope +# +# tool_output_json 0 '{"key":"value"}' # json mode → merges into envelope +# # text mode → prints the JSON string +# +# tool_error 1 "something went wrong" # always emits to stderr; sets exit 1 +# +# GATHM_OUTPUT_MODE controls format: "json" or "text" (default: text) + +# --------------------------------------------------------------------------- +# Internal: escape a string for use as a JSON value +# --------------------------------------------------------------------------- +_json_escape() { + # Handles backslash, double-quote, and control characters + printf '%s' "$1" \ + | sed 's/\\/\\\\/g; s/"/\\"/g; s/'"$(printf '\t')"'/\\t/g' \ + | tr -d '\000-\037' +} + +# --------------------------------------------------------------------------- +# Internal: detect the tool name from the calling script's filename +# --------------------------------------------------------------------------- +_tool_name() { + if [[ -n "${GATHM_TOOL_NAME:-}" ]]; then + echo "$GATHM_TOOL_NAME" + else + basename "${BASH_SOURCE[2]:-${0}}" + fi +} + +# --------------------------------------------------------------------------- +# tool_output EXIT_CODE MESSAGE +# Emit a plain-text message (or JSON envelope in json mode). +# Returns EXIT_CODE. +# --------------------------------------------------------------------------- +tool_output() { + local exit_code="${1:-0}" + local message="${2:-}" + + if [[ "${GATHM_OUTPUT_MODE:-text}" == "json" ]]; then + local tool + tool=$(_tool_name) + local escaped + escaped=$(_json_escape "$message") + local status="success" + [[ "$exit_code" -ne 0 ]] && status="error" + printf '{"tool":"%s","status":"%s","exit_code":%d,"output":"%s"}\n' \ + "$tool" "$status" "$exit_code" "$escaped" + else + printf '%s\n' "$message" + fi + + return "$exit_code" +} + +# --------------------------------------------------------------------------- +# tool_output_json EXIT_CODE JSON_STRING +# Emit a pre-formed JSON object as tool output. +# In json mode: merges tool/status/exit_code fields into the object. +# In text mode: prints the raw JSON string (pretty if jq is available). +# --------------------------------------------------------------------------- +tool_output_json() { + local exit_code="${1:-0}" + local json="${2:-{}}" + + if [[ "${GATHM_OUTPUT_MODE:-text}" == "json" ]]; then + local tool + tool=$(_tool_name) + local status="success" + [[ "$exit_code" -ne 0 ]] && status="error" + + # Inject envelope fields if jq is available; otherwise concatenate + if command -v jq &>/dev/null; then + printf '%s' "$json" \ + | jq --arg tool "$tool" --arg status "$status" --argjson code "$exit_code" \ + '. + {tool: $tool, status: $status, exit_code: $code}' + else + # Strip trailing } and append fields manually + local body="${json%\}}" + printf '%s,"tool":"%s","status":"%s","exit_code":%d}\n' \ + "$body" "$tool" "$status" "$exit_code" + fi + else + if command -v jq &>/dev/null; then + printf '%s' "$json" | jq . + else + printf '%s\n' "$json" + fi + fi + + return "$exit_code" +} + +# --------------------------------------------------------------------------- +# tool_error EXIT_CODE MESSAGE +# Emit an error to stderr and, in json mode, also to stdout as a JSON error. +# Always returns EXIT_CODE (defaults to 1). +# --------------------------------------------------------------------------- +tool_error() { + local exit_code="${1:-1}" + local message="${2:-unknown error}" + + printf '[ERROR] %s\n' "$message" >&2 + + if [[ "${GATHM_OUTPUT_MODE:-text}" == "json" ]]; then + local tool + tool=$(_tool_name) + local escaped + escaped=$(_json_escape "$message") + printf '{"tool":"%s","status":"error","exit_code":%d,"error":"%s"}\n' \ + "$tool" "$exit_code" "$escaped" + fi + + return "$exit_code" +} diff --git a/lib/recovery.bash b/lib/recovery.bash index 973503c..2d6cf90 100644 --- a/lib/recovery.bash +++ b/lib/recovery.bash @@ -3,7 +3,7 @@ # Handles automatic failure recovery, retries, and fallback chains # Cross-platform: Linux, macOS, Termux, Windows (WSL/Git Bash/MSYS2) -SCRIPT_DIR_RECOVERY="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." &>/dev/null && pwd)" +SCRIPT_DIR_RECOVERY="$(cd "$(dirname "${BASH_SOURCE[0]}")/.."\ &>/dev/null && pwd)" source "$SCRIPT_DIR_RECOVERY/lib/logging.bash" 2>/dev/null source "$SCRIPT_DIR_RECOVERY/lib/health.bash" 2>/dev/null @@ -52,12 +52,35 @@ retry_with_backoff() { # Execute a tool with full recovery pipeline # Usage: execute_with_recovery TOOL_NAME [args...] +# Optional env: _GATHM_FALLBACK_DEPTH (internal, tracks recursion depth) +# _GATHM_FALLBACK_CHAIN (internal, colon-separated visited tools) execute_with_recovery() { local tool_name="$1" shift local tool_args=("$@") local tool_path="$SCRIPT_DIR_RECOVERY/tools/$tool_name/$tool_name" + # --- Cycle/depth guard --- + local depth="${_GATHM_FALLBACK_DEPTH:-0}" + local chain="${_GATHM_FALLBACK_CHAIN:-}" + local max_fallback_depth=2 + + if [[ "$depth" -ge "$max_fallback_depth" ]]; then + log_error "recovery" "Fallback depth limit ($max_fallback_depth) reached, aborting chain: $chain -> $tool_name" + echo '{"error":"fallback_depth_exceeded","tool":"'"$tool_name"'","chain":"'"$chain"'","message":"Maximum fallback depth reached. Possible cycle in fallback configuration."}' >&2 + return 1 + fi + + case ":${chain}:" in + *":${tool_name}:"*) + log_error "recovery" "Fallback cycle detected: $chain -> $tool_name" + echo '{"error":"fallback_cycle","tool":"'"$tool_name"'","chain":"'"$chain"'","message":"Cycle detected in fallback chain. Check tool manifest fallback_tool fields."}' >&2 + return 1 + ;; + esac + + local new_chain="${chain:+${chain}:}${tool_name}" + log_info "recovery" "Executing tool with recovery: $tool_name" "\"args\":\"${tool_args[*]}\"" # Step 1: Check circuit breaker @@ -67,7 +90,8 @@ execute_with_recovery() { fallback=$(get_fallback_tool "$tool_name") if [[ -n "$fallback" && "$fallback" != "null" ]]; then log_info "recovery" "Using fallback tool: $fallback for $tool_name" - execute_with_recovery "$fallback" "${tool_args[@]}" + _GATHM_FALLBACK_DEPTH=$((depth + 1)) _GATHM_FALLBACK_CHAIN="$new_chain" \ + execute_with_recovery "$fallback" "${tool_args[@]}" return $? fi log_error "recovery" "No fallback available for $tool_name, circuit is open" @@ -104,7 +128,8 @@ execute_with_recovery() { fallback=$(get_fallback_tool "$tool_name") if [[ -n "$fallback" && "$fallback" != "null" ]]; then log_warn "recovery" "Tool $tool_name failed, trying fallback: $fallback" - execute_with_recovery "$fallback" "${tool_args[@]}" + _GATHM_FALLBACK_DEPTH=$((depth + 1)) _GATHM_FALLBACK_CHAIN="$new_chain" \ + execute_with_recovery "$fallback" "${tool_args[@]}" return $? fi @@ -151,7 +176,6 @@ auto_install_deps() { } # Map package names to Termux equivalents -# Some packages have different names on Termux _termux_pkg_name() { case "$1" in python3) echo "python" ;; @@ -166,13 +190,8 @@ _termux_pkg_name() { } # Try to install a dependency using available package manager -# Supports: apt-get (Debian/Ubuntu), pkg (Termux), brew (macOS), -# yum/dnf (RHEL/CentOS/Fedora), pacman (Arch), zypper (openSUSE), -# apk (Alpine), choco/scoop (Windows) _try_install_dep() { local dep="$1" - - # Detect platform for smarter package manager selection local platform platform=$(_detect_platform_for_install) @@ -199,7 +218,6 @@ _try_install_dep() { fi ;; windows) - # Windows: try scoop first (no admin), then choco if command -v scoop &>/dev/null; then scoop install "$dep" 2>/dev/null && { log_info "recovery" "Installed dependency: $dep (via scoop)" @@ -211,7 +229,6 @@ _try_install_dep() { return 0 } elif command -v pacman &>/dev/null; then - # MSYS2 uses pacman pacman -S --noconfirm "$dep" 2>/dev/null && { log_info "recovery" "Installed dependency: $dep (via pacman/MSYS2)" return 0 @@ -219,7 +236,6 @@ _try_install_dep() { fi ;; *) - # Linux: try all known package managers if command -v apt-get &>/dev/null; then sudo apt-get install -y "$dep" 2>/dev/null && { log_info "recovery" "Installed dependency: $dep (via apt-get)" @@ -280,9 +296,8 @@ _detect_platform_for_install() { fi ;; *) - # Check for WSL if grep -qi microsoft /proc/version 2>/dev/null; then - echo "linux" # WSL acts like Linux for packages + echo "linux" else echo "unknown" fi @@ -291,18 +306,15 @@ _detect_platform_for_install() { } # Validate tool output (basic sanity check) -# Usage: validate_output TOOL_NAME OUTPUT validate_output() { local tool_name="$1" local output="$2" - # Check for empty output if [[ -z "$output" ]]; then log_warn "recovery" "Empty output from tool: $tool_name" return 1 fi - # Check for common error patterns if echo "$output" | grep -qi "error\|failed\|not found\|connection refused" &>/dev/null; then log_warn "recovery" "Potential error in output from tool: $tool_name" return 1 @@ -312,7 +324,6 @@ validate_output() { } # Self-heal: check and fix common issues -# Usage: self_heal TOOL_NAME self_heal() { local tool_name="$1" local tool_dir="$SCRIPT_DIR_RECOVERY/tools/$tool_name" @@ -321,17 +332,14 @@ self_heal() { log_info "recovery" "Running self-heal for tool: $tool_name" - # Fix: executable permissions if [[ -f "$tool_path" && ! -x "$tool_path" ]]; then chmod +x "$tool_path" log_info "recovery" "Fixed executable permissions for: $tool_name" healed=true fi - # Fix: missing dependencies auto_install_deps "$tool_name" - # Fix: reset circuit breaker if tool is now healthy local health health=$(healthcheck_tool "$tool_name") if echo "$health" | grep -q '"status":"healthy"'; then diff --git a/pilot/main.py b/pilot/main.py index 3fac09c..50f1575 100644 --- a/pilot/main.py +++ b/pilot/main.py @@ -16,7 +16,6 @@ def load_dotenv() -> bool: # type: ignore[override] LANGCHAIN_IMPORT_ERROR: Optional[Exception] = None LANGCHAIN_AVAILABLE = True try: - from langchain_ollama import ChatOllama from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langgraph.graph import StateGraph, END except ModuleNotFoundError as exc: @@ -25,20 +24,9 @@ def load_dotenv() -> bool: # type: ignore[override] AIMessage = Any # type: ignore[assignment] BaseMessage = Any # type: ignore[assignment] HumanMessage = Any # type: ignore[assignment] - ChatOllama = None # type: ignore[assignment] StateGraph = None # type: ignore[assignment] END = "__END__" -# Optional: Google Gemini backend -GOOGLE_GENAI_AVAILABLE = False -ChatGoogleGenerativeAI: Any = None -try: - from langchain_google_genai import ChatGoogleGenerativeAI as _ChatGGAI - ChatGoogleGenerativeAI = _ChatGGAI - GOOGLE_GENAI_AVAILABLE = True -except ModuleNotFoundError: - pass - # Load environment variables load_dotenv() @@ -47,56 +35,26 @@ def load_dotenv() -> bool: # type: ignore[override] GATHM_ROOT = PILOT_DIR.parent TOOLS_DIR = GATHM_ROOT / "tools" -def _resolve_backend() -> str: - """Return 'ollama' or 'gemini' from env/config, defaulting to 'ollama'.""" - env_val = os.getenv("GATHM_LLM_BACKEND", "").lower().strip() - if env_val in ("ollama", "gemini"): - return env_val - backend_file = Path.home() / ".gathm" / "llm_backend" - if backend_file.is_file(): - stored = backend_file.read_text().strip().lower() - if stored in ("ollama", "gemini"): - return stored - return "ollama" - -# Model priority: env var > ~/.gathm/model (install.sh) > hardcoded default -def _resolve_model() -> str: - backend = _resolve_backend() - if backend == "gemini": - env_model = os.getenv("GATHM_GEMINI_MODEL") or os.getenv("GEMINI_MODEL") - if env_model: - return env_model - model_file = Path.home() / ".gathm" / "model" - if model_file.is_file(): - stored = model_file.read_text().strip() - if stored: - return stored - return "gemini-2.0-flash-lite" - else: - env_model = os.getenv("GATHM_OLLAMA_MODEL") or os.getenv("OLLAMA_MODEL") - if env_model: - return env_model - model_file = Path.home() / ".gathm" / "model" - if model_file.is_file(): - stored = model_file.read_text().strip() - if stored: - return stored - return "gemma3:12b" - -LLM_BACKEND = _resolve_backend() -OLLAMA_MODEL = _resolve_model() +# --- Unified LLM provider (single source of truth for backend/model config) --- +sys.path.insert(0, str(GATHM_ROOT)) +try: + from lib.llm import LLMConfig, LLMProvider + _llm_config = LLMConfig.from_env() + LLM_BACKEND = _llm_config.backend + OLLAMA_MODEL = _llm_config.model +except Exception: + LLM_BACKEND = os.getenv("GATHM_LLM_BACKEND", "ollama") + OLLAMA_MODEL = os.getenv("GATHM_OLLAMA_MODEL", os.getenv("OLLAMA_MODEL", "gemma3:12b")) + _llm_config = None + PILOT_MAX_HISTORY = int(os.getenv("PILOT_MAX_HISTORY", "12")) def _build_llm(): - """Instantiate the correct LangChain LLM for the configured backend.""" - if LLM_BACKEND == "gemini": - if not GOOGLE_GENAI_AVAILABLE: - raise RuntimeError( - "Google Generative AI SDK not installed. " - "Run: pip install langchain-google-genai" - ) - api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - return ChatGoogleGenerativeAI(model=OLLAMA_MODEL, google_api_key=api_key) + """Instantiate the LangChain chat model via the unified LLM provider.""" + if _llm_config is not None: + return LLMProvider(_llm_config).langchain_chat_model() + # Fallback if lib.llm failed to import + from langchain_ollama import ChatOllama # type: ignore[import] return ChatOllama(model=OLLAMA_MODEL) # Colors (kept for non-TUI code paths) @@ -395,11 +353,6 @@ def _require_langchain_runtime() -> None: "Pilot AI runtime dependencies are missing. " "Install pilot/requirements.txt and retry" + detail ) - if LLM_BACKEND == "gemini" and not GOOGLE_GENAI_AVAILABLE: - raise RuntimeError( - "Google Generative AI SDK not installed. " - "Run: pip install langchain-google-genai" - ) # System prompt for models without native tool support def call_model(state: AgentState): diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..0692392 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,406 @@ +""" +Gathm Enterprise — Core Integration Tests +Tests the Bash library internals (circuit breaker, cache, fallback guard) +and the Python API rate-limiter by running real subprocesses. + +Run: python -m pytest tests/test_core.py -v +""" + +from __future__ import annotations + +import asyncio +import json +import os +import subprocess +import sys +import tempfile +import time +import unittest +from pathlib import Path + +GATHM_ROOT = Path(__file__).resolve().parent.parent + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _bash(script: str, env: dict | None = None, timeout: int = 15) -> tuple[int, str, str]: + """Run a bash snippet, return (returncode, stdout, stderr).""" + full_env = {**os.environ} + if env: + full_env.update(env) + r = subprocess.run(["bash", "-c", script], capture_output=True, text=True, + timeout=timeout, env=full_env) + return r.returncode, r.stdout.strip(), r.stderr.strip() + + +def _lib(*names: str) -> str: + """Return source lines for the named lib/*.bash files.""" + return "\n".join(f'source "{GATHM_ROOT}/lib/{n}.bash"' for n in names) + + +# --------------------------------------------------------------------------- +# TestCircuitBreaker +# --------------------------------------------------------------------------- + +class TestCircuitBreaker(unittest.TestCase): + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory(prefix="gathm-cb-") + self._health = os.path.join(self._tmp.name, "health") + self._logs = os.path.join(self._tmp.name, "logs") + os.makedirs(self._health) + os.makedirs(self._logs) + + def tearDown(self): + self._tmp.cleanup() + + def _run(self, snippet: str) -> tuple[int, str, str]: + preamble = f""" +export GATHM_HEALTH_DIR="{self._health}" +export GATHM_LOG_DIR="{self._logs}" +export GATHM_LOG_LEVEL="ERROR" +export CB_FAILURE_THRESHOLD=3 +export CB_RECOVERY_TIMEOUT=60 +{_lib("logging", "health")} +""" + return _bash(preamble + snippet) + + def test_fresh_tool_is_allowed(self): + rc, out, _ = self._run('cb_allow_request brand_new_tool && echo ALLOWED || echo DENIED') + self.assertEqual(out, "ALLOWED") + + def test_opens_after_threshold_failures(self): + rc, out, _ = self._run(""" +for i in 1 2 3; do cb_record_failure faulty_tool; done +cb_allow_request faulty_tool && echo ALLOWED || echo DENIED +""") + self.assertEqual(out, "DENIED") + + def test_single_success_resets_to_closed(self): + _, out, _ = self._run(""" +for i in 1 2 3; do cb_record_failure t; done +cb_record_success t +cb_allow_request t && echo ALLOWED || echo DENIED +""") + self.assertEqual(out, "ALLOWED") + + def test_failure_count_persists_between_calls(self): + _, out, _ = self._run(""" +cb_record_failure count_tool +cb_record_failure count_tool +grep '^failures=' "$GATHM_HEALTH_DIR/cb_count_tool.state" | cut -d= -f2 +""") + self.assertEqual(out, "2") + + def test_recovery_timeout_yields_half_open(self): + past = int(time.time()) - 120 + _, out, _ = self._run(f""" +cat > "$GATHM_HEALTH_DIR/cb_old_tool.state" <<'EOF' +state=open +failures=3 +last_failure={past} +EOF +cb_get_state old_tool +""") + self.assertEqual(out, "half_open") + + def test_concurrent_writes_do_not_corrupt_state(self): + """Ten parallel failure recordings — state file must remain readable.""" + _, out, _ = self._run(""" +for i in $(seq 1 10); do cb_record_failure parallel_tool & done +wait +grep -c '^state=' "$GATHM_HEALTH_DIR/cb_parallel_tool.state" 2>/dev/null || echo 0 +""") + self.assertEqual(out, "1", "State file should have exactly one 'state=' line") + + def test_below_threshold_stays_closed(self): + _, out, _ = self._run(""" +cb_record_failure almost_tool +cb_record_failure almost_tool +cb_allow_request almost_tool && echo ALLOWED || echo DENIED +""") + self.assertEqual(out, "ALLOWED") + + +# --------------------------------------------------------------------------- +# TestCache +# --------------------------------------------------------------------------- + +class TestCache(unittest.TestCase): + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory(prefix="gathm-cache-") + self._cache = os.path.join(self._tmp.name, "cache") + self._logs = os.path.join(self._tmp.name, "logs") + os.makedirs(self._cache) + os.makedirs(self._logs) + + def tearDown(self): + self._tmp.cleanup() + + def _run(self, snippet: str, extra_env: dict | None = None, + timeout: int = 15) -> tuple[int, str, str]: + env = { + "GATHM_CACHE_DIR": self._cache, + "GATHM_LOG_DIR": self._logs, + "GATHM_LOG_LEVEL": "ERROR", + "GATHM_CACHE_ENABLED": "true", + } + if extra_env: + env.update(extra_env) + preamble = f""" +export GATHM_CACHE_DIR="{self._cache}" +export GATHM_LOG_DIR="{self._logs}" +export GATHM_LOG_LEVEL="ERROR" +export GATHM_CACHE_ENABLED="${{GATHM_CACHE_ENABLED:-true}}" +{_lib("logging", "cache")} +""" + return _bash(preamble + snippet, env=env, timeout=timeout) + + def test_miss_on_empty_cache(self): + rc, _, _ = self._run("cache_get my_tool argA argB") + self.assertEqual(rc, 1) + + def test_hit_after_set(self): + rc, out, _ = self._run(""" +echo "hello world" | cache_set my_tool argA +cache_get my_tool argA +""") + self.assertEqual(rc, 0) + self.assertEqual(out, "hello world") + + def test_miss_after_ttl_expiry(self): + rc, _, _ = self._run(""" +echo "ephemeral" | GATHM_CACHE_TTL=1 cache_set my_tool expiring +sleep 2 +cache_get my_tool expiring +""", timeout=20) + self.assertEqual(rc, 1, "Expired entry should be a cache miss") + + def test_invalidate_removes_entry(self): + rc, _, _ = self._run(""" +echo "gone" | cache_set my_tool key_to_kill +cache_invalidate my_tool key_to_kill +cache_get my_tool key_to_kill +""") + self.assertEqual(rc, 1) + + def test_cache_disabled_always_misses(self): + rc, _, _ = self._run(""" +echo "should_not_cache" | cache_set my_tool arg1 +cache_get my_tool arg1 +""", extra_env={"GATHM_CACHE_ENABLED": "false"}) + self.assertEqual(rc, 1) + + def test_different_args_use_different_keys(self): + _, out, _ = self._run(""" +echo "output_a" | cache_set my_tool args_a +echo "output_b" | cache_set my_tool args_b +a=$(cache_get my_tool args_a) +b=$(cache_get my_tool args_b) +echo "$a:$b" +""") + self.assertEqual(out, "output_a:output_b") + + def test_cleanup_removes_expired_entries(self): + _, out, _ = self._run(""" +echo "dies" | GATHM_CACHE_TTL=1 cache_set my_tool expiry_key +sleep 2 +count=$(cache_cleanup) +echo "$count" +""", timeout=20) + self.assertGreaterEqual(int(out), 1, "cache_cleanup should report at least 1 removed entry") + + def test_stats_counts_active_entries(self): + _, out, _ = self._run(""" +echo "a" | cache_set stat_tool k1 +echo "b" | cache_set stat_tool k2 +cache_stats +""") + stats = json.loads(out) + self.assertGreaterEqual(stats["active"], 2) + self.assertIn("total_entries", stats) + + def test_cache_invalidate_tool_removes_all_for_tool(self): + rc, _, _ = self._run(""" +echo "x" | cache_set target_tool argX +echo "y" | cache_set target_tool argY +echo "z" | cache_set other_tool argZ +cache_invalidate_tool target_tool +# both target_tool entries should be gone +cache_get target_tool argX +""") + self.assertEqual(rc, 1) + + +# --------------------------------------------------------------------------- +# TestFallbackGuard +# --------------------------------------------------------------------------- + +class TestFallbackGuard(unittest.TestCase): + """Tests cycle detection and depth cap added to execute_with_recovery.""" + + def setUp(self): + self._tmp = tempfile.TemporaryDirectory(prefix="gathm-fg-") + self._health = os.path.join(self._tmp.name, "health") + self._logs = os.path.join(self._tmp.name, "logs") + os.makedirs(self._health) + os.makedirs(self._logs) + + def tearDown(self): + self._tmp.cleanup() + + def _run(self, snippet: str) -> tuple[int, str, str]: + preamble = f""" +export GATHM_LOG_DIR="{self._logs}" +export GATHM_HEALTH_DIR="{self._health}" +export GATHM_LOG_LEVEL="ERROR" +export GATHM_MAX_RETRIES=1 + +{_lib("logging", "health")} + +# Stubs — cb is always open so we reach the fallback path immediately +cb_allow_request() {{ return 1; }} +cb_record_success() {{ :; }} +cb_record_failure() {{ :; }} +auto_install_deps() {{ :; }} +retry_with_backoff() {{ shift; return 1; }} + +{_lib("recovery")} +""" + return _bash(preamble + snippet) + + def test_depth_limit_blocks_at_max(self): + _, out, err = self._run(""" +output=$(_GATHM_FALLBACK_DEPTH=2 _GATHM_FALLBACK_CHAIN="a:b" \ + execute_with_recovery tool_c 2>&1) +echo "$output" +""") + combined = (out + err).lower() + self.assertIn("depth", combined, "Should report depth exceeded") + + def test_cycle_detection_blocks_revisit(self): + _, out, err = self._run(""" +output=$(_GATHM_FALLBACK_CHAIN="tool_a:tool_b" \ + execute_with_recovery tool_a 2>&1) +echo "$output" +""") + combined = (out + err).lower() + self.assertIn("cycle", combined, "Should report cycle detected") + + def test_first_call_has_no_guard_overhead(self): + _, out, err = self._run(""" +output=$(execute_with_recovery nonexistent_tool 2>&1) +echo "$output" +""") + combined = (out + err).lower() + self.assertNotIn("cycle", combined) + self.assertNotIn("fallback_depth", combined) + + def test_chain_accumulates_across_hops(self): + _, out, err = self._run(""" +output=$(_GATHM_FALLBACK_DEPTH=1 _GATHM_FALLBACK_CHAIN="tool_a" \ + execute_with_recovery tool_b 2>&1) +# Depth 1 < max(2), no cycle — should fail normally (no cycle/depth message) +echo "${output}" | grep -v "cycle" | grep -v "depth" && echo OK || echo GUARDED +""") + self.assertIn("OK", out + err) + + +# --------------------------------------------------------------------------- +# TestAPIRateLimit +# --------------------------------------------------------------------------- + +try: + from api.server import check_rate_limit, _rate_windows + _FASTAPI_AVAILABLE = True +except Exception: + _FASTAPI_AVAILABLE = False + + +@unittest.skipUnless(_FASTAPI_AVAILABLE, "FastAPI not installed — skipping API rate-limit tests") +class TestAPIRateLimit(unittest.TestCase): + """Tests the sliding-window rate limiter in api/server.py.""" + + def _fresh_key(self) -> str: + return f"test-{time.time()}-{id(self)}" + + def test_allows_requests_under_limit(self): + key = self._fresh_key() + for _ in range(5): + self.assertTrue(check_rate_limit(key, 10)) + + def test_blocks_at_limit(self): + key = self._fresh_key() + for _ in range(5): + check_rate_limit(key, 5) + self.assertFalse(check_rate_limit(key, 5), "6th request should be blocked") + + def test_zero_limit_is_unlimited(self): + key = self._fresh_key() + for _ in range(200): + self.assertTrue(check_rate_limit(key, 0)) + + def test_expired_window_entries_are_evicted(self): + key = self._fresh_key() + _rate_windows[key] = [time.monotonic() - 70.0] * 5 + self.assertTrue(check_rate_limit(key, 5)) + + def test_partial_window_expiry(self): + key = self._fresh_key() + _rate_windows[key] = [time.monotonic() - 70.0] * 5 + for _ in range(5): + self.assertTrue(check_rate_limit(key, 5)) + self.assertFalse(check_rate_limit(key, 5), "Should be blocked at limit") + + def test_independent_keys_do_not_interfere(self): + a = self._fresh_key() + b = self._fresh_key() + for _ in range(5): + check_rate_limit(a, 5) + self.assertTrue(check_rate_limit(b, 5)) + + +# --------------------------------------------------------------------------- +# TestJobStore (in-process, no HTTP server needed) +# --------------------------------------------------------------------------- + +@unittest.skipUnless(_FASTAPI_AVAILABLE, "FastAPI not installed — skipping job store tests") +class TestJobStore(unittest.TestCase): + """Tests Job dataclass and serialisation — no event loop required.""" + + def test_job_default_status_is_pending(self): + from api.server import Job, JobStatus + job = Job(id="abc123", kind="tool", tool="dns", args=["example.com"], timeout=5) + self.assertEqual(job.status, JobStatus.pending) + + def test_job_to_dict_has_required_fields(self): + from api.server import Job, _job_to_dict + job = Job(id="xyz", kind="tool", tool="geo", args=[], timeout=5) + d = _job_to_dict(job) + for f in ("id", "kind", "tool", "args", "status", + "created_at", "exit_code", "output", "line_count"): + self.assertIn(f, d, f"Missing field: {f}") + + def test_job_output_lines_join_to_output(self): + from api.server import Job, _job_to_dict + job = Job(id="j1", kind="tool", tool="dns", args=[], timeout=5) + job.output_lines = ["line one", "line two"] + d = _job_to_dict(job) + self.assertEqual(d["output"], "line one\nline two") + self.assertEqual(d["line_count"], 2) + + def test_job_ids_generated_are_unique(self): + import uuid + ids = {uuid.uuid4().hex for _ in range(50)} + self.assertEqual(len(ids), 50) + + def test_job_statuses_are_string_values(self): + from api.server import JobStatus + for status in JobStatus: + self.assertIsInstance(status.value, str) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/jukebox/tool.yaml b/tools/jukebox/tool.yaml index 9bacef5..0845577 100644 --- a/tools/jukebox/tool.yaml +++ b/tools/jukebox/tool.yaml @@ -9,7 +9,7 @@ dependencies: apis: - name: telehack base_url: telehack.com - health_endpoint: null + health_endpoint: "" auth_required: false input_schema: diff --git a/tools/validate_manifests.py b/tools/validate_manifests.py new file mode 100644 index 0000000..d1052f0 --- /dev/null +++ b/tools/validate_manifests.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +""" +Gathm Enterprise — Tool Manifest Validator +Validates every tools/*/tool.yaml against a strict Pydantic schema. + +Usage: + python3 tools/validate_manifests.py # validate all manifests + python3 tools/validate_manifests.py dns geo # validate specific tools + +Exit code 0 = all valid, 1 = one or more errors. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any, Literal, Optional + +try: + import yaml +except ImportError: + print("ERROR: pyyaml is required. Run: pip install pyyaml", file=sys.stderr) + sys.exit(1) + +try: + from pydantic import BaseModel, Field, field_validator, model_validator + from pydantic import ValidationError + PYDANTIC_V2 = True +except ImportError: + print("ERROR: pydantic>=2.0 is required. Run: pip install pydantic", file=sys.stderr) + sys.exit(1) + + +TOOLS_DIR = Path(__file__).resolve().parent +GATHM_ROOT = TOOLS_DIR.parent + +VALID_CATEGORIES = { + "security", "networking", "data", "media", + "utility", "finance", "science", "productivity", +} + +VALID_TYPES = {"string", "integer", "number", "boolean", "array", "object"} + + +# --------------------------------------------------------------------------- +# Schema models +# --------------------------------------------------------------------------- + +class Argument(BaseModel): + name: str + type: str = "string" + required: bool = False + description: str = "" + + @field_validator("type") + @classmethod + def valid_type(cls, v: str) -> str: + if v not in VALID_TYPES: + raise ValueError(f"type must be one of {sorted(VALID_TYPES)}, got '{v}'") + return v + + +class Flag(BaseModel): + flag: str + description: str = "" + has_argument: bool = False + + +class InputSchema(BaseModel): + arguments: list[Argument] = [] + flags: list[Flag] = [] + + +class JsonField(BaseModel): + name: str + type: str = "string" + + @field_validator("type") + @classmethod + def valid_type(cls, v: str) -> str: + if v not in VALID_TYPES: + raise ValueError(f"type must be one of {sorted(VALID_TYPES)}, got '{v}'") + return v + + +class OutputSchema(BaseModel): + text: str = "" + json_fields: list[JsonField] = [] + + +class ApiDep(BaseModel): + name: str + base_url: str + health_endpoint: str = "" + auth_required: bool = False + + +class Dependencies(BaseModel): + system: list[str] = [] + python: list[str] = [] + + +class ToolManifest(BaseModel): + name: str = Field(..., min_length=1) + version: str = Field(..., pattern=r"^\d+\.\d+\.\d+$") + description: str = Field(..., min_length=5) + category: str | list[str] + + dependencies: Dependencies = Field(default_factory=Dependencies) + apis: list[ApiDep] = [] + input_schema: Optional[InputSchema] = None + output_schema: Optional[OutputSchema] = None + fallback_tool: Optional[str] = None + tags: list[str] = [] + + @field_validator("category") + @classmethod + def valid_category(cls, v: str | list) -> str | list: + cats = [v] if isinstance(v, str) else v + for cat in cats: + if cat not in VALID_CATEGORIES: + raise ValueError( + f"category '{cat}' is not valid. " + f"Must be one of: {sorted(VALID_CATEGORIES)}" + ) + return v + + @model_validator(mode="after") + def no_self_fallback(self) -> "ToolManifest": + if self.fallback_tool and self.fallback_tool == self.name: + raise ValueError(f"fallback_tool cannot point to itself ('{self.name}')") + return self + + +# --------------------------------------------------------------------------- +# Validator runner +# --------------------------------------------------------------------------- + +def validate_manifest(path: Path) -> list[str]: + """Return a list of error strings (empty = valid).""" + errors: list[str] = [] + try: + with open(path) as f: + data = yaml.safe_load(f) + if not isinstance(data, dict): + return [f"YAML root must be a mapping, got {type(data).__name__}"] + ToolManifest.model_validate(data) + except ValidationError as exc: + for e in exc.errors(): + loc = " → ".join(str(x) for x in e["loc"]) + errors.append(f"{loc}: {e['msg']}") + except yaml.YAMLError as exc: + errors.append(f"YAML parse error: {exc}") + except Exception as exc: + errors.append(f"Unexpected error: {exc}") + return errors + + +def run(tool_names: list[str] | None = None) -> int: + """Validate manifests. Returns exit code.""" + if tool_names: + paths = [TOOLS_DIR / name / "tool.yaml" for name in tool_names] + else: + paths = sorted(TOOLS_DIR.glob("*/tool.yaml")) + + total = len(paths) + ok = 0 + failed = 0 + + for path in paths: + rel = path.relative_to(GATHM_ROOT) + if not path.exists(): + print(f"MISSING {rel}") + failed += 1 + continue + errors = validate_manifest(path) + if errors: + print(f"FAIL {rel}") + for e in errors: + print(f" {e}") + failed += 1 + else: + print(f"OK {rel}") + ok += 1 + + print() + print(f"Results: {ok}/{total} passed, {failed} failed") + + if failed > 0: + return 1 + return 0 + + +if __name__ == "__main__": + names = sys.argv[1:] or None + sys.exit(run(names))