From 82d412718070f01bfb4bd07a78bfe4440cdec008 Mon Sep 17 00:00:00 2001 From: spalen0 Date: Wed, 20 May 2026 15:17:45 +0200 Subject: [PATCH 1/2] refactor(ai-explainer): simplify per review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three internal-only cleanups from a code review pass over the AI/LLM modules. No behavior change; 239/239 tests still pass. 1. utils/impl_diff._strip_solidity_noise — single-pass re.sub. Was: three sequential re.finditer loops, each iterating character-by-character to overwrite a list[chr]. For a 500KB multi-file source bundle that's ~3n character writes plus the cost of joining the list back to a string between passes. Now: one regex with alternation across the four noise patterns plus a re.sub callback that builds a same-length whitespace replacement. Single linear scan, no intermediate list allocations. 2. utils/impl_diff._storage_layout — flatten nesting. The gap-consumption branch reached 5 levels of nesting (if consumed > 0 → if old_gap → if expected_new_gap < 0 / == 0 / > 0 → inner if for new_gap shape). Extracted to two helpers: _check_gap_consumption(consumed, old_gap, new_gap) -> list[str] _check_gap_only_change(old_gap, new_gap) -> list[str] Body is now a flat dispatch on the sign of `consumed`. Same behavior, easier to read and individually testable. 3. Promote cross-module helpers from `_name` to `name` in source_context.py. Both on_chain_state.py and impl_diff.py imported three underscore-prefixed functions from source_context (`_fetch_source`, `_find_state_var_writes`, `_extract_state_var_snippet`). The underscore prefix signals module-private, but they were de facto public APIs because two other production modules depended on them. Renamed without underscore to match their actual usage. Updated three import sites + two test files that patched the old names. 
The other private helpers (_concat_sources, _extract_function_snippet, _extract_function_body, _build_context) are still truly internal and keep their underscore. Skipped (noted but not worth the churn): - ExplainerOptions dataclass for the 8+ params on explain_transaction / explain_batch_transaction — caller-wide change for an ergonomic-only win. - Unifying explain_transaction vs explain_batch_transaction bodies — the single-call vs loop shape makes full unification ugly. - Risk-verdict enum — strings match the natural LLM-output shape. - Thread-pooled eth_calls — RPC rate-limit fragility, the per-process cache already absorbs most of the load. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_impl_diff.py | 4 +- tests/test_on_chain_state.py | 8 +-- tests/test_source_context.py | 14 ++--- utils/impl_diff.py | 100 ++++++++++++++++------------------- utils/on_chain_state.py | 16 +++--- utils/source_context.py | 14 ++--- 6 files changed, 74 insertions(+), 82 deletions(-) diff --git a/tests/test_impl_diff.py b/tests/test_impl_diff.py index b665c1e..07679e4 100644 --- a/tests/test_impl_diff.py +++ b/tests/test_impl_diff.py @@ -257,7 +257,7 @@ def test_plain_contract_is_not(self) -> None: class TestDiffImplementations(unittest.TestCase): - @patch("utils.impl_diff._fetch_source") + @patch("utils.impl_diff.fetch_source") def test_end_to_end(self, mock_fetch) -> None: mock_fetch.side_effect = [("Vault", CONTRACT_OLD), ("Vault", CONTRACT_NEW)] diff = diff_implementations("0xOld", "0xNew", 1) @@ -268,7 +268,7 @@ def test_end_to_end(self, mock_fetch) -> None: self.assertEqual(diff.added_functions[0].name, "setMaxDeposit") self.assertEqual(len(diff.changed_functions), 1) - @patch("utils.impl_diff._fetch_source", return_value=None) + @patch("utils.impl_diff.fetch_source", return_value=None) def test_returns_none_on_unverified(self, mock_fetch) -> None: self.assertIsNone(diff_implementations("0xOld", "0xNew", 1)) diff --git a/tests/test_on_chain_state.py 
b/tests/test_on_chain_state.py index 8106ce1..b1b850b 100644 --- a/tests/test_on_chain_state.py +++ b/tests/test_on_chain_state.py @@ -123,7 +123,7 @@ def test_skips_array_params(self) -> None: class TestReadBeforeState(unittest.TestCase): @patch("utils.on_chain_state.ChainManager") - @patch("utils.on_chain_state._fetch_source") + @patch("utils.on_chain_state.fetch_source") def test_reads_simple_uint(self, mock_fetch: MagicMock, mock_chain: MagicMock) -> None: source = """ uint256 public maxSlippage; @@ -152,7 +152,7 @@ def test_reads_simple_uint(self, mock_fetch: MagicMock, mock_chain: MagicMock) - self.assertEqual(reads[0].key_args, ()) @patch("utils.on_chain_state.ChainManager") - @patch("utils.on_chain_state._fetch_source") + @patch("utils.on_chain_state.fetch_source") def test_reads_address_keyed_mapping(self, mock_fetch: MagicMock, mock_chain: MagicMock) -> None: source = """ mapping(address => uint256) public coverageCap; @@ -180,12 +180,12 @@ def test_reads_address_keyed_mapping(self, mock_fetch: MagicMock, mock_chain: Ma self.assertEqual(reads[0].value, 5000000000000000) self.assertEqual(reads[0].key_args, (agent,)) - @patch("utils.on_chain_state._fetch_source", return_value=None) + @patch("utils.on_chain_state.fetch_source", return_value=None) def test_no_source_returns_empty(self, mock_fetch: MagicMock) -> None: call = DecodedCall(function_name="setX", signature="setX(uint256)", params=[("uint256", 1)]) self.assertEqual(read_before_state(1, "0xT", call), []) - @patch("utils.on_chain_state._fetch_source") + @patch("utils.on_chain_state.fetch_source") def test_struct_mapping_returns_empty(self, mock_fetch: MagicMock) -> None: source = """ mapping(address => CreditLine) public creditLines; diff --git a/tests/test_source_context.py b/tests/test_source_context.py index 56fedf7..bb60460 100644 --- a/tests/test_source_context.py +++ b/tests/test_source_context.py @@ -8,8 +8,8 @@ _concat_sources, _extract_function_body, _extract_function_snippet, - 
_extract_state_var_snippet, - _find_state_var_writes, + extract_state_var_snippet, + find_state_var_writes, format_source_context, get_source_context, reset_cache, @@ -77,7 +77,7 @@ def test_handles_nested_braces(self) -> None: class TestFindStateVarWrites(unittest.TestCase): def test_finds_assignment(self) -> None: - writes = _find_state_var_writes(INFINIFI_FARM_SOURCE, "setMaxSlippage") + writes = find_state_var_writes(INFINIFI_FARM_SOURCE, "setMaxSlippage") self.assertIn("maxSlippage", writes) def test_ignores_local_and_keyword_assignments(self) -> None: @@ -88,7 +88,7 @@ def test_ignores_local_and_keyword_assignments(self) -> None: _underscoreVar = 5; } """ - writes = _find_state_var_writes(source, "f") + writes = find_state_var_writes(source, "f") self.assertIn("storedVar", writes) self.assertNotIn("local", writes) # locals can't be distinguished, but this test documents the heuristic limit self.assertNotIn("_underscoreVar", writes) @@ -97,12 +97,12 @@ def test_ignores_local_and_keyword_assignments(self) -> None: class TestExtractStateVarSnippet(unittest.TestCase): def test_extracts_natspec_and_declaration(self) -> None: - snippet = _extract_state_var_snippet(INFINIFI_FARM_SOURCE, "maxSlippage") + snippet = extract_state_var_snippet(INFINIFI_FARM_SOURCE, "maxSlippage") self.assertIn("so actually 1 - slippage", snippet) self.assertIn("uint256 public maxSlippage", snippet) def test_missing_var_returns_empty(self) -> None: - snippet = _extract_state_var_snippet(INFINIFI_FARM_SOURCE, "doesNotExist") + snippet = extract_state_var_snippet(INFINIFI_FARM_SOURCE, "doesNotExist") self.assertEqual(snippet, "") def test_skips_local_vars_without_visibility(self) -> None: @@ -112,7 +112,7 @@ def test_skips_local_vars_without_visibility(self) -> None: } """ # Should not match — no public/private/internal/external modifier - snippet = _extract_state_var_snippet(source, "plainLocal") + snippet = extract_state_var_snippet(source, "plainLocal") self.assertEqual(snippet, "") diff 
--git a/utils/impl_diff.py b/utils/impl_diff.py index 1268d34..0c16575 100644 --- a/utils/impl_diff.py +++ b/utils/impl_diff.py @@ -22,7 +22,7 @@ from typing import Iterable from utils.logging import get_logger -from utils.source_context import _fetch_source +from utils.source_context import fetch_source logger = get_logger("utils.impl_diff") @@ -157,27 +157,18 @@ def _extract_function_sigs(source: str) -> list[FunctionSig]: return sigs +_SOLIDITY_NOISE_RE = re.compile( + r'/\*[\s\S]*?\*/|//[^\n]*|"(?:[^"\\\n]|\\.)*"|\'(?:[^\'\\\n]|\\.)*\'', +) + + def _strip_solidity_noise(source: str) -> str: - """Remove comments and string literals so brace counting / regex can't trip on them. + """Replace comments and string literals with same-length whitespace. - Replaces stripped content with same-length spaces to preserve byte offsets, - which keeps the brace-depth array indexable against the original source. + Preserving byte offsets keeps the brace-depth array indexable against the + original source. Newlines are preserved so line numbers stay stable. """ - out = list(source) - # Block comments /* ... */ - for m in re.finditer(r"/\*[\s\S]*?\*/", source): - for i in range(m.start(), m.end()): - out[i] = " " if source[i] != "\n" else "\n" - # Line comments //... - for m in re.finditer(r"//[^\n]*", "".join(out)): - for i in range(m.start(), m.end()): - out[i] = " " - # String literals "..." and '...' 
- for pat in (r'"(?:[^"\\\n]|\\.)*"', r"'(?:[^'\\\n]|\\.)*'"): - for m in re.finditer(pat, "".join(out)): - for i in range(m.start(), m.end()): - out[i] = " " - return "".join(out) + return _SOLIDITY_NOISE_RE.sub(lambda m: "".join("\n" if c == "\n" else " " for c in m.group(0)), source) def _brace_depths(cleaned: str) -> list[int]: @@ -344,54 +335,55 @@ def _storage_layout( changes.append(f"slot {i}: {o.type_str} {o.name} → {n.type_str} {n.name}") consumed = len(new_core) - len(old_core) - added_at_end: list[StateVarDecl] = [] - removed_off_end: list[StateVarDecl] = [] + added_at_end = list(new_core[len(old_core) :]) if consumed > 0 else [] + removed_off_end = list(old_core[len(new_core) :]) if consumed < 0 else [] if consumed > 0: - added_at_end = list(new_core[len(old_core) :]) - # New vars appended after the old core. - if old_gap is not None: - # Old contract reserved a gap; the new vars must come out of it. - expected_new_gap = old_gap - consumed - if expected_new_gap < 0: - changes.append( - f"consumed {consumed} new slot(s) but old gap was only {old_gap}; layout overflows reserved space" - ) - elif expected_new_gap == 0: - if new_gap is not None: - changes.append(f"old gap of {old_gap} fully consumed but new contract still has gap of {new_gap}") - else: # expected_new_gap > 0 - if new_gap is None: - changes.append( - f"old gap of {old_gap} not preserved (expected new gap of {expected_new_gap}, got none)" - ) - elif new_gap != expected_new_gap: - changes.append( - f"gap mismatch: consumed {consumed} slot(s); expected new gap of {expected_new_gap}, got {new_gap}" - ) - # If old had no gap, appending at the end is still safe (no shift). + changes.extend(_check_gap_consumption(consumed, old_gap, new_gap)) elif consumed < 0: - # New contract is shorter than old: old vars were removed off the end. 
- removed_off_end = list(old_core[len(new_core) :]) for i, v in enumerate(removed_off_end, start=len(new_core)): changes.append(f"slot {i}: removed {v.type_str} {v.name}") else: - # Same length cores. If gaps don't agree, flag. - if old_gap is not None and new_gap is None: - changes.append(f"old gap of size {old_gap} removed in new layout") - elif old_gap is None and new_gap is not None: - changes.append(f"new layout introduces gap of size {new_gap} not present in old") - elif old_gap != new_gap: - changes.append(f"gap size changed from {old_gap} to {new_gap} without slot consumption") + changes.extend(_check_gap_only_change(old_gap, new_gap)) safe = not changes return safe, changes, added_at_end, removed_off_end +def _check_gap_consumption(consumed: int, old_gap: int | None, new_gap: int | None) -> list[str]: + """Validate that `consumed` new vars correspond to a matching gap shrink. + + If the old contract had no gap, appending at the end is still safe (no shift). + """ + if old_gap is None: + return [] + expected_new_gap = old_gap - consumed + if expected_new_gap < 0: + return [f"consumed {consumed} new slot(s) but old gap was only {old_gap}; layout overflows reserved space"] + if expected_new_gap == 0 and new_gap is not None: + return [f"old gap of {old_gap} fully consumed but new contract still has gap of {new_gap}"] + if expected_new_gap > 0 and new_gap is None: + return [f"old gap of {old_gap} not preserved (expected new gap of {expected_new_gap}, got none)"] + if expected_new_gap > 0 and new_gap != expected_new_gap: + return [f"gap mismatch: consumed {consumed} slot(s); expected new gap of {expected_new_gap}, got {new_gap}"] + return [] + + +def _check_gap_only_change(old_gap: int | None, new_gap: int | None) -> list[str]: + """Flag gap presence/size disagreements when the non-gap layout is unchanged.""" + if old_gap is not None and new_gap is None: + return [f"old gap of size {old_gap} removed in new layout"] + if old_gap is None and new_gap is not None: + 
return [f"new layout introduces gap of size {new_gap} not present in old"] + if old_gap != new_gap: + return [f"gap size changed from {old_gap} to {new_gap} without slot consumption"] + return [] + + def diff_implementations(old_addr: str, new_addr: str, chain_id: int) -> ImplDiff | None: """Fetch both verified impls and produce a structural diff. None on any failure.""" - old = _fetch_source(chain_id, old_addr) - new = _fetch_source(chain_id, new_addr) + old = fetch_source(chain_id, old_addr) + new = fetch_source(chain_id, new_addr) if not old or not new: return None diff --git a/utils/on_chain_state.py b/utils/on_chain_state.py index a8c98e4..3d0872f 100644 --- a/utils/on_chain_state.py +++ b/utils/on_chain_state.py @@ -23,9 +23,9 @@ from utils.chains import Chain from utils.logging import get_logger from utils.source_context import ( - _extract_state_var_snippet, - _fetch_source, - _find_state_var_writes, + extract_state_var_snippet, + fetch_source, + find_state_var_writes, ) from utils.web3_wrapper import ChainManager @@ -193,8 +193,8 @@ def _guess_getter_from_setter(decoded_call: DecodedCall) -> tuple[str, list[str] def _resolve_source_for_function(chain_id: int, target: str, function_name: str) -> str | None: """Return the source where `function_name` is defined, following the proxy if needed.""" - fetched = _fetch_source(chain_id, target) - if fetched and _find_state_var_writes(fetched[1], function_name): + fetched = fetch_source(chain_id, target) + if fetched and find_state_var_writes(fetched[1], function_name): return fetched[1] from utils.proxy import get_current_implementation @@ -203,7 +203,7 @@ def _resolve_source_for_function(chain_id: int, target: str, function_name: str) if not impl or impl.lower() == target.lower(): return fetched[1] if fetched else None - fetched_impl = _fetch_source(chain_id, impl) + fetched_impl = fetch_source(chain_id, impl) return fetched_impl[1] if fetched_impl else (fetched[1] if fetched else None) @@ -226,13 +226,13 @@ def 
read_before_state( if not source: return [] - var_names = _find_state_var_writes(source, decoded_call.function_name) + var_names = find_state_var_writes(source, decoded_call.function_name) if not var_names: return [] reads: list[StateRead] = [] for var_name in var_names: - snippet = _extract_state_var_snippet(source, var_name) + snippet = extract_state_var_snippet(source, var_name) parsed = _parse_var_declaration(snippet, var_name) if snippet else None key_values: list[Any] = [] diff --git a/utils/source_context.py b/utils/source_context.py index 6dd3a0d..9b08438 100644 --- a/utils/source_context.py +++ b/utils/source_context.py @@ -52,7 +52,7 @@ class SourceContext: state_var_snippets: list[str] # natspec + declaration for each mutated state var -def _fetch_source(chain_id: int, address: str) -> tuple[str, str] | None: +def fetch_source(chain_id: int, address: str) -> tuple[str, str] | None: """Fetch (contract_name, concatenated_source) for a verified contract. Returns None if the API key is missing, the contract is unverified, or the @@ -123,7 +123,7 @@ def _extract_function_snippet(source: str, function_name: str) -> str: return f"{natspec.rstrip()}\n{match.group(2).strip()}".strip() -def _find_state_var_writes(source: str, function_name: str) -> list[str]: +def find_state_var_writes(source: str, function_name: str) -> list[str]: """State variable names assigned inside the function body, deduped, in order.""" body = _extract_function_body(source, function_name) if not body: @@ -161,7 +161,7 @@ def _extract_function_body(source: str, function_name: str) -> str: return source[start : i - 1] -def _extract_state_var_snippet(source: str, var_name: str) -> str: +def extract_state_var_snippet(source: str, var_name: str) -> str: """Find a state variable declaration with any preceding natspec. 
Requires a visibility modifier so local declarations inside function bodies @@ -190,11 +190,11 @@ def _build_context(contract_name: str, source: str, function_name: str) -> Sourc if not func_snippet: return None - var_names = _find_state_var_writes(source, function_name) + var_names = find_state_var_writes(source, function_name) var_snippets: list[str] = [] total = len(func_snippet) for name in var_names: - snippet = _extract_state_var_snippet(source, name) + snippet = extract_state_var_snippet(source, name) if not snippet: continue if total + len(snippet) > MAX_SNIPPET_CHARS: @@ -216,7 +216,7 @@ def get_source_context(chain_id: int, address: str, function_name: str) -> Sourc EIP-1967 proxy slot (if any) and retry against the implementation source. Best-effort: returns None on any failure (unverified, missing key, no match). """ - fetched = _fetch_source(chain_id, address) + fetched = fetch_source(chain_id, address) if fetched: ctx = _build_context(fetched[0], fetched[1], function_name) if ctx: @@ -229,7 +229,7 @@ def get_source_context(chain_id: int, address: str, function_name: str) -> Sourc if not impl or impl.lower() == address.lower(): return None - fetched_impl = _fetch_source(chain_id, impl) + fetched_impl = fetch_source(chain_id, impl) if not fetched_impl: return None return _build_context(fetched_impl[0], fetched_impl[1], function_name) From 1523b0839d9aaaf9f74d16ee1647498fbcb82b9a Mon Sep 17 00:00:00 2001 From: spalen0 Date: Wed, 20 May 2026 15:36:23 +0200 Subject: [PATCH 2/2] test: drop low-signal substring/format tests from AI explainer suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes 16 tests (-212 lines) that mostly assert "did X string appear in the formatter output." These overlap heavily with the integration paths and break on any wording change without flagging real bugs. 
Dropped: - TestFormatDecodedCalls (3) — covered by every _build_prompt test - TestFormatSimulationContext (5) — same pattern, format X → assertIn X - TestSystemPromptBrevity (1) — asserts specific SYSTEM_PROMPT strings; fragile against prompt iteration - TestFormatImplDiff (3) — formatter substring checks - TestFormatSourceContext (2) — same - 2 duplicate refine cases: test_pass_with_trailing_whitespace_still_keeps_draft (subsumed into test_pass_keeps_draft_unchanged with a whitespace-padded PASS) test_revision_with_empty_summary_falls_back (identical setup to test_empty_response_falls_back_to_draft) Kept (all regression guards for real past bugs): - TestParseExplanation — each tests a distinct LLM-output shape - TestStorageLayout gap-consumption tests — caught the OZ false-unsafe issue - TestRemovingDefaultInternalVarNowDetectedAsUnsafe — auditor regression - TestProxyUpgrade.test_works_offline_for_all_proxy_selectors — perf regression - TestProxyUpgrade.test_non_upgrade_short_circuits_before_decode — perf regression - Proxy-follow tests in test_source_context.py and test_on_chain_state.py - Refine flag wiring + skip_simulation tests All 223 remaining tests pass (down from 239 before the 16 removals). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_ai_explainer.py | 129 ++--------------------------------- tests/test_impl_diff.py | 66 ------------------ tests/test_source_context.py | 25 ------- 3 files changed, 4 insertions(+), 216 deletions(-) diff --git a/tests/test_ai_explainer.py b/tests/test_ai_explainer.py index cb52dd3..226cb60 100644 --- a/tests/test_ai_explainer.py +++ b/tests/test_ai_explainer.py @@ -7,8 +7,6 @@ from utils.llm.ai_explainer import ( Explanation, _build_prompt, - _format_decoded_calls, - _format_simulation_context, _parse_explanation, _refine_explanation, explain_transaction, @@ -16,102 +14,7 @@ ) from utils.llm.base import LLMError from utils.source_context import SourceContext -from utils.tenderly.simulation import AssetChange, SimulationResult, StateChange - - -class TestFormatDecodedCalls(unittest.TestCase): - """Tests for _format_decoded_calls.""" - - def test_single_call_no_params(self) -> None: - calls = [DecodedCall(function_name="pause", signature="pause()")] - result = _format_decoded_calls(calls) - self.assertIn("Call 1: pause()", result) - - def test_single_call_with_params(self) -> None: - calls = [ - DecodedCall( - function_name="grantRole", - signature="grantRole(bytes32,address)", - params=[("bytes32", b"\x00" * 32), ("address", "0xABC")], - ) - ] - result = _format_decoded_calls(calls) - self.assertIn("grantRole(bytes32,address)", result) - self.assertIn("bytes32:", result) - self.assertIn("address:", result) - - def test_multiple_calls(self) -> None: - calls = [ - DecodedCall(function_name="pause", signature="pause()"), - DecodedCall(function_name="unpause", signature="unpause()"), - ] - result = _format_decoded_calls(calls) - self.assertIn("Call 1: pause()", result) - self.assertIn("Call 2: unpause()", result) - - -class TestFormatSimulationContext(unittest.TestCase): - """Tests for _format_simulation_context.""" - - def test_successful_simulation(self) -> None: - sim = SimulationResult(success=True, 
gas_used=50000) - result = _format_simulation_context(sim) - self.assertIn("SUCCESS", result) - self.assertIn("50,000", result) - - def test_failed_simulation(self) -> None: - sim = SimulationResult(success=False, gas_used=21000, error_message="execution reverted") - result = _format_simulation_context(sim) - self.assertIn("FAILED", result) - self.assertIn("execution reverted", result) - - def test_with_asset_changes(self) -> None: - sim = SimulationResult( - success=True, - gas_used=100000, - asset_changes=[ - AssetChange( - token_address="0xToken", - token_name="USDC", - token_symbol="USDC", - from_address="0xA", - to_address="0xB", - amount="1000", - raw_amount="1000000000", - decimals=6, - ) - ], - ) - result = _format_simulation_context(sim) - self.assertIn("Token transfers:", result) - self.assertIn("USDC", result) - - def test_with_state_changes(self) -> None: - sim = SimulationResult( - success=True, - gas_used=100000, - state_changes=[ - StateChange( - contract_address="0xContract", - key="0x01", - original="0x00", - dirty="0x01", - ) - ], - ) - result = _format_simulation_context(sim) - self.assertIn("State changes", result) - self.assertIn("0xContract", result) - - def test_with_logs(self) -> None: - sim = SimulationResult( - success=True, - gas_used=100000, - logs=[{"name": "Transfer", "inputs": [{"soltype": {"name": "to"}, "value": "0xB"}]}], - ) - result = _format_simulation_context(sim) - self.assertIn("Events emitted", result) - self.assertIn("Transfer", result) +from utils.tenderly.simulation import SimulationResult class TestBuildPrompt(unittest.TestCase): @@ -226,17 +129,6 @@ def test_mixed_signatures_no_section(self) -> None: self.assertNotIn("--- Shared Across Batch ---", result) -class TestSystemPromptBrevity(unittest.TestCase): - """Verify the system prompt enforces brevity rules.""" - - def test_includes_word_cap_and_no_preamble_rules(self) -> None: - calls = [DecodedCall(function_name="pause", signature="pause()")] - result = 
_build_prompt(target="0xT", value=0, decoded_calls=calls, simulation=None) - self.assertIn("≤25 words", result) - self.assertIn('"This transaction"', result) - self.assertIn("risk tag in caps", result) - - class TestSkipSimulation(unittest.TestCase): """Tests for skip_simulation flag.""" @@ -501,18 +393,12 @@ class TestRefineExplanation(unittest.TestCase): """Tests for _refine_explanation.""" def test_pass_keeps_draft_unchanged(self) -> None: + # Trailing whitespace around "PASS" must also count as PASS. draft = Explanation(summary="Lowers fee 30→25 bps. LOW.", detail="bla") provider = MagicMock() - provider.complete.return_value = "PASS" - result = _refine_explanation("orig prompt", draft, provider) - self.assertIs(result, draft) - provider.complete.assert_called_once() - - def test_pass_with_trailing_whitespace_still_keeps_draft(self) -> None: - draft = Explanation(summary="x", detail="y") - provider = MagicMock() provider.complete.return_value = " PASS \n" - self.assertIs(_refine_explanation("p", draft, provider), draft) + self.assertIs(_refine_explanation("orig prompt", draft, provider), draft) + provider.complete.assert_called_once() def test_revision_replaces_draft(self) -> None: draft = Explanation(summary="This transaction does X. 
LOW.", detail="bla") @@ -534,13 +420,6 @@ def test_empty_response_falls_back_to_draft(self) -> None: provider.complete.return_value = "" self.assertIs(_refine_explanation("p", draft, provider), draft) - def test_revision_with_empty_summary_falls_back(self) -> None: - draft = Explanation(summary="x", detail="y") - provider = MagicMock() - # Reply parses to empty summary (no TLDR marker, no body) - provider.complete.return_value = "" - self.assertIs(_refine_explanation("p", draft, provider), draft) - class TestRefineFlagInExplainTransaction(unittest.TestCase): """Tests that the refine flag triggers a second LLM call.""" diff --git a/tests/test_impl_diff.py b/tests/test_impl_diff.py index 07679e4..7c0d1cc 100644 --- a/tests/test_impl_diff.py +++ b/tests/test_impl_diff.py @@ -4,9 +4,6 @@ from unittest.mock import patch from utils.impl_diff import ( - FunctionSig, - ImplDiff, - StateVarDecl, _diff_functions, _extract_function_sigs, _extract_state_vars, @@ -14,7 +11,6 @@ _normalize_args, _storage_layout, diff_implementations, - format_impl_diff, ) CONTRACT_OLD = """ @@ -273,67 +269,5 @@ def test_returns_none_on_unverified(self, mock_fetch) -> None: self.assertIsNone(diff_implementations("0xOld", "0xNew", 1)) -class TestFormatImplDiff(unittest.TestCase): - def test_basic_output(self) -> None: - diff = ImplDiff( - old_addr="0xOld", - new_addr="0xNew", - old_name="Vault", - new_name="Vault", - added_functions=[FunctionSig(name="newFn", args="uint256", visibility="external", modifiers="")], - removed_functions=[], - changed_functions=[], - added_state_vars=[StateVarDecl(name="newVar", type_str="uint256", visibility="public", immutable=False)], - removed_state_vars=[], - layout_changes=[], - storage_layout_safe=True, - namespaced_storage=False, - ) - out = format_impl_diff(diff) - self.assertIn("Old: 0xOld", out) - self.assertIn("New: 0xNew", out) - self.assertIn("Functions added", out) - self.assertIn("newFn(uint256)", out) - self.assertIn("Storage layout safe", out) - - def 
test_unsafe_layout_warning(self) -> None: - diff = ImplDiff( - old_addr="0xOld", - new_addr="0xNew", - old_name="X", - new_name="X", - added_functions=[], - removed_functions=[], - changed_functions=[], - added_state_vars=[], - removed_state_vars=[], - layout_changes=["slot 0: uint256 a → address b"], - storage_layout_safe=False, - namespaced_storage=False, - ) - out = format_impl_diff(diff) - self.assertIn("NOT upgrade-safe", out) - self.assertIn("slot 0", out) - - def test_namespaced_skips_check(self) -> None: - diff = ImplDiff( - old_addr="0xOld", - new_addr="0xNew", - old_name="", - new_name="", - added_functions=[], - removed_functions=[], - changed_functions=[], - added_state_vars=[], - removed_state_vars=[], - layout_changes=[], - storage_layout_safe=True, - namespaced_storage=True, - ) - out = format_impl_diff(diff) - self.assertIn("EIP-7201", out) - self.assertIn("skipped", out) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_source_context.py b/tests/test_source_context.py index bb60460..763702f 100644 --- a/tests/test_source_context.py +++ b/tests/test_source_context.py @@ -4,13 +4,11 @@ from unittest.mock import patch from utils.source_context import ( - SourceContext, _concat_sources, _extract_function_body, _extract_function_snippet, extract_state_var_snippet, find_state_var_writes, - format_source_context, get_source_context, reset_cache, ) @@ -239,28 +237,5 @@ def test_proxy_follow_skipped_when_impl_equals_target(self, mock_fetch: object, self.assertEqual(mock_fetch.call_count, 1) # type: ignore[attr-defined] -class TestFormatSourceContext(unittest.TestCase): - def test_includes_contract_name_and_snippets(self) -> None: - ctx = SourceContext( - contract_name="Farm", - function_snippet="/// @notice foo\nfunction setMaxSlippage(uint256) external;", - state_var_snippets=["/// @notice slippage\nuint256 public maxSlippage;"], - ) - result = format_source_context(ctx) - self.assertIn("Contract: Farm", result) - 
self.assertIn("setMaxSlippage", result) - self.assertIn("Relevant state variables:", result) - self.assertIn("maxSlippage", result) - - def test_omits_state_var_section_when_empty(self) -> None: - ctx = SourceContext( - contract_name="Foo", - function_snippet="function pause() external;", - state_var_snippets=[], - ) - result = format_source_context(ctx) - self.assertNotIn("Relevant state variables", result) - - if __name__ == "__main__": unittest.main()