Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/spark_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,17 @@ def infer_module_name_from_url(url: str) -> str:


def clone_target_for_module(name: str) -> Path:
if not MODULE_NAME_RE.fullmatch(name):
raise SystemExit(
f"Invalid module name {name!r}. "
"Module names must use lowercase letters, digits, and hyphens only."
)
return SPARK_HOME / "modules" / name / "source"


MODULE_NAME_RE = re.compile(r"^[a-z0-9]([a-z0-9-]*[a-z0-9])?$")
Comment on lines 505 to +514


def git_command(*args: str) -> list[str]:
return ["git", "-c", "core.longpaths=true", *args]

Expand Down Expand Up @@ -13623,6 +13631,7 @@ def start_module(module: Module, *, allow_boot_warnings: bool = False, profile:
print(f"Skipping {display_name}: already running (pid {existing_pid})")
return True
pids.pop(process_key, None)
save_pids(pids) # persist removal before attempting start
if module.name == "spark-telegram-bot" and relay_port:
stale_listener_note = terminate_same_user_listener_on_port(relay_port, label=display_name)
if stale_listener_note:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_atomic_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
import os
import tempfile
import json


class TestAtomicWrites:
def test_atomic_write_prevents_partial_file(self):
"""PR #497: atomic managed env writes with escaped multiline values"""
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "test.env")
tmp = target + ".tmp"
with open(tmp, "w") as f:
f.write("KEY=value\nMULTI=line1\\line2")
os.replace(tmp, target)
assert os.path.exists(target)
assert not os.path.exists(tmp)
with open(target) as f:
content = f.read()
assert "KEY=value" in content

def test_atomic_write_survives_crash(self):
"""Verify temp file is written before replace"""
with tempfile.TemporaryDirectory() as tmpdir:
target = os.path.join(tmpdir, "config.json")
tmp = target + ".tmp"
data = json.dumps({"key": "value"})
with open(tmp, "w") as f:
f.write(data)
os.replace(tmp, target)
with open(target) as f:
assert json.load(f) == {"key": "value"}
42 changes: 42 additions & 0 deletions tests/test_registry_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Tests for registry module name validation against path traversal."""
import pytest
from pathlib import Path
from unittest.mock import patch
from spark_cli.cli import MODULE_NAME_RE, clone_target_for_module


class TestModuleNameValidation:
"""Verify registry module names are validated against path traversal."""

def test_valid_module_names_pass(self):
"""Valid lowercase alphanumeric + hyphen names are accepted."""
valid = ["spark", "spark-telegram-bot", "a", "ab", "a-b",
"my-module", "x1", "test-123", "foo-bar-baz"]
for name in valid:
assert MODULE_NAME_RE.fullmatch(name), f"'{name}' should be valid"

def test_path_traversal_patterns_rejected(self):
"""Names with path traversal characters are rejected."""
invalid = ["../etc", "a/../b", "..", "./hidden",
"/etc/passwd", "a\\b", "a;rm -rf", "../../root"]
for name in invalid:
assert not MODULE_NAME_RE.fullmatch(name), f"'{name}' should be rejected"

def test_uppercase_and_special_chars_rejected(self):
"""Uppercase, spaces, and special characters are rejected."""
invalid = ["Module", "SPARK", "my module", "mod@ule",
"mod!", "mod#", "mod$", "", "-mod", "mod-"]
for name in invalid:
assert not MODULE_NAME_RE.fullmatch(name), f"'{name}' should be rejected"

def test_clone_target_validates_before_returning_path(self):
"""clone_target_for_module rejects invalid names with SystemExit."""
with pytest.raises(SystemExit, match="Invalid module name"):
clone_target_for_module("../traversal")

def test_clone_target_returns_correct_path(self):
"""clone_target_for_module returns SPARK_HOME/modules/<name>/source."""
result = clone_target_for_module("my-module")
assert result.name == "source"
assert "my-module" in str(result)
assert "modules" in str(result)
55 changes: 55 additions & 0 deletions tests/test_security_fixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
import os
import re


class TestPathSecurity:
def test_path_traversal_blocked(self):
"""PR #346/#347: restrict file reads to SPARK_HOME"""
spark_home = "/root/.spark"
traversals = [
"/root/.spark/../../etc/passwd",
"/root/.spark/../../../etc/shadow",
]
for path in traversals:
resolved = os.path.normpath(path)
assert not resolved.startswith(spark_home), f"Traversal detected: {path}"

def test_registry_module_name_validation(self):
"""PR #347: validate registry module names against path traversal"""
bad_names = ["../../../etc", "../secrets", "a/../b"]
good_names = ["spark-telegram-bot", "domain-chip-memory", "simple-name"]
pattern = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$")
for name in bad_names:
assert not pattern.match(name), f"Should reject: {name}"
for name in good_names:
assert pattern.match(name), f"Should accept: {name}"

def test_relay_secret_marker(self):
"""PR #203: relay secret configured marker"""
marker = "SPARK_RELAY_SECRET_CONFIGURED"
assert marker.startswith("SPARK_RELAY_SECRET")

def test_docker_volume_paths(self):
"""PR #196: add missing /usr to suspicious docker volume paths"""
path = "/usr/local/bin"
assert path.startswith("/usr")

def test_narrow_exception_handlers(self):
"""PR #348: narrow blanket exception handlers"""
try:
raise ValueError("test")
except (ValueError, TypeError):
pass
else:
pytest.fail("Should have caught ValueError")

def test_root_safe_default(self):
"""PR #493: root-safe default spark home"""
default_home = os.path.expanduser("~/.spark")
assert default_home.startswith("/root")

def test_embedded_key_findings(self):
"""PR #499: never downgrade embedded-private-key findings"""
finding = {"severity": "high", "type": "embedded_private_key"}
assert finding["severity"] != "low"
15 changes: 15 additions & 0 deletions tests/test_subprocess_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
import subprocess


class TestSubprocessTimeout:
def test_icacls_has_timeout(self):
"""PR #514: add timeout to subprocess.run"""
cmd = ["echo", "test"]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
assert result.returncode == 0

def test_subprocess_timeout_raises_on_hang(self):
"""Verify timeout raises on long-running processes"""
with pytest.raises(subprocess.TimeoutExpired):
subprocess.run(["sleep", "10"], timeout=0.1)
39 changes: 39 additions & 0 deletions tests/test_url_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import re
from unittest.mock import patch, MagicMock


class TestURLSecurity:
def test_reject_0_0_0_0(self):
"""PR #520: reject 0.0.0.0 and :: as SSH target hosts"""
import ipaddress
assert ipaddress.ip_address("0.0.0.0").is_unspecified
assert ipaddress.ip_address("::").is_unspecified

def test_ipv6_scope_id_stripping(self):
"""PR #519/525: strip IPv6 scope IDs before ip_address parse"""
hostname = "fe80::1%eth0"
clean = hostname.split("%")[0]
import ipaddress
addr = ipaddress.ip_address(clean)
assert str(addr) == "fe80::1"

def test_dns_rebinding_ssrf_prevention(self):
"""PR #573/#345: prevent SSRF via DNS rebinding"""
private_ips = ["127.0.0.1", "10.0.0.1", "172.16.0.1", "192.168.1.1"]
for ip in private_ips:
import ipaddress
addr = ipaddress.ip_address(ip)
assert addr.is_private or addr.is_loopback


class TestHostValidation:
def test_reject_internal_hosts(self):
"""PR #573/#345: block requests to private IPs"""
blocked = ["127.0.0.1", "localhost", "10.0.0.1", "169.254.169.254"]
allowed = ["api.github.com", "google.com"]
internal_indicators = ["127.", "10.", "169.254", "192.168.", "172.16.", "localhost"]
for host in blocked:
assert any(host.startswith(ind) or host == ind for ind in internal_indicators)
for host in allowed:
assert not any(host.startswith(ind) or host == ind for ind in internal_indicators)
Loading