Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions tests/test_impl_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,52 @@ def test_reorder_is_unsafe(self) -> None:
self.assertFalse(safe)
self.assertTrue(any("slot 0" in c for c in changes))

def test_consuming_one_gap_slot_is_safe(self) -> None:
"""Canonical OZ pattern: append a new state var by shrinking the trailing gap."""
old = _extract_state_vars("contract C { uint256 a; uint256[50] __gap; }")
new = _extract_state_vars("contract C { uint256 a; uint256 b; uint256[49] __gap; }")
safe, changes, added, _ = _storage_layout(old, new)
self.assertTrue(safe, f"expected safe gap consumption, got changes={changes}")
self.assertEqual([v.name for v in added], ["b"])

def test_consuming_multiple_gap_slots_is_safe(self) -> None:
old = _extract_state_vars("contract C { uint256 a; uint256[50] __gap; }")
new = _extract_state_vars("contract C { uint256 a; uint256 b; address c; uint48 d; uint256[47] __gap; }")
safe, changes, _, _ = _storage_layout(old, new)
self.assertTrue(safe, f"expected safe multi-slot consumption, got changes={changes}")

def test_gap_size_underflow_is_unsafe(self) -> None:
"""If the new contract consumes MORE slots than the old gap reserved."""
old = _extract_state_vars("contract C { uint256 a; uint256[2] __gap; }")
new = _extract_state_vars("contract C { uint256 a; uint256 b; uint256 c; uint256 d; }")
safe, changes, _, _ = _storage_layout(old, new)
self.assertFalse(safe)
self.assertTrue(any("overflow" in c for c in changes))

def test_gap_not_shrunk_correctly_is_unsafe(self) -> None:
"""Consumed 1 slot but gap kept its original size — slots[2..] now shifted."""
old = _extract_state_vars("contract C { uint256 a; uint256[50] __gap; }")
new = _extract_state_vars("contract C { uint256 a; uint256 b; uint256[50] __gap; }")
safe, changes, _, _ = _storage_layout(old, new)
self.assertFalse(safe)
self.assertTrue(any("gap mismatch" in c for c in changes))

def test_fully_consumed_gap_removed_is_safe(self) -> None:
"""Old had a 1-slot gap; new fills it and removes the gap entirely."""
old = _extract_state_vars("contract C { uint256 a; uint256[1] __gap; }")
new = _extract_state_vars("contract C { uint256 a; uint256 b; }")
safe, changes, _, _ = _storage_layout(old, new)
self.assertTrue(safe, f"expected safe full consumption, got changes={changes}")

def test_gap_removed_without_consumption_is_unsafe(self) -> None:
"""Removing a gap without filling it changes the slot count and is unsafe
if any inheriting contract assumed it would still be there."""
old = _extract_state_vars("contract C { uint256 a; uint256[5] __gap; }")
new = _extract_state_vars("contract C { uint256 a; }")
safe, changes, _, _ = _storage_layout(old, new)
self.assertFalse(safe)
self.assertTrue(any("gap" in c for c in changes))


class TestNamespacedStorage(unittest.TestCase):
def test_detected(self) -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_proxy_upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def test_works_offline_for_all_proxy_selectors(self) -> None:
assert result is not None
self.assertEqual(result.new_implementation, NEW_IMPL)

def test_non_upgrade_short_circuits_before_decode(self) -> None:
"""Perf regression guard: a non-upgrade selector must NOT trigger a
Sourcify lookup. Without the early-return guard, every alert call
could wait on a 30s timeout for unknown selectors."""
from unittest.mock import patch

# Random non-upgrade selector + arbitrary bytes — looks like unknown data
data = "0xdeadbeef" + "00" * 32
with patch("utils.calldata.decoder.fetch_json") as mock_fetch:
mock_fetch.side_effect = AssertionError("Sourcify lookup triggered on non-upgrade selector")
result = detect_proxy_upgrade(data, PROXY_ADDR)
self.assertIsNone(result)
mock_fetch.assert_not_called()


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions timelock/timelock_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def _build_call_info(event: dict, explorer: str | None, show_index: bool, chain_
if len(data_hex) >= 10:
upgrade = detect_proxy_upgrade(data_hex, target)
if upgrade and chain_id:
# For ProxyAdmin-routed upgrades, `target` is the ProxyAdmin contract;
# the proxy being upgraded is inside the calldata. Surface it explicitly
# so recipients know which contract is changing.
if upgrade.proxy_address.lower() != target.lower():
lines.append(f"🅿️ Proxy: `{upgrade.proxy_address}`")
old_impl = get_current_implementation(upgrade.proxy_address, chain_id)
new_impl = upgrade.new_implementation
if old_impl:
Expand Down
91 changes: 81 additions & 10 deletions utils/impl_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,32 +286,103 @@ def _diff_functions(
return added, removed, changed


# An OZ-style trailing storage gap: `uintN[K] __gap;` (or `_gap`, `gap`).
# Reserved for future upgrades; consuming part of it is the canonical safe
# pattern, so we detach the trailing gap before comparing layouts.
_STORAGE_GAP_TYPE_RE = re.compile(r"^u?int\d*\s*\[\s*(\d+)\s*\]$")


def _gap_size(v: StateVarDecl) -> int | None:
"""If `v` looks like an OZ trailing storage gap, return its size. Else None."""
if not v.name.lower().endswith("gap"):
return None
m = _STORAGE_GAP_TYPE_RE.match(v.type_str.replace(" ", ""))
return int(m.group(1)) if m else None


def _detach_trailing_gap(slots: list[StateVarDecl]) -> tuple[list[StateVarDecl], int | None]:
"""Strip the trailing storage gap (if any) and return (slots_before_gap, gap_size)."""
if not slots:
return slots, None
size = _gap_size(slots[-1])
if size is None:
return slots, None
return slots[:-1], size


def _storage_layout(
old_vars: list[StateVarDecl], new_vars: list[StateVarDecl]
) -> tuple[bool, list[str], list[StateVarDecl], list[StateVarDecl]]:
"""Return (safe, layout_changes, net_added, net_removed).

Safe ⇔ the new layout starts with the old layout in the same order (append-only).
Comparison key is (name, type) since rename or type change both shift bytecode-level layout.
Safe upgrade patterns:
1. Append-only: new layout begins with the old layout (in the same order).
2. OZ storage-gap consumption: trailing `uintN[K] __gap` shrinks by exactly
the number of new vars inserted before it. Old contracts often reserve
a gap so future upgrades can claim slots without shifting parent
storage. Mis-handling this would produce false "unsafe" warnings for
most real OpenZeppelin upgradeable contracts.

Comparison key is (name, type) since rename or type change both shift the
bytecode-level storage layout.
"""
# Filter out immutable/constant — they don't occupy a storage slot
old_slots = [v for v in old_vars if not v.immutable]
new_slots = [v for v in new_vars if not v.immutable]

# Detach trailing gaps so we can analyze gap consumption separately.
old_core, old_gap = _detach_trailing_gap(old_slots)
new_core, new_gap = _detach_trailing_gap(new_slots)

changes: list[str] = []
n_common = min(len(old_slots), len(new_slots))

# Compare positions both contracts share. Any mismatch here is unsafe.
n_common = min(len(old_core), len(new_core))
for i in range(n_common):
o, n = old_slots[i], new_slots[i]
o, n = old_core[i], new_core[i]
if (o.name, o.type_str) != (n.name, n.type_str):
changes.append(f"slot {i}: {o.type_str} {o.name} → {n.type_str} {n.name}")

if len(new_slots) < len(old_slots):
for i in range(n_common, len(old_slots)):
v = old_slots[i]
consumed = len(new_core) - len(old_core)
added_at_end: list[StateVarDecl] = []
removed_off_end: list[StateVarDecl] = []

if consumed > 0:
added_at_end = list(new_core[len(old_core) :])
# New vars appended after the old core.
if old_gap is not None:
# Old contract reserved a gap; the new vars must come out of it.
expected_new_gap = old_gap - consumed
if expected_new_gap < 0:
changes.append(
f"consumed {consumed} new slot(s) but old gap was only {old_gap}; layout overflows reserved space"
)
elif expected_new_gap == 0:
if new_gap is not None:
changes.append(f"old gap of {old_gap} fully consumed but new contract still has gap of {new_gap}")
else: # expected_new_gap > 0
if new_gap is None:
changes.append(
f"old gap of {old_gap} not preserved (expected new gap of {expected_new_gap}, got none)"
)
elif new_gap != expected_new_gap:
changes.append(
f"gap mismatch: consumed {consumed} slot(s); expected new gap of {expected_new_gap}, got {new_gap}"
)
# If old had no gap, appending at the end is still safe (no shift).
elif consumed < 0:
# New contract is shorter than old: old vars were removed off the end.
removed_off_end = list(old_core[len(new_core) :])
for i, v in enumerate(removed_off_end, start=len(new_core)):
changes.append(f"slot {i}: removed {v.type_str} {v.name}")

added_at_end = [v for v in new_slots[n_common:]] if len(new_slots) > len(old_slots) else []
removed_off_end = [v for v in old_slots[n_common:]] if len(old_slots) > len(new_slots) else []
else:
# Same length cores. If gaps don't agree, flag.
if old_gap is not None and new_gap is None:
changes.append(f"old gap of size {old_gap} removed in new layout")
elif old_gap is None and new_gap is not None:
changes.append(f"new layout introduces gap of size {new_gap} not present in old")
elif old_gap != new_gap:
changes.append(f"gap size changed from {old_gap} to {new_gap} without slot consumption")

safe = not changes
return safe, changes, added_at_end, removed_off_end
Expand Down
5 changes: 5 additions & 0 deletions utils/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def detect_proxy_upgrade(data_hex: str, target: str = "") -> ProxyUpgrade | None
return None

selector = data_hex[:10].lower()
# Short-circuit before decoding: the vast majority of calls are not proxy
# upgrades, and a `decode_calldata` miss can wait on a 30s Sourcify lookup.
if selector not in _PROXY_DIRECT_SELECTORS and selector != _PROXY_ADMIN_SELECTOR:
return None

decoded = decode_calldata(data_hex)
if not decoded or not decoded.params:
return None
Expand Down