From e90909055483413fcfcdb1f4947404b2737b30f9 Mon Sep 17 00:00:00 2001 From: Vexx Date: Sun, 24 May 2026 02:27:15 +0530 Subject: [PATCH] fix(plugins): add argv-level injection validation to build_command MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User-supplied scan inputs were interpolated into argv lists without validation. With asyncio.create_subprocess_exec (no shell=True), shell metacharacters are safe but argv-level flag injection is real: a value like ports="--script=evil.nse" becomes a standalone argv element that nmap accepts as a flag. Changes: - plugins.py: add _reject_injected_args() — rejects values that begin with '-' (flag injection); for port fields enforces digits/commas/hyphens only via _PORT_SPEC_PATTERN - plugins.py: add _validate_inputs_against_schema() — enforces SELECT allowed-values, INTEGER type, BOOLEAN type, and field-level regex patterns before any interpolation occurs; called in build_command() BEFORE _normalize_inputs() so SELECT checks run on raw user values - port_scanner.py: add _resolve_scan_type() and _resolve_ports() to map shorthand inputs (e.g. "-sV", "top100") to schema-compliant values; removes the old 'speed' field that had no corresponding plugin field - test_command_injection.py: 52 tests covering injection rejection, SELECT/INTEGER/BOOLEAN validation, pattern validation, port normalisation, scan-type normalisation Fixes #201 (command argument injection via unsanitized scan_type/ports/speed). --- backend/secuscan/plugins.py | 88 +++++- backend/secuscan/scanners/port_scanner.py | 127 +++++---- .../backend/unit/test_command_injection.py | 252 ++++++++++++++++++ 3 files changed, 416 insertions(+), 51 deletions(-) create mode 100644 testing/backend/unit/test_command_injection.py diff --git a/backend/secuscan/plugins.py b/backend/secuscan/plugins.py index 4af73a3c..5f93ce94 100644 --- a/backend/secuscan/plugins.py +++ b/backend/secuscan/plugins.py @@ -12,9 +12,12 @@ import hashlib import hmac -from .models import PluginMetadata +from .models import PluginMetadata, PluginFieldType from .config import settings +# Port specifications: digits, commas, and hyphens only (e.g. "22,80,443" or "1-1000") +_PORT_SPEC_PATTERN = re.compile(r"^[\d,\-]+$") + logger = logging.getLogger(__name__) @@ -311,14 +314,91 @@ def _normalize_inputs(self, plugin: PluginMetadata, inputs: Dict[str, Any]) -> D normalized["wordlist"] = self._resolve_wordlist_path(wordlist_value.strip()) return normalized + def _reject_injected_args(self, field_id: str, value: str) -> None: + """Raise ValueError if value looks like a flag injection attempt. + + Port fields are exempt from the leading-dash check but must match the + numeric port-specification grammar. All other string fields must not + begin with a '-' character. + """ + if field_id in ("ports", "port"): + if value and not _PORT_SPEC_PATTERN.match(value): + raise ValueError( + f"Invalid port specification {value!r}: " + "only digits, commas, and hyphens are permitted" + ) + return + if value.lstrip().startswith("-"): + raise ValueError( + f"Field '{field_id}' value must not begin with '-': {value!r}" + ) + + def _validate_inputs_against_schema( + self, plugin: PluginMetadata, inputs: Dict[str, Any] + ) -> None: + """Validate caller-supplied inputs against the plugin's declared field schema. + + Raises ValueError with a descriptive message for the first violation found. + """ + field_map = {f.id: f for f in plugin.fields} + + for field_id, raw_value in inputs.items(): + field = field_map.get(field_id) + if field is None: + continue + + # Skip None / empty values — defaults will be applied later by _with_field_defaults + if raw_value is None or raw_value == "": + continue + + if field.type == PluginFieldType.INTEGER: + try: + int(raw_value) + except (TypeError, ValueError): + raise ValueError( + f"Field '{field_id}' expects an integer; got {raw_value!r}" + ) + continue + + if field.type == PluginFieldType.BOOLEAN: + if isinstance(raw_value, bool): + continue + if isinstance(raw_value, str) and raw_value.lower() in ("true", "false", "1", "0"): + continue + raise ValueError( + f"Field '{field_id}' expects a boolean; got {raw_value!r}" + ) + + if field.type == PluginFieldType.SELECT: + allowed = [opt.get("value") for opt in (field.options or [])] + if raw_value not in allowed: + raise ValueError( + f"Field '{field_id}' value {raw_value!r} is not in allowed " + f"values {allowed}" + ) + continue + + if field.type in (PluginFieldType.STRING, PluginFieldType.TEXT): + value_str = str(raw_value) + + # Pattern validation from field metadata + validation = field.validation or {} + pattern = validation.get("pattern") + if pattern and not re.match(pattern, value_str): + msg = validation.get("message", f"Value does not match pattern {pattern!r}") + raise ValueError(f"Field '{field_id}': {msg}") + + # Reject argv-level flag injection + self._reject_injected_args(field_id, value_str) + def build_command(self, plugin_id: str, inputs: Dict) -> Optional[List[str]]: """ Build command from plugin template and user inputs. - + Args: plugin_id: Plugin identifier inputs: User input values - + Returns: Command as list of arguments """ @@ -326,6 +406,8 @@ def build_command(self, plugin_id: str, inputs: Dict) -> Optional[List[str]]: if not plugin: return None + # Validate before normalisation so SELECT checks run against raw user values + self._validate_inputs_against_schema(plugin, inputs) inputs = self._normalize_inputs(plugin, inputs) command = [] diff --git a/backend/secuscan/scanners/port_scanner.py b/backend/secuscan/scanners/port_scanner.py index b70bbe65..0851c1bc 100644 --- a/backend/secuscan/scanners/port_scanner.py +++ b/backend/secuscan/scanners/port_scanner.py @@ -1,11 +1,9 @@ import asyncio -import json import re -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional, Tuple from .base import BaseScanner from ..plugins import get_plugin_manager -from ..config import settings -from datetime import datetime + class PortScanner(BaseScanner): """ @@ -21,50 +19,82 @@ def name(self) -> str: def category(self) -> str: return "Network Security" - async def run(self, target: str, inputs: Dict[str, Any]) -> Dict[str, Any]: + # ------------------------------------------------------------------ + # Input normalisation helpers + # ------------------------------------------------------------------ + + @staticmethod + def _resolve_scan_type(raw: Any) -> str: + """Map caller-supplied scan_type to the nmap plugin's SELECT value. + + The plugin field 'scan_type' accepts only "S" | "T" | "U". + Callers may pass the raw letter, "-sX", or "sX" forms. """ - Runs Nmap scan and parses output into structured findings. + _VALID = {"S", "T", "U"} + if not raw: + return "T" + value = str(raw).strip().upper() + # Already a bare valid letter + if value in _VALID: + return value + # Strip a leading "-S" or "S" prefix (e.g. "-sT" → "T", "sS" → "S") + stripped = re.sub(r"^-?S", "", value) + letter = stripped[0] if stripped else "" + return letter if letter in _VALID else "T" + + @staticmethod + def _resolve_ports(raw: Any) -> str: + """Map shorthand port specs to a clean numeric range string accepted by the plugin. + + Returns: + Empty string → use plugin default (top-100 via command template) + Numeric range → passed through as-is """ + if not raw or raw in ("", "top100"): + return "" + if raw == "top1000": + return "1-1000" + if raw == "all": + return "1-65535" + # If it looks like a numeric spec (digits, commas, hyphens), pass through + if re.match(r"^[\d,\-]+$", str(raw)): + return str(raw) + return "" + + async def run(self, target: str, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Runs Nmap scan and parses output into structured findings.""" self.update_progress(0.1) - - # Prepare inputs for the Nmap plugin - # Map PortScanner inputs to Nmap plugin fields + plugin_inputs = { "target": target, - "scan_type": inputs.get("scan_type", "-sV"), - "ports": inputs.get("ports", "top100"), - "speed": inputs.get("speed", "T4"), - "safe_mode": inputs.get("safe_mode", True) + "scan_type": self._resolve_scan_type(inputs.get("scan_type", "T")), + "ports": self._resolve_ports(inputs.get("ports", "")), + "service_detection": bool(inputs.get("service_detection", True)), + "os_detection": bool(inputs.get("os_detection", False)), + "safe_mode": bool(inputs.get("safe_mode", True)), } - - # Handle port shortcuts - if plugin_inputs["ports"] == "top100": - plugin_inputs["ports"] = "--top-ports 100" - elif plugin_inputs["ports"] == "top1000": - plugin_inputs["ports"] = "--top-ports 1000" - elif plugin_inputs["ports"] == "all": - plugin_inputs["ports"] = "-p-" plugin_manager = get_plugin_manager() command = plugin_manager.build_command("nmap", plugin_inputs) - + if not command: raise ValueError("Failed to build nmap command") - # Execute self.update_progress(0.2) output, exit_code = await self._execute_command(command) self.update_progress(0.8) - - # Parse + findings = self._parse_nmap_output(output, target) - + self.update_progress(1.0) return { "findings": findings, - "summary": [f"Scanned {target} for open ports.", f"Discovered {len(findings)} open ports."], + "summary": [ + f"Scanned {target} for open ports.", + f"Discovered {len(findings)} open ports.", + ], "open_ports": [f["metadata"]["port"] for f in findings], - "status": "completed" if exit_code == 0 else "failed" + "status": "completed" if exit_code == 0 else "failed", } async def _execute_command(self, command: List[str]) -> tuple: @@ -73,11 +103,11 @@ async def _execute_command(self, command: List[str]) -> tuple: process = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT + stderr=asyncio.subprocess.STDOUT, ) try: stdout, _ = await process.communicate() - return stdout.decode('utf-8', errors='replace'), process.returncode + return stdout.decode("utf-8", errors="replace"), process.returncode except asyncio.CancelledError: try: process.kill() @@ -88,30 +118,31 @@ async def _execute_command(self, command: List[str]) -> tuple: def _parse_nmap_output(self, output: str, target: str) -> List[Dict[str, Any]]: findings = [] - # Regex for open ports: 80/tcp open http port_pattern = re.compile(r"(\d+)/(tcp|udp)\s+open\s+([\w-]+)\s*(.*)") - + for match in port_pattern.finditer(output): port_str, proto, service, version = match.groups() - + title = f"Open Port: {port_str}/{proto} ({service})" description = f"Port {port_str} is open and running {service} service." if version.strip(): description += f" Version detected: {version.strip()}" - - findings.append({ - "title": title, - "category": "Network Service", - "severity": self.normalize_severity("low"), - "target": target, - "description": description, - "remediation": "Close unnecessary ports and use a firewall to restrict access.", - "metadata": { - "port": port_str, - "protocol": proto, - "service": service, - "version": version.strip() + + findings.append( + { + "title": title, + "category": "Network Service", + "severity": self.normalize_severity("low"), + "target": target, + "description": description, + "remediation": "Close unnecessary ports and use a firewall to restrict access.", + "metadata": { + "port": port_str, + "protocol": proto, + "service": service, + "version": version.strip(), + }, } - }) - + ) + return findings diff --git a/testing/backend/unit/test_command_injection.py b/testing/backend/unit/test_command_injection.py new file mode 100644 index 00000000..be9587bb --- /dev/null +++ b/testing/backend/unit/test_command_injection.py @@ -0,0 +1,252 @@ +""" +Security tests for command argument injection prevention (issue #201). + +Verifies that: +- Flag injection via `ports`, `scan_type`, and other fields is blocked +- SELECT fields reject values outside their declared option list +- INTEGER fields reject non-integer strings +- Pattern-validated STRING fields reject non-matching input +- PortScanner input normalisation produces schema-compliant values +- Valid inputs are accepted unchanged +""" + +import pytest +from unittest.mock import MagicMock +from typing import Any, Dict, List, Optional + +from backend.secuscan.plugins import PluginManager, _PORT_SPEC_PATTERN +from backend.secuscan.models import PluginMetadata, PluginField, PluginFieldType +from backend.secuscan.scanners.port_scanner import PortScanner + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_plugin(**extra_fields) -> PluginMetadata: + """Build a minimal PluginMetadata with caller-supplied field list.""" + base = { + "id": "test-plugin", + "name": "Test Plugin", + "version": "1.0.0", + "description": "test", + "category": "test", + "engine": {"type": "cli", "binary": "echo"}, + "command_template": ["{target}"], + "safety": {"level": "safe"}, + "output": {"format": "text", "parser": "none"}, + "fields": [], + "presets": {}, + } + base.update(extra_fields) + return PluginMetadata(**base) + + +def _make_manager() -> PluginManager: + return PluginManager(plugins_dir="/nonexistent") + + +def _nmap_like_plugin() -> PluginMetadata: + """Minimal replica of the nmap plugin field schema used in validation tests.""" + return _make_plugin( + id="nmap", + fields=[ + PluginField( + id="target", + label="Target", + type=PluginFieldType.STRING, + validation={"pattern": r"^[a-zA-Z0-9.\-]+$", "message": "Invalid target"}, + ), + PluginField( + id="scan_type", + label="Scan Type", + type=PluginFieldType.SELECT, + options=[{"value": "S"}, {"value": "T"}, {"value": "U"}], + ), + PluginField( + id="ports", + label="Ports", + type=PluginFieldType.STRING, + ), + PluginField( + id="timeout", + label="Timeout", + type=PluginFieldType.INTEGER, + ), + PluginField( + id="service_detection", + label="Service detection", + type=PluginFieldType.BOOLEAN, + ), + ], + ) + + +# --------------------------------------------------------------------------- +# _reject_injected_args +# --------------------------------------------------------------------------- + +class TestRejectInjectedArgs: + def setup_method(self): + self.mgr = _make_manager() + + def test_ports_valid_numeric(self): + self.mgr._reject_injected_args("ports", "22,80,443") + + def test_ports_valid_range(self): + self.mgr._reject_injected_args("ports", "1-1000") + + def test_ports_empty_ok(self): + self.mgr._reject_injected_args("ports", "") + + def test_ports_flag_injection_rejected(self): + with pytest.raises(ValueError, match="port specification"): + self.mgr._reject_injected_args("ports", "--script=evil.nse") + + def test_ports_space_injection_rejected(self): + with pytest.raises(ValueError, match="port specification"): + self.mgr._reject_injected_args("ports", "80 --script malware") + + def test_string_leading_dash_rejected(self): + with pytest.raises(ValueError, match="must not begin with '-'"): + self.mgr._reject_injected_args("wordlist", "--dump-header /etc/passwd") + + def test_string_value_ok(self): + self.mgr._reject_injected_args("wordlist", "/usr/share/wordlists/common.txt") + + def test_target_with_valid_hostname_ok(self): + self.mgr._reject_injected_args("target", "example.com") + + +# --------------------------------------------------------------------------- +# _validate_inputs_against_schema +# --------------------------------------------------------------------------- + +class TestSchemaValidation: + def setup_method(self): + self.mgr = _make_manager() + self.plugin = _nmap_like_plugin() + + def _validate(self, inputs: Dict[str, Any]) -> None: + self.mgr._validate_inputs_against_schema(self.plugin, inputs) + + # SELECT field + def test_select_valid_value_accepted(self): + self._validate({"scan_type": "T"}) + + def test_select_invalid_value_rejected(self): + with pytest.raises(ValueError, match="not in allowed values"): + self._validate({"scan_type": "-sV"}) + + def test_select_injection_rejected(self): + with pytest.raises(ValueError, match="not in allowed values"): + self._validate({"scan_type": "T --script malware"}) + + # INTEGER field + def test_integer_valid(self): + self._validate({"timeout": 30}) + self._validate({"timeout": "30"}) + + def test_integer_string_rejected(self): + with pytest.raises(ValueError, match="expects an integer"): + self._validate({"timeout": "thirty"}) + + def test_integer_flag_rejected(self): + with pytest.raises(ValueError, match="expects an integer"): + self._validate({"timeout": "--evil"}) + + # BOOLEAN field + def test_boolean_true_accepted(self): + self._validate({"service_detection": True}) + self._validate({"service_detection": "true"}) + + def test_boolean_invalid_rejected(self): + with pytest.raises(ValueError, match="expects a boolean"): + self._validate({"service_detection": "yes"}) + + # Pattern-validated STRING field (target) + def test_target_valid(self): + self._validate({"target": "example.com"}) + self._validate({"target": "192.168.1.1"}) + + def test_target_invalid_pattern_rejected(self): + with pytest.raises(ValueError, match="Invalid target"): + self._validate({"target": "$(evil)"}) + + # ports STRING field — custom logic + def test_ports_valid(self): + self._validate({"ports": "22,80,443"}) + self._validate({"ports": "1-1000"}) + + def test_ports_flag_injection_rejected(self): + with pytest.raises(ValueError, match="port specification"): + self._validate({"ports": "--script=vuln"}) + + # None/empty values are skipped (defaults handled later) + def test_none_value_skipped(self): + self._validate({"scan_type": None}) + + def test_empty_string_skipped(self): + self._validate({"ports": ""}) + + # Unknown fields are silently ignored + def test_unknown_field_ignored(self): + self._validate({"unknown_field": "--evil"}) + + +# --------------------------------------------------------------------------- +# PortScanner input normalisation +# --------------------------------------------------------------------------- + +class TestPortScannerResolveScanType: + @pytest.mark.parametrize("raw,expected", [ + ("T", "T"), + ("S", "S"), + ("U", "U"), + ("-sT", "T"), + ("-sS", "S"), + ("-sU", "U"), + ("sT", "T"), + ("-sV", "T"), # -sV is not a scan-type letter; defaults to T + ("V", "T"), # V is not a valid scan-type + (None, "T"), + ("", "T"), + ]) + def test_resolve_scan_type(self, raw, expected): + assert PortScanner._resolve_scan_type(raw) == expected + + +class TestPortScannerResolvePorts: + @pytest.mark.parametrize("raw,expected", [ + ("", ""), + (None, ""), + ("top100", ""), + ("top1000", "1-1000"), + ("all", "1-65535"), + ("22,80,443", "22,80,443"), + ("1-1000", "1-1000"), + ("--script=evil.nse", ""), # injection attempt → empty (rejected by schema validation upstream) + ("--top-ports 100", ""), # flag injection → empty + ]) + def test_resolve_ports(self, raw, expected): + assert PortScanner._resolve_ports(raw) == expected + + +# --------------------------------------------------------------------------- +# Port spec pattern +# --------------------------------------------------------------------------- + +class TestPortSpecPattern: + @pytest.mark.parametrize("value,should_match", [ + ("22", True), + ("22,80,443", True), + ("1-1000", True), + ("1-65535", True), + ("", False), + ("--script=evil", False), + ("80 --script", False), + ("22;whoami", False), + ("$(id)", False), + ]) + def test_pattern(self, value, should_match): + assert bool(_PORT_SPEC_PATTERN.match(value)) == should_match