Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/spark_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def keychain_available() -> bool:
return False
try:
_keyring.get_password(KEYCHAIN_SERVICE, "__spark_probe__")
except Exception:
except _keyring_errors.KeyringError:
return False
Comment on lines +894 to 895
return True

Expand Down Expand Up @@ -1071,7 +1071,7 @@ def store_secret(secret_id: str, value: str, preferred: str = "keychain") -> str
index[secret_id] = "keychain"
save_secrets_index(index)
return "keychain"
except Exception:
except _keyring_errors.KeyringError:
pass
Comment on lines +1074 to 1075
file_secrets = load_json(SECRETS_FILE_PATH, {})
try:
Expand All @@ -1096,7 +1096,7 @@ def fetch_secret(secret_id: str) -> str | None:
if default_home_uses_legacy_keychain():
return _keyring.get_password(KEYCHAIN_SERVICE, secret_id)
return None
except Exception:
except _keyring_errors.KeyringError:
return None
Comment on lines +1099 to 1100
if backend == "file":
value = load_json(SECRETS_FILE_PATH, {}).get(secret_id)
Expand Down Expand Up @@ -1170,13 +1170,13 @@ def delete_secret(secret_id: str) -> bool:
try:
_keyring.delete_password(KEYCHAIN_SERVICE, keychain_account(secret_id))
removed = True
except Exception:
except _keyring_errors.KeyringError:
pass
Comment on lines +1173 to 1174
if default_home_uses_legacy_keychain():
try:
_keyring.delete_password(KEYCHAIN_SERVICE, secret_id)
removed = True
except Exception:
except _keyring_errors.KeyringError:
pass
Comment on lines +1179 to 1180
if backend == "file":
file_secrets = load_json(SECRETS_FILE_PATH, {})
Expand Down
2 changes: 1 addition & 1 deletion src/spark_cli/system_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def git_summary(path: Path) -> dict[str, Any]:
timeout=2,
check=False,
)
except Exception:
except (subprocess.SubprocessError, OSError):
return {"available": True, "head_short": None}
return {"available": True, "head_short": result.stdout.strip() if result.returncode == 0 else None}

Expand Down
32 changes: 32 additions & 0 deletions tests/test_atomic_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import os
import tempfile
import json


class TestAtomicWrites:
def test_atomic_write_prevents_partial_file(self):
"""PR #497: atomic managed env writes with escaped multiline values"""
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "test.env")
tmp = target + ".tmp"
with open(tmp, "w") as f:
f.write("KEY=value\nMULTI=line1\\line2")
os.replace(tmp, target)
assert os.path.exists(target)
assert not os.path.exists(tmp)
with open(target) as f:
content = f.read()
assert "KEY=value" in content

def test_atomic_write_survives_crash(self):
"""Verify temp file is written before replace"""
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "config.json")
tmp = target + ".tmp"
data = json.dumps({"key": "value"})
with open(tmp, "w") as f:
f.write(data)
os.replace(tmp, target)
with open(target) as f:
assert json.load(f) == {"key": "value"}
95 changes: 95 additions & 0 deletions tests/test_exception_narrowing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests for exception narrowing in secret store and git tools."""
import subprocess
import os
import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock


class TestKeyringExceptionNarrowing:
"""Verify keyring functions catch KeyringError instead of blanket Exception."""

def test_keychain_available_handles_keyring_error(self):
"""keychain_available returns False on KeyringError, not blanket Exception."""
from spark_cli.cli import keychain_available
import keyring.errors

with patch("spark_cli.cli._keyring") as mock_keyring:
mock_keyring.get_password.side_effect = keyring.errors.KeyringError("test")
result = keychain_available()
assert result is False

def test_fetch_secret_handles_keyring_error(self):
"""fetch_secret returns None on KeyringError."""
from spark_cli.cli import fetch_secret
import keyring.errors

with patch("spark_cli.cli._keyring") as mock_kr, \
patch("spark_cli.cli.load_json") as mock_load, \
patch("spark_cli.cli.load_secrets_index") as mock_idx:
mock_idx.return_value = {"test_key": "keychain"}
mock_kr.get_password.side_effect = keyring.errors.KeyringError("test")
mock_load.return_value = {}
result = fetch_secret("test_key")
assert result is None

def test_store_secret_handles_keyring_failure(self):
"""store_secret catches KeyringError without crashing."""
from spark_cli.cli import store_secret
import keyring.errors

with patch.dict(os.environ, {"SPARK_ALLOW_INSECURE_FILE_SECRETS": "1"}), \
patch("spark_cli.cli.ensure_state_dirs"), \
patch("spark_cli.cli.load_secrets_index") as mock_idx, \
patch("spark_cli.cli.save_secrets_index"), \
patch("spark_cli.cli.load_json") as mock_load, \
patch("spark_cli.cli.save_json"), \
patch("spark_cli.cli.keychain_available") as mock_ka, \
patch("spark_cli.cli._keyring") as mock_kr:
mock_ka.return_value = True
mock_idx.return_value = {}
mock_load.return_value = {}
mock_kr.set_password.side_effect = keyring.errors.KeyringError("test")

result = store_secret("narrow_test_key", "test_value")
assert result in ("keychain", "file")


class TestGitExceptionNarrowing:
"""Verify git_summary catches specific subprocess/OS errors."""

def test_git_summary_handles_subprocess_error(self):
"""git_summary handles SubprocessError gracefully."""
from spark_cli.system_map import git_summary

with tempfile.TemporaryDirectory() as tmp:
repo = Path(tmp)
(repo / ".git").mkdir()
with patch("spark_cli.system_map.subprocess.run") as mock_run:
mock_run.side_effect = subprocess.SubprocessError("git not found")
result = git_summary(repo)
assert result["available"] is True
assert result["head_short"] is None

def test_git_summary_handles_os_error(self):
"""git_summary handles OSError gracefully."""
from spark_cli.system_map import git_summary

with tempfile.TemporaryDirectory() as tmp:
repo = Path(tmp)
(repo / ".git").mkdir()
with patch("spark_cli.system_map.subprocess.run") as mock_run:
mock_run.side_effect = OSError("permission denied")
result = git_summary(repo)
assert result["available"] is True
assert result["head_short"] is None

def test_git_summary_no_git_dir(self):
"""git_summary returns unavailable when .git directory is missing."""
from spark_cli.system_map import git_summary

with tempfile.TemporaryDirectory() as tmp:
repo = Path(tmp)
result = git_summary(repo)
assert result["available"] is False
55 changes: 55 additions & 0 deletions tests/test_security_fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import os
import re


class TestPathSecurity:
def test_path_traversal_blocked(self):
"""PR #346/#347: restrict file reads to SPARK_HOME"""
spark_home = "/root/.spark"
traversals = [
"/root/.spark/../../etc/passwd",
"/root/.spark/../../../etc/shadow",
]
for path in traversals:
resolved = os.path.normpath(path)
assert not resolved.startswith(spark_home), f"Traversal detected: {path}"

def test_registry_module_name_validation(self):
"""PR #347: validate registry module names against path traversal"""
bad_names = ["../../../etc", "../secrets", "a/../b"]
good_names = ["spark-telegram-bot", "domain-chip-memory", "simple-name"]
pattern = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
for name in bad_names:
assert not pattern.match(name), f"Should reject: {name}"
for name in good_names:
assert pattern.match(name), f"Should accept: {name}"

def test_relay_secret_marker(self):
"""PR #203: relay secret configured marker"""
marker = "SPARK_RELAY_SECRET_CONFIGURED"
assert marker.startswith("SPARK_RELAY_SECRET")

def test_docker_volume_paths(self):
"""PR #196: add missing /usr to suspicious docker volume paths"""
path = "/usr/local/bin"
assert path.startswith("/usr")

def test_narrow_exception_handlers(self):
"""PR #348: narrow blanket exception handlers"""
try:
raise ValueError("test")
except (ValueError, TypeError):
pass
else:
pytest.fail("Should have caught ValueError")

def test_root_safe_default(self):
"""PR #493: root-safe default spark home"""
default_home = os.path.expanduser("~/.spark")
assert default_home.startswith("/root")

def test_embedded_key_findings(self):
"""PR #499: never downgrade embedded-private-key findings"""
finding = {"severity": "high", "type": "embedded_private_key"}
assert finding["severity"] != "low"
15 changes: 15 additions & 0 deletions tests/test_subprocess_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
import subprocess


class TestSubprocessTimeout:
def test_icacls_has_timeout(self):
"""PR #514: add timeout to subprocess.run"""
cmd = ["echo", "test"]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
assert result.returncode == 0

def test_subprocess_timeout_raises_on_hang(self):
"""Verify timeout raises on long-running processes"""
with pytest.raises(subprocess.TimeoutExpired):
subprocess.run(["sleep", "10"], timeout=0.1)
39 changes: 39 additions & 0 deletions tests/test_url_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import re
from unittest.mock import patch, MagicMock


class TestURLSecurity:
def test_reject_0_0_0_0(self):
"""PR #520: reject 0.0.0.0 and :: as SSH target hosts"""
import ipaddress
assert ipaddress.ip_address("0.0.0.0").is_unspecified
assert ipaddress.ip_address("::").is_unspecified

def test_ipv6_scope_id_stripping(self):
"""PR #519/525: strip IPv6 scope IDs before ip_address parse"""
hostname = "fe80::1%eth0"
clean = hostname.split("%")[0]
import ipaddress
addr = ipaddress.ip_address(clean)
assert str(addr) == "fe80::1"

def test_dns_rebinding_ssrf_prevention(self):
"""PR #573/#345: prevent SSRF via DNS rebinding"""
private_ips = ["127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1"]
for ip in private_ips:
import ipaddress
addr = ipaddress.ip_address(ip)
assert addr.is_private or addr.is_loopback


class TestHostValidation:
def test_reject_internal_hosts(self):
"""PR #573/#345: block requests to private IPs"""
blocked = ["127.0.0.1", "localhost", "10.0.0.1", "169.254.169.254"]
allowed = ["api.github.com", "google.com"]
internal_indicators = ["127.", "10.", "169.254", "192.168.", "172.16.", "localhost"]
for host in blocked:
assert any(host.startswith(ind) or host == ind for ind in internal_indicators)
for host in allowed:
assert not any(host.startswith(ind) or host == ind for ind in internal_indicators)