diff --git a/backend/secuscan/auth.py b/backend/secuscan/auth.py new file mode 100644 index 00000000..28c39ab5 --- /dev/null +++ b/backend/secuscan/auth.py @@ -0,0 +1,71 @@ +""" +API key authentication for SecuScan backend. + +A random key is generated at startup and written to /.api_key. +Clients must supply it via: + - Authorization: Bearer + - X-Api-Key: +""" + +import secrets +from pathlib import Path + +from fastapi import Depends, HTTPException, Security, status +from fastapi.security import APIKeyHeader, HTTPAuthorizationCredentials, HTTPBearer + +_bearer_scheme = HTTPBearer(auto_error=False) +_api_key_header = APIKeyHeader(name="X-Api-Key", auto_error=False) + +_api_key: str | None = None + + +def init_api_key(data_dir: str) -> str: + """ + Load the persisted API key, or generate and persist a new one. + + Called once during application startup; the returned key is also stored in + the module-level ``_api_key`` variable so the FastAPI dependency can reach it. + """ + global _api_key + key_file = Path(data_dir) / ".api_key" + if key_file.exists(): + _api_key = key_file.read_text().strip() + else: + _api_key = secrets.token_hex(32) + key_file.parent.mkdir(parents=True, exist_ok=True) + key_file.write_text(_api_key) + key_file.chmod(0o600) + return _api_key + + +async def require_api_key( + bearer: HTTPAuthorizationCredentials | None = Depends(_bearer_scheme), + x_api_key: str | None = Security(_api_key_header), +) -> str: + """ + FastAPI dependency — rejects requests that do not carry the correct API key. + + Accepts the key in either: + - ``Authorization: Bearer `` + - ``X-Api-Key: `` + """ + if _api_key is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Authentication service not initialised", + ) + + candidate: str | None = None + if bearer is not None: + candidate = bearer.credentials + elif x_api_key is not None: + candidate = x_api_key + + if candidate is None or not secrets.compare_digest(candidate, _api_key): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return candidate diff --git a/backend/secuscan/main.py b/backend/secuscan/main.py index 08eb02c2..59c384c8 100644 --- a/backend/secuscan/main.py +++ b/backend/secuscan/main.py @@ -13,6 +13,7 @@ from fastapi.staticfiles import StaticFiles from .config import settings +from .auth import init_api_key from .cache import init_cache, cache as global_cache from .database import init_db, db as global_db from .plugins import init_plugins @@ -42,6 +43,10 @@ async def lifespan(app: FastAPI): # Ensure directories exist settings.ensure_directories() logger.info("✓ Directories initialized") + + # Initialize API key authentication + api_key = init_api_key(settings.data_dir) + logger.info("✓ API key authentication ready (key file: %s/.api_key)", settings.data_dir) # Initialize database await init_db(settings.database_path) diff --git a/backend/secuscan/routes.py b/backend/secuscan/routes.py index f1d53063..5efb8328 100644 --- a/backend/secuscan/routes.py +++ b/backend/secuscan/routes.py @@ -2,6 +2,7 @@ API routes for SecuScan backend """ +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Response from fastapi import APIRouter, HTTPException, BackgroundTasks, Response, Request from fastapi.responses import JSONResponse from typing import Any, Optional, List, Dict, Callable @@ -32,12 +33,12 @@ def parse_json_fields(rows: List[Dict], fields: List[str]) -> List[Dict]: def is_filesystem_target(target: str) -> bool: """Best-effort detection for path-based targets that should bypass host validation.""" - if target.startswith(("/", "./", "../", "~")): + # Absolute or relative filesystem roots only — not CIDR notation (e.g. 8.8.8.8/32) + if target.startswith(("/", "./", "../", "~/")): return True + # Windows drive paths (C:\ or C:/) if re.match(r"^[A-Za-z]:[\\/]", target): return True - if "/" in target and not target.startswith(("http://", "https://")): - return True return False @@ -76,10 +77,11 @@ def build_report_filename(task: Dict[str, Any], extension: str) -> str: from .reporting import reporting from .vault import VaultCrypto from .workflows import scheduler +from .auth import require_api_key from sse_starlette.sse import EventSourceResponse -router = APIRouter(prefix="/api/v1") +router = APIRouter(prefix="/api/v1", dependencies=[Depends(require_api_key)]) async def get_or_set_cached(key: str, builder): diff --git a/testing/backend/conftest.py b/testing/backend/conftest.py index 4805de89..0a74d4e6 100644 --- a/testing/backend/conftest.py +++ b/testing/backend/conftest.py @@ -14,6 +14,7 @@ from backend.secuscan.database import init_db from backend.secuscan.main import app from backend.secuscan.plugins import init_plugins +from backend.secuscan import auth as auth_module from backend.secuscan.ratelimit import concurrent_limiter, rate_limiter @@ -50,7 +51,9 @@ async def setup(): asyncio.run(setup()) - with TestClient(app) as client: + api_key = auth_module.init_api_key(settings.data_dir) + + with TestClient(app, headers={"X-Api-Key": api_key}) as client: yield client async def teardown(): diff --git a/testing/backend/integration/test_task_cleanup.py b/testing/backend/integration/test_task_cleanup.py index 6fbb88af..9f4526e3 100644 --- a/testing/backend/integration/test_task_cleanup.py +++ b/testing/backend/integration/test_task_cleanup.py @@ -41,6 +41,8 @@ async def app_client(db_path): from backend.secuscan.main import app from backend.secuscan import database as db_module from backend.secuscan import cache as cache_module + from backend.secuscan import auth as auth_module + import tempfile # Initialise a real in-memory cache (it's just a dict, no external deps) await cache_module.init_cache() @@ -48,13 +50,19 @@ async def app_client(db_path): # Initialise a fresh DB pointing at our temp file test_db = await db_module.init_db(db_path) - async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" - ) as client: - client._mock_executor = mock_executor - client._db = test_db - client._db_path = db_path - yield client + # Initialise API key in a temporary directory so the dependency resolves + with tempfile.TemporaryDirectory() as tmp_auth_dir: + api_key = auth_module.init_api_key(tmp_auth_dir) + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + headers={"X-Api-Key": api_key}, + ) as client: + client._mock_executor = mock_executor + client._db = test_db + client._db_path = db_path + yield client # Teardown await test_db.disconnect() diff --git a/testing/backend/test_task_pagination.py b/testing/backend/test_task_pagination.py index 3174fa3e..97f61c4a 100644 --- a/testing/backend/test_task_pagination.py +++ b/testing/backend/test_task_pagination.py @@ -3,75 +3,57 @@ """ import pytest -from fastapi.testclient import TestClient -from backend.secuscan.main import app -from backend.secuscan.database import init_db - -# IMPORTANT: Initialize database before any tests run -@pytest.fixture(scope="session", autouse=True) -def setup_database(): - """Initialize database for testing""" - import asyncio - asyncio.run(init_db()) - -client = TestClient(app) class TestTasksPagination: """Test pagination metadata for /api/v1/tasks endpoint""" - def test_pagination_has_next_previous_fields(self): + def test_pagination_has_next_previous_fields(self, test_client): """Test that next and previous fields exist in response""" - response = client.get("/api/v1/tasks") + response = test_client.get("/api/v1/tasks") - # Check if we got a response if response.status_code == 200: data = response.json() assert "pagination" in data pagination = data["pagination"] - # These fields should always exist assert "next" in pagination assert "previous" in pagination assert "page" in pagination assert "per_page" in pagination assert "total_items" in pagination assert "total_pages" in pagination - print("✅ All pagination fields present!") else: pytest.fail(f"API returned {response.status_code}") - def test_default_pagination_values(self): + def test_default_pagination_values(self, test_client): """Test default page=1, per_page=25""" - response = client.get("/api/v1/tasks") + response = test_client.get("/api/v1/tasks") assert response.status_code == 200 pagination = response.json()["pagination"] assert pagination["page"] == 1 assert pagination["per_page"] == 25 - print(f"✅ Default values: page={pagination['page']}, per_page={pagination['per_page']}") - def test_custom_per_page(self): + def test_custom_per_page(self, test_client): """Test that per_page parameter is respected""" - response = client.get("/api/v1/tasks?page=1&per_page=10") + response = test_client.get("/api/v1/tasks?page=1&per_page=10") assert response.status_code == 200 pagination = response.json()["pagination"] assert pagination["per_page"] == 10 - print(f"✅ Custom per_page=10 works") - def test_first_page_previous_is_null(self): + def test_first_page_previous_is_null(self, test_client): """Test that previous is None on first page""" - response = client.get("/api/v1/tasks?page=1&per_page=10") + response = test_client.get("/api/v1/tasks?page=1&per_page=10") assert response.status_code == 200 pagination = response.json()["pagination"] assert pagination["previous"] is None - print("✅ First page has previous=None") - def test_next_url_preserves_filters(self): + def test_next_url_preserves_filters(self, test_client): """Test that next URL keeps filter parameters""" - response = client.get( + response = test_client.get( "/api/v1/tasks?page=1&per_page=5&status=completed&plugin_id=nmap" ) assert response.status_code == 200 @@ -79,10 +61,7 @@ def test_next_url_preserves_filters(self): data = response.json() next_url = data["pagination"]["next"] - if next_url: # If there are more pages + if next_url: assert "per_page=5" in next_url assert "status=completed" in next_url - assert "plugin_id=nmap" in next_url - print(f"✅ Next URL preserves filters: {next_url}") - else: - print("ℹ️ No next page (database might be empty)") \ No newline at end of file + assert "plugin_id=nmap" in next_url \ No newline at end of file diff --git a/testing/backend/unit/test_api_auth.py b/testing/backend/unit/test_api_auth.py new file mode 100644 index 00000000..d53c93f6 --- /dev/null +++ b/testing/backend/unit/test_api_auth.py @@ -0,0 +1,106 @@ +""" +Unit tests for API key authentication (issue #199). +""" + +import asyncio +import tempfile +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from backend.secuscan import auth as auth_module +from backend.secuscan.main import app +from backend.secuscan.config import settings +from backend.secuscan.database import init_db +from backend.secuscan.plugins import init_plugins + + +@pytest.fixture() +def client_with_key(setup_test_environment): + """TestClient with a valid API key pre-seeded.""" + asyncio.run(init_db(settings.database_path)) + asyncio.run(init_plugins(settings.plugins_dir)) + api_key = auth_module.init_api_key(settings.data_dir) + with TestClient(app) as c: + yield c, api_key + + +class TestApiKeyInit: + def test_key_file_created(self, tmp_path): + key = auth_module.init_api_key(str(tmp_path)) + assert (tmp_path / ".api_key").exists() + assert len(key) == 64 # 32 bytes → 64 hex chars + + def test_existing_key_reloaded(self, tmp_path): + k1 = auth_module.init_api_key(str(tmp_path)) + k2 = auth_module.init_api_key(str(tmp_path)) + assert k1 == k2 + + def test_key_file_permissions(self, tmp_path): + auth_module.init_api_key(str(tmp_path)) + mode = (tmp_path / ".api_key").stat().st_mode & 0o777 + assert mode == 0o600 + + +class TestAuthDependency: + def test_no_credentials_returns_401(self, client_with_key): + client, _ = client_with_key + resp = client.get("/api/v1/plugins", headers={}) + assert resp.status_code == 401 + + def test_wrong_key_returns_401(self, client_with_key): + client, _ = client_with_key + resp = client.get("/api/v1/plugins", headers={"X-Api-Key": "wrong-key"}) + assert resp.status_code == 401 + + def test_valid_x_api_key_header(self, client_with_key): + client, api_key = client_with_key + resp = client.get("/api/v1/plugins", headers={"X-Api-Key": api_key}) + assert resp.status_code == 200 + + def test_valid_bearer_token(self, client_with_key): + client, api_key = client_with_key + resp = client.get("/api/v1/plugins", headers={"Authorization": f"Bearer {api_key}"}) + assert resp.status_code == 200 + + def test_bearer_wrong_key_returns_401(self, client_with_key): + client, _ = client_with_key + resp = client.get("/api/v1/plugins", headers={"Authorization": "Bearer bad"}) + assert resp.status_code == 401 + + def test_health_endpoint_not_protected(self, client_with_key): + client, _ = client_with_key + resp = client.get("/api/v1/health", headers={}) + # health check is defined on `app` directly, not inside the authenticated router + assert resp.status_code == 200 + + def test_root_endpoint_not_protected(self, client_with_key): + client, _ = client_with_key + resp = client.get("/", headers={}) + assert resp.status_code == 200 + + +class TestIsFilesystemTarget: + """Regression tests for is_filesystem_target — CIDR must not be treated as a path.""" + + from backend.secuscan.routes import is_filesystem_target + + @pytest.mark.parametrize("target,expected", [ + ("/etc/passwd", True), + ("./relative/path", True), + ("../parent/path", True), + ("~/home/dir", True), + ("C:\\Windows\\System32", True), + ("C:/Windows/System32", True), + # These are NOT filesystem targets + ("8.8.8.8/32", False), + ("192.168.1.0/24", False), + ("example.com", False), + ("http://example.com/path", False), + ("https://example.com/path", False), + ("10.0.0.1", False), + ]) + def test_filesystem_target_detection(self, target, expected): + from backend.secuscan.routes import is_filesystem_target + assert is_filesystem_target(target) == expected