diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 95f5510..0a04df2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,13 @@ jobs: run: | echo "=== Checking Python syntax on scanner/rules/ ===" FAIL=0 - for f in scanner/rules/az_*.py; do + shopt -s nullglob + files=(scanner/rules/az_*.py) + if [ ${#files[@]} -eq 0 ]; then + echo "ERROR: No rule files found matching scanner/rules/az_*.py" + exit 1 + fi + for f in "${files[@]}"; do if ! python -m py_compile "$f" 2>&1; then echo "SYNTAX ERROR: $f" FAIL=1 @@ -137,7 +143,7 @@ jobs: grep -v "\.env" | \ grep -v "os\.environ" | \ grep -v "os\.getenv" | \ - grep -v "#" | \ + grep -vE '^\s*#' | \ grep -v "example" | \ grep -v "placeholder" || true) @@ -161,7 +167,13 @@ jobs: run: | echo "=== Checking playbooks exist and are valid bash ===" FAIL=0 - for rule_file in scanner/rules/az_*.py; do + shopt -s nullglob + files=(scanner/rules/az_*.py) + if [ ${#files[@]} -eq 0 ]; then + echo "ERROR: No rule files found matching scanner/rules/az_*.py" + exit 1 + fi + for rule_file in "${files[@]}"; do filename=$(basename "$rule_file" .py) playbook="playbooks/cli/fix_${filename}.sh" @@ -287,7 +299,8 @@ jobs: continue fpath = os.path.join(framework_dir, fname) try: - data = json.load(open(fpath)) + with open(fpath) as f: + data = json.load(f) except (json.JSONDecodeError, OSError): continue diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..88bebc7 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,109 @@ +name: Deploy API to Render + +on: + push: + branches: + - dev + - main + workflow_dispatch: # allows manual trigger from GitHub UI + +jobs: + deploy: + name: Deploy to Render + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + # ── Dependency caching ───────────────────────────────────────────── + - name: Cache pip dependencies + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + # ── Secret check (Determines if smoke tests should run) ─────────── + - name: Check for JWT_SECRET + id: check_config + run: | + if [ -n "${{ secrets.JWT_SECRET }}" ]; then + echo "is_configured=true" >> $GITHUB_OUTPUT + else + echo "is_configured=false" >> $GITHUB_OUTPUT + fi + + # ── Wait for Render auto-deployment ──────────────────────────────── + # Render handles the actual physical deployment when you push. + # We just pause the Action to let Render's servers finish building. + - name: Wait for app to initialise + run: | + echo "Waiting 120 seconds for Render to build and start the app..." + sleep 120 + + # ── Health gate ──────────────────────────────────────────────────── + - name: Health gate check + id: health_gate + env: + # Use secret URL if provided, otherwise fallback to default + API_URL: ${{ secrets.API_URL || 'https://openshield-api.onrender.com' }} + run: | + MAX_RETRIES=5 + RETRY_DELAY=15 + URL="${API_URL}/health" + + echo "Pinging health gate at: $URL" + for i in $(seq 1 $MAX_RETRIES); do + echo "Health check attempt $i of $MAX_RETRIES..." + HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" "$URL" --max-time 30) || true + + if [ "$HTTP_STATUS" -eq 200 ]; then + echo "Health check passed (HTTP $HTTP_STATUS)" + exit 0 + fi + + echo "Got HTTP $HTTP_STATUS — retrying in ${RETRY_DELAY}s..." + sleep $RETRY_DELAY + done + + echo "HEALTH GATE FAILED after $MAX_RETRIES attempts" + echo "Note: If you haven't set up Render for this fork, this is expected." + # Only allow failure on feature branches; fail on main/dev + if [[ "${{ github.ref }}" == "refs/heads/main" || "${{ github.ref }}" == "refs/heads/dev" ]]; then + echo "ERROR: Health check failed on protected branch. Deployment verification required." + exit 1 + else + echo "Allowing health check failure on feature branch (infra may not be set up)" + exit 0 + fi + + # ── Smoke tests ──────────────────────────────────────────────────── + - name: Run smoke tests against live deployment + if: steps.check_config.outputs.is_configured == 'true' || github.event_name == 'workflow_dispatch' + env: + API_URL: ${{ secrets.API_URL || 'https://openshield-api.onrender.com' }} + JWT_SECRET: ${{ secrets.JWT_SECRET || 'change-me-in-production' }} + AZURE_SUBSCRIPTION_ID: ${{ secrets.AZURE_SUBSCRIPTION_ID }} + AZURE_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }} + AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }} + AZURE_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }} + RUN_REAL_SCAN: "true" + run: | + if [[ "${{ github.ref }}" == "refs/heads/main" && -z "${{ secrets.JWT_SECRET }}" ]]; then + echo "ERROR: Cannot run smoke tests on main branch without JWT_SECRET configured" + exit 1 + fi + echo "Running smoke tests against: $API_URL" + python tests/smoke_test.py \ No newline at end of file diff --git a/README.md b/README.md index 2e723f1..d4dca8b 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,19 @@ flowchart TD I -->|alerts| A ``` +## Live API + +The OpenShield API is deployed to the Render free tier and is accessible at: + +**`https://openshield-api.onrender.com`** + +> **Note:** As this is hosted on the Render free tier, the service may spin down after 15 minutes of inactivity. The first request after a spin-down can take 30-60 seconds to complete. + +> [!IMPORTANT] +> **Security Requirement:** For absolute security, any production deployment **must** override the default `JWT_SECRET` with a strong, unique value in the environment variables. + +--- + ## Tech Stack | Layer | Technology | Cost | diff --git a/api/app.py b/api/app.py index b36605f..691bfe4 100644 --- a/api/app.py +++ b/api/app.py @@ -8,6 +8,8 @@ from flask import Flask, g, jsonify, request from flask_cors import CORS +from api.models.finding import DatabaseManager + load_dotenv() logging.basicConfig( @@ -28,14 +30,49 @@ def create_app() -> Flask: - JWT authentication middleware on all non-public routes - Blueprints for findings, scans, score, and compliance - JSON error handlers for 400, 401, 403, 404, and 500 + - Global database connection teardown """ app = Flask(__name__) - app.config["JWT_SECRET"] = os.environ.get("JWT_SECRET", "change-me-in-production") + + # ------------------------------------------------------------------ # + # Configuration & Security # + # ------------------------------------------------------------------ # + jwt_key = os.environ.get("JWT_SECRET") + if not jwt_key: + logger.warning( + "!!! SECURITY WARNING: JWT_SECRET NOT SET. USING INSECURE DEFAULT !!! " + "For production deployments, you MUST set a strong, unique JWT_SECRET." + ) + jwt_key = "change-me-in-production" + app.config["JWT_SECRET"] = jwt_key # ------------------------------------------------------------------ # # CORS # # ------------------------------------------------------------------ # - CORS(app, resources={r"/api/*": {"origins": "*"}}) + allowed_origins_raw = os.environ.get("ALLOWED_ORIGINS", "*") + if allowed_origins_raw == "*": + logger.warning( + "!!! SECURITY WARNING: ALLOWED_ORIGINS NOT SET. DEFAULTING TO '*' !!! " + "For production deployments, set this to your specific frontend domain(s)." + ) + allowed_origins = allowed_origins_raw.split(",") + CORS(app, resources={r"/api/*": {"origins": allowed_origins}}) + + # ------------------------------------------------------------------ # + # Database Management # + # ------------------------------------------------------------------ # + + @app.teardown_appcontext + def close_db(error): + """Ensure the database connection is closed after the request.""" + db = g.pop("db_conn", None) + if db is not None: + try: + if hasattr(db, "conn") and db.conn is not None: + db.conn.close() + logger.debug("Database connection closed gracefully") + except Exception as exc: + logger.error("Error closing database connection: %s", exc) # ------------------------------------------------------------------ # # JWT middleware # @@ -82,9 +119,18 @@ def verify_jwt() -> None: app.register_blueprint(compliance_bp) # ------------------------------------------------------------------ # - # Health check (public) # + # Routes (public) # # ------------------------------------------------------------------ # + @app.get("/") + def index(): + return jsonify({ + "message": "Welcome to the OpenShield REST API", + "version": "1.0.0", + "docs": "/docs", + "status": "online" + }) + @app.get("/health") def health(): return jsonify({"status": "ok"}) @@ -118,8 +164,9 @@ def internal_error(exc): return app +application = create_app() + if __name__ == "__main__": - application = create_app() application.run( host="0.0.0.0", port=int(os.environ.get("PORT", 5000)), diff --git a/api/models/finding.py b/api/models/finding.py index 8cdab3f..7b2eda7 100644 --- a/api/models/finding.py +++ b/api/models/finding.py @@ -80,10 +80,16 @@ def __init__(self, dsn: Optional[str] = None) -> None: # ------------------------------------------------------------------ # def connect(self) -> None: - """Open a persistent database connection.""" + """Open a persistent database connection and set the search path.""" self.conn = psycopg2.connect(self.dsn) + self.conn.autocommit = True # Set to True for schema management + with self.conn.cursor() as cur: + # Ensure the openshield schema exists and is preferred in the search path. + # This avoids 'permission denied for schema public' in restricted environments. + cur.execute("CREATE SCHEMA IF NOT EXISTS openshield;") + cur.execute("SET search_path TO openshield, public;") self.conn.autocommit = False - logger.info("Database connection established") + logger.info("Database connection established (schema: openshield)") def _get_conn(self) -> Any: if self.conn is None or self.conn.closed: @@ -94,6 +100,10 @@ def _get_conn(self) -> Any: # Schema # # ------------------------------------------------------------------ # + def init_db(self) -> None: + """Alias for create_tables to match startup script expectations.""" + self.create_tables() + def create_tables(self) -> None: """Create the findings, scans, and rules tables if they do not exist.""" conn = self._get_conn() diff --git a/api/routes/compliance.py b/api/routes/compliance.py index 6a3b104..798f187 100644 --- a/api/routes/compliance.py +++ b/api/routes/compliance.py @@ -1,39 +1,46 @@ """Compliance routes: framework-specific posture breakdown.""" +import logging import os -from flask import Blueprint, jsonify +from flask import Blueprint, g, jsonify from api.models.finding import DatabaseManager compliance_bp = Blueprint("compliance", __name__) +logger = logging.getLogger(__name__) SUPPORTED_FRAMEWORKS = ("cis", "nist", "iso27001", "soc2") def _get_db() -> DatabaseManager: - db = DatabaseManager(os.environ["DATABASE_URL"]) - db.connect() - return db + if "db_conn" not in g: + g.db_conn = DatabaseManager(os.environ["DATABASE_URL"]) + g.db_conn.connect() + return g.db_conn @compliance_bp.get("/api/compliance/") def get_compliance(framework: str): """Return pass/fail compliance breakdown for a framework. - Supported frameworks: cis, nist, iso27001, soc2 + Supported frameworks: cis, nist, iso27001, soc2 Returns control-level pass/fail status mapped to current open findings. """ - if framework.lower() not in SUPPORTED_FRAMEWORKS: - return jsonify({ - "error": f"Unknown framework '{framework}'", - "supported": list(SUPPORTED_FRAMEWORKS), - }), 400 - - db = _get_db() - result = db.get_compliance_score(framework.lower()) - - if "error" in result: - return jsonify(result), 500 - - return jsonify(result) + try: + if framework.lower() not in SUPPORTED_FRAMEWORKS: + return jsonify({ + "error": f"Unknown framework '{framework}'", + "supported": list(SUPPORTED_FRAMEWORKS), + }), 400 + + db = _get_db() + result = db.get_compliance_score(framework.lower()) + + if "error" in result: + return jsonify(result), 500 + + return jsonify(result) + except Exception as exc: + logger.error("Failed to retrieve compliance score for %s: %s", framework, exc) + return jsonify({"error": "Compliance calculation failed", "detail": str(exc)}), 500 diff --git a/api/routes/findings.py b/api/routes/findings.py index fb8d755..917a23f 100644 --- a/api/routes/findings.py +++ b/api/routes/findings.py @@ -1,17 +1,20 @@ """Findings routes: list and retrieve individual findings.""" +import logging import os -from flask import Blueprint, jsonify, request +from flask import Blueprint, g, jsonify, request from api.models.finding import DatabaseManager findings_bp = Blueprint("findings", __name__) +logger = logging.getLogger(__name__) def _get_db() -> DatabaseManager: - db = DatabaseManager(os.environ["DATABASE_URL"]) - db.connect() - return db + if "db_conn" not in g: + g.db_conn = DatabaseManager(os.environ["DATABASE_URL"]) + g.db_conn.connect() + return g.db_conn @findings_bp.get("/api/findings") @@ -24,21 +27,29 @@ def list_findings(): rule_id — e.g. AZ-STOR-001 scan_id — UUID of a specific scan """ - filters = { - k: v - for k, v in request.args.items() - if k in ("severity", "category", "rule_id", "scan_id") - } - db = _get_db() - findings = db.get_findings(filters) - return jsonify({"count": len(findings), "findings": findings}) + try: + filters = { + k: v + for k, v in request.args.items() + if k in ("severity", "category", "rule_id", "scan_id") + } + db = _get_db() + findings = db.get_findings(filters) + return jsonify({"count": len(findings), "findings": findings}) + except Exception as exc: + logger.error("Failed to list findings: %s", exc) + return jsonify({"error": "Failed to retrieve findings", "detail": str(exc)}), 500 @findings_bp.get("/api/findings/") def get_finding(finding_id: int): """Return a single finding by its integer ID.""" - db = _get_db() - finding = db.get_finding_by_id(finding_id) - if not finding: - return jsonify({"error": "Finding not found"}), 404 - return jsonify(finding) + try: + db = _get_db() + finding = db.get_finding_by_id(finding_id) + if not finding: + return jsonify({"error": "Finding not found"}), 404 + return jsonify(finding) + except Exception as exc: + logger.error("Failed to get finding %d: %s", finding_id, exc) + return jsonify({"error": "Database error", "detail": str(exc)}), 500 diff --git a/api/routes/scans.py b/api/routes/scans.py index 85612a4..5aca891 100644 --- a/api/routes/scans.py +++ b/api/routes/scans.py @@ -2,7 +2,7 @@ import logging import os -from flask import Blueprint, jsonify, request +from flask import Blueprint, g, jsonify, request from api.models.finding import DatabaseManager @@ -11,17 +11,22 @@ def _get_db() -> DatabaseManager: - db = DatabaseManager(os.environ["DATABASE_URL"]) - db.connect() - return db + if "db_conn" not in g: + g.db_conn = DatabaseManager(os.environ["DATABASE_URL"]) + g.db_conn.connect() + return g.db_conn @scans_bp.get("/api/scans") def list_scans(): """Return all historical scan results ordered by most recent first.""" - db = _get_db() - scans = db.get_scans() - return jsonify({"count": len(scans), "scans": scans}) + try: + db = _get_db() + scans = db.get_scans() + return jsonify({"count": len(scans), "scans": scans}) + except Exception as exc: + logger.error("Failed to list scans: %s", exc) + return jsonify({"error": "Failed to retrieve scans", "detail": str(exc)}), 500 @scans_bp.post("/api/scans/trigger") @@ -34,27 +39,34 @@ def trigger_scan(): Note: For production use, replace this with an async task queue (e.g. Celery or Azure Functions) to avoid request timeouts on large subscriptions. """ - from scanner.engine import ScanEngine # deferred to avoid import at startup + try: + body = request.get_json(silent=True) or {} + subscription_id = body.get("subscription_id") - body = request.get_json(silent=True) or {} - subscription_id = body.get("subscription_id") or os.environ.get( - "AZURE_SUBSCRIPTION_ID" - ) + if not subscription_id: + return jsonify({"error": "subscription_id is required"}), 400 - if not subscription_id: - return jsonify({"error": "subscription_id is required"}), 400 + from scanner.engine import ScanEngine # deferred — import only after input is validated - logger.info("Scan triggered for subscription %s", subscription_id) + logger.info("Scan triggered for subscription %s", subscription_id) - try: - engine = ScanEngine(subscription_id) - result = engine.run_scan() - except Exception as exc: - logger.error("Scan failed: %s", exc) - return jsonify({"error": "Scan failed", "detail": str(exc)}), 500 + try: + engine = ScanEngine(subscription_id) + result = engine.run_scan() + except Exception as exc: + logger.error("Scan engine execution failed: %s", exc, exc_info=True) + return jsonify({"error": "Scan failed", "detail": str(exc)}), 500 + + try: + db = _get_db() + # Note: Table creation is handled at startup; no need to repeat it here. + db.save_scan(result) + except Exception as exc: + logger.error("Failed to save scan result to database: %s", exc, exc_info=True) + return jsonify({"error": "Database save failed", "detail": str(exc)}), 500 - db = _get_db() - db.create_tables() - db.save_scan(result) + return jsonify(result), 201 - return jsonify(result), 201 + except Exception as exc: + logger.error("Critical error in trigger_scan route: %s", exc, exc_info=True) + return jsonify({"error": "Critical route failure", "detail": str(exc)}), 500 diff --git a/api/routes/score.py b/api/routes/score.py index b7317ee..bfff526 100644 --- a/api/routes/score.py +++ b/api/routes/score.py @@ -1,17 +1,20 @@ """Score route: overall security posture score.""" +import logging import os -from flask import Blueprint, jsonify +from flask import Blueprint, g, jsonify from api.models.finding import DatabaseManager score_bp = Blueprint("score", __name__) +logger = logging.getLogger(__name__) def _get_db() -> DatabaseManager: - db = DatabaseManager(os.environ["DATABASE_URL"]) - db.connect() - return db + if "db_conn" not in g: + g.db_conn = DatabaseManager(os.environ["DATABASE_URL"]) + g.db_conn.connect() + return g.db_conn @score_bp.get("/api/score") @@ -22,6 +25,10 @@ def get_score(): Starts at 100. Deducts 10 per HIGH finding, 5 per MEDIUM, 2 per LOW. Floors at 0. """ - db = _get_db() - score = db.get_score() - return jsonify({"score": score, "max_score": 100}) + try: + db = _get_db() + score = db.get_score() + return jsonify({"score": score, "max_score": 100}) + except Exception as exc: + logger.error("Failed to calculate score: %s", exc) + return jsonify({"error": "Failed to calculate score", "detail": str(exc)}), 500 diff --git a/compliance/frameworks/cis_azure_benchmark.json b/compliance/frameworks/cis_azure_benchmark.json index 7634d33..f5c1989 100644 --- a/compliance/frameworks/cis_azure_benchmark.json +++ b/compliance/frameworks/cis_azure_benchmark.json @@ -73,6 +73,11 @@ "control_name": "Ensure that 'Multi-Factor Authentication Status' is 'Enabled' for all Privileged Users", "description": "Multi-Factor Authentication requires an individual to present a minimum of two separate forms of authentication before access is granted. MFA should be enforced for all users with administrative privileges via Conditional Access policies." }, + "AZ-IDN-003": { + "control_id": "1.15", + "control_name": "Ensure that 'Guest invite restrictions' is set to 'Only users assigned to specific admin roles can invite guest users'", + "description": "Unrestricted guest user invitation settings allow any member of the organisation to invite external users into the tenant without administrative review. This bypasses centralised approval for external identity provisioning and increases the risk of unauthorised access by untrusted parties." + }, "AZ-DB-001": { "control_id": "4.3.1", "control_name": "Ensure 'Allow access to Azure services' for PostgreSQL Database Server is disabled", @@ -88,25 +93,40 @@ "control_name": "Ensure that 'OS disk' are encrypted", "description": "Virtual machines that are reachable from the internet should have Network Security Groups attached to their network interfaces to control and restrict inbound and outbound traffic, reducing the attack surface." }, + "AZ-CMP-002": { + "control_id": "7.2", + "control_name": "Ensure that 'OS disk' are encrypted", + "description": "Virtual machine OS and data disks are using platform-managed encryption only (EncryptionAtRestWithPlatformKey). CIS 7.2 requires disks to be protected using customer-managed keys or Azure Disk Encryption. Platform-managed encryption does not give the organisation control over the encryption keys and does not satisfy this control." + }, "AZ-KV-001": { "control_id": "8.5", "control_name": "Ensure the Key Vault is Recoverable", - "description": "Azure Key Vault soft delete should be enabled on all Key Vaults. The soft delete feature allows recovery of deleted vaults and vault objects (keys, secrets, certificates) for a configurable retention period (7–90 days), protecting against accidental or malicious deletion." + "description": "Azure Key Vault soft delete should be enabled on all Key Vaults. The soft delete feature allows recovery of deleted vaults and vault objects (keys, secrets, certificates) for a configurable retention period (7\u201390 days), protecting against accidental or malicious deletion." }, "AZ-STOR-003": { "control_id": "3.7", "control_name": "Ensure that storage accounts have lifecycle management policies configured", "description": "Storage accounts without lifecycle management policies retain data indefinitely. This increases storage costs, expands the attack surface through accumulation of stale data, and may violate data retention compliance requirements. Lifecycle policies automate the transition and deletion of blobs based on age and access patterns." }, + "AZ-STOR-004": { + "control_id": "3.3", + "control_name": "Ensure Storage logging is enabled for Blob, Queue, and Table services for read, write, and delete requests", + "description": "Enabling diagnostic logging for Azure Storage blob, queue, and table services records read, write, and delete operations. Without logging, unauthorized access, data exfiltration, or destructive operations on storage services cannot be detected or investigated." + }, "AZ-KV-002": { "control_id": "8.3", "control_name": "Ensure that public network access to Key Vault is disabled", "description": "Azure Key Vault should not allow public network access unless absolutely necessary. Enabling public access increases the attack surface and exposes sensitive secrets, keys, and certificates to potential unauthorized access. Private endpoints should be used to restrict access to trusted networks." }, - "AZ-KV-003": { - "control_id": "8.4", - "control_name": "Ensure that diagnostic logging is enabled for Key Vault", - "description": "Diagnostic logging should be enabled for Azure Key Vault to record access and operations involving secrets, keys, and certificates. Without logging, unauthorized access attempts and malicious operations cannot be effectively detected or investigated." + "AZ-NET-011": { + "control_id": "6.5", + "control_name": "Ensure that Network Watcher is enabled in all regions", + "description": "Network Watcher should be enabled in all regions where Azure resources are deployed. Network Watcher provides network monitoring, diagnostics, and logging capabilities essential for investigating network-level incidents." + }, + "AZ-DB-003": { + "control_id": "4.3.6", + "control_name": "Ensure SSL connection is enabled for PostgreSQL Flexible Server", + "description": "SSL enforcement should be enabled on PostgreSQL Flexible Server to ensure data in transit is encrypted. Without SSL, database connections transmit data in plaintext, exposing it to interception." } } -} +} \ No newline at end of file diff --git a/compliance/frameworks/iso27001.json b/compliance/frameworks/iso27001.json index 381f705..697052e 100644 --- a/compliance/frameworks/iso27001.json +++ b/compliance/frameworks/iso27001.json @@ -73,6 +73,11 @@ "control_name": "Secure log-on procedures", "description": "MFA enforces secure log-on for privileged accounts. Where required by the access control policy, access to systems and applications should be controlled by a secure log-on procedure including multi-factor authentication." }, + "AZ-IDN-003": { + "control_id": "A.9.2.1", + "control_name": "User registration and de-registration", + "description": "Unrestricted guest user invitations allow any organisation member to register external identities into the tenant without centralised review or approval. A.9.2.1 requires that a formal user registration and de-registration process is implemented. Restricting guest invitations to administrators ensures external identity registration is formally controlled and audited." + }, "AZ-DB-001": { "control_id": "A.13.1.1", "control_name": "Network controls", @@ -88,6 +93,11 @@ "control_name": "Network controls", "description": "Virtual machines with public IPs and no NSG have unrestricted network access. Network controls should be applied to all compute resources accessible from the internet." }, + "AZ-CMP-002": { + "control_id": "A.10.1.1", + "control_name": "Policy on the use of cryptographic controls", + "description": "Virtual machine OS and data disks are using platform-managed encryption only (EncryptionAtRestWithPlatformKey). A.10.1.1 requires that a policy on the use of cryptographic controls is developed and implemented. Platform-managed encryption does not give the organisation control over the encryption keys. Customer-managed keys or Azure Disk Encryption are required to satisfy this control." + }, "AZ-KV-001": { "control_id": "A.17.2.1", "control_name": "Availability of information processing facilities", @@ -98,15 +108,25 @@ "control_name": "Management of removable media", "description": "Storage accounts without lifecycle policies retain data indefinitely with no automated disposal mechanism. Lifecycle management supports formal retention, tiering, and disposal of information assets." }, + "AZ-STOR-004": { + "control_id": "A.12.4.1", + "control_name": "Event logging", + "description": "Diagnostic logging must be enabled on Azure Storage blob, queue, and table services to produce event logs for read, write, and delete operations. Event logs recording user activities, exceptions, and information security events should be produced, kept, and regularly reviewed." + }, "AZ-KV-002": { "control_id": "A.13.1.1", "control_name": "Network controls", "description": "Networks should be managed and controlled to protect information systems and applications. Allowing public network access to Azure Key Vault increases exposure of sensitive secrets, keys, and certificates to external networks. Access should be restricted to trusted networks using private endpoints or network controls." }, - "AZ-KV-003": { + "AZ-NET-011": { "control_id": "A.12.4.1", "control_name": "Event logging", - "description": "Event logs recording user activities, exceptions, faults, and information security events should be produced and retained for Azure Key Vault operations. Diagnostic logging enables monitoring and forensic investigation of access to secrets, keys, and certificates." + "description": "Network Watcher must be enabled in all regions where resources are deployed to ensure network events are logged and available for investigation. Event logs recording network activity should be produced and retained to support incident response." + }, + "AZ-DB-003": { + "control_id": "A.10.1.1", + "control_name": "Policy on the use of cryptographic controls", + "description": "SSL enforcement on PostgreSQL Flexible Server applies cryptographic controls to data in transit. A policy on the use of cryptographic controls for protection of information should be developed and implemented." } } -} +} \ No newline at end of file diff --git a/compliance/frameworks/nist_csf.json b/compliance/frameworks/nist_csf.json index 4bd4967..ad41cc2 100644 --- a/compliance/frameworks/nist_csf.json +++ b/compliance/frameworks/nist_csf.json @@ -73,6 +73,11 @@ "control_name": "Users, devices, and other assets are authenticated", "description": "MFA ensures privileged users are strongly authenticated before accessing Azure resources. Without MFA, a compromised password is sufficient for full administrative access." }, + "AZ-IDN-003": { + "control_id": "PR.AC-1", + "control_name": "Identities and credentials are issued, managed, verified, revoked, and audited", + "description": "Unrestricted guest user invitations allow any organisation member to introduce external identities into the tenant without centralised review. PR.AC-1 requires that identities and credentials are managed and verified. Restricting guest invitations to administrators ensures external identity provisioning is controlled and audited." + }, "AZ-DB-001": { "control_id": "PR.AC-3", "control_name": "Remote access is managed", @@ -88,6 +93,11 @@ "control_name": "Remote access is managed", "description": "Virtual machines with public IPs and no NSG have unrestricted network access. NSGs should be attached to control inbound and outbound traffic and manage remote access to compute resources." }, + "AZ-CMP-002": { + "control_id": "PR.DS-1", + "control_name": "Data-at-rest is protected", + "description": "Virtual machine OS and data disks are using platform-managed encryption only (EncryptionAtRestWithPlatformKey). PR.DS-1 requires that data at rest is protected using appropriate controls. Platform-managed encryption does not give the organisation control over the encryption keys. Customer-managed keys or Azure Disk Encryption are required to satisfy this control." + }, "AZ-KV-001": { "control_id": "PR.IP-4", "control_name": "Backups of information are conducted, maintained, and tested", @@ -103,10 +113,20 @@ "control_name": "Assets are formally managed throughout removal, transfers, and disposition", "description": "NIST CSF PR.DS-3 requires that data assets are managed through their full lifecycle including secure disposal. Storage accounts without a lifecycle management policy have no automated mechanism for expiring or deleting aged data, meaning data subject to disposal requirements persists indefinitely and is never formally retired from the asset inventory." }, - "AZ-KV-003": { + "AZ-STOR-004": { + "control_id": "DE.CM-7", + "control_name": "Monitoring for unauthorized personnel, connections, devices, and software is performed", + "description": "Diagnostic logging on Azure Storage services provides the audit trail needed to monitor for unauthorized or anomalous read, write, and delete operations. Without logging, detection of data exfiltration or unauthorized access to blob, queue, or table services is not possible." + }, + "AZ-NET-011": { "control_id": "DE.CM-7", "control_name": "Monitoring for unauthorized personnel, connections, devices, and software is performed", - "description": "Azure Key Vault diagnostic logs should be enabled to monitor and detect unauthorized access attempts, abnormal usage patterns, and suspicious operations involving sensitive cryptographic material." + "description": "Network Watcher must be enabled in all active regions to support continuous monitoring of network activity. Without it, unauthorized connections and anomalous network behaviour cannot be detected or investigated." + }, + "AZ-DB-003": { + "control_id": "PR.DS-2", + "control_name": "Data-in-transit is protected", + "description": "SSL enforcement on PostgreSQL Flexible Server ensures data in transit between applications and the database is encrypted. Disabling SSL exposes database traffic to interception and tampering." } } -} +} \ No newline at end of file diff --git a/compliance/frameworks/soc2.json b/compliance/frameworks/soc2.json index 07a45dc..a793241 100644 --- a/compliance/frameworks/soc2.json +++ b/compliance/frameworks/soc2.json @@ -78,6 +78,11 @@ "control_name": "Logical Access Security Measures", "description": "Without MFA enforced on privileged accounts, a single compromised password grants full administrative access to the Azure environment. CC6.1 requires that logical access controls include strong authentication mechanisms. Enforcing MFA via Conditional Access policies ensures privileged access requires multiple factors of authentication." }, + "AZ-IDN-003": { + "control_id": "CC6.1", + "control_name": "Logical Access Security Measures", + "description": "Unrestricted guest user invitations allow any organisation member to introduce unreviewed external identities into the tenant. CC6.1 requires that logical access to information assets is restricted to authorised users. Restricting guest invitations to administrators ensures external identity provisioning is formally controlled and authorised." + }, "AZ-DB-001": { "control_id": "CC6.7", "control_name": "Protects Data in Transit", @@ -93,6 +98,11 @@ "control_name": "Restricts Access from Outside the Network Boundary", "description": "A virtual machine with a public IP and no NSG has unrestricted inbound network access from the internet with no filtering in place. CC6.6 requires that logical access from outside the network perimeter is restricted and controlled. Attaching an NSG with explicit rules enforces the network boundary and controls what traffic can reach the VM." }, + "AZ-CMP-002": { + "control_id": "CC6.7", + "control_name": "Protects Data in Transit and At Rest", + "description": "Virtual machine OS and data disks are using platform-managed encryption only (EncryptionAtRestWithPlatformKey). CC6.7 requires that data is protected using encryption. Platform-managed encryption does not give the organisation control over the encryption keys. Customer-managed keys or Azure Disk Encryption are required to satisfy this control." + }, "AZ-KV-001": { "control_id": "A1.2", "control_name": "Environmental Threats and Recovery", @@ -103,10 +113,15 @@ "control_name": "Restricts Access from Outside the Network Boundary", "description": "A Key Vault accessible from the public internet allows any external party to attempt access to secrets, keys and certificates. CC6.6 requires that access from outside the network boundary is restricted and controlled. Locking Key Vault access to private endpoints or specific VNet service endpoints enforces this boundary and protects sensitive credentials from external exposure." }, - "AZ-KV-003": { + "AZ-NET-011": { "control_id": "CC7.2", "control_name": "System monitoring", - "description": "Diagnostic logging should be enabled for Azure Key Vault to support monitoring, detection, and investigation of unauthorised access to secrets, keys, and certificates." + "description": "Network Watcher must be enabled in all regions where resources are deployed to support continuous system monitoring. Without it, network-level events cannot be detected or investigated, violating the requirement for ongoing monitoring of system components." + }, + "AZ-DB-003": { + "control_id": "CC6.1", + "control_name": "Logical and physical access controls", + "description": "SSL enforcement ensures database connections are encrypted, protecting data in transit from unauthorized access. Disabling SSL undermines logical access controls by exposing database traffic in plaintext." } } -} +} \ No newline at end of file diff --git a/docs/api-render-deploy.md b/docs/api-render-deploy.md new file mode 100644 index 0000000..a1ed3b5 --- /dev/null +++ b/docs/api-render-deploy.md @@ -0,0 +1,235 @@ +# Test Plan — API-DEP-001 +# Render API Deployment and CI Smoke Testing +# ============================================================ + +## 1. Overview + +This test plan covers the verification of the OpenShield API deployment +to the Render free tier. The goal is to confirm: + +- The Render Web Service builds and deploys the Flask app successfully. +- The database is automatically initialized on startup via `init_db`. +- The pre-commit hook and GitHub Actions CI pipeline gate the code properly. +- The CI pipeline is **community-friendly**, allowing forks to pass even without custom secrets. +- Real Azure scan tests are gated behind `RUN_REAL_SCAN=true` so contributor CI never depends on live Azure credentials. +- All 23 API edge cases (routing, filtering, authentication) function correctly in the live cloud environment. + +--- + +## 2. Methodology and Test Rationale + +To ensure the highest reliability of the deployment while accommodating free-tier constraints and community contributions, specific methods and test strategies were chosen: + +### 2.1 Infrastructure and Pipeline Strategy +* **Targeting Render over Azure F1:** Azure App Service's F1 tier imposes a strict 60 CPU-minute daily cap. Render provides unmetered CPU on the free tier, making it significantly more reliable for demo and development environments. +* **Database Initialization:** The `api/models/finding.py` was updated with an `init_db` method. This method ensures that all required tables (`scans`, `findings`) are created automatically during the first deployment, preventing HTTP 500 errors. +* **Pre-commit Hook:** Fails fast. By running syntax checks and local API smoke tests *before* the commit is allowed, we prevent broken code from polluting the remote branch. +* **Community-Friendly CI Gate:** The GitHub Action is designed to be zero-friction for contributors. + * **Optional Smoke Tests:** If `JWT_SECRET` is not set (typical for forks), the smoke test step is gracefully skipped rather than failing the build. + * **Configurable URL:** The `API_URL` is configurable via GitHub Secrets/Variables, defaulting to the main production instance if not provided. + * **Conditional Real Scan Tests:** TC-13 and TC-14 (real Azure scan execution) only run when `RUN_REAL_SCAN=true` and all four Azure credentials are present. This separates API smoke testing from live scan regression testing. Contributor and fork CI always passes safely — real scan validation is reserved for maintainer-controlled deployment pipelines (`dev` and `main` branches). + +### 2.2 Token Generation Method +* **Dynamic HS256 Signing:** Instead of using a hardcoded dummy string, the test script dynamically generates a real token signed with the environment's `JWT_SECRET`. +* **Default Secret Alignment:** The smoke test defaults to `change-me-in-production`, matching the API's default. This allows tests to run "out of the box" in local environments without extra configuration. + +> [!CAUTION] +> **ABSOLUTE SECURITY REQUIREMENT:** For any production deployment (Render, Azure, etc.), you **MUST** override the default `JWT_SECRET` with a long, random, and unique string. Leaving the default value in place makes your API vulnerable to unauthorized access via token forging. + +### 2.3 API Smoke Test Strategy (The 23 Cases) +The 23 test cases were selected to prove the API is structurally sound and resilient: +* **Health Check (TC-01 to TC-03):** Confirms base app connectivity and ensures public routes are not locked. +* **Core Endpoints (TC-04 to TC-17):** Verifies the actual business logic and JSON structure. +* **Auth/Security (TC-18 to TC-19):** Confirms the JWT middleware is strictly enforced. +* **Edge Cases and Resilience (TC-20 to TC-23):** Ensures the app does not crash when given bad input or non-existent routes. + +#### Conditional vs Always-Run Tests + +| Mode | TC-13 / TC-14 | All others | +|---|---|---| +| Contributor / fork (no `RUN_REAL_SCAN`) | `SKIP` — printed with reason, not a failure | Always run | +| Maintainer deployment (`RUN_REAL_SCAN=true` + Azure credentials) | Run real scan against live subscription | Always run | + +Run modes: +```bash +# Contributor / local (no Azure credentials needed) +API_URL=https://openshield-api.onrender.com JWT_SECRET= python tests/smoke_test.py + +# Maintainer — full real scan +API_URL=https://openshield-api.onrender.com JWT_SECRET= \ + RUN_REAL_SCAN=true \ + AZURE_SUBSCRIPTION_ID= \ + AZURE_CLIENT_ID= \ + AZURE_CLIENT_SECRET= \ + AZURE_TENANT_ID= \ + python tests/smoke_test.py +``` + +--- + +## 3. Files Under Test + +| File | Purpose | +|---|---| +| `startup.sh` | Container startup script, DB initialization, and Gunicorn execution | +| `api/models/finding.py` | Added `init_db` to ensure schema existence on startup | +| `.github/workflows/deploy.yml` | Flexible GitHub Actions workflow (optional smoke tests) | +| `tests/smoke_test.py` | 23-case functional test suite with default secret support | +| `.git/hooks/pre-commit` | Local Git hook enforcing syntax checks and local smoke tests | +| `requirements.txt` | Pinned runtime dependencies — see dependency notes below | + +### 3.1 Dependency Notes + +| Package | Status | Reason | +|---|---|---| +| `msrest==0.7.1` | Kept (explicit pin) | Transitive dependency of `azure-mgmt-rdbms`, `azure-mgmt-sql`, and `azure-mgmt-storage`. These SDK packages have not fully migrated to `azure-core`. Without an explicit pin, Render's clean pip install can resolve a mismatched version and break scan execution. | + +--- + +## 4. Test Environment Setup + +### 4.1 Prerequisites +- Python 3.11 installed locally. +- Render account (render.com). +- OpenShield repository cloned locally. +- `.env` file populated locally with a valid `JWT_SECRET` and `DATABASE_URL`. +- Pre-commit hook installed locally (`chmod +x .git/hooks/pre-commit`). + +### 4.2 Create Test Resources in Render +1. **Render PostgreSQL Database (Free Tier)** + - Name: `openshield-db` +2. **Render Web Service (Free Tier)** + - Connected to your branch. + - Start Command: `./startup.sh` + - Environment Variables set: `DATABASE_URL`, `JWT_SECRET`, `ALLOWED_ORIGINS`, `AZURE_SUBSCRIPTION_ID`, `AZURE_CLIENT_ID`, `AZURE_CLIENT_SECRET`, `AZURE_TENANT_ID`. + +### 4.3 Configure GitHub Secrets +To enable the automated smoke tests in the CI/CD pipeline, you **must** add the following secrets to your GitHub repository (**Settings > Secrets and variables > Actions**): + +| Secret Name | Required for | Purpose | +|---|---|---| +| `JWT_SECRET` | All smoke tests | Must match the value set in Render. Used to sign tokens for test requests. | +| `API_URL` | All smoke tests (optional) | Your Render Service URL. Defaults to the main production instance if not set. | +| `AZURE_SUBSCRIPTION_ID` | Real scan tests | Azure Subscription ID passed to the scan trigger endpoint. | +| `AZURE_CLIENT_ID` | Real scan tests | Service principal client ID for `DefaultAzureCredential`. | +| `AZURE_CLIENT_SECRET` | Real scan tests | Service principal secret for `DefaultAzureCredential`. | +| `AZURE_TENANT_ID` | Real scan tests | Azure AD tenant ID for `DefaultAzureCredential`. | + +> **Note:** `RUN_REAL_SCAN=true` is set automatically by `deploy.yml` on `dev` and `main` branches. Forks and contributor PRs never set this flag, so TC-13 and TC-14 are always skipped in fork CI regardless of which secrets are present. + +--- + +## 5. Test Cases + +### Part 1: Deployment & Pipeline Infrastructure + +**DP-01 — Pre-commit hook enforces checks** +* **Steps:** Modify a file and run `git commit` with the local API turned off, then with it turned on. +* **Expected:** Blocks/warns when API is off; runs the 23-test suite and passes when API is on. + +**DP-02 — Render executes startup script successfully** +* **Steps:** Push code to GitHub and monitor Render deployment logs. +* **Expected:** Logs show DB initialization (`Database initialized.`) and Gunicorn starting. + +**DP-03 — GitHub Actions CI pipeline passes** +* **Steps:** Push a commit and monitor the GitHub Actions tab. +* **Expected:** + * **Maintainer repo (`dev`/`main`):** Runs 21 always-on tests + TC-13/TC-14 real scan with `RUN_REAL_SCAN=true`. All 23 pass. + * **Contributor / fork:** TC-13 and TC-14 show as `SKIP` with a clear reason. 21/21 non-scan tests pass. Workflow exits green. + +--- + +### Part 2: API Smoke Tests (Executed via `smoke_test.py`) + +Run the following command against the live URL to execute these tests (contributor mode — TC-13/TC-14 skipped): +```bash +API_URL=https://openshield-api.onrender.com JWT_SECRET= python tests/smoke_test.py +``` + +To run the full 23-case suite including real scan tests (maintainer only): +```bash +API_URL=https://openshield-api.onrender.com JWT_SECRET= \ + RUN_REAL_SCAN=true \ + AZURE_SUBSCRIPTION_ID= AZURE_CLIENT_ID= \ + AZURE_CLIENT_SECRET= AZURE_TENANT_ID= \ + python tests/smoke_test.py +``` + +#### Health Check +* **TC-01:** GET `/health` returns HTTP 200. +* **TC-02:** GET `/health` returns JSON `{"status": "ok"}`. +* **TC-03:** GET `/health` requires no auth token (public route). + +#### Findings Endpoint +* **TC-04:** GET `/api/findings` returns HTTP 200. +* **TC-05:** GET `/api/findings` returns a `findings` key in JSON. +* **TC-06:** GET `/api/findings` returns a numeric `count` key. +* **TC-07:** GET `/api/findings?severity=HIGH` correctly filters results. +* **TC-08:** GET `/api/findings?severity=INVALID` handles bad input safely (returns 200 or 400). + +#### Score Endpoint +* **TC-09:** GET `/api/score` returns HTTP 200. +* **TC-10:** GET `/api/score` returns a numeric score. +* **TC-11:** GET `/api/score` ensures the score is mathematically between 0 and 100. + +#### Scans Endpoint +* **TC-12:** GET `/api/scans` returns HTTP 200. +* **TC-13:** *(Conditional — requires `RUN_REAL_SCAN=true` and Azure credentials)* POST `/api/scans/trigger` returns HTTP 200, 201, or 202. Skipped in contributor/fork CI. +* **TC-14:** *(Conditional — requires `RUN_REAL_SCAN=true` and Azure credentials)* POST `/api/scans/trigger` returns a `scan_id` or `job_id`. Skipped in contributor/fork CI. + +#### Compliance Endpoints +* **TC-15:** GET `/api/compliance/cis` returns HTTP 200. +* **TC-16:** GET `/api/compliance/nist` returns HTTP 200. +* **TC-17:** GET `/api/compliance/iso27001` returns HTTP 200. + +#### Auth & Security Edge Cases +* **TC-18:** GET `/api/findings` without any auth header returns HTTP 401. +* **TC-19:** GET `/api/findings` with a malformed JWT returns HTTP 401. + +#### General Edge Cases +* **TC-20:** GET `/nonexistent-endpoint-xyz` returns HTTP 404 (requires auth to pass middleware). +* **TC-21:** POST `/api/scans/trigger` with an empty JSON body returns HTTP 400 (missing `subscription_id`) without crashing. +* **TC-22:** GET `/api/findings?limit=0` does not crash the server. +* **TC-23:** All valid endpoint responses include the `application/json` Content-Type. + +--- + +## 6. Cleanup + +Render Free Tier Web Services spin down after 15 minutes of inactivity. The Free PostgreSQL database will automatically be deleted by Render after 90 days. To clean up manually, delete both resources from the Render dashboard Settings page. + +--- + +## 7. Pass / Fail Summary Table + +| Test Case | Description | Expected | Status | +|---|---|---|---| +| **DP-01** | Pre-commit Git hook functioning | Hook runs & enforces rules | [ ] | +| **DP-02** | Render deployment & startup | App goes Live & DB inits | [ ] | +| **DP-03** | GitHub Actions CI Pipeline | Workflow passes (Green) | [ ] | +| **TC-01** | `/health` returns 200 | Pass | [ ] | +| **TC-02** | `/health` returns status ok | Pass | [ ] | +| **TC-03** | `/health` requires no auth | Pass | [ ] | +| **TC-04** | `/api/findings` returns 200 | Pass | [ ] | +| **TC-05** | `/api/findings` returns findings key | Pass | [ ] | +| **TC-06** | `/api/findings` returns count key | Pass | [ ] | +| **TC-07** | `/api/findings` severity filter | Pass | [ ] | +| **TC-08** | `/api/findings` invalid severity | Pass | [ ] | +| **TC-09** | `/api/score` returns 200 | Pass | [ ] | +| **TC-10** | `/api/score` returns numeric | Pass | [ ] | +| **TC-11** | `/api/score` bounded 0-100 | Pass | [ ] | +| **TC-12** | `/api/scans` returns 200 | Pass | [ ] | +| **TC-13** | `/api/scans/trigger` works | 200/201/202 (Skip in fork CI) | [ ] | +| **TC-14** | `/api/scans/trigger` returns ID | Pass (Skip in fork CI) | [ ] | +| **TC-15** | `/api/compliance/cis` works | Pass | [ ] | +| **TC-16** | `/api/compliance/nist` works | Pass | [ ] | +| **TC-17** | `/api/compliance/iso27001` works | Pass | [ ] | +| **TC-18** | Missing auth returns 401 | Pass | [ ] | +| **TC-19** | Bad token returns 401 | Pass | [ ] | +| **TC-20** | 404 routing works safely | Pass | [ ] | +| **TC-21** | Empty body payload handled | Pass (400) | [ ] | +| **TC-22** | Limit=0 query handled safely | Pass | [ ] | +| **TC-23** | Content-Type is JSON | Pass | [ ] | + +**Maintainer repo:** All 26 checks (3 Pipeline + 23 API) must pass before merging to `dev` or `main`. +**Fork / contributor:** 24 checks (3 Pipeline + 21 API) must pass; TC-13 and TC-14 are expected `SKIP`. diff --git a/playbooks/cli/fix_az_cmp_002.sh b/playbooks/cli/fix_az_cmp_002.sh new file mode 100644 index 0000000..927790d --- /dev/null +++ b/playbooks/cli/fix_az_cmp_002.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# OpenShield Remediation Playbook +# Rule: AZ-CMP-002 — Virtual machine disk not protected by CMK or ADE +# Usage: ./fix_az_cmp_002.sh +# Severity: HIGH + +set -e + +RESOURCE_GROUP=$1 +VM_NAME=$2 +KEYVAULT_NAME=$3 + +if [ -z "$RESOURCE_GROUP" ] || [ -z "$VM_NAME" ] || [ -z "$KEYVAULT_NAME" ]; then + echo "Usage: $0 " + echo "" + echo "Prerequisites:" + echo " 1. Create a Key Vault if one does not exist:" + echo " az keyvault create --resource-group --name --enabled-for-disk-encryption true" + echo " 2. Ensure the VM is running before enabling encryption" + exit 1 +fi + +echo "Enabling Azure Disk Encryption on VM '$VM_NAME'..." + +az vm encryption enable \ + --resource-group "$RESOURCE_GROUP" \ + --name "$VM_NAME" \ + --disk-encryption-keyvault "$KEYVAULT_NAME" \ + --volume-type All + +echo "Waiting for encryption to complete..." + +az vm encryption show \ + --resource-group "$RESOURCE_GROUP" \ + --name "$VM_NAME" + +echo "Disk encryption enabled on all volumes for VM '$VM_NAME'." +echo "The VM may restart during the encryption process." +echo "Encryption of large disks can take several hours to complete." diff --git a/playbooks/cli/fix_az_db_003.sh b/playbooks/cli/fix_az_db_003.sh new file mode 100755 index 0000000..63df89d --- /dev/null +++ b/playbooks/cli/fix_az_db_003.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Playbook: fix_az_db_003.sh +# Rule: AZ-DB-003 — PostgreSQL Flexible Server SSL enforcement disabled + +set -euo pipefail + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 " + exit 1 +fi + +SUBSCRIPTION_ID="$1" + +echo "Setting subscription..." +az account set --subscription "$SUBSCRIPTION_ID" + +echo "Fetching PostgreSQL Flexible Servers..." +SERVERS=$(az postgres flexible-server list --subscription "$SUBSCRIPTION_ID" --query "[].{name:name, rg:resourceGroup}" --output tsv) + +if [[ -z "$SERVERS" ]]; then + echo "No PostgreSQL Flexible Servers found." + exit 0 +fi + +while IFS=$'\t' read -r SERVER_NAME RESOURCE_GROUP; do + echo "Checking $SERVER_NAME in $RESOURCE_GROUP..." + SSL_VALUE=$(az postgres flexible-server parameter show --resource-group "$RESOURCE_GROUP" --server-name "$SERVER_NAME" --name require_secure_transport --query "value" --output tsv 2>/dev/null || echo "on") + + if [[ "${SSL_VALUE,,}" == "off" ]]; then + echo "Enabling SSL on $SERVER_NAME..." + az postgres flexible-server parameter set --resource-group "$RESOURCE_GROUP" --server-name "$SERVER_NAME" --name require_secure_transport --value ON --output none + echo "Done." + else + echo "$SERVER_NAME already has SSL enabled, skipping." + fi +done <<< "$SERVERS" + +echo "Done. Verify with:" +echo " az postgres flexible-server parameter show --name require_secure_transport --server-name --resource-group " diff --git a/playbooks/cli/fix_az_idn_003.sh b/playbooks/cli/fix_az_idn_003.sh new file mode 100644 index 0000000..0b910d7 --- /dev/null +++ b/playbooks/cli/fix_az_idn_003.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# OpenShield Remediation Playbook +# Rule: AZ-IDN-003 — Guest user invitations not restricted to admins in Entra ID +# Usage: ./fix_az_idn_003.sh +# Severity: MEDIUM +# +# Prerequisites: +# - Azure CLI logged in with a Global Administrator or User Administrator role +# - Microsoft Graph or az rest permissions + +set -e + +echo "Restricting guest user invitations to admins only..." + +az rest \ + --method PATCH \ + --uri "https://graph.microsoft.com/v1.0/policies/authorizationPolicy" \ + --headers "Content-Type=application/json" \ + --body '{ + "allowInvitesFrom": "adminsAndGuestInviters" + }' + +echo "Remediation complete." +echo "allowInvitesFrom is now set to: adminsAndGuestInviters" +echo "Only users assigned to the Guest Inviter role or admins can now invite external users." +echo "Review existing guest accounts to ensure they are still required." diff --git a/playbooks/cli/fix_az_net_011.sh b/playbooks/cli/fix_az_net_011.sh new file mode 100755 index 0000000..4e55011 --- /dev/null +++ b/playbooks/cli/fix_az_net_011.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Playbook: fix_az_net_011.sh +# Rule: AZ-NET-011 — Network Watcher not enabled in all regions + +set -euo pipefail + +if [[ $# -lt 1 ]]; then + echo "Usage: $0 " + exit 1 +fi + +SUBSCRIPTION_ID="$1" + +echo "Setting subscription..." +az account set --subscription "$SUBSCRIPTION_ID" + +echo "Fetching regions with resources..." +RESOURCE_REGIONS=$(az resource list --subscription "$SUBSCRIPTION_ID" \ + --query "[].location" --output tsv | sort -u | tr -d ' ') + +echo "Fetching regions with Network Watcher..." +WATCHED_REGIONS=$(az network watcher list --subscription "$SUBSCRIPTION_ID" \ + --query "[].location" --output tsv 2>/dev/null | sort -u | tr -d ' ' || echo "") + +echo "Enabling Network Watcher in unmonitored regions..." +while IFS= read -r REGION; do + if echo "$WATCHED_REGIONS" | grep -qx "$REGION"; then + echo " [SKIP] $REGION — already enabled" + else + RESOURCE_GROUP="NetworkWatcherRG-${REGION}" + echo " [FIX] $REGION — creating resource group $RESOURCE_GROUP..." + az group create --name "$RESOURCE_GROUP" --location "$REGION" --output none + echo " [FIX] $REGION — enabling Network Watcher..." + az network watcher configure \ + --resource-group "$RESOURCE_GROUP" \ + --locations "$REGION" \ + --enabled true \ + --subscription "$SUBSCRIPTION_ID" \ + --output none + echo " Done." + fi +done <<< "$RESOURCE_REGIONS" + +echo "Done! Verify with:" +echo " az network watcher list --subscription $SUBSCRIPTION_ID --output table" diff --git a/playbooks/cli/fix_az_stor_004.sh b/playbooks/cli/fix_az_stor_004.sh new file mode 100644 index 0000000..c565f68 --- /dev/null +++ b/playbooks/cli/fix_az_stor_004.sh @@ -0,0 +1,150 @@ +#!/bin/bash +# OpenShield Remediation Playbook +# Rule: AZ-STOR-004 — Storage Account Diagnostic Logging Disabled +# Usage: ./fix_az_stor_004.sh +# Severity: MEDIUM +# +# What this script does: +# Enables Azure Monitor diagnostic settings on the blob, queue, and table +# service sub-resources of the specified storage account. Each service gets +# a diagnostic setting named "openshield-storage-logging" with StorageRead, +# StorageWrite, and StorageDelete enabled at a 90-day retention. Logs are +# written to the destination storage account you supply. +# +# Prerequisites: +# - Azure CLI installed and logged in (az login) +# - Contributor or Monitoring Contributor role on the target subscription +# - A destination storage account for logs (pass its full resource ID) +# +# Example: +# ./fix_az_stor_004.sh my-rg my-storage-account \ +# /subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/log-rg/providers/Microsoft.Storage/storageAccounts/logstore + +set -euo pipefail + +RESOURCE_GROUP="${1:-}" +STORAGE_ACCOUNT="${2:-}" +LOG_STORAGE_ACCOUNT_ID="${3:-}" + +# ── Argument validation ────────────────────────────────────────────────────── + +if [ -z "$RESOURCE_GROUP" ] || [ -z "$STORAGE_ACCOUNT" ] || [ -z "$LOG_STORAGE_ACCOUNT_ID" ]; then + echo "Usage: $0 " + echo "" + echo "Arguments:" + echo " resource-group Resource group of the storage account to remediate" + echo " storage-account-name Name of the storage account to remediate" + echo " log-storage-account-id Full Azure resource ID of the destination log storage account" + echo "" + echo "Example:" + echo " $0 my-rg my-storage \\" + echo " /subscriptions//resourceGroups/log-rg/providers/Microsoft.Storage/storageAccounts/logstore" + exit 1 +fi + +# ── Validate names contain only Azure-safe characters ─────────────────────── + +if ! [[ "$RESOURCE_GROUP" =~ ^[a-zA-Z0-9._()-]+$ ]]; then + echo "ERROR: resource-group contains invalid characters: '$RESOURCE_GROUP'" + exit 1 +fi + +if ! [[ "$STORAGE_ACCOUNT" =~ ^[a-z0-9]{3,24}$ ]]; then + echo "ERROR: storage-account-name must be 3-24 lowercase letters and numbers only." + exit 1 +fi + +# ── Resolve subscription ID ────────────────────────────────────────────────── + +SUBSCRIPTION_ID=$(az account show --query id -o tsv) +if [ -z "$SUBSCRIPTION_ID" ]; then + echo "ERROR: Could not determine subscription ID. Run 'az login' first." + exit 1 +fi + +# ── Build base resource ID ─────────────────────────────────────────────────── + +BASE_ID="/subscriptions/${SUBSCRIPTION_ID}/resourceGroups/${RESOURCE_GROUP}/providers/Microsoft.Storage/storageAccounts/${STORAGE_ACCOUNT}" + +BLOB_RESOURCE_ID="${BASE_ID}/blobServices/default" +QUEUE_RESOURCE_ID="${BASE_ID}/queueServices/default" +TABLE_RESOURCE_ID="${BASE_ID}/tableServices/default" + +LOG_SETTING_NAME="openshield-storage-logging" + +LOG_CATEGORIES='[ + {"category":"StorageRead","enabled":true,"retentionPolicy":{"days":90,"enabled":true}}, + {"category":"StorageWrite","enabled":true,"retentionPolicy":{"days":90,"enabled":true}}, + {"category":"StorageDelete","enabled":true,"retentionPolicy":{"days":90,"enabled":true}} +]' + +# ── Confirm before making changes ──────────────────────────────────────────── + +echo "============================================================" +echo " OpenShield Remediation — AZ-STOR-004" +echo "============================================================" +echo "" +echo " Storage account : $STORAGE_ACCOUNT" +echo " Resource group : $RESOURCE_GROUP" +echo " Log destination : $LOG_STORAGE_ACCOUNT_ID" +echo "" +echo " Services to configure:" +echo " - blobServices/default" +echo " - queueServices/default" +echo " - tableServices/default" +echo "" +echo " Each service will have diagnostic setting '$LOG_SETTING_NAME' with:" +echo " StorageRead, StorageWrite, StorageDelete (retention 90 days)" +echo "" +read -r -p "Proceed? [y/N] " CONFIRM +if [[ "$CONFIRM" != "y" && "$CONFIRM" != "Y" ]]; then + echo "Aborted. No changes were made." + exit 0 +fi + +# ── Enable diagnostic settings on all three services ──────────────────────── + +echo "" +echo "[1/3] Enabling diagnostic logging on blob service ..." +az monitor diagnostic-settings create \ + --resource "$BLOB_RESOURCE_ID" \ + --name "$LOG_SETTING_NAME" \ + --storage-account "$LOG_STORAGE_ACCOUNT_ID" \ + --logs "$LOG_CATEGORIES" +echo " Done." + +echo "" +echo "[2/3] Enabling diagnostic logging on queue service ..." +az monitor diagnostic-settings create \ + --resource "$QUEUE_RESOURCE_ID" \ + --name "$LOG_SETTING_NAME" \ + --storage-account "$LOG_STORAGE_ACCOUNT_ID" \ + --logs "$LOG_CATEGORIES" +echo " Done." + +echo "" +echo "[3/3] Enabling diagnostic logging on table service ..." +az monitor diagnostic-settings create \ + --resource "$TABLE_RESOURCE_ID" \ + --name "$LOG_SETTING_NAME" \ + --storage-account "$LOG_STORAGE_ACCOUNT_ID" \ + --logs "$LOG_CATEGORIES" +echo " Done." + +# ── Confirmation ───────────────────────────────────────────────────────────── + +echo "" +echo "============================================================" +echo " Remediation complete for: $STORAGE_ACCOUNT" +echo "============================================================" +echo "" +echo " Diagnostic setting '$LOG_SETTING_NAME' created on:" +echo " blobServices/default — StorageRead, StorageWrite, StorageDelete (90-day retention)" +echo " queueServices/default — StorageRead, StorageWrite, StorageDelete (90-day retention)" +echo " tableServices/default — StorageRead, StorageWrite, StorageDelete (90-day retention)" +echo "" +echo " To verify:" +echo " az monitor diagnostic-settings list --resource $BLOB_RESOURCE_ID" +echo " az monitor diagnostic-settings list --resource $QUEUE_RESOURCE_ID" +echo " az monitor diagnostic-settings list --resource $TABLE_RESOURCE_ID" +echo "============================================================" diff --git a/requirements.txt b/requirements.txt index f084573..52f1710 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,9 +10,13 @@ azure-mgmt-keyvault==10.3.0 azure-mgmt-rdbms==10.1.0 azure-mgmt-authorization==4.0.0 azure-monitor-ingestion==1.0.3 +azure-mgmt-monitor==6.0.0 psycopg2-binary==2.9.9 python-dotenv==1.0.0 pyjwt==2.8.0 requests==2.31.0 pyyaml==6.0.1 -azure-mgmt-monitor==6.0.2 \ No newline at end of file +gunicorn==21.2.0 +cryptography==42.0.5 +msrest==0.7.1 +azure-mgmt-postgresqlflexibleservers==1.0.0b1 diff --git a/scanner/azure_client.py b/scanner/azure_client.py index d873edb..e65f567 100644 --- a/scanner/azure_client.py +++ b/scanner/azure_client.py @@ -11,6 +11,7 @@ from azure.mgmt.network import NetworkManagementClient from azure.mgmt.rdbms.postgresql import PostgreSQLManagementClient from azure.mgmt.sql import SqlManagementClient +from azure.mgmt.monitor import MonitorManagementClient from azure.mgmt.storage import StorageManagementClient from azure.mgmt.monitor import MonitorManagementClient @@ -129,6 +130,83 @@ def get_storage_lifecycle_policy( ) return None + def get_storage_service_logging( + self, resource_group: str, account_name: str, service: str + ) -> Optional[bool]: + """Check Azure Monitor diagnostic settings for a storage service sub-resource. + + Three-state return — the calling rule uses strict identity checks + (is False / is None) to distinguish these states: + + True — at least one diagnostic setting has StorageRead, StorageWrite, + and StorageDelete all enabled (compliant). + False — no setting covers all three required categories (non-compliant). + None — permission error or unexpected SDK failure. + Caller must NOT create a finding — skip with a warning + to avoid false positives. + + Args: + resource_group: Resource group containing the storage account. + account_name: Name of the storage account. + service: Sub-service to check: "blob", "queue", or "table". + + Returns: + Optional[bool] — True, False, or None as described above. + """ + _REQUIRED = {"StorageRead", "StorageWrite", "StorageDelete"} + _SERVICE_MAP = { + "blob": "blobServices", + "queue": "queueServices", + "table": "tableServices", + } + svc_path = _SERVICE_MAP.get(service) + if not svc_path: + logger.error( + "get_storage_service_logging: unknown service %r — must be " + "blob, queue, or table", + service, + ) + return None + + resource_uri = ( + f"/subscriptions/{self.subscription_id}" + f"/resourceGroups/{resource_group}" + f"/providers/Microsoft.Storage/storageAccounts/{account_name}" + f"/{svc_path}/default" + ) + try: + client = MonitorManagementClient(self.credential, self.subscription_id) + settings = list(client.diagnostic_settings.list(resource_uri)) + for setting in settings: + enabled_categories = { + log.category + for log in (getattr(setting, "logs", None) or []) + if getattr(log, "enabled", False) + } + if _REQUIRED.issubset(enabled_categories): + return True + return False + + except HttpResponseError as exc: + logger.error( + "get_storage_service_logging(%s/%s) HTTP %s — " + "check service principal permissions: %s", + account_name, + service, + exc.status_code, + exc, + ) + return None + + except Exception as exc: + logger.error( + "get_storage_service_logging(%s/%s) unexpected error: %s", + account_name, + service, + exc, + ) + return None + # ------------------------------------------------------------------ # # Network # # ------------------------------------------------------------------ # @@ -310,6 +388,28 @@ def get_service_principals(self) -> List[Any]: logger.error("get_service_principals failed: %s", exc) return [] + + def get_postgresql_flexible_servers(self) -> List[Any]: + """List all PostgreSQL Flexible Server instances in the subscription.""" + try: + from azure.mgmt.postgresqlflexibleservers import PostgreSQLManagementClient as FlexClient + client = FlexClient(self.credential, self.subscription_id) + return list(client.servers.list()) + except Exception as exc: + logger.error("get_postgresql_flexible_servers failed: %s", exc) + return [] + + + def get_postgresql_flexible_server_parameters(self, resource_group: str, server_name: str) -> List[Any]: + """List all configuration parameters for a PostgreSQL Flexible Server.""" + try: + from azure.mgmt.postgresqlflexibleservers import PostgreSQLManagementClient as FlexClient + client = FlexClient(self.credential, self.subscription_id) + return list(client.configurations.list_by_server(resource_group, server_name)) + except Exception as exc: + logger.error("get_postgresql_flexible_server_parameters(%s) failed: %s", server_name, exc) + return [] + def get_conditional_access_policies(self) -> List[Any]: """Fetch Conditional Access policies from the Microsoft Graph API. @@ -330,4 +430,32 @@ def get_conditional_access_policies(self) -> List[Any]: return response.json().get("value", []) except Exception as exc: logger.error("get_conditional_access_policies failed: %s", exc) - return [] \ No newline at end of file + return [] + def get_regions_with_resources(self) -> List[str]: + """List all regions that have at least one resource deployed.""" + try: + from azure.mgmt.resource import ResourceManagementClient + client = ResourceManagementClient(self.credential, self.subscription_id) + regions = { + r.location.lower().replace(" ", "") + for r in client.resources.list() + if r.location + } + return list(regions) + except Exception as exc: + logger.error("get_regions_with_resources failed: %s", exc) + return [] + + def get_network_watcher_regions(self) -> List[str]: + """List all regions that already have Network Watcher enabled.""" + try: + client = NetworkManagementClient(self.credential, self.subscription_id) + regions = { + w.location.lower().replace(" ", "") + for w in client.network_watchers.list_all() + if w.location + } + return list(regions) + except Exception as exc: + logger.error("get_network_watcher_regions failed: %s", exc) + return [] diff --git a/scanner/engine.py b/scanner/engine.py index 46ce0e3..4c1813f 100644 --- a/scanner/engine.py +++ b/scanner/engine.py @@ -3,6 +3,7 @@ import importlib.util import logging import uuid +import json from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List @@ -14,6 +15,34 @@ RULES_DIR = Path(__file__).parent / "rules" +def make_serializable(data: Any) -> Any: + """Recursively convert non-serializable objects (datetime, etc) to strings.""" + if data is None: + return None + if isinstance(data, (str, int, float, bool)): + return data + if isinstance(data, dict): + return {str(k): make_serializable(v) for k, v in data.items()} + if isinstance(data, (list, tuple, set)): + return [make_serializable(i) for i in data] + if isinstance(data, datetime): + return data.isoformat() + + # Handle Azure SDK models and other objects + if hasattr(data, "as_dict") and callable(data.as_dict): + return make_serializable(data.as_dict()) + + # Fallback to string representation for unknown objects + try: + # Check if it has a __dict__ but avoid infinite recursion for complex types + if hasattr(data, "__dict__") and not str(type(data)).startswith(" Dict[str, Any]: rule_id = getattr(rule, "RULE_ID", "UNKNOWN") try: rule_findings = rule.scan(self.client, self.subscription_id) + if not isinstance(rule_findings, list): + logger.warning("Rule %s returned %s instead of list — skipped", rule_id, type(rule_findings)) + continue + for finding in rule_findings: + if not isinstance(finding, dict): continue finding.setdefault("detected_at", detected_at) finding.setdefault("scan_id", scan_id) findings.extend(rule_findings) @@ -92,15 +126,11 @@ def run_scan(self) -> Dict[str, Any]: "Rule %s produced %d finding(s)", rule_id, len(rule_findings) ) except Exception as exc: - logger.error("Rule %s raised an exception: %s", rule_id, exc) + logger.error("Rule %s raised an exception: %s", rule_id, exc, exc_info=True) completed_at = datetime.now(timezone.utc).isoformat() - logger.info( - "Scan %s complete — %d total finding(s)", scan_id, len(findings) - ) - - return { + result = { "scan_id": scan_id, "subscription_id": self.subscription_id, "started_at": started_at, @@ -108,3 +138,9 @@ def run_scan(self) -> Dict[str, Any]: "total_findings": len(findings), "findings": findings, } + + logger.info( + "Scan %s complete — %d total finding(s). Normalising results...", scan_id, len(findings) + ) + + return make_serializable(result) diff --git a/scanner/rules/az_cmp_002.py b/scanner/rules/az_cmp_002.py new file mode 100644 index 0000000..cefbef4 --- /dev/null +++ b/scanner/rules/az_cmp_002.py @@ -0,0 +1,115 @@ +"""AZ-CMP-002: Virtual machine OS or data disk using platform-managed encryption only.""" + +import logging +from typing import Any, Dict, List + +RULE_ID = "AZ-CMP-002" +RULE_NAME = "Virtual machine disk not protected by customer-managed key or ADE" +SEVERITY = "HIGH" +CATEGORY = "Compute" +FRAMEWORKS = {"CIS": "7.2", "NIST": "PR.DS-1", "ISO27001": "A.10.1.1", "SOC2": "CC6.7"} +DESCRIPTION = ( + "One or more disks attached to this virtual machine are using platform-managed " + "encryption only (EncryptionAtRestWithPlatformKey). CIS 7.2 requires disks to be " + "protected using either Azure Disk Encryption (ADE) or server-side encryption with " + "a customer-managed key (CMK). Platform-managed encryption does not give the " + "organisation control over the encryption keys." +) +REMEDIATION = ( + "Configure server-side encryption with a customer-managed key via a Disk Encryption " + "Set, or enable Azure Disk Encryption on all OS and data disks. Navigate to: " + "Virtual Machine > Disks > Additional settings > Disk encryption set, or use " + "az vm encryption enable with a Key Vault." +) +PLAYBOOK = "playbooks/cli/fix_az_cmp_002.sh" + +logger = logging.getLogger(__name__) + + +def _disk_needs_flagging(managed_disk: Any) -> bool: + """Return True only if the disk uses platform-managed encryption. + + Azure platform-managed encryption (EncryptionAtRestWithPlatformKey) is the + default for all managed disks and does not satisfy CIS 7.2, which requires + customer-managed keys (CMK) or Azure Disk Encryption (ADE). + + Disks using EncryptionAtRestWithCustomerKey or + EncryptionAtRestWithPlatformAndCustomerKeys are compliant and should not + be flagged. + """ + if managed_disk is None: + return False + + encryption = getattr(managed_disk, "security_profile", None) + if encryption is None: + encryption = getattr(managed_disk, "encryption", None) + + encryption_type = getattr(encryption, "type", None) + + if encryption_type is None: + return False + + return encryption_type == "EncryptionAtRestWithPlatformKey" + + +def scan(azure_client: Any, subscription_id: str) -> List[Dict[str, Any]]: + """Detect virtual machines whose disks use platform-managed encryption only.""" + findings: List[Dict[str, Any]] = [] + + for vm in azure_client.get_virtual_machines(): + vm_id = getattr(vm, "id", "") + vm_name = getattr(vm, "name", "") + location = getattr(vm, "location", "") + + if not vm_id or not vm_name: + continue + + parsed = azure_client.parse_resource_id(vm_id) + resource_group = parsed.get("resource_group", "") + + storage_profile = getattr(vm, "storage_profile", None) + if not storage_profile: + continue + + unencrypted_disks = [] + + # Check OS disk + os_disk = getattr(storage_profile, "os_disk", None) + if os_disk: + managed_disk = getattr(os_disk, "managed_disk", None) + if _disk_needs_flagging(managed_disk): + unencrypted_disks.append( + getattr(os_disk, "name", "os-disk") + ) + + # Check data disks + data_disks = getattr(storage_profile, "data_disks", []) or [] + for disk in data_disks: + managed_disk = getattr(disk, "managed_disk", None) + if _disk_needs_flagging(managed_disk): + unencrypted_disks.append( + getattr(disk, "name", f"data-disk-{getattr(disk, 'lun', '?')}") + ) + + if unencrypted_disks: + findings.append({ + "rule_id": RULE_ID, + "rule_name": RULE_NAME, + "severity": SEVERITY, + "category": CATEGORY, + "resource_id": vm_id, + "resource_name": vm_name, + "resource_type": "Microsoft.Compute/virtualMachines", + "description": DESCRIPTION, + "remediation": REMEDIATION, + "playbook": PLAYBOOK, + "frameworks": FRAMEWORKS, + "metadata": { + "resource_group": resource_group, + "location": location, + "unencrypted_disks": unencrypted_disks, + "unencrypted_disk_count": len(unencrypted_disks), + }, + }) + + return findings diff --git a/scanner/rules/az_db_003.py b/scanner/rules/az_db_003.py new file mode 100644 index 0000000..cc0b0c1 --- /dev/null +++ b/scanner/rules/az_db_003.py @@ -0,0 +1,81 @@ +"""AZ-DB-003: PostgreSQL Flexible Server SSL enforcement disabled.""" +from typing import Any, Dict, List +import logging + +logger = logging.getLogger(__name__) + +RULE_ID = "AZ-DB-003" +RULE_NAME = "PostgreSQL Flexible Server SSL Enforcement Disabled" +SEVERITY = "HIGH" +CATEGORY = "Database" +FRAMEWORKS = {"CIS": "4.3.6", "NIST": "PR.DS-2", "ISO27001": "A.10.1.1", "SOC2": "CC6.1"} +DESCRIPTION = ( + "The Azure Database for PostgreSQL Flexible Server has SSL enforcement disabled. " + "Without SSL, data in transit between the application and database is transmitted " + "in plaintext and is vulnerable to interception and man-in-the-middle attacks." +) +REMEDIATION = ( + "Enable SSL enforcement on the PostgreSQL Flexible Server by setting " + "require_secure_transport to ON. " + "Run: az postgres flexible-server parameter set --resource-group " + "--server-name --name require_secure_transport --value ON" +) +PLAYBOOK = "playbooks/cli/fix_az_db_003.sh" + + +def scan(azure_client: Any, subscription_id: str) -> List[Dict[str, Any]]: + """Detect PostgreSQL Flexible Servers with SSL enforcement disabled.""" + findings: List[Dict[str, Any]] = [] + + for server in azure_client.get_postgresql_flexible_servers(): + parsed = azure_client.parse_resource_id(server.id) + resource_group = parsed.get("resource_group", "") + + params = azure_client.get_postgresql_flexible_server_parameters( + resource_group, server.name + ) + + if not params: + # Cannot determine SSL state — skip to avoid false positives + logger.warning( + "az_db_003: skipping %s — get_postgresql_flexible_server_parameters " + "returned empty (permission or API failure)", + server.name, + ) + continue + + ssl_param = next( + (p for p in params if getattr(p, "name", "") == "require_secure_transport"), + None, + ) + + if ssl_param is None: + # Parameter not found — cannot determine compliance, skip + logger.warning( + "az_db_003: skipping %s — require_secure_transport parameter not found", + server.name, + ) + continue + + ssl_value = str(getattr(ssl_param, "value", "on")).lower() + if ssl_value in ("off", "false", "0"): + findings.append({ + "rule_id": RULE_ID, + "rule_name": RULE_NAME, + "severity": SEVERITY, + "category": CATEGORY, + "resource_id": server.id, + "resource_name": server.name, + "resource_type": "Microsoft.DBforPostgreSQL/flexibleServers", + "description": DESCRIPTION, + "remediation": REMEDIATION, + "playbook": PLAYBOOK, + "frameworks": FRAMEWORKS, + "metadata": { + "resource_group": resource_group, + "location": getattr(server, "location", ""), + "ssl_value": ssl_value, + }, + }) + + return findings diff --git a/scanner/rules/az_idn_003.py b/scanner/rules/az_idn_003.py new file mode 100644 index 0000000..398d580 --- /dev/null +++ b/scanner/rules/az_idn_003.py @@ -0,0 +1,83 @@ +"""AZ-IDN-003: Guest user invitations not restricted to admins in Entra ID.""" + +import logging +from typing import Any, Dict, List + +RULE_ID = "AZ-IDN-003" +RULE_NAME = "Guest user invitations not restricted to admins in Entra ID" +SEVERITY = "MEDIUM" +CATEGORY = "Identity" +FRAMEWORKS = {"CIS": "1.15", "NIST": "PR.AC-1", "ISO27001": "A.9.2.1"} +DESCRIPTION = ( + "Guest user invitations in Entra ID are not restricted to administrators. " + "Any organisation member can invite external users into the tenant without " + "centralised review or approval. This bypasses formal external identity " + "provisioning controls and increases the risk of unauthorised access by " + "untrusted parties." +) +REMEDIATION = ( + "Restrict guest invitations to admins only by setting the " + "'allowInvitesFrom' policy to 'adminsAndGuestInviters' or 'admins' " + "in Entra ID. Navigate to: Entra ID > External Identities > " + "External collaboration settings > Guest invite settings. " + "Set to 'Only users assigned to specific admin roles can invite guest users'." +) +PLAYBOOK = "playbooks/cli/fix_az_idn_003.sh" + +logger = logging.getLogger(__name__) + + +def scan(azure_client: Any, subscription_id: str) -> List[Dict[str, Any]]: + """Detect unrestricted guest user invitation settings in Entra ID.""" + findings: List[Dict[str, Any]] = [] + + try: + import requests + + token = azure_client.credential.get_token( + "https://graph.microsoft.com/.default" + ) + headers = {"Authorization": f"Bearer {token.token}"} + + response = requests.get( + "https://graph.microsoft.com/v1.0/policies/authorizationPolicy", + headers=headers, + timeout=30, + ) + response.raise_for_status() + policy = response.json() + + except Exception as exc: + logger.error( + "AZ-IDN-003: Failed to fetch authorization policy from Graph API: %s", exc + ) + logger.warning( + "AZ-IDN-003: Ensure the service principal has " + "Directory.Read.All permission on Microsoft Graph." + ) + return findings + + allow_invites_from = policy.get("allowInvitesFrom", "everyone") + + restricted_values = {"admins", "adminsAndGuestInviters"} + if allow_invites_from not in restricted_values: + findings.append({ + "rule_id": RULE_ID, + "rule_name": RULE_NAME, + "severity": SEVERITY, + "category": CATEGORY, + "resource_id": f"/tenants/{policy.get('id', 'unknown')}/policies/authorizationPolicy", + "resource_name": "authorizationPolicy", + "resource_type": "Microsoft.Graph/authorizationPolicy", + "description": DESCRIPTION, + "remediation": REMEDIATION, + "playbook": PLAYBOOK, + "frameworks": FRAMEWORKS, + "metadata": { + "allow_invites_from": allow_invites_from, + "policy_id": policy.get("id", ""), + "display_name": policy.get("displayName", ""), + }, + }) + + return findings diff --git a/scanner/rules/az_net_011.py b/scanner/rules/az_net_011.py new file mode 100644 index 0000000..978b2a0 --- /dev/null +++ b/scanner/rules/az_net_011.py @@ -0,0 +1,48 @@ +"""AZ-NET-011: Network Watcher not enabled in all regions.""" +from typing import Any, Dict, List + +RULE_ID = "AZ-NET-011" +RULE_NAME = "Network Watcher Not Enabled in All Regions" +SEVERITY = "LOW" +CATEGORY = "Network" +FRAMEWORKS = {"CIS": "6.5", "NIST": "DE.CM-7", "ISO27001": "A.12.4.1", "SOC2": "CC7.2"} +DESCRIPTION = ( + "Network Watcher is not enabled in one or more Azure regions where resources " + "are deployed. Network Watcher provides network monitoring, diagnostics, and " + "logging capabilities. Without it, network-level incidents cannot be " + "investigated or diagnosed." +) +REMEDIATION = ( + "Enable Network Watcher in all regions where Azure resources are deployed. " + "Run: az network watcher configure --resource-group NetworkWatcherRG " + "--locations --enabled true" +) +PLAYBOOK = "playbooks/cli/fix_az_net_011.sh" + + +def scan(azure_client: Any, subscription_id: str) -> List[Dict[str, Any]]: + """Detect regions where resources exist but Network Watcher is not enabled.""" + findings: List[Dict[str, Any]] = [] + + regions_with_resources = azure_client.get_regions_with_resources() + regions_with_watcher = azure_client.get_network_watcher_regions() + + unmonitored_regions = set(regions_with_resources) - set(regions_with_watcher) + + for region in sorted(unmonitored_regions): + findings.append({ + "rule_id": RULE_ID, + "rule_name": RULE_NAME, + "severity": SEVERITY, + "category": CATEGORY, + "resource_id": f"/subscriptions/{subscription_id}/regions/{region}", + "resource_name": region, + "resource_type": "Microsoft.Network/networkWatchers", + "description": DESCRIPTION, + "remediation": REMEDIATION, + "playbook": PLAYBOOK, + "frameworks": FRAMEWORKS, + "metadata": {"region": region}, + }) + + return findings diff --git a/scanner/rules/az_stor_004.py b/scanner/rules/az_stor_004.py new file mode 100644 index 0000000..17a167d --- /dev/null +++ b/scanner/rules/az_stor_004.py @@ -0,0 +1,121 @@ +"""AZ-STOR-004: Storage account diagnostic logging disabled for blob, queue, or table.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# ── Required module-level constants ───────────────────────────────────────── + +RULE_ID = "AZ-STOR-004" +RULE_NAME = "Storage Account Diagnostic Logging Disabled" +SEVERITY = "MEDIUM" +CATEGORY = "Storage" +FRAMEWORKS = { + "CIS": "3.3", + "NIST": "DE.CM-7", + "ISO27001": "A.12.4.1", +} +DESCRIPTION = ( + "Azure Monitor diagnostic logging is not fully enabled for the {service} " + "service on this storage account. StorageRead, StorageWrite, and " + "StorageDelete must all be enabled. Without logging, operations on this " + "service cannot be detected or investigated, making it impossible to " + "identify data exfiltration or unauthorised access. CIS Azure Benchmark " + "3.3 requires logging for blob, queue, and table services for read, write, " + "and delete requests." +) +REMEDIATION = ( + "Enable Azure Monitor diagnostic settings on the storage account's " + "{service} service with StorageRead, StorageWrite, and StorageDelete all " + "set to enabled. Navigate to: Storage Account > Monitoring > " + "Diagnostic settings > {service} > Add diagnostic setting, then check " + "StorageRead, StorageWrite, and StorageDelete." +) +PLAYBOOK = "playbooks/cli/fix_az_stor_004.sh" + +# Maps service key → (sub-resource path segment, resource_type) +_SERVICES: Dict[str, Tuple[str, str]] = { + "blob": ("blobServices", "Microsoft.Storage/storageAccounts/blobServices"), + "queue": ("queueServices", "Microsoft.Storage/storageAccounts/queueServices"), + "table": ("tableServices", "Microsoft.Storage/storageAccounts/tableServices"), +} + + +# ── Required scan function ─────────────────────────────────────────────────── + +def scan(azure_client: Any, subscription_id: str) -> List[Dict[str, Any]]: + """Detect storage account services with incomplete diagnostic logging. + + For each storage account, all three sub-services (blob, queue, table) are + checked independently. A separate finding is emitted for each service that + does not have StorageRead, StorageWrite, and StorageDelete all enabled. + + Three-state return from get_storage_service_logging(): + True — all three log categories enabled → skip (compliant) + False — one or more categories missing → create finding + None — permissions error or unexpected failure → skip with warning + to avoid false positives + + Args: + azure_client: An AzureClient instance with all SDK clients + pre-configured. + subscription_id: The Azure subscription ID being scanned. + + Returns: + A list of finding dicts — one per storage service sub-resource that + does not have full diagnostic logging. Services that could not be + checked are skipped and logged as warnings. + """ + findings: List[Dict[str, Any]] = [] + + for account in azure_client.get_storage_accounts(): + resource_id = getattr(account, "id", "") + account_name = getattr(account, "name", "") + location = getattr(account, "location", "") + + if not resource_id or not account_name: + continue + + parsed = azure_client.parse_resource_id(resource_id) + resource_group = parsed.get("resource_group", "") + if not resource_group: + continue + + for service, (svc_path, resource_type) in _SERVICES.items(): + # True = compliant, False = logging incomplete, None = could not determine + logging_status: Optional[bool] = azure_client.get_storage_service_logging( + resource_group, account_name, service + ) + + if logging_status is None: + logger.warning( + "AZ-STOR-004: Could not determine %s logging status for %s " + "— skipping. Ensure the service principal has " + "microsoft.insights/diagnosticSettings/read permission.", + service, + account_name, + ) + continue + + if logging_status is False: + findings.append({ + "rule_id": RULE_ID, + "rule_name": RULE_NAME, + "severity": SEVERITY, + "category": CATEGORY, + "resource_id": f"{resource_id}/{svc_path}/default", + "resource_name": f"{account_name}/{svc_path}", + "resource_type": resource_type, + "description": DESCRIPTION.format(service=service), + "remediation": REMEDIATION.format(service=service), + "playbook": PLAYBOOK, + "frameworks": FRAMEWORKS, + "metadata": { + "resource_group": resource_group, + "location": location, + "service": service, + }, + }) + + return findings diff --git a/startup.sh b/startup.sh new file mode 100755 index 0000000..ac3b44c --- /dev/null +++ b/startup.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -euo pipefail + +echo "=== OpenShield startup ===" +echo "Running database initialisation..." + +python -c " +import os, sys +try: + from api.models.finding import DatabaseManager + db = DatabaseManager(os.environ['DATABASE_URL']) + if hasattr(db, 'init_db'): + db.init_db() + print('Database initialised.') + else: + print('WARNING: DatabaseManager has no init_db() method — skipping.') +except Exception as e: + print(f'ERROR during DB init: {e}', file=sys.stderr) + sys.exit(1) +" + +echo "Startup complete. Starting Gunicorn..." +exec gunicorn --bind=0.0.0.0:$PORT --timeout 120 --workers 2 api.app:application \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/smoke_test.py b/tests/smoke_test.py new file mode 100755 index 0000000..3d9c043 --- /dev/null +++ b/tests/smoke_test.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +""" +OpenShield API Smoke Test Suite +Runs against a live deployment to verify all endpoints. + +Usage: + # Local + # Set API_URL: http://localhost:5000 and JWT_SECRET: your-secret + python tests/smoke_test.py + + # Live Render deployment + # Set API_URL: https://openshield-api.onrender.com and JWT_SECRET: your-secret + python tests/smoke_test.py + +JWT_SECRET must be the same value set in Render config — the test +generates a properly signed HS256 token from it automatically. +""" + +import os +import sys +import json +import time +import urllib.request +import urllib.error +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + + +# ── Token generation ────────────────────────────────────────────────────── +# The app's before_request middleware calls jwt.decode() with HS256. +# Passing the raw JWT_SECRET as a Bearer token will always return 401. +# We must sign a real token using the same secret. + +def _generate_token(secret: str) -> str: + """Generate a valid HS256 JWT signed with the app's JWT_SECRET.""" + try: + import jwt as pyjwt + payload = { + "sub": "smoke-test", + "role": "admin", + "iat": int(time.time()), + "exp": int(time.time()) + 3600, # 1 hour expiry + } + return pyjwt.encode(payload, secret, algorithm="HS256") + except ImportError: + print("ERROR: PyJWT not installed. Run: pip install PyJWT") + sys.exit(1) + except Exception as e: + print(f"ERROR generating JWT token: {e}") + sys.exit(1) + + +API_URL = os.environ.get("API_URL", "http://localhost:5000").rstrip("/") +_JWT_VAL = os.environ.get("JWT_SECRET", "change-me-in-production") +_REAL_SUB = os.environ.get("AZURE_SUBSCRIPTION_ID", "") + +# Real scan gate — requires explicit opt-in AND all four Azure credentials. +# Set RUN_REAL_SCAN=true in maintainer-controlled CI only. +_RUN_REAL_SCAN = os.environ.get("RUN_REAL_SCAN", "").lower() == "true" +_AZURE_CREDS_PRESENT = all([ + os.environ.get("AZURE_SUBSCRIPTION_ID"), + os.environ.get("AZURE_CLIENT_ID"), + os.environ.get("AZURE_CLIENT_SECRET"), + os.environ.get("AZURE_TENANT_ID"), +]) + +if not _JWT_VAL or _JWT_VAL == "change-me-in-production": + print("INFO: Using default JWT_SECRET ('change-me-in-production').") + print("To use a custom one, set the JWT_SECRET environment variable.") + +JWT_TOKEN = _generate_token(_JWT_VAL) + +PASS = "\033[92mPASS\033[0m" +FAIL = "\033[91mFAIL\033[0m" +SKIP = "\033[93mSKIP\033[0m" + +results = [] + + +def request(method, path, body=None, auth=True, bad_token=False): + """Make an HTTP request and return (status_code, response_body).""" + url = f"{API_URL}{path}" + headers = {"Content-Type": "application/json"} + + if bad_token: + # Deliberately malformed token to test rejection + headers["Authorization"] = "Bearer this.is.not.a.valid.jwt" + elif auth and JWT_TOKEN: + headers["Authorization"] = f"Bearer {JWT_TOKEN}" + + data = json.dumps(body).encode() if body else None + req = urllib.request.Request(url, data=data, headers=headers, method=method) + + try: + with urllib.request.urlopen(req, timeout=45) as resp: + return resp.status, json.loads(resp.read()) + except urllib.error.HTTPError as e: + try: + body_bytes = e.read() + return e.code, json.loads(body_bytes) + except Exception: + return e.code, {} + except Exception as e: + return 0, {"error": str(e)} + + +def test(name, method, path, check_fn, body=None, auth=True, bad_token=False): + """Run a single test case.""" + status, body_resp = request(method, path, body=body, auth=auth, bad_token=bad_token) + try: + passed = check_fn(status, body_resp) + except Exception as e: + passed = False + body_resp = {"exception": str(e)} + + label = PASS if passed else FAIL + print(f" [{label}] {name}") + if not passed: + print(f" Status: {status}") + print(f" Body: {json.dumps(body_resp, indent=2)[:300]}") + + results.append((name, passed)) + return passed + + +def skip(name, reason): + """Record a test as skipped — does not count as a failure.""" + print(f" [{SKIP}] {name}") + print(f" {reason}") + results.append((name, None)) + + +# ── TC-01: Health check ──────────────────────────────────────────────────── +print("\n=== Health Check ===") +test( + "TC-01 GET /health returns 200", + "GET", "/health", + lambda s, b: s == 200, + auth=False, +) +test( + "TC-02 GET /health returns status ok", + "GET", "/health", + lambda s, b: b.get("status") == "ok", + auth=False, +) +test( + "TC-03 GET /health requires no auth token", + "GET", "/health", + lambda s, b: s == 200, # Public path — must not return 401 + auth=False, +) + +# ── TC-04 to TC-08: Findings endpoint ───────────────────────────────────── +print("\n=== Findings Endpoint ===") +test( + "TC-04 GET /api/findings returns 200", + "GET", "/api/findings", + lambda s, b: s == 200, +) +test( + "TC-05 GET /api/findings returns 'findings' key", + "GET", "/api/findings", + lambda s, b: "findings" in b, +) +test( + "TC-06 GET /api/findings returns 'count' key", + "GET", "/api/findings", + lambda s, b: "count" in b and isinstance(b["count"], int), +) +test( + "TC-07 GET /api/findings?severity=HIGH filters correctly", + "GET", "/api/findings?severity=HIGH", + lambda s, b: s == 200 and all( + f.get("severity") == "HIGH" + for f in b.get("findings", []) + ), +) +test( + "TC-08 GET /api/findings?severity=INVALID returns 400 or empty", + "GET", "/api/findings?severity=INVALID", + lambda s, b: s in (200, 400), +) + +# ── TC-09 to TC-11: Score endpoint ──────────────────────────────────────── +print("\n=== Score Endpoint ===") +test( + "TC-09 GET /api/score returns 200", + "GET", "/api/score", + lambda s, b: s == 200, +) +test( + "TC-10 GET /api/score returns numeric score", + "GET", "/api/score", + lambda s, b: isinstance(b.get("score"), (int, float)), +) +test( + "TC-11 GET /api/score is between 0 and 100", + "GET", "/api/score", + lambda s, b: 0 <= b.get("score", -1) <= 100, +) + +# ── TC-12 to TC-14: Scans endpoint ──────────────────────────────────────── +print("\n=== Scans Endpoint ===") +test( + "TC-12 GET /api/scans returns 200", + "GET", "/api/scans", + lambda s, b: s == 200, +) + +if _RUN_REAL_SCAN and _AZURE_CREDS_PRESENT: + test( + "TC-13 POST /api/scans/trigger returns 200, 201 or 202", + "POST", "/api/scans/trigger", + lambda s, b: s in (200, 201, 202), + body={"subscription_id": _REAL_SUB}, + ) + test( + "TC-14 POST /api/scans/trigger returns scan_id or job_id", + "POST", "/api/scans/trigger", + lambda s, b: any(k in b for k in ("scan_id", "job_id", "id", "message")), + body={"subscription_id": _REAL_SUB}, + ) +else: + _skip_reason = ( + "Real scan skipped — set RUN_REAL_SCAN=true with all four Azure credentials to enable." + if not _RUN_REAL_SCAN + else "Real scan skipped — one or more Azure credentials (SUBSCRIPTION_ID, CLIENT_ID, CLIENT_SECRET, TENANT_ID) are missing." + ) + skip("TC-13 POST /api/scans/trigger returns 200, 201 or 202", _skip_reason) + skip("TC-14 POST /api/scans/trigger returns scan_id or job_id", _skip_reason) + +# ── TC-15 to TC-17: Compliance endpoints ────────────────────────────────── +print("\n=== Compliance Endpoints ===") +for framework in ("cis", "nist", "iso27001"): + test( + f"TC GET /api/compliance/{framework} returns 200", + "GET", f"/api/compliance/{framework}", + lambda s, b: s == 200, + ) + +# ── TC-18: Unauthenticated request is rejected ──────────────────────────── +print("\n=== Auth / Security Edge Cases ===") +test( + "TC-18 GET /api/findings without auth returns 401", + "GET", "/api/findings", + lambda s, b: s == 401, + auth=False, +) +test( + "TC-19 GET /api/findings with malformed token returns 401", + "GET", "/api/findings", + lambda s, b: s == 401, + bad_token=True, +) + +# ── TC-20 to TC-23: Edge cases ──────────────────────────────────────────── +print("\n=== Edge Cases ===") +test( + "TC-20 GET /nonexistent returns 404", + "GET", "/nonexistent-endpoint-xyz", + lambda s, b: s == 404, + auth=True, +) +test( + "TC-21 POST /api/scans/trigger with empty body still works", + "POST", "/api/scans/trigger", + lambda s, b: s in (200, 201, 202, 400), + body={}, +) +test( + "TC-22 GET /api/findings?limit=0 does not crash", + "GET", "/api/findings?limit=0", + lambda s, b: s in (200, 400), +) +test( + "TC-23 Response Content-Type is JSON", + "GET", "/api/findings", + lambda s, b: isinstance(b, dict), +) + +# ── Summary ──────────────────────────────────────────────────────────────── +print("\n=== Summary ===") +passed = sum(1 for _, p in results if p is True) +skipped = sum(1 for _, p in results if p is None) +failed_tests = [name for name, p in results if p is False] +total = len(results) + +skip_note = f", {skipped} skipped" if skipped else "" +print(f" {passed}/{total - skipped} tests passed{skip_note}") + +if skipped: + print(f"\n Skipped tests (not failures):") + for name, p in results: + if p is None: + print(f" - {name}") + print(f"\n To enable real scan tests: RUN_REAL_SCAN=true with AZURE_SUBSCRIPTION_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID") + +if failed_tests: + print(f"\n Failed tests:") + for name in failed_tests: + print(f" - {name}") + print(f"\nSmoke test FAILED. Do not open a PR until all tests pass.") + sys.exit(1) +else: + print(f"\n All smoke tests passed.") + sys.exit(0)