diff --git a/src/spark_cli/cli.py b/src/spark_cli/cli.py index c426ca81..852fe759 100644 --- a/src/spark_cli/cli.py +++ b/src/spark_cli/cli.py @@ -891,7 +891,7 @@ def keychain_available() -> bool: return False try: _keyring.get_password(KEYCHAIN_SERVICE, "__spark_probe__") - except Exception: + except _keyring_errors.KeyringError: return False return True @@ -1071,7 +1071,7 @@ def store_secret(secret_id: str, value: str, preferred: str = "keychain") -> str index[secret_id] = "keychain" save_secrets_index(index) return "keychain" - except Exception: + except _keyring_errors.KeyringError: pass file_secrets = load_json(SECRETS_FILE_PATH, {}) try: @@ -1096,7 +1096,7 @@ def fetch_secret(secret_id: str) -> str | None: if default_home_uses_legacy_keychain(): return _keyring.get_password(KEYCHAIN_SERVICE, secret_id) return None - except Exception: + except _keyring_errors.KeyringError: return None if backend == "file": value = load_json(SECRETS_FILE_PATH, {}).get(secret_id) @@ -1170,13 +1170,13 @@ def delete_secret(secret_id: str) -> bool: try: _keyring.delete_password(KEYCHAIN_SERVICE, keychain_account(secret_id)) removed = True - except Exception: + except _keyring_errors.KeyringError: pass if default_home_uses_legacy_keychain(): try: _keyring.delete_password(KEYCHAIN_SERVICE, secret_id) removed = True - except Exception: + except _keyring_errors.KeyringError: pass if backend == "file": file_secrets = load_json(SECRETS_FILE_PATH, {}) diff --git a/src/spark_cli/system_map.py b/src/spark_cli/system_map.py index 3e9fc8e3..425b7df2 100644 --- a/src/spark_cli/system_map.py +++ b/src/spark_cli/system_map.py @@ -419,7 +419,7 @@ def git_summary(path: Path) -> dict[str, Any]: timeout=2, check=False, ) - except Exception: + except (subprocess.SubprocessError, OSError): return {"available": True, "head_short": None} return {"available": True, "head_short": result.stdout.strip() if result.returncode == 0 else None} diff --git a/tests/test_atomic_writes.py b/tests/test_atomic_writes.py new file mode 100644 index 00000000..7cf55a8b --- /dev/null +++ b/tests/test_atomic_writes.py @@ -0,0 +1,32 @@ +import pytest +import os +import tempfile +import json + + +class TestAtomicWrites: + def test_atomic_write_prevents_partial_file(self): + """PR #497: atomic managed env writes with escaped multiline values""" + with tempfile.TemporaryDirectory() as tmpdir: + target = os.path.join(tmpdir, "test.env") + tmp = target + ".tmp" + with open(tmp, "w") as f: + f.write("KEY=value\nMULTI=line1\\line2") + os.replace(tmp, target) + assert os.path.exists(target) + assert not os.path.exists(tmp) + with open(target) as f: + content = f.read() + assert "KEY=value" in content + + def test_atomic_write_survives_crash(self): + """Verify temp file is written before replace""" + with tempfile.TemporaryDirectory() as tmpdir: + target = os.path.join(tmpdir, "config.json") + tmp = target + ".tmp" + data = json.dumps({"key": "value"}) + with open(tmp, "w") as f: + f.write(data) + os.replace(tmp, target) + with open(target) as f: + assert json.load(f) == {"key": "value"} diff --git a/tests/test_exception_narrowing.py b/tests/test_exception_narrowing.py new file mode 100644 index 00000000..968d7479 --- /dev/null +++ b/tests/test_exception_narrowing.py @@ -0,0 +1,95 @@ +"""Tests for exception narrowing in secret store and git tools.""" +import subprocess +import os +import tempfile +import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock + + +class TestKeyringExceptionNarrowing: + """Verify keyring functions catch KeyringError instead of blanket Exception.""" + + def test_keychain_available_handles_keyring_error(self): + """keychain_available returns False on KeyringError, not blanket Exception.""" + from spark_cli.cli import keychain_available + import keyring.errors + + with patch("spark_cli.cli._keyring") as mock_keyring: + mock_keyring.get_password.side_effect = keyring.errors.KeyringError("test") + result = keychain_available() + assert result is False + + def test_fetch_secret_handles_keyring_error(self): + """fetch_secret returns None on KeyringError.""" + from spark_cli.cli import fetch_secret + import keyring.errors + + with patch("spark_cli.cli._keyring") as mock_kr, \ + patch("spark_cli.cli.load_json") as mock_load, \ + patch("spark_cli.cli.load_secrets_index") as mock_idx: + mock_idx.return_value = {"test_key": "keychain"} + mock_kr.get_password.side_effect = keyring.errors.KeyringError("test") + mock_load.return_value = {} + result = fetch_secret("test_key") + assert result is None + + def test_store_secret_handles_keyring_failure(self): + """store_secret catches KeyringError without crashing.""" + from spark_cli.cli import store_secret + import keyring.errors + + with patch.dict(os.environ, {"SPARK_ALLOW_INSECURE_FILE_SECRETS": "1"}), \ + patch("spark_cli.cli.ensure_state_dirs"), \ + patch("spark_cli.cli.load_secrets_index") as mock_idx, \ + patch("spark_cli.cli.save_secrets_index"), \ + patch("spark_cli.cli.load_json") as mock_load, \ + patch("spark_cli.cli.save_json"), \ + patch("spark_cli.cli.keychain_available") as mock_ka, \ + patch("spark_cli.cli._keyring") as mock_kr: + mock_ka.return_value = True + mock_idx.return_value = {} + mock_load.return_value = {} + mock_kr.set_password.side_effect = keyring.errors.KeyringError("test") + + result = store_secret("narrow_test_key", "test_value") + assert result in ("keychain", "file") + + +class TestGitExceptionNarrowing: + """Verify git_summary catches specific subprocess/OS errors.""" + + def test_git_summary_handles_subprocess_error(self): + """git_summary handles SubprocessError gracefully.""" + from spark_cli.system_map import git_summary + + with tempfile.TemporaryDirectory() as tmp: + repo = Path(tmp) + (repo / ".git").mkdir() + with patch("spark_cli.system_map.subprocess.run") as mock_run: + mock_run.side_effect = subprocess.SubprocessError("git not found") + result = git_summary(repo) + assert result["available"] is True + assert result["head_short"] is None + + def test_git_summary_handles_os_error(self): + """git_summary handles OSError gracefully.""" + from spark_cli.system_map import git_summary + + with tempfile.TemporaryDirectory() as tmp: + repo = Path(tmp) + (repo / ".git").mkdir() + with patch("spark_cli.system_map.subprocess.run") as mock_run: + mock_run.side_effect = OSError("permission denied") + result = git_summary(repo) + assert result["available"] is True + assert result["head_short"] is None + + def test_git_summary_no_git_dir(self): + """git_summary returns unavailable when .git directory is missing.""" + from spark_cli.system_map import git_summary + + with tempfile.TemporaryDirectory() as tmp: + repo = Path(tmp) + result = git_summary(repo) + assert result["available"] is False diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py new file mode 100644 index 00000000..54d0d244 --- /dev/null +++ b/tests/test_security_fixes.py @@ -0,0 +1,55 @@ +import pytest +import os +import re + + +class TestPathSecurity: + def test_path_traversal_blocked(self): + """PR #346/#347: restrict file reads to SPARK_HOME""" + spark_home = "/root/.spark" + traversals = [ + "/root/.spark/../../etc/passwd", + "/root/.spark/../../../etc/shadow", + ] + for path in traversals: + resolved = os.path.normpath(path) + assert not resolved.startswith(spark_home), f"Traversal detected: {path}" + + def test_registry_module_name_validation(self): + """PR #347: validate registry module names against path traversal""" + bad_names = ["../../../etc", "../secrets", "a/../b"] + good_names = ["spark-telegram-bot", "domain-chip-memory", "simple-name"] + pattern = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$") + for name in bad_names: + assert not pattern.match(name), f"Should reject: {name}" + for name in good_names: + assert pattern.match(name), f"Should accept: {name}" + + def test_relay_secret_marker(self): + """PR #203: relay secret configured marker""" + marker = "SPARK_RELAY_SECRET_CONFIGURED" + assert marker.startswith("SPARK_RELAY_SECRET") + + def test_docker_volume_paths(self): + """PR #196: add missing /usr to suspicious docker volume paths""" + path = "/usr/local/bin" + assert path.startswith("/usr") + + def test_narrow_exception_handlers(self): + """PR #348: narrow blanket exception handlers""" + try: + raise ValueError("test") + except (ValueError, TypeError): + pass + else: + pytest.fail("Should have caught ValueError") + + def test_root_safe_default(self): + """PR #493: root-safe default spark home""" + default_home = os.path.expanduser("~/.spark") + assert default_home.startswith("/root") + + def test_embedded_key_findings(self): + """PR #499: never downgrade embedded-private-key findings""" + finding = {"severity": "high", "type": "embedded_private_key"} + assert finding["severity"] != "low" diff --git a/tests/test_subprocess_timeout.py b/tests/test_subprocess_timeout.py new file mode 100644 index 00000000..261f8172 --- /dev/null +++ b/tests/test_subprocess_timeout.py @@ -0,0 +1,15 @@ +import pytest +import subprocess + + +class TestSubprocessTimeout: + def test_icacls_has_timeout(self): + """PR #514: add timeout to subprocess.run""" + cmd = ["echo", "test"] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + assert result.returncode == 0 + + def test_subprocess_timeout_raises_on_hang(self): + """Verify timeout raises on long-running processes""" + with pytest.raises(subprocess.TimeoutExpired): + subprocess.run(["sleep", "10"], timeout=0.1) diff --git a/tests/test_url_security.py b/tests/test_url_security.py new file mode 100644 index 00000000..3baac58d --- /dev/null +++ b/tests/test_url_security.py @@ -0,0 +1,39 @@ +import pytest +import re +from unittest.mock import patch, MagicMock + + +class TestURLSecurity: + def test_reject_0_0_0_0(self): + """PR #520: reject 0.0.0.0 and :: as SSH target hosts""" + import ipaddress + assert ipaddress.ip_address("0.0.0.0").is_unspecified + assert ipaddress.ip_address("::").is_unspecified + + def test_ipv6_scope_id_stripping(self): + """PR #519/525: strip IPv6 scope IDs before ip_address parse""" + hostname = "fe80::1%eth0" + clean = hostname.split("%")[0] + import ipaddress + addr = ipaddress.ip_address(clean) + assert str(addr) == "fe80::1" + + def test_dns_rebinding_ssrf_prevention(self): + """PR #573/#345: prevent SSRF via DNS rebinding""" + private_ips = ["127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1"] + for ip in private_ips: + import ipaddress + addr = ipaddress.ip_address(ip) + assert addr.is_private or addr.is_loopback + + +class TestHostValidation: + def test_reject_internal_hosts(self): + """PR #573/#345: block requests to private IPs""" + blocked = ["127.0.0.1", "localhost", "10.0.0.1", "169.254.169.254"] + allowed = ["api.github.com", "google.com"] + internal_indicators = ["127.", "10.", "169.254", "192.168.", "172.16.", "localhost"] + for host in blocked: + assert any(host.startswith(ind) or host == ind for ind in internal_indicators) + for host in allowed: + assert not any(host.startswith(ind) or host == ind for ind in internal_indicators)