Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 223 additions & 0 deletions tests/test_ai_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,228 @@ def test_refine_on_makes_two_calls(
self.assertIn("Your Previous Draft", second_call_prompt)


class TestFailedSimulationDropped(unittest.TestCase):
"""Failed Tenderly simulations must not leak into the LLM prompt."""

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_contract_label", return_value="")
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction")
@patch("utils.llm.ai_explainer.decode_calldata")
def test_failed_sim_omitted_from_single_prompt(
self,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_label: MagicMock,
mock_source: MagicMock,
) -> None:
mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()")
mock_simulate.return_value = SimulationResult(
success=False, gas_used=0, error_message="execution reverted: not authorized"
)
provider = MagicMock()
provider.complete.return_value = "TLDR: pauses. LOW."
provider.model_name = "test"
mock_get_provider.return_value = provider

explain_transaction(target="0xT", calldata="0x8456cb59", chain_id=1)
prompt = provider.complete.call_args[0][0]

self.assertNotIn("--- Simulation Results ---", prompt)
self.assertNotIn("FAILED", prompt)
self.assertNotIn("execution reverted", prompt)

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_contract_label", return_value="")
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction")
@patch("utils.llm.ai_explainer.decode_calldata")
def test_failed_sim_omitted_from_batch_prompt(
self,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_label: MagicMock,
mock_source: MagicMock,
) -> None:
from utils.llm.ai_explainer import explain_batch_transaction

mock_decode.return_value = DecodedCall(function_name="pause", signature="pause()")
mock_simulate.return_value = SimulationResult(success=False, gas_used=0, error_message="reverted")
provider = MagicMock()
provider.complete.return_value = "TLDR: pauses both. LOW."
provider.model_name = "test"
mock_get_provider.return_value = provider

explain_batch_transaction(
calls=[
{"target": "0xT1", "data": "0x8456cb59", "value": "0"},
{"target": "0xT2", "data": "0x8456cb59", "value": "0"},
],
chain_id=1,
)
prompt = provider.complete.call_args[0][0]
self.assertNotIn("--- Simulation Results ---", prompt)
self.assertNotIn("FAILED", prompt)


class TestAddressLabels(unittest.TestCase):
"""Tests for address-argument annotation in the LLM prompt."""

REGISTRY = "0xF5f2718708F471e43968271956cC01Aaa8C46119"
FARM = "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694"
FARM_CKS = "0xAc21B22B5aEb11bc32De4ecF59E4538fCa48b694"

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction", return_value=None)
@patch("utils.llm.ai_explainer.decode_calldata")
@patch("utils.llm.ai_explainer.get_contract_label")
def test_address_array_arg_is_labeled(
self,
mock_label: MagicMock,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_source: MagicMock,
) -> None:
mock_decode.return_value = DecodedCall(
function_name="addFarms",
signature="addFarms(uint256,address[])",
params=[("uint256", 1), ("address[]", (self.FARM,))],
)
mock_label.return_value = "MorphoFarm"
provider = MagicMock()
provider.complete.return_value = "TLDR: adds farm. LOW."
provider.model_name = "test-model"
mock_get_provider.return_value = provider

explain_transaction(target=self.REGISTRY, calldata="0xabcdef10" + "00" * 64, chain_id=1)

prompt = provider.complete.call_args[0][0]
self.assertIn("MorphoFarm", prompt)
self.assertIn(self.FARM_CKS, prompt)
# Address goes on its own line, bulleted, under the type label.
self.assertIn("address[]:", prompt)
self.assertIn(f"- {self.FARM_CKS} (MorphoFarm)", prompt)

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction", return_value=None)
@patch("utils.llm.ai_explainer.decode_calldata")
@patch("utils.llm.ai_explainer.get_contract_label")
def test_scalar_address_arg_is_labeled(
self,
mock_label: MagicMock,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_source: MagicMock,
) -> None:
mock_decode.return_value = DecodedCall(
function_name="setOracle",
signature="setOracle(address)",
params=[("address", self.FARM)],
)
mock_label.return_value = "ChainlinkOracle"
provider = MagicMock()
provider.complete.return_value = "TLDR: rewires oracle. MEDIUM."
provider.model_name = "test-model"
mock_get_provider.return_value = provider

explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1)

prompt = provider.complete.call_args[0][0]
self.assertIn(f"address: {self.FARM_CKS} (ChainlinkOracle)", prompt)

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction", return_value=None)
@patch("utils.llm.ai_explainer.decode_calldata")
@patch("utils.llm.ai_explainer.get_contract_label")
def test_target_address_is_not_relabeled(
self,
mock_label: MagicMock,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_source: MagicMock,
) -> None:
# The target appears as its own argument — should be skipped so we don't
# double up with the Contract Source Context block.
mock_decode.return_value = DecodedCall(
function_name="selfWire",
signature="selfWire(address)",
params=[("address", self.REGISTRY.lower())],
)
provider = MagicMock()
provider.complete.return_value = "TLDR: wires self. LOW."
provider.model_name = "test-model"
mock_get_provider.return_value = provider

explain_transaction(target=self.REGISTRY, calldata="0xdeadbeef" + "00" * 32, chain_id=1)

mock_label.assert_not_called()

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction", return_value=None)
@patch("utils.llm.ai_explainer.decode_calldata")
@patch("utils.llm.ai_explainer.get_contract_label")
def test_unverified_address_left_unannotated(
self,
mock_label: MagicMock,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_source: MagicMock,
) -> None:
mock_decode.return_value = DecodedCall(
function_name="setOracle",
signature="setOracle(address)",
params=[("address", self.FARM)],
)
mock_label.return_value = "" # unverified / EOA / no API key
provider = MagicMock()
provider.complete.return_value = "TLDR: rewires. MEDIUM."
provider.model_name = "test-model"
mock_get_provider.return_value = provider

explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1)

prompt = provider.complete.call_args[0][0]
# Address shows up, but with no `(Label)` suffix.
self.assertIn(self.FARM_CKS, prompt)
self.assertNotIn(f"{self.FARM_CKS} (", prompt)

@patch("utils.llm.ai_explainer.get_source_context", return_value=None)
@patch("utils.llm.ai_explainer.get_llm_provider")
@patch("utils.llm.ai_explainer.simulate_transaction", return_value=None)
@patch("utils.llm.ai_explainer.decode_calldata")
@patch("utils.llm.ai_explainer.get_contract_label")
def test_zero_address_not_queried(
self,
mock_label: MagicMock,
mock_decode: MagicMock,
mock_simulate: MagicMock,
mock_get_provider: MagicMock,
mock_source: MagicMock,
) -> None:
zero = "0x" + "00" * 20
mock_decode.return_value = DecodedCall(
function_name="setOracle",
signature="setOracle(address)",
params=[("address", zero)],
)
provider = MagicMock()
provider.complete.return_value = "TLDR: unsets oracle. LOW."
provider.model_name = "test-model"
mock_get_provider.return_value = provider

explain_transaction(target=self.REGISTRY, calldata="0x7adbf973" + "00" * 32, chain_id=1)
mock_label.assert_not_called()


if __name__ == "__main__":
unittest.main()
60 changes: 60 additions & 0 deletions tests/test_source_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_extract_function_snippet,
extract_state_var_snippet,
find_state_var_writes,
get_contract_label,
get_source_context,
reset_cache,
)
Expand Down Expand Up @@ -237,5 +238,64 @@ def test_proxy_follow_skipped_when_impl_equals_target(self, mock_fetch: object,
self.assertEqual(mock_fetch.call_count, 1) # type: ignore[attr-defined]


class TestGetContractLabel(unittest.TestCase):
"""Tests for get_contract_label()."""

def setUp(self) -> None:
reset_cache()

@patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"})
@patch("utils.source_context.fetch_json")
def test_returns_verified_contract_name(self, mock_fetch: object) -> None:
mock_fetch.return_value = { # type: ignore[attr-defined]
"status": "1",
"result": [{"SourceCode": "contract Farm { }", "ContractName": "MorphoFarm"}],
}
label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694")
self.assertEqual(label, "MorphoFarm")

@patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"})
@patch("utils.source_context.fetch_json")
def test_unverified_returns_empty(self, mock_fetch: object) -> None:
mock_fetch.return_value = { # type: ignore[attr-defined]
"status": "1",
"result": [{"SourceCode": "", "ContractName": ""}],
}
label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694")
self.assertEqual(label, "")

def test_safe_utility_shortcut(self) -> None:
# MultiSendCallOnly — no Etherscan call should be needed.
label = get_contract_label(1, "0x40A2aCCbd92BCA938b02010E17A5b8929b49130D")
self.assertEqual(label, "Safe MultiSendCallOnly")

@patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"})
@patch("utils.proxy.get_current_implementation")
@patch("utils.source_context.fetch_json")
def test_follows_proxy_when_name_is_generic(self, mock_fetch: object, mock_impl: object) -> None:
mock_fetch.side_effect = [ # type: ignore[attr-defined]
{"status": "1", "result": [{"SourceCode": "/* proxy */", "ContractName": "TransparentUpgradeableProxy"}]},
{"status": "1", "result": [{"SourceCode": "/* impl */", "ContractName": "InfinifiBorrowingFarm"}]},
]
mock_impl.return_value = "0x000000000000000000000000000000000000beef" # type: ignore[attr-defined]
label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694")
self.assertEqual(label, "InfinifiBorrowingFarm")

@patch.dict("os.environ", {"ETHERSCAN_TOKEN": "test-key"})
@patch("utils.proxy.get_current_implementation", return_value=None)
@patch("utils.source_context.fetch_json")
def test_keeps_specific_name_without_proxy_follow(self, mock_fetch: object, mock_impl: object) -> None:
mock_fetch.return_value = { # type: ignore[attr-defined]
"status": "1",
"result": [{"SourceCode": "/* x */", "ContractName": "FarmRegistry"}],
}
label = get_contract_label(1, "0xac21b22b5aeb11bc32de4ecf59e4538fca48b694")
self.assertEqual(label, "FarmRegistry")
mock_impl.assert_not_called() # type: ignore[attr-defined]

def test_empty_address_returns_empty(self) -> None:
self.assertEqual(get_contract_label(1, ""), "")


if __name__ == "__main__":
unittest.main()
82 changes: 82 additions & 0 deletions tests/test_timelock_alerts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for timelock/timelock_alerts.py — build_alert_message truncation logic."""

import unittest
import unittest.mock
from unittest.mock import patch

from timelock.timelock_alerts import TimelockConfig, build_alert_message
Expand Down Expand Up @@ -112,5 +113,86 @@ def test_message_under_limit_with_ai(self, mock_ai: object, mock_format: object)
self.assertNotIn("...", msg)


class TestMapleProposalUnwrap(unittest.TestCase):
"""Maple ProposalScheduled has no target/data; recover them from the source tx."""

@staticmethod
def _make_schedule_calldata(targets: list[str], datas: list[bytes]) -> str:
from eth_abi import encode
from eth_utils import function_signature_to_4byte_selector

selector = function_signature_to_4byte_selector("scheduleProposals(address[],bytes[])")
body = encode(["address[]", "bytes[]"], [targets, datas])
return "0x" + selector.hex() + body.hex()

@staticmethod
def _wrap_in_safe(inner_hex: str, safe_target: str) -> str:
from eth_abi import encode
from eth_utils import function_signature_to_4byte_selector

selector = function_signature_to_4byte_selector(
"execTransaction(address,uint256,bytes,uint8,uint256,uint256,uint256,address,address,bytes)"
)
zero = "0x" + "00" * 20
body = encode(
["address", "uint256", "bytes", "uint8", "uint256", "uint256", "uint256", "address", "address", "bytes"],
[safe_target, 0, bytes.fromhex(inner_hex[2:]), 0, 0, 0, 0, zero, zero, b""],
)
return "0x" + selector.hex() + body.hex()

@patch("timelock.timelock_alerts.ChainManager")
def test_unwraps_safe_wrapped_schedule_proposals(self, mock_cm: object) -> None:
from timelock.timelock_alerts import _maple_proposal_calls

targets = ["0x" + "aa" * 20, "0x" + "bb" * 20]
datas = [bytes.fromhex("8456cb59"), bytes.fromhex("3f4ba83a")] # pause(), unpause()
inner_hex = self._make_schedule_calldata(targets, datas)
outer = self._wrap_in_safe(inner_hex, "0x2efff88747eb5a3ff00d4d8d0f0800e306c0426b")

mock_client = unittest.mock.MagicMock()
mock_client.eth.get_transaction.return_value = {"input": outer}
mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined]

event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32)
calls = _maple_proposal_calls(event, chain_id=1)

assert calls is not None
self.assertEqual(len(calls), 2)
self.assertEqual(calls[0]["target"], targets[0])
self.assertEqual(calls[0]["data"], "0x8456cb59")
self.assertEqual(calls[1]["target"], targets[1])
self.assertEqual(calls[1]["data"], "0x3f4ba83a")

@patch("timelock.timelock_alerts.ChainManager")
def test_unwraps_direct_schedule_proposals(self, mock_cm: object) -> None:
from timelock.timelock_alerts import _maple_proposal_calls

targets = ["0x" + "cc" * 20]
datas = [bytes.fromhex("8456cb59")]
inner_hex = self._make_schedule_calldata(targets, datas)

mock_client = unittest.mock.MagicMock()
mock_client.eth.get_transaction.return_value = {"input": inner_hex}
mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined]

event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32)
calls = _maple_proposal_calls(event, chain_id=1)
assert calls is not None
self.assertEqual(len(calls), 1)
self.assertEqual(calls[0]["data"], "0x8456cb59")

@patch("timelock.timelock_alerts.ChainManager")
def test_returns_none_for_unknown_selector(self, mock_cm: object) -> None:
from timelock.timelock_alerts import _maple_proposal_calls

# proposeRoleUpdates path — we can't synthesize (target, data) pairs from it.
mock_client = unittest.mock.MagicMock()
mock_client.eth.get_transaction.return_value = {"input": "0x2d6e853c" + "00" * 100}
mock_cm.get_client.return_value = mock_client # type: ignore[attr-defined]

event = _make_event(timelock_type="Maple", transactionHash="0x" + "ff" * 32)
self.assertIsNone(_maple_proposal_calls(event, chain_id=1))


if __name__ == "__main__":
unittest.main()
Loading