Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 145 additions & 19 deletions refactron/verification/checks/test_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
Expand All @@ -20,6 +21,8 @@ class TestSuiteGate(BaseCheck):

def __init__(self, project_root: Optional[Path] = None):
self.project_root = project_root
self._test_file_cache: Optional[Dict[str, List[Path]]] = None
self._all_test_files: Optional[List[Path]] = None

def verify(self, original: str, transformed: str, file_path: Path) -> CheckResult:
start = time.monotonic()
Expand All @@ -29,12 +32,16 @@ def verify(self, original: str, transformed: str, file_path: Path) -> CheckResul
test_files = self._find_relevant_tests(file_path)
if not test_files:
elapsed = int((time.monotonic() - start) * 1000)
details["note"] = "No tests cover this module"
# No matching tests is not a failure, but it is also not the
# strong assurance a passing test run gives — the change is
# simply unverified here, so confidence is reduced rather than
# left high (which would mask a potential false negative).
details["note"] = "No tests found importing this module — change not covered by the gate"
return CheckResult(
check_name=self.name,
passed=True,
blocking_reason="",
confidence=0.9,
confidence=0.6,
duration_ms=elapsed,
details=details,
)
Expand All @@ -54,16 +61,30 @@ def verify(self, original: str, transformed: str, file_path: Path) -> CheckResul
# Delete .pyc cache
self._clear_pycache(file_path)

# Run pytest
cmd = ["python3", "-m", "pytest", "-x", "-q"]
# Run pytest from the project root with the host interpreter so
# the repo's pyproject.toml / pytest.ini / conftest.py are picked
# up and the same venv as the host process is used.
run_cwd = (self.project_root or file_path.parent).resolve()

# Make the project root importable for edge cases where the layout
# relies on PYTHONPATH rather than an installed package.
env = {**os.environ, "PYTHONDONTWRITEBYTECODE": "1"}
existing_pythonpath = env.get("PYTHONPATH", "")
env["PYTHONPATH"] = (
os.pathsep.join([str(run_cwd), existing_pythonpath])
if existing_pythonpath
else str(run_cwd)
)

cmd = [sys.executable, "-m", "pytest", "-x", "-q"]
cmd += [str(f) for f in test_files]
result = subprocess.run(
cmd,
timeout=45,
capture_output=True,
text=True,
env={**os.environ, "PYTHONDONTWRITEBYTECODE": "1"},
cwd=str(file_path.parent),
env=env,
cwd=str(run_cwd),
)

elapsed = int((time.monotonic() - start) * 1000)
Expand Down Expand Up @@ -107,44 +128,149 @@ def verify(self, original: str, transformed: str, file_path: Path) -> CheckResul

def _find_relevant_tests(self, file_path: Path) -> List[Path]:
"""Find test files that import the module at file_path."""
module_name = file_path.stem
search_root = self.project_root or file_path.parent
targets = self._module_targets(file_path)

if self._test_file_cache is None:
self._test_file_cache = {}
self._all_test_files = []

test_dirs = [d for d in [search_root / "tests", search_root / "test"] if d.is_dir()]
search_dirs = test_dirs if test_dirs else [search_root]
excluded_dirs = {".git", ".rag", "__pycache__", "venv", ".venv", "env", "node_modules"}

for root_dir in search_dirs:
for py_file in root_dir.rglob("*.py"):
if any(excluded in py_file.parts for excluded in excluded_dirs):
continue
name = py_file.name
if name.startswith("test_") or name.endswith("_test.py"):
self._all_test_files.append(py_file)

cache_key = str(file_path)
if cache_key in self._test_file_cache:
return self._test_file_cache[cache_key]

target_file = file_path.resolve()
test_files: List[Path] = []
for py_file in search_root.rglob("*.py"):
name = py_file.name
if not (name.startswith("test_") or name.endswith("_test.py")):
continue
for py_file in self._all_test_files: # type: ignore
if py_file == file_path:
continue
try:
source = py_file.read_text(encoding="utf-8")
if self._imports_module(source, module_name):
if self._imports_module(source, targets, target_file, py_file):
test_files.append(py_file)
except Exception:
continue

self._test_file_cache[cache_key] = test_files
return test_files

def _module_targets(self, file_path: Path) -> set:
"""Qualified module names a test could import to exercise file_path.

Returns both the project-root-relative dotted path and the
package-root-relative path (walking up the ``__init__.py`` chain), so
``mypkg/submodule/foo.py`` is matched by ``from mypkg.submodule import
foo`` regardless of where the project is rooted. The bare stem is kept
as a loose fallback for flat (non-package) layouts.
"""
path = Path(file_path)
abs_path = path.resolve()
targets: set = set()
if path.name != "__init__.py":
targets.add(path.stem)

# Project-root-relative qualified name.
root = (self.project_root or path.parent)
try:
rel = abs_path.relative_to(root.resolve())
dotted = self._dotted(rel)
if dotted:
targets.add(dotted)
except ValueError:
pass

# Package-root-relative qualified name (walk up the __init__.py chain).
parts: List[str] = [] if path.name == "__init__.py" else [path.stem]
pkg_dir = abs_path.parent
while (pkg_dir / "__init__.py").is_file():
parts.append(pkg_dir.name)
pkg_dir = pkg_dir.parent
if parts:
targets.add(".".join(reversed(parts)))

return targets

@staticmethod
def _imports_module(source: str, module_name: str) -> bool:
"""Check if source code imports the given module name."""
def _dotted(rel_path: Path) -> str:
"""Convert a relative file path to a dotted module name."""
parts = list(rel_path.with_suffix("").parts)
if parts and parts[-1] == "__init__":
parts = parts[:-1]
return ".".join(parts)

@staticmethod
def _imports_module(
source: str, targets: set, target_file: Path, test_path: Path
) -> bool:
"""Whether ``source`` imports the module under verification.

Absolute imports are matched by qualified name against ``targets``;
package-relative imports (``from . import x``) are resolved on the
filesystem and compared directly to ``target_file``.
"""
try:
tree = ast.parse(source)
except SyntaxError:
return False

imported: set = set()
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name == module_name or alias.name.startswith(module_name + "."):
return True
imported.add(alias.name)
elif isinstance(node, ast.ImportFrom):
if node.module and (
node.module == module_name or node.module.startswith(module_name + ".")
):
if node.level and node.level > 0:
if TestSuiteGate._relative_import_hits(node, test_path, target_file):
return True
continue
base = node.module or ""
if base:
imported.add(base)
for alias in node.names:
imported.add(f"{base}.{alias.name}" if base else alias.name)

for name in imported:
for target in targets:
if name == target or name.startswith(target + "."):
return True
return False

@staticmethod
def _relative_import_hits(
node: ast.ImportFrom, test_path: Path, target_file: Path
) -> bool:
"""Resolve a relative import to file paths and test against target."""
base_dir = test_path.resolve().parent
for _ in range(node.level - 1):
base_dir = base_dir.parent
if node.module:
base_dir = base_dir.joinpath(*node.module.split("."))

candidates = [base_dir.with_suffix(".py"), base_dir / "__init__.py"]
for alias in node.names:
candidates.append((base_dir / alias.name).with_suffix(".py"))
candidates.append(base_dir / alias.name / "__init__.py")

for cand in candidates:
try:
if cand.resolve() == target_file:
return True
except OSError:
continue
return False

@staticmethod
def _clear_pycache(file_path: Path) -> None:
"""Remove .pyc files for this module to force re-import."""
Expand Down
126 changes: 126 additions & 0 deletions tests/test_test_gate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Unit tests for TestSuiteGate (Check 3)."""

import os
import sys
from pathlib import Path

import pytest

from refactron.verification.checks import test_gate as test_gate_module
from refactron.verification.checks.test_gate import TestSuiteGate

FIXTURES_DIR = Path(__file__).parent / "fixtures"
Expand Down Expand Up @@ -71,3 +74,126 @@ def test_confidence_is_0_9(self, gate):
original = file_path.read_text(encoding="utf-8")
cr = gate.verify(original, original, file_path)
assert cr.confidence == 0.9


class TestPackageImportMatching:
"""Relevance matching must handle qualified package imports, not just stems."""

def _layout(self, tmp_path: Path, test_import: str):
"""Create mypkg/submodule/foo.py and tests/test_foo.py importing it."""
pkg = tmp_path / "mypkg" / "submodule"
pkg.mkdir(parents=True)
(tmp_path / "mypkg" / "__init__.py").write_text("", encoding="utf-8")
(pkg / "__init__.py").write_text("", encoding="utf-8")
foo = pkg / "foo.py"
foo.write_text("def f():\n return 1\n", encoding="utf-8")
tests = tmp_path / "tests"
tests.mkdir()
test_file = tests / "test_foo.py"
test_file.write_text(
f"{test_import}\n\n\ndef test_it():\n assert True\n", encoding="utf-8"
)
return foo, test_file

def test_qualified_from_import_matches(self, tmp_path):
"""from mypkg.submodule import foo -> matches mypkg/submodule/foo.py."""
foo, test_file = self._layout(tmp_path, "from mypkg.submodule import foo")
gate = TestSuiteGate(project_root=tmp_path)
assert test_file in gate._find_relevant_tests(foo)

def test_dotted_import_matches(self, tmp_path):
"""import mypkg.submodule.foo -> matches the file under verification."""
foo, test_file = self._layout(tmp_path, "import mypkg.submodule.foo")
gate = TestSuiteGate(project_root=tmp_path)
assert test_file in gate._find_relevant_tests(foo)

def test_unrelated_import_does_not_match(self, tmp_path):
foo, _ = self._layout(tmp_path, "import os")
gate = TestSuiteGate(project_root=tmp_path)
assert gate._find_relevant_tests(foo) == []

def test_relative_import_matches(self, tmp_path):
"""from . import foo in a test inside the package is attributed."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
foo = pkg / "foo.py"
foo.write_text("def f():\n return 1\n", encoding="utf-8")
test_file = pkg / "test_foo.py"
test_file.write_text(
"from . import foo\n\n\ndef test_it():\n assert foo.f() == 1\n",
encoding="utf-8",
)
gate = TestSuiteGate(project_root=tmp_path)
assert test_file in gate._find_relevant_tests(foo)

def test_no_tests_yields_reduced_confidence(self, tmp_path):
"""An uncovered change is passed but with reduced (not high) confidence."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
foo = pkg / "foo.py"
foo.write_text("X = 1\n", encoding="utf-8")
gate = TestSuiteGate(project_root=tmp_path)
cr = gate.verify("X = 1\n", "X = 1\n", foo)
assert cr.passed is True
assert cr.confidence == 0.6
assert "not covered" in cr.details["note"]


class TestPytestInvocation:
"""The pytest subprocess must use the host interpreter and project root."""

def _capture_run(self, monkeypatch):
"""Patch subprocess.run to capture its arguments without running pytest."""
captured = {}

def fake_run(cmd, **kwargs):
captured["cmd"] = cmd
captured["cwd"] = kwargs.get("cwd")
captured["env"] = kwargs.get("env")

class _Completed:
returncode = 0
stdout = ""
stderr = ""

return _Completed()

monkeypatch.setattr(test_gate_module.subprocess, "run", fake_run)
return captured

def test_uses_host_interpreter(self, monkeypatch):
"""pytest must run via sys.executable, not a hard-coded python3."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

assert captured["cmd"][0] == sys.executable
assert captured["cmd"][1:4] == ["-m", "pytest", "-x"]

def test_cwd_is_resolved_project_root(self, monkeypatch):
"""pytest must run from the resolved project root, not file_path.parent."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

assert captured["cwd"] == str(FIXTURES_DIR.resolve())

def test_project_root_added_to_pythonpath(self, monkeypatch):
"""The project root must be present on PYTHONPATH for the subprocess."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

pythonpath = captured["env"]["PYTHONPATH"].split(os.pathsep)
assert str(FIXTURES_DIR.resolve()) in pythonpath
Loading