Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions refactron/verification/checks/test_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
Expand All @@ -20,6 +21,8 @@ class TestSuiteGate(BaseCheck):

def __init__(self, project_root: Optional[Path] = None):
self.project_root = project_root
self._test_file_cache: Optional[Dict[str, List[Path]]] = None
self._all_test_files: Optional[List[Path]] = None

def verify(self, original: str, transformed: str, file_path: Path) -> CheckResult:
start = time.monotonic()
Expand Down Expand Up @@ -54,16 +57,30 @@ def verify(self, original: str, transformed: str, file_path: Path) -> CheckResul
# Delete .pyc cache
self._clear_pycache(file_path)

# Run pytest
cmd = ["python3", "-m", "pytest", "-x", "-q"]
# Run pytest from the project root with the host interpreter so
# the repo's pyproject.toml / pytest.ini / conftest.py are picked
# up and the same venv as the host process is used.
run_cwd = (self.project_root or file_path.parent).resolve()

# Make the project root importable for edge cases where the layout
# relies on PYTHONPATH rather than an installed package.
env = {**os.environ, "PYTHONDONTWRITEBYTECODE": "1"}
existing_pythonpath = env.get("PYTHONPATH", "")
env["PYTHONPATH"] = (
os.pathsep.join([str(run_cwd), existing_pythonpath])
if existing_pythonpath
else str(run_cwd)
)

cmd = [sys.executable, "-m", "pytest", "-x", "-q"]
cmd += [str(f) for f in test_files]
result = subprocess.run(
cmd,
timeout=45,
capture_output=True,
text=True,
env={**os.environ, "PYTHONDONTWRITEBYTECODE": "1"},
cwd=str(file_path.parent),
env=env,
cwd=str(run_cwd),
)

elapsed = int((time.monotonic() - start) * 1000)
Expand Down Expand Up @@ -110,11 +127,27 @@ def _find_relevant_tests(self, file_path: Path) -> List[Path]:
module_name = file_path.stem
search_root = self.project_root or file_path.parent

if self._test_file_cache is None:
self._test_file_cache = {}
self._all_test_files = []

test_dirs = [d for d in [search_root / "tests", search_root / "test"] if d.is_dir()]
search_dirs = test_dirs if test_dirs else [search_root]
excluded_dirs = {".git", ".rag", "__pycache__", "venv", ".venv", "env", "node_modules"}

for root_dir in search_dirs:
for py_file in root_dir.rglob("*.py"):
if any(excluded in py_file.parts for excluded in excluded_dirs):
continue
name = py_file.name
if name.startswith("test_") or name.endswith("_test.py"):
self._all_test_files.append(py_file)

if module_name in self._test_file_cache:
return self._test_file_cache[module_name]

test_files: List[Path] = []
for py_file in search_root.rglob("*.py"):
name = py_file.name
if not (name.startswith("test_") or name.endswith("_test.py")):
continue
for py_file in self._all_test_files: # type: ignore
if py_file == file_path:
continue
try:
Expand All @@ -123,6 +156,8 @@ def _find_relevant_tests(self, file_path: Path) -> List[Path]:
test_files.append(py_file)
except Exception:
continue

self._test_file_cache[module_name] = test_files
return test_files

@staticmethod
Expand Down
61 changes: 61 additions & 0 deletions tests/test_test_gate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Unit tests for TestSuiteGate (Check 3)."""

import os
import sys
from pathlib import Path

import pytest

from refactron.verification.checks import test_gate as test_gate_module
from refactron.verification.checks.test_gate import TestSuiteGate

FIXTURES_DIR = Path(__file__).parent / "fixtures"
Expand Down Expand Up @@ -71,3 +74,61 @@ def test_confidence_is_0_9(self, gate):
original = file_path.read_text(encoding="utf-8")
cr = gate.verify(original, original, file_path)
assert cr.confidence == 0.9


class TestPytestInvocation:
"""The pytest subprocess must use the host interpreter and project root."""

def _capture_run(self, monkeypatch):
"""Patch subprocess.run to capture its arguments without running pytest."""
captured = {}

def fake_run(cmd, **kwargs):
captured["cmd"] = cmd
captured["cwd"] = kwargs.get("cwd")
captured["env"] = kwargs.get("env")

class _Completed:
returncode = 0
stdout = ""
stderr = ""

return _Completed()

monkeypatch.setattr(test_gate_module.subprocess, "run", fake_run)
return captured

def test_uses_host_interpreter(self, monkeypatch):
"""pytest must run via sys.executable, not a hard-coded python3."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

assert captured["cmd"][0] == sys.executable
assert captured["cmd"][1:4] == ["-m", "pytest", "-x"]

def test_cwd_is_resolved_project_root(self, monkeypatch):
"""pytest must run from the resolved project root, not file_path.parent."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

assert captured["cwd"] == str(FIXTURES_DIR.resolve())

def test_project_root_added_to_pythonpath(self, monkeypatch):
"""The project root must be present on PYTHONPATH for the subprocess."""
captured = self._capture_run(monkeypatch)
gate = TestSuiteGate(project_root=FIXTURES_DIR)
file_path = FIXTURES_DIR / "fixture_test_break.py"
original = file_path.read_text(encoding="utf-8")

gate.verify(original, original, file_path)

pythonpath = captured["env"]["PYTHONPATH"].split(os.pathsep)
assert str(FIXTURES_DIR.resolve()) in pythonpath
Loading