class TestFailedSimulationDropped(unittest.TestCase):
    """Failed Tenderly simulations must not leak into the LLM prompt."""

    @staticmethod
    def _stub_provider(reply: str) -> MagicMock:
        # Fake LLM provider: `complete` answers with `reply`, model name is fixed.
        stub = MagicMock()
        stub.complete.return_value = reply
        stub.model_name = "test"
        return stub

    @patch("utils.llm.ai_explainer.get_source_context", return_value=None)
    @patch("utils.llm.ai_explainer.get_contract_label", return_value="")
    @patch("utils.llm.ai_explainer.get_llm_provider")
    @patch("utils.llm.ai_explainer.simulate_transaction")
    @patch("utils.llm.ai_explainer.decode_calldata")
    def test_failed_sim_omitted_from_single_prompt(
        self,
        mock_decode: MagicMock,
        mock_simulate: MagicMock,
        mock_get_provider: MagicMock,
        mock_label: MagicMock,
        mock_source: MagicMock,
    ) -> None:
        # Single-call path: the simulator reports a revert for a plain pause().
        mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()")
        mock_simulate.return_value = SimulationResult(
            success=False,
            gas_used=0,
            error_message="execution reverted: not authorized",
        )
        llm = self._stub_provider("TLDR: pauses. LOW.")
        mock_get_provider.return_value = llm

        explain_transaction(target="0xT", calldata="0x8456cb59", chain_id=1)
        prompt = llm.complete.call_args[0][0]

        # Neither the section header, the status word, nor the revert reason
        # may reach the prompt.
        for leaked in ("--- Simulation Results ---", "FAILED", "execution reverted"):
            self.assertNotIn(leaked, prompt)

    @patch("utils.llm.ai_explainer.get_source_context", return_value=None)
    @patch("utils.llm.ai_explainer.get_contract_label", return_value="")
    @patch("utils.llm.ai_explainer.get_llm_provider")
    @patch("utils.llm.ai_explainer.simulate_transaction")
    @patch("utils.llm.ai_explainer.decode_calldata")
    def test_failed_sim_omitted_from_batch_prompt(
        self,
        mock_decode: MagicMock,
        mock_simulate: MagicMock,
        mock_get_provider: MagicMock,
        mock_label: MagicMock,
        mock_source: MagicMock,
    ) -> None:
        from utils.llm.ai_explainer import explain_batch_transaction

        # Batch path: every inner call gets the same failed simulation result.
        mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()")
        mock_simulate.return_value = SimulationResult(success=False, gas_used=0, error_message="reverted")
        llm = self._stub_provider("TLDR: pauses both. LOW.")
        mock_get_provider.return_value = llm

        batch = [{"target": tgt, "data": "0x8456cb59", "value": "0"} for tgt in ("0xT1", "0xT2")]
        explain_batch_transaction(calls=batch, chain_id=1)
        prompt = llm.complete.call_args[0][0]

        for leaked in ("--- Simulation Results ---", "FAILED"):
            self.assertNotIn(leaked, prompt)
+ self.assertIn("address[]:", prompt) + self.assertIn(f"- {self.FARM_CKS} (MorphoFarm)", prompt) + + @patch("utils.llm.ai_explainer.get_source_context", return_value=None) + @patch("utils.llm.ai_explainer.get_llm_provider") + @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) + @patch("utils.llm.ai_explainer.decode_calldata") + @patch("utils.llm.ai_explainer.get_contract_label") + def test_scalar_address_arg_is_labeled( + self, + mock_label: MagicMock, + mock_decode: MagicMock, + mock_simulate: MagicMock, + mock_get_provider: MagicMock, + mock_source: MagicMock, + ) -> None: + mock_decode.return_value = DecodedCall( + function_name="setOracle", + signature="setOracle(address)", + params=[("address", self.FARM)], + ) + mock_label.return_value = "ChainlinkOracle" + provider = MagicMock() + provider.complete.return_value = "TLDR: rewires oracle. MEDIUM." + provider.model_name = "test-model" + mock_get_provider.return_value = provider + + explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1) + + prompt = provider.complete.call_args[0][0] + self.assertIn(f"address: {self.FARM_CKS} (ChainlinkOracle)", prompt) + + @patch("utils.llm.ai_explainer.get_source_context", return_value=None) + @patch("utils.llm.ai_explainer.get_llm_provider") + @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) + @patch("utils.llm.ai_explainer.decode_calldata") + @patch("utils.llm.ai_explainer.get_contract_label") + def test_target_address_is_not_relabeled( + self, + mock_label: MagicMock, + mock_decode: MagicMock, + mock_simulate: MagicMock, + mock_get_provider: MagicMock, + mock_source: MagicMock, + ) -> None: + # The target appears as its own argument — should be skipped so we don't + # double up with the Contract Source Context block. 
+ mock_decode.return_value = DecodedCall( + function_name="selfWire", + signature="selfWire(address)", + params=[("address", self.REGISTRY.lower())], + ) + provider = MagicMock() + provider.complete.return_value = "TLDR: wires self. LOW." + provider.model_name = "test-model" + mock_get_provider.return_value = provider + + explain_transaction(target=self.REGISTRY, calldata="0xdeadbeef" + "00" * 32, chain_id=1) + + mock_label.assert_not_called() + + @patch("utils.llm.ai_explainer.get_source_context", return_value=None) + @patch("utils.llm.ai_explainer.get_llm_provider") + @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) + @patch("utils.llm.ai_explainer.decode_calldata") + @patch("utils.llm.ai_explainer.get_contract_label") + def test_unverified_address_left_unannotated( + self, + mock_label: MagicMock, + mock_decode: MagicMock, + mock_simulate: MagicMock, + mock_get_provider: MagicMock, + mock_source: MagicMock, + ) -> None: + mock_decode.return_value = DecodedCall( + function_name="setOracle", + signature="setOracle(address)", + params=[("address", self.FARM)], + ) + mock_label.return_value = "" # unverified / EOA / no API key + provider = MagicMock() + provider.complete.return_value = "TLDR: rewires. MEDIUM." + provider.model_name = "test-model" + mock_get_provider.return_value = provider + + explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1) + + prompt = provider.complete.call_args[0][0] + # Address shows up, but with no `(Label)` suffix. 
+ self.assertIn(self.FARM_CKS, prompt) + self.assertNotIn(f"{self.FARM_CKS} (", prompt) + + @patch("utils.llm.ai_explainer.get_source_context", return_value=None) + @patch("utils.llm.ai_explainer.get_llm_provider") + @patch("utils.llm.ai_explainer.simulate_transaction", return_value=None) + @patch("utils.llm.ai_explainer.decode_calldata") + @patch("utils.llm.ai_explainer.get_contract_label") + def test_zero_address_not_queried( + self, + mock_label: MagicMock, + mock_decode: MagicMock, + mock_simulate: MagicMock, + mock_get_provider: MagicMock, + mock_source: MagicMock, + ) -> None: + zero = "0x" + "00" * 20 + mock_decode.return_value = DecodedCall( + function_name="setOracle", + signature="setOracle(address)", + params=[("address", zero)], + ) + provider = MagicMock() + provider.complete.return_value = "TLDR: unsets oracle. LOW." + provider.model_name = "test-model" + mock_get_provider.return_value = provider + + explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1) + mock_label.assert_not_called() + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_source_context.py b/tests/test_source_context.py index 763702f..af41cce 100644 --- a/tests/test_source_context.py +++ b/tests/test_source_context.py @@ -9,6 +9,7 @@ _extract_function_snippet, extract_state_var_snippet, find_state_var_writes, + get_contract_label, get_source_context, reset_cache, ) @@ -237,5 +238,64 @@ def test_proxy_follow_skipped_when_impl_equals_target(self, mock_fetch: object, self.assertEqual(mock_fetch.call_count, 1) # type: ignore[attr-defined] +class TestGetContractLabel(unittest.TestCase): + """Tests for get_contract_label().""" + + def setUp(self) -> None: + reset_cache() + + @patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"}) + @patch("utils.source_context.fetch_json") + def test_returns_verified_contract_name(self, mock_fetch: object) -> None: + mock_fetch.return_value = { # type: ignore[attr-defined] + "status": "1", + "result": 
[{"SourceCode": "contract Farm { }", "ContractName": "MorphoFarm"}], + } + label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694") + self.assertEqual(label, "MorphoFarm") + + @patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"}) + @patch("utils.source_context.fetch_json") + def test_unverified_returns_empty(self, mock_fetch: object) -> None: + mock_fetch.return_value = { # type: ignore[attr-defined] + "status": "1", + "result": [{"SourceCode": "", "ContractName": ""}], + } + label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694") + self.assertEqual(label, "") + + def test_safe_utility_shortcut(self) -> None: + # MultiSendCallOnly — no Etherscan call should be needed. + label = get_contract_label(1, "0x40A2aCCbd92BCA938b02010E17A5b8929b49130D") + self.assertEqual(label, "Safe MultiSendCallOnly") + + @patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"}) + @patch("utils.proxy.get_current_implementation") + @patch("utils.source_context.fetch_json") + def test_follows_proxy_when_name_is_generic(self, mock_fetch: object, mock_impl: object) -> None: + mock_fetch.side_effect = [ # type: ignore[attr-defined] + {"status": "1", "result": [{"SourceCode": "/* proxy */", "ContractName": "TransparentUpgradeableProxy"}]}, + {"status": "1", "result": [{"SourceCode": "/* impl */", "ContractName": "InfinifiBorrowingFarm"}]}, + ] + mock_impl.return_value = "0x000000000000000000000000000000000000beef" # type: ignore[attr-defined] + label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694") + self.assertEqual(label, "InfinifiBorrowingFarm") + + @patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"}) + @patch("utils.proxy.get_current_implementation", return_value=None) + @patch("utils.source_context.fetch_json") + def test_keeps_specific_name_without_proxy_follow(self, mock_fetch: object, mock_impl: object) -> None: + mock_fetch.return_value = { # type: ignore[attr-defined] + "status": "1", + "result": [{"SourceCode": 
"/* x */", "ContractName": "FarmRegistry"}], + } + label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694") + self.assertEqual(label, "FarmRegistry") + mock_impl.assert_not_called() # type: ignore[attr-defined] + + def test_empty_address_returns_empty(self) -> None: + self.assertEqual(get_contract_label(1, ""), "") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_timelock_alerts.py b/tests/test_timelock_alerts.py index 80905dc..2882858 100644 --- a/tests/test_timelock_alerts.py +++ b/tests/test_timelock_alerts.py @@ -1,6 +1,7 @@ """Tests for timelock/timelock_alerts.py — build_alert_message truncation logic.""" import unittest +import unittest.mock from unittest.mock import patch from timelock.timelock_alerts import TimelockConfig, build_alert_message @@ -112,5 +113,86 @@ def test_message_under_limit_with_ai(self, mock_ai: object, mock_format: object) self.assertNotIn("...", msg) +class TestMapleProposalUnwrap(unittest.TestCase): + """Maple ProposalScheduled has no target/data; recover them from the source tx.""" + + @staticmethod + def _make_schedule_calldata(targets: list[str], datas: list[bytes]) -> str: + from eth_abi import encode + from eth_utils import function_signature_to_4byte_selector + + selector = function_signature_to_4byte_selector("scheduleProposals(address[],bytes[])") + body = encode(["address[]", "bytes[]"], [targets, datas]) + return "0x" + selector.hex() + body.hex() + + @staticmethod + def _wrap_in_safe(inner_hex: str, safe_target: str) -> str: + from eth_abi import encode + from eth_utils import function_signature_to_4byte_selector + + selector = function_signature_to_4byte_selector( + "execTransaction(address,uint256,bytes,uint8,uint256,uint256,uint256,address,address,bytes)" + ) + zero = "0x" + "00" * 20 + body = encode( + ["address", "uint256", "bytes", "uint8", "uint256", "uint256", "uint256", "address", "address", "bytes"], + [safe_target, 0, bytes.fromhex(inner_hex[2:]), 0, 0, 0, 0, zero, zero, 
b""], + ) + return "0x" + selector.hex() + body.hex() + + @patch("timelock.timelock_alerts.ChainManager") + def test_unwraps_safe_wrapped_schedule_proposals(self, mock_cm: object) -> None: + from timelock.timelock_alerts import _maple_proposal_calls + + targets = ["0x" + "aa" * 20, "0x" + "bb" * 20] + datas = [bytes.fromhex("8456cb59"), bytes.fromhex("3f4ba83a")] # pause(), unpause() + inner_hex = self._make_schedule_calldata(targets, datas) + outer = self._wrap_in_safe(inner_hex, "0x2efff88747eb5a3ff00d4d8d0f0800e306c0426b") + + mock_client = unittest.mock.MagicMock() + mock_client.eth.get_transaction.return_value = {"input": outer} + mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined] + + event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32) + calls = _maple_proposal_calls(event, chain_id=1) + + assert calls is not None + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0]["target"], targets[0]) + self.assertEqual(calls[0]["data"], "0x8456cb59") + self.assertEqual(calls[1]["target"], targets[1]) + self.assertEqual(calls[1]["data"], "0x3f4ba83a") + + @patch("timelock.timelock_alerts.ChainManager") + def test_unwraps_direct_schedule_proposals(self, mock_cm: object) -> None: + from timelock.timelock_alerts import _maple_proposal_calls + + targets = ["0x" + "cc" * 20] + datas = [bytes.fromhex("8456cb59")] + inner_hex = self._make_schedule_calldata(targets, datas) + + mock_client = unittest.mock.MagicMock() + mock_client.eth.get_transaction.return_value = {"input": inner_hex} + mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined] + + event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32) + calls = _maple_proposal_calls(event, chain_id=1) + assert calls is not None + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0]["data"], "0x8456cb59") + + @patch("timelock.timelock_alerts.ChainManager") + def test_returns_none_for_unknown_selector(self, mock_cm: object) -> 
None: + from timelock.timelock_alerts import _maple_proposal_calls + + # proposeRoleUpdates path — we can't synthesize (target, data) pairs from it. + mock_client = unittest.mock.MagicMock() + mock_client.eth.get_transaction.return_value = {"input": "0x2d6e853c" + "00" * 100} + mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined] + + event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32) + self.assertIsNone(_maple_proposal_calls(event, chain_id=1)) + + if __name__ == "__main__": unittest.main() diff --git a/timelock/timelock_alerts.py b/timelock/timelock_alerts.py index 42c7e83..a71ac42 100644 --- a/timelock/timelock_alerts.py +++ b/timelock/timelock_alerts.py @@ -14,12 +14,14 @@ from eth_utils import to_checksum_address from utils.cache import cache_filename, get_last_value_for_key_from_file, write_last_value_to_file -from utils.calldata.decoder import format_call_lines +from utils.calldata.decoder import decode_calldata, format_call_lines from utils.chains import EXPLORER_URLS, Chain from utils.llm.ai_explainer import explain_batch_transaction, explain_transaction, format_explanation_line from utils.logging import get_logger from utils.proxy import build_diff_url, detect_proxy_upgrade, get_current_implementation +from utils.safe_tx import unwrap_safe_exec_transaction from utils.telegram import MAX_MESSAGE_LENGTH, send_telegram_message +from utils.web3_wrapper import ChainManager load_dotenv() @@ -253,9 +255,78 @@ def _build_call_info(event: dict, explorer: str | None, show_index: bool, chain_ return lines +def _maple_proposal_calls(event: dict, chain_id: int) -> list[dict[str, str]] | None: + """Recover the inner (target, data) pairs from a Maple ProposalScheduled event. + + The GovernorTimelock only stores a hash of the proposal calls on-chain, so the + ProposalScheduled event itself has no target/data. 
The actual payload lives in + the transaction that emitted the event — typically a Safe execTransaction wrapping + a scheduleProposals(address[], bytes[]) call into the GovernorTimelock. + + Returns None if the tx can't be fetched or doesn't match the expected shape. + """ + tx_hash = event.get("transactionHash") + if not tx_hash: + return None + + try: + chain = Chain.from_chain_id(chain_id) + client = ChainManager.get_client(chain) + tx = client.eth.get_transaction(tx_hash) + except Exception as e: # noqa: BLE001 + _logger.info("Failed to fetch Maple proposal tx %s: %s", tx_hash, e) + return None + + raw_input = tx.get("input") + input_hex = raw_input.hex() if isinstance(raw_input, bytes) else str(raw_input or "") + if input_hex and not input_hex.startswith("0x"): + input_hex = "0x" + input_hex + + # Unwrap one layer of Safe execTransaction if present; otherwise decode directly. + inner = unwrap_safe_exec_transaction(input_hex) + inner_data = inner.data if inner else input_hex + if not inner_data or len(inner_data) < 10: + return None + + decoded = decode_calldata(inner_data) + if not decoded or len(decoded.params) < 2: + return None + + # Expect scheduleProposals(address[] targets, bytes[] data). Some Maple proposal + # paths (proposeRoleUpdates etc.) don't carry concrete (target, data) tuples we + # can hand to the explainer — bail out cleanly for those. + targets_type, targets = decoded.params[0] + data_type, datas = decoded.params[1] + if targets_type != "address[]" or data_type != "bytes[]" or len(targets) != len(datas): + return None + + def _to_hex(d: object) -> str: + if isinstance(d, bytes): + return "0x" + d.hex() + s = str(d) + return s if s.startswith("0x") else "0x" + s + + return [{"target": str(t), "data": _to_hex(d), "value": "0"} for t, d in zip(targets, datas)] + + def _get_ai_explanation(events: list[dict], timelock_info: TimelockConfig, chain_id: int) -> str | None: """Generate AI explanation for timelock events. 
Returns None on any failure.""" try: + # Maple's ProposalScheduled event only carries an opaque proposalId — the + # targets/data must be recovered from the originating transaction. + if events and events[0].get("timelockType") == "Maple": + calls = _maple_proposal_calls(events[0], chain_id) + if not calls: + return None + return explain_batch_transaction( + calls=calls, + chain_id=chain_id, + protocol=timelock_info.protocol, + label=timelock_info.label, + from_address=timelock_info.address, + refine=True, + ) + calls_with_data = [e for e in events if e.get("target") and e.get("data") and len(e.get("data", "")) >= 10] if not calls_with_data: return None diff --git a/utils/calldata/known_selectors.py b/utils/calldata/known_selectors.py index 9bdab06..b5186e7 100644 --- a/utils/calldata/known_selectors.py +++ b/utils/calldata/known_selectors.py @@ -38,6 +38,11 @@ "0xddf0b009": "queue(uint256)", "0xfe0d94c1": "execute(uint256)", "0xbb913f41": "_setImplementation(address)", + # Maple GovernorTimelock + "0xd9ab9270": "scheduleProposals(address[],bytes[])", + "0x2d6e853c": "proposeRoleUpdates(bytes32[],address[],bool[])", + # Safe (Gnosis) + "0x6a761202": "execTransaction(address,uint256,bytes,uint8,uint256,uint256,uint256,address,address,bytes)", # Compound Comptroller "0xa76b3fda": "_supportMarket(address)", "0x55ee1fe1": "_setPriceOracle(address)", diff --git a/utils/llm/README.md b/utils/llm/README.md index 86148a5..ec4a57e 100644 --- a/utils/llm/README.md +++ b/utils/llm/README.md @@ -121,7 +121,7 @@ The structural diff is injected into the prompt under `--- Implementation Diff - The prompt is built from all available context. 
The system prompt enforces brevity: -- TLDR ≤25 words, starts with a verb, no "This transaction…" preamble +- Starts with a verb, no "This transaction…" preamble - Trailing risk tag in caps (LOW / MEDIUM / HIGH / CRITICAL) - Refuses to assume parameter units from function name alone - Trusts source-context natspec over prior assumptions @@ -182,7 +182,7 @@ The full prompt is logged at INFO level for debugging. ### 7. Optional Refine Pass -When `refine=True` is passed to `explain_transaction` / `explain_batch_transaction`, a second LLM call critiques the draft against a checklist (verb-leading TLDR, ≤25 words, supported units, risk-magnitude consistency) and revises only if it finds concrete issues. Hard rules forbid introducing new unit assumptions, removing hedges, escalating LOW out of caution, or style-only churn. Falls back to the draft on `PASS`, on any `LLMError`, or on an empty revision. +When `refine=True` is passed to `explain_transaction` / `explain_batch_transaction`, a second LLM call critiques the draft against a checklist (verb-leading TLDR, supported units, risk-magnitude consistency) and revises only if it finds concrete issues. Hard rules forbid introducing new unit assumptions, removing hedges, escalating LOW out of caution, or style-only churn. Falls back to the draft on `PASS`, on any `LLMError`, or on an empty revision. Cost: ~2× LLM calls per alert when enabled. Default is **off**. 
diff --git a/utils/llm/ai_explainer.py b/utils/llm/ai_explainer.py index 570708e..c6873a0 100644 --- a/utils/llm/ai_explainer.py +++ b/utils/llm/ai_explainer.py @@ -7,6 +7,8 @@ from dataclasses import dataclass +from eth_utils import to_checksum_address + from utils.calldata.decoder import DecodedCall, decode_calldata from utils.impl_diff import diff_implementations, format_impl_diff from utils.llm import get_llm_provider @@ -15,7 +17,7 @@ from utils.on_chain_state import StateRead, format_state_reads, read_before_state from utils.paste import upload_to_paste from utils.proxy import build_diff_url, detect_proxy_upgrade, get_current_implementation -from utils.source_context import SourceContext, format_source_context, get_source_context +from utils.source_context import SourceContext, format_source_context, get_contract_label, get_source_context from utils.telegram import escape_markdown from utils.tenderly.simulation import SimulationResult, simulate_transaction @@ -23,12 +25,14 @@ SYSTEM_PROMPT = """You are a DeFi risk analyst writing alerts for a monitoring team. Output two sections. -TLDR: ≤25 words. Start with a verb describing the effect. Do NOT open with -"This transaction", "The proposal", or similar — the reader already knows -what kind of tx this is. End with a risk tag in caps: LOW / MEDIUM / HIGH / CRITICAL. +TLDR: 2-4 short sentences. Cover [what changed] · [magnitude or impact] · [risk tag]. +Start with a verb describing the effect. Do NOT open with "This transaction", "The proposal", +or similar — the reader already knows what kind of tx this is. +End with a risk tag in caps: LOW / MEDIUM / HIGH / CRITICAL. Good example: "Lowers swap fee 30→25 bps on USDC/USDT pool. Marginal LP revenue cut. LOW." -Bad example: "This governance transaction adjusts the swap fee parameter on the USDC/USDT pool from 30 basis points to 25 basis points, which slightly reduces revenue for liquidity providers. Risk is LOW." +Bad (too terse, drops impact): "Adds farm. LOW." 
+Bad (preamble + run-on): "This governance transaction adjusts the swap fee parameter on the USDC/USDT pool from 30 basis points to 25 basis points, which slightly reduces revenue for liquidity providers. Risk is LOW." DETAIL: thorough analysis covering: - What each call does and why @@ -61,7 +65,8 @@ 1. Does the TLDR start with a verb (NOT "This transaction" / "The proposal" / "The transaction" / "This governance")? -2. Is the TLDR ≤25 words? +2. Is the TLDR 2-4 short sentences? (Single-sentence TLDRs + that omit the impact/magnitude beat are too terse — flag for revision.) 3. Does the TLDR end with a risk tag in CAPS (LOW / MEDIUM / HIGH / CRITICAL)? 4. Are all numeric magnitudes/units in the draft supported by either the Contract Source Context section or the Current State section above? Or @@ -209,13 +214,98 @@ def _get_proxy_upgrade_info(calldata: str, target: str, chain_id: int) -> str: return info -def _format_decoded_calls(calls: list[DecodedCall]) -> str: - """Format decoded calls into a readable string for the LLM prompt.""" +def _checksum_or_none(addr: str) -> str | None: + """Return checksummed address or None if `addr` isn't a parseable hex address.""" + if not isinstance(addr, str) or not addr.startswith("0x"): + return None + try: + return to_checksum_address(addr) + except ValueError: + return None + + +def _annotate_address(addr: str, labels: dict[str, str]) -> str: + """Render a single address with an optional `(ContractName)` suffix.""" + checksum = _checksum_or_none(addr) + if checksum is None: + return str(addr) + label = labels.get(checksum) + return f"{checksum} ({label})" if label else checksum + + +def _extract_address_args(decoded: DecodedCall) -> list[str]: + """All address-typed argument values (scalars and arrays) for one decoded call.""" + out: list[str] = [] + for type_str, value in decoded.params: + if type_str == "address" and isinstance(value, str): + out.append(value) + elif type_str.startswith("address[") and isinstance(value, 
def _collect_address_labels(
    targets_and_calls: list[tuple[str, DecodedCall]],
    chain_id: int,
) -> dict[str, str]:
    """Look up ``{checksum_address: contract_name}`` for every address-typed argument.

    Skips each call's own target — its name is already surfaced via Contract
    Source Context. Best-effort: malformed values, the zero address and any
    lookup failure are dropped rather than raised.

    Args:
        targets_and_calls: ``(target, decoded_call)`` pairs whose arguments are scanned.
        chain_id: Chain passed through to ``get_contract_label``.

    Returns:
        Mapping of checksummed address to label; addresses with an empty label
        are omitted.
    """
    own_targets = {tgt.lower() for tgt, _ in targets_and_calls if tgt}
    visited: set[str] = set()
    labels: dict[str, str] = {}

    for _, decoded in targets_and_calls:
        for raw in _extract_address_args(decoded):
            key = raw.lower()
            # Each distinct address is examined once, whatever the outcome.
            if key in visited:
                continue
            visited.add(key)
            if key in own_targets:
                continue
            if len(key) != 42:
                continue
            # Validate BEFORE doing arithmetic on the value: a 42-character string
            # that is not hex would make int(..., 16) raise ValueError and abort the
            # whole explanation, which contradicts the best-effort contract.
            checksum = _checksum_or_none(raw)
            if checksum is None:
                continue
            # Zero address — no point asking Etherscan.
            if int(checksum, 16) == 0:
                continue
            try:
                label = get_contract_label(chain_id, checksum)
            except Exception as e:  # noqa: BLE001 - best-effort enrichment
                logger.info("Contract label fetch failed for %s: %s", checksum, e)
                continue
            if label:
                labels[checksum] = label
    return labels
+ """ + labels = address_labels or {} parts: list[str] = [] for i, call in enumerate(calls): lines = [f"Call {i + 1}: {call.signature}"] for type_str, value in call.params: - lines.append(f" {type_str}: {value}") + if type_str == "address": + lines.append(f" {type_str}: {_annotate_address(value, labels)}") + elif type_str.startswith("address[") and isinstance(value, (list, tuple)): + if not value: + lines.append(f" {type_str}: []") + else: + lines.append(f" {type_str}:") + lines.extend(f" - {_annotate_address(v, labels)}" for v in value) + else: + lines.append(f" {type_str}: {value}") parts.append("\n".join(lines)) return "\n\n".join(parts) @@ -266,6 +356,7 @@ def _build_prompt( source_contexts: list[SourceContext] | None = None, context_note: str = "", state_reads: list[tuple[str, list[StateRead]]] | None = None, + address_labels: dict[str, str] | None = None, ) -> str: """Build the full prompt for the LLM.""" parts: list[str] = [SYSTEM_PROMPT, ""] @@ -281,7 +372,7 @@ def _build_prompt( if context_note: parts.append(f"\n--- Execution Context ---\n{context_note}") - parts.append(f"\n--- Decoded Calldata ---\n{_format_decoded_calls(decoded_calls)}") + parts.append(f"\n--- Decoded Calldata ---\n{_format_decoded_calls(decoded_calls, address_labels)}") constants_note = _format_batch_param_constants(decoded_calls) if constants_note: @@ -439,6 +530,7 @@ def explain_transaction( proxy_upgrade_info = _get_proxy_upgrade_info(calldata, target, chain_id) source_contexts = _collect_source_contexts([(target, decoded)], chain_id) state_reads = _collect_state_reads([(target, decoded)], chain_id) + address_labels = _collect_address_labels([(target, decoded)], chain_id) simulation: SimulationResult | None = None if not skip_simulation: @@ -451,6 +543,13 @@ def explain_transaction( ) if simulation: logger.info("Simulation completed: success=%s gas=%s", simulation.success, simulation.gas_used) + if not simulation.success: + # Tenderly often misreports legitimate governance calls as 
reverting + # (wrong msg.sender, missing storage overrides). Including a failed + # sim in the prompt biases the LLM toward "this tx will revert" + # and inflates risk — drop it so the LLM works from calldata only. + logger.warning("Simulation reported failure (%s); omitting from prompt", simulation.error_message) + simulation = None else: logger.info("Simulation unavailable, proceeding with decoded calldata only") @@ -465,6 +564,7 @@ def explain_transaction( source_contexts=source_contexts, context_note=context_note, state_reads=state_reads, + address_labels=address_labels, ) logger.info("Full AI context for %s:\n%s", target, prompt) @@ -542,7 +642,12 @@ def explain_batch_transaction( if not decoded_calls: return None - simulation = next((s for s in simulations if s is not None), None) + # Prefer a successful sim; if every inner call failed in Tenderly, drop the + # sim section entirely rather than feeding the LLM "FAILED" + a misleading + # revert reason from a sim that probably just couldn't model the real call. 
+ simulation = next((s for s in simulations if s is not None and s.success), None) + if simulation is None and any(s is not None and not s.success for s in simulations): + logger.warning("All batch simulations reported failure; omitting from prompt") upgrade_parts: list[str] = [] for call in calls: @@ -553,6 +658,7 @@ def explain_batch_transaction( source_contexts = _collect_source_contexts(decoded_with_target, chain_id) state_reads = _collect_state_reads(decoded_with_target, chain_id) + address_labels = _collect_address_labels(decoded_with_target, chain_id) targets = ", ".join(c.get("target", "?") for c in calls) total_value = sum(int(c.get("value", "0")) for c in calls) @@ -568,6 +674,7 @@ def explain_batch_transaction( source_contexts=source_contexts, context_note=context_note, state_reads=state_reads, + address_labels=address_labels, ) logger.info("Full AI context for batch (%s calls):\n%s", len(calls), prompt) diff --git a/utils/safe_tx.py b/utils/safe_tx.py new file mode 100644 index 0000000..937d403 --- /dev/null +++ b/utils/safe_tx.py @@ -0,0 +1,51 @@ +"""Helpers for inspecting Gnosis Safe transactions. + +The Safe contract wraps every batch via ``execTransaction(to, value, data, ...)``, +so when a Safe schedules a Maple timelock proposal (or any other on-chain +action) the actual call payload lives one layer deep, inside the ``data`` +parameter of that outer call. +""" + +from dataclasses import dataclass + +from utils.calldata.decoder import decode_calldata + + +@dataclass(frozen=True) +class InnerCall: + """The inner call wrapped by a Safe ``execTransaction``.""" + + target: str + value: int + data: str # hex with 0x prefix + operation: int # 0 = CALL, 1 = DELEGATECALL + + +_EXEC_TRANSACTION_SELECTOR = "0x6a761202" + + +def unwrap_safe_exec_transaction(input_hex: str) -> InnerCall | None: + """Decode a Safe ``execTransaction`` call into its inner (target, value, data, operation). + + Returns None if ``input_hex`` is not an ``execTransaction`` payload. 
+ """ + if not input_hex or len(input_hex) < 10: + return None + if input_hex[:10].lower() != _EXEC_TRANSACTION_SELECTOR: + return None + + decoded = decode_calldata(input_hex) + if not decoded or len(decoded.params) < 4: + return None + + target = decoded.params[0][1] + value = int(decoded.params[1][1]) + data = decoded.params[2][1] + operation = int(decoded.params[3][1]) + + if isinstance(data, bytes): + data = "0x" + data.hex() + elif isinstance(data, str) and not data.startswith("0x"): + data = "0x" + data + + return InnerCall(target=target, value=value, data=data, operation=operation) diff --git a/utils/source_context.py b/utils/source_context.py index 9b08438..f36b64f 100644 --- a/utils/source_context.py +++ b/utils/source_context.py @@ -42,6 +42,21 @@ _CONTROL_KEYWORDS = frozenset({"if", "for", "while", "require", "revert", "return", "emit", "assembly", "unchecked"}) +# Proxy contract names that are not informative on their own — when the target is +# named one of these, prefer the implementation contract's name. +_GENERIC_PROXY_NAMES = frozenset( + { + "TransparentUpgradeableProxy", + "ERC1967Proxy", + "BeaconProxy", + "EIP173Proxy", + "Proxy", + "UpgradeableProxy", + "InitializableImmutableAdminUpgradeabilityProxy", + "InitializableAdminUpgradeabilityProxy", + } +) + @dataclass(frozen=True) class SourceContext: @@ -235,6 +250,48 @@ def get_source_context(chain_id: int, address: str, function_name: str) -> Sourc return _build_context(fetched_impl[0], fetched_impl[1], function_name) +def get_contract_label(chain_id: int, address: str) -> str: + """Best-effort human label for a contract address. + + Resolution order: + 1. Safe utility registry (no API call) — covers MultiSendCallOnly etc. + 2. Etherscan ContractName for the address. + 3. If that name is a generic proxy wrapper, follow EIP-1967 to the impl + and use the impl's contract name instead. + + Returns "" for EOAs, unverified contracts, missing API key, or any failure. 
+ """ + if not address: + return "" + + # Lazy import to keep utils/source_context.py free of safe/ dependencies at + # module load time (and to avoid a cycle if safe ever imports from here). + from safe.multisend import safe_utility_label + + cheap = safe_utility_label(address) + if cheap: + return cheap + + fetched = fetch_source(chain_id, address) + if not fetched: + return "" + + name = fetched[0] + if name and name not in _GENERIC_PROXY_NAMES: + return name + + from utils.proxy import get_current_implementation + + impl = get_current_implementation(address, chain_id) + if not impl or impl.lower() == address.lower(): + return name + + impl_fetched = fetch_source(chain_id, impl) + if impl_fetched and impl_fetched[0]: + return impl_fetched[0] + return name + + def format_source_context(ctx: SourceContext) -> str: """Format a SourceContext into a prompt-ready string.""" lines: list[str] = []