diff --git a/refactron/autofix/engine.py b/refactron/autofix/engine.py index 0de5af8..b46fb30 100644 --- a/refactron/autofix/engine.py +++ b/refactron/autofix/engine.py @@ -12,6 +12,26 @@ from refactron.autofix.models import FixResult, FixRiskLevel from refactron.core.models import CodeIssue +# Directory markers used to locate the project/VCS root when no explicit +# root is supplied to fix_file(). +_PROJECT_ROOT_MARKERS = (".git", ".hg", ".svn", "pyproject.toml", "setup.py", "setup.cfg") + + +def _discover_project_root(file_path: Path) -> Path: + """Walk up from ``file_path`` to find the project/VCS root. + + Returns the nearest ancestor directory containing a VCS directory + (``.git``/``.hg``/``.svn``) or a project marker + (``pyproject.toml``/``setup.py``/``setup.cfg``). Falls back to the + file's own directory when no marker is found. + """ + start = file_path.parent if file_path.suffix else file_path + for directory in (start, *start.parents): + for marker in _PROJECT_ROOT_MARKERS: + if (directory / marker).exists(): + return directory + return start + class AutoFixEngine: """ @@ -146,6 +166,7 @@ def fix_file( issues: List[CodeIssue], dry_run: bool = True, verify: bool = False, + project_root: Optional[Path] = None, ) -> Tuple[str, Optional[str]]: """ Apply all fixable issues to a file. @@ -163,6 +184,13 @@ def fix_file( issues: List of CodeIssue objects to attempt to fix. dry_run: When True, no bytes are written to disk. verify: When True, run VerificationEngine before writing. + project_root: Root directory used by verification (e.g. for test + discovery). When None, the root is discovered by walking up + from ``file_path`` to the nearest VCS/project marker, falling + back to the file's own directory. Callers that already know + the project root (e.g. RefactronPipeline) should pass it + explicitly so monorepo/nested layouts verify against the + correct root. Returns: Tuple of (fixed_code, diff). diff is None/empty when no changes @@ -189,7 +217,8 @@ def fix_file( from refactron.verification import VerificationEngine logger = logging.getLogger(__name__) - ve = VerificationEngine(project_root=file_path.parent) + root = project_root or _discover_project_root(file_path) + ve = VerificationEngine(project_root=root) vr = ve.verify(code, current_code, file_path) if not vr.safe_to_apply: logger.warning("Verification blocked %s: %s", file_path, vr.blocking_reason) diff --git a/refactron/verification/checks/test_gate.py b/refactron/verification/checks/test_gate.py index 3f8327f..3bf824b 100644 --- a/refactron/verification/checks/test_gate.py +++ b/refactron/verification/checks/test_gate.py @@ -20,6 +20,8 @@ class TestSuiteGate(BaseCheck): def __init__(self, project_root: Optional[Path] = None): self.project_root = project_root + self._test_file_cache: Optional[Dict[str, List[Path]]] = None + self._all_test_files: Optional[List[Path]] = None def verify(self, original: str, transformed: str, file_path: Path) -> CheckResult: start = time.monotonic() @@ -110,11 +112,27 @@ def _find_relevant_tests(self, file_path: Path) -> List[Path]: module_name = file_path.stem search_root = self.project_root or file_path.parent + if self._test_file_cache is None: + self._test_file_cache = {} + self._all_test_files = [] + + test_dirs = [d for d in [search_root / "tests", search_root / "test"] if d.is_dir()] + search_dirs = test_dirs if test_dirs else [search_root] + excluded_dirs = {".git", ".rag", "__pycache__", "venv", ".venv", "env", "node_modules"} + + for root_dir in search_dirs: + for py_file in root_dir.rglob("*.py"): + if any(excluded in py_file.parts for excluded in excluded_dirs): + continue + name = py_file.name + if name.startswith("test_") or name.endswith("_test.py"): + self._all_test_files.append(py_file) + + if module_name in self._test_file_cache: + return self._test_file_cache[module_name] + test_files: List[Path] = [] - for py_file in search_root.rglob("*.py"): - name = py_file.name - if not (name.startswith("test_") or name.endswith("_test.py")): - continue + for py_file in self._all_test_files: # type: ignore if py_file == file_path: continue try: @@ -123,6 +141,8 @@ def _find_relevant_tests(self, file_path: Path) -> List[Path]: test_files.append(py_file) except Exception: continue + + self._test_file_cache[module_name] = test_files return test_files @staticmethod diff --git a/tests/test_fix_file_project_root.py b/tests/test_fix_file_project_root.py new file mode 100644 index 0000000..9070518 --- /dev/null +++ b/tests/test_fix_file_project_root.py @@ -0,0 +1,136 @@ +"""Tests for project_root threading into verification from AutoFixEngine.fix_file(). + +Covers: +- _discover_project_root() walks up to the nearest VCS/project marker +- _discover_project_root() falls back to the file's directory when no marker exists +- fix_file(verify=True) passes an explicit project_root to VerificationEngine +- fix_file(verify=True) discovers the root when none is supplied, + instead of always using file_path.parent +""" + +from pathlib import Path + +from refactron.autofix import engine as engine_module +from refactron.autofix.engine import AutoFixEngine, _discover_project_root +from refactron.core.models import CodeIssue, IssueCategory, IssueLevel + + +def _trailing_ws_issue(file_path: Path) -> CodeIssue: + return CodeIssue( + rule_id="remove_trailing_whitespace", + message="Trailing whitespace detected", + file_path=file_path, + line_number=1, + category=IssueCategory.STYLE, + level=IssueLevel.WARNING, + ) + + +# ─── _discover_project_root ────────────────────────────────────────────────── + + +def test_discover_project_root_finds_vcs_root(tmp_path): + """A nested file resolves to the ancestor that holds the .git directory.""" + (tmp_path / ".git").mkdir() + nested = tmp_path / "services" / "api" / "src" + nested.mkdir(parents=True) + file_path = nested / "module.py" + file_path.write_text("x = 1\n", encoding="utf-8") + + assert _discover_project_root(file_path) == tmp_path + + +def test_discover_project_root_finds_pyproject(tmp_path): + """pyproject.toml is treated as a project-root marker.""" + (tmp_path / "pyproject.toml").write_text("[project]\n", encoding="utf-8") + nested = tmp_path / "pkg" + nested.mkdir() + file_path = nested / "module.py" + file_path.write_text("x = 1\n", encoding="utf-8") + + assert _discover_project_root(file_path) == tmp_path + + +def test_discover_project_root_falls_back_to_file_dir(tmp_path): + """With no marker anywhere, the root falls back to the file's own directory.""" + nested = tmp_path / "loose" + nested.mkdir() + file_path = nested / "module.py" + file_path.write_text("x = 1\n", encoding="utf-8") + + assert _discover_project_root(file_path) == nested + + +# ─── fix_file → VerificationEngine project_root ────────────────────────────── + + +class _RecordingVerificationEngine: + """Stand-in that records the project_root it was constructed with.""" + + last_project_root = None + + def __init__(self, project_root=None, checks=None): + type(self).last_project_root = project_root + + def verify(self, original, transformed, file_path): + # Always allow the transform through so fix_file proceeds normally. + class _Result: + safe_to_apply = True + blocking_reason = None + + return _Result() + + +def _patch_verification_engine(monkeypatch): + """Route `from refactron.verification import VerificationEngine` to the recorder.""" + import refactron.verification as verification_pkg + + _RecordingVerificationEngine.last_project_root = None + monkeypatch.setattr(verification_pkg, "VerificationEngine", _RecordingVerificationEngine) + + +def test_fix_file_passes_explicit_project_root(tmp_path, monkeypatch): + """An explicit project_root must reach VerificationEngine unchanged.""" + _patch_verification_engine(monkeypatch) + + nested = tmp_path / "a" / "b" + nested.mkdir(parents=True) + file_path = nested / "module.py" + file_path.write_text("x = 1 \n", encoding="utf-8") # trailing whitespace + + AutoFixEngine().fix_file( + file_path, + [_trailing_ws_issue(file_path)], + dry_run=True, + verify=True, + project_root=tmp_path, + ) + + assert _RecordingVerificationEngine.last_project_root == tmp_path + + +def test_fix_file_discovers_root_when_none_given(tmp_path, monkeypatch): + """Without an explicit root, fix_file discovers the VCS root, not file_path.parent.""" + _patch_verification_engine(monkeypatch) + + (tmp_path / ".git").mkdir() + nested = tmp_path / "deep" / "nested" + nested.mkdir(parents=True) + file_path = nested / "module.py" + file_path.write_text("x = 1 \n", encoding="utf-8") # trailing whitespace + + AutoFixEngine().fix_file( + file_path, + [_trailing_ws_issue(file_path)], + dry_run=True, + verify=True, + ) + + # The discovered root must be the repo root, not the file's parent directory. + assert _RecordingVerificationEngine.last_project_root == tmp_path + assert _RecordingVerificationEngine.last_project_root != file_path.parent + + +def test_engine_module_exposes_discover_helper(): + """_discover_project_root is importable from the engine module.""" + assert hasattr(engine_module, "_discover_project_root")