diff --git a/tests/test_impl_diff.py b/tests/test_impl_diff.py new file mode 100644 index 00000000..8bb0f8e5 --- /dev/null +++ b/tests/test_impl_diff.py @@ -0,0 +1,293 @@ +"""Tests for utils/impl_diff.py.""" + +import unittest +from unittest.mock import patch + +from utils.impl_diff import ( + FunctionSig, + ImplDiff, + StateVarDecl, + _diff_functions, + _extract_function_sigs, + _extract_state_vars, + _is_namespaced_storage, + _normalize_args, + _storage_layout, + diff_implementations, + format_impl_diff, +) + +CONTRACT_OLD = """ +// SPDX-License-Identifier: MIT +contract Vault { + uint256 public minDeposit; + address public owner; + uint256 public constant FEE_BPS = 30; + + function deposit(uint256 _amount) external returns (uint256) { return _amount; } + function withdraw(uint256 _amount) external onlyOwner { _amount; } + function setOwner(address _o) external onlyOwner { owner = _o; } +} +""" + +CONTRACT_NEW = """ +// SPDX-License-Identifier: MIT +contract Vault { + uint256 public minDeposit; + address public owner; + uint256 public maxDeposit; // NEW state var appended at end + uint256 public constant FEE_BPS = 30; + + function deposit(uint256 _amount) external returns (uint256) { return _amount; } + function withdraw(uint256 _amount) external onlyOwner { _amount; } + function setOwner(address _o) external onlyAdmin { owner = _o; } // modifier changed + function setMaxDeposit(uint256 _m) external onlyOwner { maxDeposit = _m; } // new +} +""" + +CONTRACT_REORDERED = """ +contract Bad { + address public owner; // SWAPPED with minDeposit — unsafe + uint256 public minDeposit; +} +""" + +CONTRACT_NAMESPACED = """ +contract NS { + function _getXxxStorage() private pure returns (XxxStorage storage $) { + assembly { $.slot := 0x1234 } + } +} +""" + + +class TestNormalizeArgs(unittest.TestCase): + def test_strips_names(self) -> None: + self.assertEqual(_normalize_args("address _a, uint256 _b"), "address,uint256") + + def test_strips_data_locations(self) -> None: + self.assertEqual(_normalize_args("uint256[] memory arr, bytes calldata data"), "uint256[],bytes") + + def test_empty(self) -> None: + self.assertEqual(_normalize_args(""), "") + self.assertEqual(_normalize_args(" "), "") + + +class TestExtractFunctionSigs(unittest.TestCase): + def test_finds_all(self) -> None: + sigs = _extract_function_sigs(CONTRACT_OLD) + names = [s.name for s in sigs] + self.assertIn("deposit", names) + self.assertIn("withdraw", names) + self.assertIn("setOwner", names) + + def test_captures_visibility_and_modifiers(self) -> None: + sigs = {s.name: s for s in _extract_function_sigs(CONTRACT_OLD)} + self.assertEqual(sigs["setOwner"].visibility, "external") + self.assertIn("onlyOwner", sigs["setOwner"].modifiers) + + +class TestExtractStateVars(unittest.TestCase): + def test_finds_in_order(self) -> None: + vars_ = _extract_state_vars(CONTRACT_OLD) + names = [v.name for v in vars_] + # First two are slot vars in declaration order + self.assertEqual(names[:2], ["minDeposit", "owner"]) + self.assertIn("FEE_BPS", names) + + def test_marks_constant_as_immutable(self) -> None: + vars_ = {v.name: v for v in _extract_state_vars(CONTRACT_OLD)} + self.assertTrue(vars_["FEE_BPS"].immutable) + self.assertFalse(vars_["minDeposit"].immutable) + + def test_default_internal_state_vars_captured(self) -> None: + """Regression: Solidity defaults state-var visibility to internal. + Declarations like `uint256 cap;` were previously skipped because the + regex required an explicit modifier — that produced a false-safe verdict + when an upgrade removed or reordered such vars.""" + src = """ + contract C { + uint256 explicitPublic; // wait, no — explicit visibility test below + uint256 cap; // default internal, NO visibility + address admin; // default internal, NO visibility + mapping(address => uint256) balances; // default internal mapping + } + """ + vars_ = _extract_state_vars(src) + names = [v.name for v in vars_] + self.assertEqual(names, ["explicitPublic", "cap", "admin", "balances"]) + # The default-visibility ones should record visibility as "" + by_name = {v.name: v for v in vars_} + self.assertEqual(by_name["cap"].visibility, "") + self.assertEqual(by_name["admin"].visibility, "") + self.assertEqual(by_name["balances"].visibility, "") + + def test_function_locals_not_captured_after_visibility_fix(self) -> None: + """Even with visibility now optional, locals inside function bodies must + be excluded via the brace-depth check.""" + src = """ + contract C { + uint256 stateVar; + function f() external { + uint256 localUint = 1; + address localAddr; + if (true) { + uint256 deeper = 2; + } + } + } + """ + names = [v.name for v in _extract_state_vars(src)] + self.assertEqual(names, ["stateVar"]) + self.assertNotIn("localUint", names) + self.assertNotIn("localAddr", names) + self.assertNotIn("deeper", names) + + def test_struct_members_not_captured(self) -> None: + """Struct members are at brace depth 2 inside the struct, not state vars.""" + src = """ + contract C { + struct Cfg { uint256 fee; address admin; } + uint256 stateVar; + } + """ + names = [v.name for v in _extract_state_vars(src)] + self.assertNotIn("fee", names) + self.assertNotIn("admin", names) + self.assertIn("stateVar", names) + + def test_removing_default_internal_var_now_detected_as_unsafe(self) -> None: + """End-to-end: an upgrade that removes a default-internal var must be + flagged as unsafe, not silently treated as no-change.""" + old = "contract C { uint256 a; uint256 b; uint256 c; }" + new = "contract C { uint256 a; uint256 c; }" # b removed, c shifts up + old_vars = _extract_state_vars(old) + new_vars = _extract_state_vars(new) + from utils.impl_diff import _storage_layout + + safe, changes, _, _ = _storage_layout(old_vars, new_vars) + self.assertFalse(safe) + self.assertTrue(changes, "expected concrete layout changes") + + +class TestDiffFunctions(unittest.TestCase): + def test_detects_added_and_changed(self) -> None: + old = _extract_function_sigs(CONTRACT_OLD) + new = _extract_function_sigs(CONTRACT_NEW) + added, removed, changed = _diff_functions(old, new) + + added_names = [f.name for f in added] + self.assertIn("setMaxDeposit", added_names) + self.assertEqual(removed, []) + + changed_names = [(o.name, n.name) for o, n in changed] + self.assertIn(("setOwner", "setOwner"), changed_names) + + +class TestStorageLayout(unittest.TestCase): + def test_append_only_is_safe(self) -> None: + old_vars = _extract_state_vars(CONTRACT_OLD) + new_vars = _extract_state_vars(CONTRACT_NEW) + safe, changes, added, removed = _storage_layout(old_vars, new_vars) + self.assertTrue(safe) + self.assertEqual(changes, []) + added_names = [v.name for v in added] + self.assertEqual(added_names, ["maxDeposit"]) + + def test_reorder_is_unsafe(self) -> None: + old_vars = _extract_state_vars(CONTRACT_OLD) + bad_vars = _extract_state_vars(CONTRACT_REORDERED) + safe, changes, _, _ = _storage_layout(old_vars, bad_vars) + self.assertFalse(safe) + self.assertTrue(any("slot 0" in c for c in changes)) + + +class TestNamespacedStorage(unittest.TestCase): + def test_detected(self) -> None: + self.assertTrue(_is_namespaced_storage(CONTRACT_NAMESPACED)) + + def test_plain_contract_is_not(self) -> None: + self.assertFalse(_is_namespaced_storage(CONTRACT_OLD)) + + +class TestDiffImplementations(unittest.TestCase): + @patch("utils.impl_diff._fetch_source") + def test_end_to_end(self, mock_fetch) -> None: + mock_fetch.side_effect = [("Vault", CONTRACT_OLD), ("Vault", CONTRACT_NEW)] + diff = diff_implementations("0xOld", "0xNew", 1) + self.assertIsNotNone(diff) + assert diff is not None + self.assertTrue(diff.storage_layout_safe) + self.assertEqual(len(diff.added_functions), 1) + self.assertEqual(diff.added_functions[0].name, "setMaxDeposit") + self.assertEqual(len(diff.changed_functions), 1) + + @patch("utils.impl_diff._fetch_source", return_value=None) + def test_returns_none_on_unverified(self, mock_fetch) -> None: + self.assertIsNone(diff_implementations("0xOld", "0xNew", 1)) + + +class TestFormatImplDiff(unittest.TestCase): + def test_basic_output(self) -> None: + diff = ImplDiff( + old_addr="0xOld", + new_addr="0xNew", + old_name="Vault", + new_name="Vault", + added_functions=[FunctionSig(name="newFn", args="uint256", visibility="external", modifiers="")], + removed_functions=[], + changed_functions=[], + added_state_vars=[StateVarDecl(name="newVar", type_str="uint256", visibility="public", immutable=False)], + removed_state_vars=[], + layout_changes=[], + storage_layout_safe=True, + namespaced_storage=False, + ) + out = format_impl_diff(diff) + self.assertIn("Old: 0xOld", out) + self.assertIn("New: 0xNew", out) + self.assertIn("Functions added", out) + self.assertIn("newFn(uint256)", out) + self.assertIn("Storage layout safe", out) + + def test_unsafe_layout_warning(self) -> None: + diff = ImplDiff( + old_addr="0xOld", + new_addr="0xNew", + old_name="X", + new_name="X", + added_functions=[], + removed_functions=[], + changed_functions=[], + added_state_vars=[], + removed_state_vars=[], + layout_changes=["slot 0: uint256 a → address b"], + storage_layout_safe=False, + namespaced_storage=False, + ) + out = format_impl_diff(diff) + self.assertIn("NOT upgrade-safe", out) + self.assertIn("slot 0", out) + + def test_namespaced_skips_check(self) -> None: + diff = ImplDiff( + old_addr="0xOld", + new_addr="0xNew", + old_name="", + new_name="", + added_functions=[], + removed_functions=[], + changed_functions=[], + added_state_vars=[], + removed_state_vars=[], + layout_changes=[], + storage_layout_safe=True, + namespaced_storage=True, + ) + out = format_impl_diff(diff) + self.assertIn("EIP-7201", out) + self.assertIn("skipped", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_proxy_upgrade.py b/tests/test_proxy_upgrade.py new file mode 100644 index 00000000..9866d242 --- /dev/null +++ b/tests/test_proxy_upgrade.py @@ -0,0 +1,102 @@ +"""Tests for utils/proxy.detect_proxy_upgrade.""" + +import unittest + +from eth_abi import encode +from eth_utils import function_signature_to_4byte_selector +from eth_utils import to_checksum_address as _cs + +from utils.proxy import ProxyUpgrade, detect_proxy_upgrade + + +def encode_call(sig: str, types: list[str], vals: list) -> str: + selector = function_signature_to_4byte_selector(sig).hex() + encoded = encode(types, vals).hex() + return "0x" + selector + encoded + + +PROXY_ADDR = _cs("0x40a2accbd92bca938b02010e17a5b8929b49130d") +NEW_IMPL = _cs("0x2038a35264815ce78bd57787de119dda4f57e216") + + +class TestDetectProxyUpgrade(unittest.TestCase): + def test_upgrade_to(self) -> None: + data = encode_call("upgradeTo(address)", ["address"], [NEW_IMPL]) + result = detect_proxy_upgrade(data, PROXY_ADDR) + self.assertEqual(result, ProxyUpgrade(proxy_address=PROXY_ADDR, new_implementation=NEW_IMPL)) + + def test_upgrade_to_and_call(self) -> None: + data = encode_call("upgradeToAndCall(address,bytes)", ["address", "bytes"], [NEW_IMPL, b""]) + result = detect_proxy_upgrade(data, PROXY_ADDR) + assert result is not None + self.assertEqual(result.new_implementation, NEW_IMPL) + self.assertEqual(result.proxy_address, PROXY_ADDR) + + def test_proxy_admin_upgrade_and_call(self) -> None: + # ProxyAdmin pattern: proxy is arg 0, new impl is arg 1 + data = encode_call( + "upgradeAndCall(address,address,bytes)", + ["address", "address", "bytes"], + [PROXY_ADDR, NEW_IMPL, b""], + ) + # Target is the ProxyAdmin itself; proxy address comes from calldata + admin = _cs("0xecda55c32966b00592ed3922e386063e1bc752c2") + result = detect_proxy_upgrade(data, admin) + assert result is not None + self.assertEqual(result.proxy_address, PROXY_ADDR) + self.assertEqual(result.new_implementation, NEW_IMPL) + + def test_non_upgrade_returns_none(self) -> None: + data = encode_call("transfer(address,uint256)", ["address", "uint256"], [NEW_IMPL, 1]) + self.assertIsNone(detect_proxy_upgrade(data, PROXY_ADDR)) + + def test_empty_calldata(self) -> None: + self.assertIsNone(detect_proxy_upgrade("0x", PROXY_ADDR)) + self.assertIsNone(detect_proxy_upgrade("", PROXY_ADDR)) + + def test_missing_target_for_direct_upgrade(self) -> None: + # When upgrade is called on the proxy itself, target is needed + data = encode_call("upgradeTo(address)", ["address"], [NEW_IMPL]) + self.assertIsNone(detect_proxy_upgrade(data, "")) + + def test_works_offline_for_all_proxy_selectors(self) -> None: + """Regression: detect_proxy_upgrade must not depend on the Sourcify 4byte + lookup for proxy upgrade selectors — those are in KNOWN_SELECTORS so the + decode resolves locally even when the network is unreachable.""" + from unittest.mock import patch + + cases = [ + ( + "upgradeTo(address)", + ["address"], + [NEW_IMPL], + PROXY_ADDR, + ), + ( + "upgradeToAndCall(address,bytes)", + ["address", "bytes"], + [NEW_IMPL, b""], + PROXY_ADDR, + ), + ( + "upgradeAndCall(address,address,bytes)", + ["address", "address", "bytes"], + [PROXY_ADDR, NEW_IMPL, b""], + _cs("0xecda55c32966b00592ed3922e386063e1bc752c2"), + ), + ] + # Patch the 4byte lookup so any call to it would raise — proving we + # never hit the network. + with patch("utils.calldata.decoder.fetch_json") as mock_fetch: + mock_fetch.side_effect = AssertionError("4byte fetch must not be called for known proxy selectors") + for sig, types, vals, tx_target in cases: + with self.subTest(sig=sig): + data = encode_call(sig, types, vals) + result = detect_proxy_upgrade(data, tx_target) + self.assertIsNotNone(result, f"detection failed offline for {sig}") + assert result is not None + self.assertEqual(result.new_implementation, NEW_IMPL) + + +if __name__ == "__main__": + unittest.main() diff --git a/timelock/timelock_alerts.py b/timelock/timelock_alerts.py index ff22632c..4c47c23e 100644 --- a/timelock/timelock_alerts.py +++ b/timelock/timelock_alerts.py @@ -228,9 +228,10 @@ def _build_call_info(event: dict, explorer: str | None, show_index: bool, chain_ # Proxy upgrade detection: show diff link between old and new implementation if len(data_hex) >= 10: - new_impl = detect_proxy_upgrade(data_hex) - if new_impl and chain_id: - old_impl = get_current_implementation(target, chain_id) + upgrade = detect_proxy_upgrade(data_hex, target) + if upgrade and chain_id: + old_impl = get_current_implementation(upgrade.proxy_address, chain_id) + new_impl = upgrade.new_implementation if old_impl: lines.append(f"🔄 Upgrade: `{old_impl}` → `{new_impl}`") diff_url = build_diff_url(old_impl, new_impl, chain_id) diff --git a/utils/calldata/known_selectors.py b/utils/calldata/known_selectors.py index a43d425f..9bdab06e 100644 --- a/utils/calldata/known_selectors.py +++ b/utils/calldata/known_selectors.py @@ -24,6 +24,7 @@ # Proxy / upgrades "0x3659cfe6": "upgradeTo(address)", "0x4f1ef286": "upgradeToAndCall(address,bytes)", + "0x9623609d": "upgradeAndCall(address,address,bytes)", # OpenZeppelin ProxyAdmin # Ownable "0xf2fde38b": "transferOwnership(address)", # Pausable diff --git a/utils/impl_diff.py b/utils/impl_diff.py new file mode 100644 index 00000000..25186ea8 --- /dev/null +++ b/utils/impl_diff.py @@ -0,0 +1,417 @@ +"""Compare two proxy implementations' verified source to surface upgrade diffs. + +When a governance tx upgrades a proxy, the LLM normally sees just the new impl +address and a diff URL it can't follow. This module fetches both impls' source, +extracts the structural surface (function signatures + state variables in +declaration order), and produces a textual diff focused on: + +- Functions added / removed / changed signature +- Storage layout safety (append-only is safe; reorderings or removals are not) + +Skipped in v1: +- Function body changes (would either explode the prompt or require a body hash + signal that's hard to interpret). +- Inherited storage from base contracts (extractor sees the flat source bundle + as fetched from Etherscan, which usually contains inherited contracts, but we + don't follow inheritance ourselves). +- EIP-7201 namespaced storage layouts (flagged and layout check skipped). +""" + +import re +from dataclasses import dataclass +from typing import Iterable + +from utils.logging import get_logger +from utils.source_context import _fetch_source + +logger = get_logger("utils.impl_diff") + +# function () +# Args don't typically contain nested parens in Solidity, so `[^)]*` works. +_FUNCTION_DEF_RE = re.compile( + r"\bfunction\s+(\w+)\s*\(([^)]*)\)([^{;]*)(?:\{|;)", + re.MULTILINE, +) + +# State variable declaration: [visibility] [modifiers] [= value]; +# Visibility is OPTIONAL — Solidity defaults state vars to internal, so plain +# `uint256 cap;` is a valid storage declaration. To avoid matching function-local +# declarations like `uint256 x = 1;`, the caller filters matches by brace depth +# (only depth==1, inside a contract body but outside any function). +_STATE_VAR_RE = re.compile( + r"((?:mapping\s*\([^)]+(?:\([^)]*\)[^)]*)*\))|(?:[A-Za-z_]\w*(?:\[[^\]]*\])?))" # type + r"((?:\s+(?:public|private|internal|external|immutable|constant|override(?:\s*\([^)]*\))?|virtual))*)" # modifiers + r"\s+" + r"([A-Za-z_]\w*)" # name + r"\s*(?:=|;)", +) + +# Solidity keywords that look like types but introduce non-state-var declarations +# at depth 1 (function/struct/enum/etc. headers). Skip these as "types". +_NON_TYPE_KEYWORDS = frozenset( + { + "function", + "modifier", + "constructor", + "receive", + "fallback", + "struct", + "enum", + "event", + "error", + "using", + "contract", + "library", + "interface", + "abstract", + "pragma", + "import", + "type", # `type Foo is uint256;` user-defined value types — not a storage slot + "return", + } +) + +_VISIBILITIES = frozenset({"public", "private", "internal", "external"}) +_FUNCTION_KEYWORDS_TO_SKIP = frozenset( + {"if", "for", "while", "modifier", "function", "constructor", "receive", "fallback"} +) + + +@dataclass(frozen=True) +class FunctionSig: + name: str + args: str # raw arg list e.g. "address _a, uint256 _b" + visibility: str # "external" / "public" / "internal" / "private" or "" + modifiers: str # remaining tokens after visibility (view, payable, onlyOwner, etc.) + + +@dataclass(frozen=True) +class StateVarDecl: + name: str + type_str: str # canonical-ish type, e.g. "uint256", "mapping(address => uint256)" + visibility: str # "public" / etc. + immutable: bool # True if `immutable` or `constant` (NOT a storage slot) + + +@dataclass(frozen=True) +class ImplDiff: + old_addr: str + new_addr: str + old_name: str + new_name: str + added_functions: list[FunctionSig] + removed_functions: list[FunctionSig] + changed_functions: list[tuple[FunctionSig, FunctionSig]] + added_state_vars: list[StateVarDecl] # net additions at the end (append-only) + removed_state_vars: list[StateVarDecl] + layout_changes: list[str] # human-readable list of incompatible changes + storage_layout_safe: bool + namespaced_storage: bool # if true, layout check was skipped + + +def _normalize_args(args: str) -> str: + """Strip param names, collapse whitespace — so `(uint256 a)` and `(uint256 b)` match.""" + parts: list[str] = [] + for raw in args.split(","): + raw = raw.strip() + if not raw: + continue + tokens = raw.split() + # First token is the type. Subsequent: data location keyword(s) + param name. + type_str = tokens[0] + # Skip location markers if present + idx = 1 + while idx < len(tokens) and tokens[idx] in {"memory", "calldata", "storage"}: + idx += 1 + # remaining is param name (may be absent in interface declarations) + parts.append(type_str) + return ",".join(parts) + + +def _extract_function_sigs(source: str) -> list[FunctionSig]: + """Find every `function () ` definition in source order.""" + sigs: list[FunctionSig] = [] + for m in _FUNCTION_DEF_RE.finditer(source): + name = m.group(1) + if name in _FUNCTION_KEYWORDS_TO_SKIP: + continue + args = _normalize_args(m.group(2)) + mods = m.group(3) or "" + tokens = mods.split() + visibility = "" + other: list[str] = [] + for t in tokens: + t_clean = t.rstrip("(") + if t_clean in _VISIBILITIES and not visibility: + visibility = t_clean + else: + other.append(t) + sigs.append( + FunctionSig( + name=name, + args=args, + visibility=visibility, + modifiers=" ".join(other).strip(), + ) + ) + return sigs + + +def _strip_solidity_noise(source: str) -> str: + """Remove comments and string literals so brace counting / regex can't trip on them. + + Replaces stripped content with same-length spaces to preserve byte offsets, + which keeps the brace-depth array indexable against the original source. + """ + out = list(source) + # Block comments /* ... */ + for m in re.finditer(r"/\*[\s\S]*?\*/", source): + for i in range(m.start(), m.end()): + out[i] = " " if source[i] != "\n" else "\n" + # Line comments //... + for m in re.finditer(r"//[^\n]*", "".join(out)): + for i in range(m.start(), m.end()): + out[i] = " " + # String literals "..." and '...' + for pat in (r'"(?:[^"\\\n]|\\.)*"', r"'(?:[^'\\\n]|\\.)*'"): + for m in re.finditer(pat, "".join(out)): + for i in range(m.start(), m.end()): + out[i] = " " + return "".join(out) + + +def _brace_depths(cleaned: str) -> list[int]: + """Return a per-character array of brace nesting depth (post-character). + + Depth at index i is the brace depth *after* processing cleaned[i]. So a + state var declaration matched at start position p has its lexical depth + equal to depths[p - 1] (or 0 if p == 0). + """ + depths = [0] * len(cleaned) + depth = 0 + for i, c in enumerate(cleaned): + if c == "{": + depth += 1 + elif c == "}": + depth -= 1 + depths[i] = depth + return depths + + +def _extract_state_vars(source: str) -> list[StateVarDecl]: + """Find every state-var declaration in source order. + + Captures default-internal vars (no visibility modifier) as well as explicit + ones. Uses brace-depth tracking to exclude function-local declarations. + """ + cleaned = _strip_solidity_noise(source) + depths = _brace_depths(cleaned) + + vars_out: list[StateVarDecl] = [] + seen: set[tuple[str, str]] = set() + for m in _STATE_VAR_RE.finditer(cleaned): + # Determine the brace depth at the START of the match. State vars live + # at depth == 1 (inside a contract/library/interface body, outside any + # function/modifier/constructor body). + start = m.start() + depth_before = depths[start - 1] if start > 0 else 0 + if depth_before != 1: + continue + + type_str = " ".join(m.group(1).split()) + modifier_block = m.group(2) or "" + name = m.group(3) + + # Reject false positives where the regex matched a non-state-var keyword + # as the "type" (e.g., `event Foo(...)` or `function bar(...)` if it + # somehow slipped through). Most are blocked by the `(=|;)` terminator, + # but `using X for Y;` and similar edge cases get filtered here. + if type_str.split()[0] in _NON_TYPE_KEYWORDS: + continue + + # Visibility token if present in the modifier block + visibility = "" + for tok in modifier_block.split(): + if tok in _VISIBILITIES: + visibility = tok + break + + immutable = "immutable" in modifier_block or "constant" in modifier_block + + key = (name, type_str) + if key in seen: + continue + seen.add(key) + + vars_out.append( + StateVarDecl( + name=name, + type_str=type_str, + visibility=visibility, + immutable=immutable, + ) + ) + return vars_out + + +def _is_namespaced_storage(source: str) -> bool: + """Heuristic: EIP-7201 contracts have a `_getXxxStorage()` returning a `storage $`.""" + return bool( + re.search( + r"function\s+_?[gG]et\w*Storage\b[^{]*\breturns\s*\([^)]*\bstorage\b[^)]*\$", + source, + ) + ) + + +def _fkey(f: FunctionSig) -> tuple[str, str]: + """Function identity for diffing: name + arg types (handles overloads).""" + return (f.name, f.args) + + +def _diff_functions( + old_fns: list[FunctionSig], new_fns: list[FunctionSig] +) -> tuple[list[FunctionSig], list[FunctionSig], list[tuple[FunctionSig, FunctionSig]]]: + by_old = {_fkey(f): f for f in old_fns} + by_new = {_fkey(f): f for f in new_fns} + + added = [new for k, new in by_new.items() if k not in by_old] + removed = [old for k, old in by_old.items() if k not in by_new] + changed: list[tuple[FunctionSig, FunctionSig]] = [] + for k in by_old.keys() & by_new.keys(): + old = by_old[k] + new = by_new[k] + if old.visibility != new.visibility or old.modifiers != new.modifiers: + changed.append((old, new)) + return added, removed, changed + + +def _storage_layout( + old_vars: list[StateVarDecl], new_vars: list[StateVarDecl] +) -> tuple[bool, list[str], list[StateVarDecl], list[StateVarDecl]]: + """Return (safe, layout_changes, net_added, net_removed). + + Safe ⇔ the new layout starts with the old layout in the same order (append-only). + Comparison key is (name, type) since rename or type change both shift bytecode-level layout. + """ + # Filter out immutable/constant — they don't occupy a storage slot + old_slots = [v for v in old_vars if not v.immutable] + new_slots = [v for v in new_vars if not v.immutable] + + changes: list[str] = [] + n_common = min(len(old_slots), len(new_slots)) + for i in range(n_common): + o, n = old_slots[i], new_slots[i] + if (o.name, o.type_str) != (n.name, n.type_str): + changes.append(f"slot {i}: {o.type_str} {o.name} → {n.type_str} {n.name}") + + if len(new_slots) < len(old_slots): + for i in range(n_common, len(old_slots)): + v = old_slots[i] + changes.append(f"slot {i}: removed {v.type_str} {v.name}") + + added_at_end = [v for v in new_slots[n_common:]] if len(new_slots) > len(old_slots) else [] + removed_off_end = [v for v in old_slots[n_common:]] if len(old_slots) > len(new_slots) else [] + + safe = not changes + return safe, changes, added_at_end, removed_off_end + + +def diff_implementations(old_addr: str, new_addr: str, chain_id: int) -> ImplDiff | None: + """Fetch both verified impls and produce a structural diff. None on any failure.""" + old = _fetch_source(chain_id, old_addr) + new = _fetch_source(chain_id, new_addr) + if not old or not new: + return None + + old_name, old_src = old + new_name, new_src = new + + old_fns = _extract_function_sigs(old_src) + new_fns = _extract_function_sigs(new_src) + added_fns, removed_fns, changed_fns = _diff_functions(old_fns, new_fns) + + old_vars = _extract_state_vars(old_src) + new_vars = _extract_state_vars(new_src) + namespaced = _is_namespaced_storage(old_src) or _is_namespaced_storage(new_src) + + if namespaced: + layout_safe = True + layout_changes: list[str] = [] + added_at_end: list[StateVarDecl] = [] + removed_off_end: list[StateVarDecl] = [] + else: + layout_safe, layout_changes, added_at_end, removed_off_end = _storage_layout(old_vars, new_vars) + + return ImplDiff( + old_addr=old_addr, + new_addr=new_addr, + old_name=old_name, + new_name=new_name, + added_functions=added_fns, + removed_functions=removed_fns, + changed_functions=changed_fns, + added_state_vars=added_at_end, + removed_state_vars=removed_off_end, + layout_changes=layout_changes, + storage_layout_safe=layout_safe, + namespaced_storage=namespaced, + ) + + +def _fmt_function(f: FunctionSig) -> str: + parts = [f"{f.name}({f.args})"] + if f.visibility: + parts.append(f.visibility) + if f.modifiers: + parts.append(f.modifiers) + return " ".join(parts) + + +def _section(title: str, items: Iterable[str]) -> str: + items = list(items) + if not items: + return "" + return f"{title} ({len(items)}):\n" + "\n".join(f" {x}" for x in items) + + +def format_impl_diff(diff: ImplDiff) -> str: + """Render an ImplDiff into a prompt-ready text block.""" + lines: list[str] = [ + f"Old: {diff.old_addr}" + (f" ({diff.old_name})" if diff.old_name else ""), + f"New: {diff.new_addr}" + (f" ({diff.new_name})" if diff.new_name else ""), + ] + + func_blocks: list[str] = [] + if diff.added_functions: + func_blocks.append(_section("Functions added", (f"+ {_fmt_function(f)}" for f in diff.added_functions))) + if diff.removed_functions: + func_blocks.append(_section("Functions removed", (f"- {_fmt_function(f)}" for f in diff.removed_functions))) + if diff.changed_functions: + func_blocks.append( + _section( + "Functions with changed visibility/modifiers", + (f"~ {_fmt_function(o)} → {_fmt_function(n)}" for o, n in diff.changed_functions), + ) + ) + if func_blocks: + lines.append("") + lines.extend(func_blocks) + + if diff.namespaced_storage: + lines.append("") + lines.append("Storage layout: uses EIP-7201 namespaced storage; positional layout check skipped.") + elif not diff.storage_layout_safe: + lines.append("") + lines.append("⚠ Storage layout NOT upgrade-safe:") + lines.extend(f" {c}" for c in diff.layout_changes) + elif diff.added_state_vars: + lines.append("") + lines.append("Storage layout safe (append-only). New state vars at end:") + for v in diff.added_state_vars: + lines.append(f" + {v.type_str} {v.visibility} {v.name}") + else: + lines.append("") + lines.append("Storage layout: unchanged.") + + return "\n".join(lines) diff --git a/utils/llm/ai_explainer.py b/utils/llm/ai_explainer.py index 0ff09367..570708e0 100644 --- a/utils/llm/ai_explainer.py +++ b/utils/llm/ai_explainer.py @@ -8,6 +8,7 @@ from dataclasses import dataclass from utils.calldata.decoder import DecodedCall, decode_calldata +from utils.impl_diff import diff_implementations, format_impl_diff from utils.llm import get_llm_provider from utils.llm.base import LLMError from utils.logging import get_logger @@ -182,22 +183,30 @@ def _collect_source_contexts( def _get_proxy_upgrade_info(calldata: str, target: str, chain_id: int) -> str: - """Detect proxy upgrade and return context string for the LLM prompt.""" - new_impl = detect_proxy_upgrade(calldata) - if not new_impl: + """Detect proxy upgrade, fetch impl diff, and return context string for the prompt.""" + upgrade = detect_proxy_upgrade(calldata, target) + if not upgrade: return "" - old_impl = get_current_implementation(target, chain_id) - if old_impl: - info = ( - f"This is a PROXY UPGRADE on {target}.\nCurrent implementation: {old_impl}\nNew implementation: {new_impl}" - ) - diff_url = build_diff_url(old_impl, new_impl, chain_id) - if diff_url: - info += f"\nDiff: {diff_url}" - return info + proxy = upgrade.proxy_address + new_impl = upgrade.new_implementation + old_impl = get_current_implementation(proxy, chain_id) + if not old_impl: + return f"This is a PROXY UPGRADE on {proxy}.\nNew implementation: {new_impl}" + + info = f"This is a PROXY UPGRADE on {proxy}.\nCurrent implementation: {old_impl}\nNew implementation: {new_impl}" + diff_url = build_diff_url(old_impl, new_impl, chain_id) + if diff_url: + info += f"\nDiff: {diff_url}" + + try: + impl_diff = diff_implementations(old_impl, new_impl, chain_id) + if impl_diff: + info += "\n\n" + format_impl_diff(impl_diff) + except Exception as e: # noqa: BLE001 - best-effort enrichment + logger.info("Impl diff failed for %s → %s: %s", old_impl, new_impl, e) - return f"This is a PROXY UPGRADE on {target}.\nNew implementation: {new_impl}" + return info def _format_decoded_calls(calls: list[DecodedCall]) -> str: diff --git a/utils/proxy.py b/utils/proxy.py index 4bfd70d2..da250770 100644 --- a/utils/proxy.py +++ b/utils/proxy.py @@ -4,6 +4,8 @@ to compare old vs new implementation source code on Etherscan. """ +from dataclasses import dataclass + from eth_utils import to_checksum_address from utils.calldata.decoder import decode_calldata @@ -16,35 +18,65 @@ # bytes32(uint256(keccak256("eip1967.proxy.implementation")) - 1) EIP1967_IMPL_SLOT = 0x360894A13BA1A3210667C828492DB98DCA3E2076CC3735A920A3CA505D382BBC -# Selectors that indicate a proxy upgrade -_UPGRADE_SELECTORS = frozenset({"0x3659cfe6", "0x4f1ef286"}) +# Selectors that indicate a proxy upgrade. +# - upgradeTo(address) — called on the proxy itself +# - upgradeToAndCall(address,bytes) — called on the proxy itself +# - upgradeAndCall(address,address,bytes) — called on a ProxyAdmin (proxy = arg 0) +_PROXY_DIRECT_SELECTORS = frozenset({"0x3659cfe6", "0x4f1ef286"}) +_PROXY_ADMIN_SELECTOR = "0x9623609d" + + +@dataclass(frozen=True) +class ProxyUpgrade: + """Result of detecting a proxy upgrade in calldata.""" + proxy_address: str # the proxy whose impl is being changed (may differ from tx target) + new_implementation: str -def detect_proxy_upgrade(data_hex: str) -> str | None: - """Check if calldata is a proxy upgrade and return the new implementation address. + +def detect_proxy_upgrade(data_hex: str, target: str = "") -> ProxyUpgrade | None: + """Check if calldata is a proxy upgrade and return proxy + new impl. Supports: - - upgradeTo(address) selector 0x3659cfe6 - - upgradeToAndCall(address,bytes) selector 0x4f1ef286 + - upgradeTo(address) (called on the proxy itself) + - upgradeToAndCall(address,bytes) (called on the proxy itself) + - upgradeAndCall(address,address,bytes) (called on ProxyAdmin; proxy is arg 0) + + Args: + data_hex: calldata hex with 0x prefix + target: the tx's target address — used as the proxy address for the + "called on the proxy itself" variants. For the ProxyAdmin variant + the proxy address comes from the calldata. Returns: - New implementation address (checksummed), or None if not a proxy upgrade. + ProxyUpgrade(proxy_address, new_implementation) or None. """ if not data_hex or len(data_hex) < 10: return None selector = data_hex[:10].lower() - if selector not in _UPGRADE_SELECTORS: - return None - decoded = decode_calldata(data_hex) if not decoded or not decoded.params: return None - # First param is always the new implementation address - type_str, value = decoded.params[0] - if type_str == "address": - return to_checksum_address(value) + if selector in _PROXY_DIRECT_SELECTORS: + type_str, value = decoded.params[0] + if type_str != "address" or not target: + return None + return ProxyUpgrade( + proxy_address=to_checksum_address(target), + new_implementation=to_checksum_address(value), + ) + + if selector == _PROXY_ADMIN_SELECTOR and len(decoded.params) >= 2: + proxy_type, proxy_addr = decoded.params[0] + impl_type, impl_addr = decoded.params[1] + if proxy_type != "address" or impl_type != "address": + return None + return ProxyUpgrade( + proxy_address=to_checksum_address(proxy_addr), + new_implementation=to_checksum_address(impl_addr), + ) return None