Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 4 additions & 125 deletions tests/test_ai_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,111 +7,14 @@
from utils.llm.ai_explainer import (
Explanation,
_build_prompt,
_format_decoded_calls,
_format_simulation_context,
_parse_explanation,
_refine_explanation,
explain_transaction,
format_explanation_line,
)
from utils.llm.base import LLMError
from utils.source_context import SourceContext
from utils.tenderly.simulation import AssetChange, SimulationResult, StateChange


class TestFormatDecodedCalls(unittest.TestCase):
"""Tests for _format_decoded_calls."""

def test_single_call_no_params(self) -> None:
calls = [DecodedCall(function_name="pause", signature="pause()")]
result = _format_decoded_calls(calls)
self.assertIn("Call 1: pause()", result)

def test_single_call_with_params(self) -> None:
calls = [
DecodedCall(
function_name="grantRole",
signature="grantRole(bytes32,address)",
params=[("bytes32", b"\x00" * 32), ("address", "0xABC")],
)
]
result = _format_decoded_calls(calls)
self.assertIn("grantRole(bytes32,address)", result)
self.assertIn("bytes32:", result)
self.assertIn("address:", result)

def test_multiple_calls(self) -> None:
calls = [
DecodedCall(function_name="pause", signature="pause()"),
DecodedCall(function_name="unpause", signature="unpause()"),
]
result = _format_decoded_calls(calls)
self.assertIn("Call 1: pause()", result)
self.assertIn("Call 2: unpause()", result)


class TestFormatSimulationContext(unittest.TestCase):
"""Tests for _format_simulation_context."""

def test_successful_simulation(self) -> None:
sim = SimulationResult(success=True, gas_used=50000)
result = _format_simulation_context(sim)
self.assertIn("SUCCESS", result)
self.assertIn("50,000", result)

def test_failed_simulation(self) -> None:
sim = SimulationResult(success=False, gas_used=21000, error_message="execution reverted")
result = _format_simulation_context(sim)
self.assertIn("FAILED", result)
self.assertIn("execution reverted", result)

def test_with_asset_changes(self) -> None:
sim = SimulationResult(
success=True,
gas_used=100000,
asset_changes=[
AssetChange(
token_address="0xToken",
token_name="USDC",
token_symbol="USDC",
from_address="0xA",
to_address="0xB",
amount="1000",
raw_amount="1000000000",
decimals=6,
)
],
)
result = _format_simulation_context(sim)
self.assertIn("Token transfers:", result)
self.assertIn("USDC", result)

def test_with_state_changes(self) -> None:
sim = SimulationResult(
success=True,
gas_used=100000,
state_changes=[
StateChange(
contract_address="0xContract",
key="0x01",
original="0x00",
dirty="0x01",
)
],
)
result = _format_simulation_context(sim)
self.assertIn("State changes", result)
self.assertIn("0xContract", result)

def test_with_logs(self) -> None:
sim = SimulationResult(
success=True,
gas_used=100000,
logs=[{"name": "Transfer", "inputs": [{"soltype": {"name": "to"}, "value": "0xB"}]}],
)
result = _format_simulation_context(sim)
self.assertIn("Events emitted", result)
self.assertIn("Transfer", result)
from utils.tenderly.simulation import SimulationResult


class TestBuildPrompt(unittest.TestCase):
Expand Down Expand Up @@ -226,17 +129,6 @@ def test_mixed_signatures_no_section(self) -> None:
self.assertNotIn("--- Shared Across Batch ---", result)


class TestSystemPromptBrevity(unittest.TestCase):
"""Verify the system prompt enforces brevity rules."""

def test_includes_word_cap_and_no_preamble_rules(self) -> None:
calls = [DecodedCall(function_name="pause", signature="pause()")]
result = _build_prompt(target="0xT", value=0, decoded_calls=calls, simulation=None)
self.assertIn("≤25 words", result)
self.assertIn('"This transaction"', result)
self.assertIn("risk tag in caps", result)


class TestSkipSimulation(unittest.TestCase):
"""Tests for skip_simulation flag."""

Expand Down Expand Up @@ -501,18 +393,12 @@ class TestRefineExplanation(unittest.TestCase):
"""Tests for _refine_explanation."""

def test_pass_keeps_draft_unchanged(self) -> None:
# Trailing whitespace around "PASS" must also count as PASS.
draft = Explanation(summary="Lowers fee 30→25 bps. LOW.", detail="bla")
provider = MagicMock()
provider.complete.return_value = "PASS"
result = _refine_explanation("orig prompt", draft, provider)
self.assertIs(result, draft)
provider.complete.assert_called_once()

def test_pass_with_trailing_whitespace_still_keeps_draft(self) -> None:
draft = Explanation(summary="x", detail="y")
provider = MagicMock()
provider.complete.return_value = " PASS \n"
self.assertIs(_refine_explanation("p", draft, provider), draft)
self.assertIs(_refine_explanation("orig prompt", draft, provider), draft)
provider.complete.assert_called_once()

def test_revision_replaces_draft(self) -> None:
draft = Explanation(summary="This transaction does X. LOW.", detail="bla")
Expand All @@ -534,13 +420,6 @@ def test_empty_response_falls_back_to_draft(self) -> None:
provider.complete.return_value = ""
self.assertIs(_refine_explanation("p", draft, provider), draft)

def test_revision_with_empty_summary_falls_back(self) -> None:
draft = Explanation(summary="x", detail="y")
provider = MagicMock()
# Reply parses to empty summary (no TLDR marker, no body)
provider.complete.return_value = ""
self.assertIs(_refine_explanation("p", draft, provider), draft)


class TestRefineFlagInExplainTransaction(unittest.TestCase):
"""Tests that the refine flag triggers a second LLM call."""
Expand Down
70 changes: 2 additions & 68 deletions tests/test_impl_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,13 @@
from unittest.mock import patch

from utils.impl_diff import (
FunctionSig,
ImplDiff,
StateVarDecl,
_diff_functions,
_extract_function_sigs,
_extract_state_vars,
_is_namespaced_storage,
_normalize_args,
_storage_layout,
diff_implementations,
format_impl_diff,
)

CONTRACT_OLD = """
Expand Down Expand Up @@ -257,7 +253,7 @@ def test_plain_contract_is_not(self) -> None:


class TestDiffImplementations(unittest.TestCase):
@patch("utils.impl_diff._fetch_source")
@patch("utils.impl_diff.fetch_source")
def test_end_to_end(self, mock_fetch) -> None:
mock_fetch.side_effect = [("Vault", CONTRACT_OLD), ("Vault", CONTRACT_NEW)]
diff = diff_implementations("0xOld", "0xNew", 1)
Expand All @@ -268,72 +264,10 @@ def test_end_to_end(self, mock_fetch) -> None:
self.assertEqual(diff.added_functions[0].name, "setMaxDeposit")
self.assertEqual(len(diff.changed_functions), 1)

@patch("utils.impl_diff._fetch_source", return_value=None)
@patch("utils.impl_diff.fetch_source", return_value=None)
def test_returns_none_on_unverified(self, mock_fetch) -> None:
self.assertIsNone(diff_implementations("0xOld", "0xNew", 1))


class TestFormatImplDiff(unittest.TestCase):
def test_basic_output(self) -> None:
diff = ImplDiff(
old_addr="0xOld",
new_addr="0xNew",
old_name="Vault",
new_name="Vault",
added_functions=[FunctionSig(name="newFn", args="uint256", visibility="external", modifiers="")],
removed_functions=[],
changed_functions=[],
added_state_vars=[StateVarDecl(name="newVar", type_str="uint256", visibility="public", immutable=False)],
removed_state_vars=[],
layout_changes=[],
storage_layout_safe=True,
namespaced_storage=False,
)
out = format_impl_diff(diff)
self.assertIn("Old: 0xOld", out)
self.assertIn("New: 0xNew", out)
self.assertIn("Functions added", out)
self.assertIn("newFn(uint256)", out)
self.assertIn("Storage layout safe", out)

def test_unsafe_layout_warning(self) -> None:
diff = ImplDiff(
old_addr="0xOld",
new_addr="0xNew",
old_name="X",
new_name="X",
added_functions=[],
removed_functions=[],
changed_functions=[],
added_state_vars=[],
removed_state_vars=[],
layout_changes=["slot 0: uint256 a → address b"],
storage_layout_safe=False,
namespaced_storage=False,
)
out = format_impl_diff(diff)
self.assertIn("NOT upgrade-safe", out)
self.assertIn("slot 0", out)

def test_namespaced_skips_check(self) -> None:
diff = ImplDiff(
old_addr="0xOld",
new_addr="0xNew",
old_name="",
new_name="",
added_functions=[],
removed_functions=[],
changed_functions=[],
added_state_vars=[],
removed_state_vars=[],
layout_changes=[],
storage_layout_safe=True,
namespaced_storage=True,
)
out = format_impl_diff(diff)
self.assertIn("EIP-7201", out)
self.assertIn("skipped", out)


if __name__ == "__main__":
unittest.main()
8 changes: 4 additions & 4 deletions tests/test_on_chain_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_skips_array_params(self) -> None:

class TestReadBeforeState(unittest.TestCase):
@patch("utils.on_chain_state.ChainManager")
@patch("utils.on_chain_state._fetch_source")
@patch("utils.on_chain_state.fetch_source")
def test_reads_simple_uint(self, mock_fetch: MagicMock, mock_chain: MagicMock) -> None:
source = """
uint256 public maxSlippage;
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_reads_simple_uint(self, mock_fetch: MagicMock, mock_chain: MagicMock) -
self.assertEqual(reads[0].key_args, ())

@patch("utils.on_chain_state.ChainManager")
@patch("utils.on_chain_state._fetch_source")
@patch("utils.on_chain_state.fetch_source")
def test_reads_address_keyed_mapping(self, mock_fetch: MagicMock, mock_chain: MagicMock) -> None:
source = """
mapping(address => uint256) public coverageCap;
Expand Down Expand Up @@ -180,12 +180,12 @@ def test_reads_address_keyed_mapping(self, mock_fetch: MagicMock, mock_chain: Ma
self.assertEqual(reads[0].value, 5000000000000000)
self.assertEqual(reads[0].key_args, (agent,))

@patch("utils.on_chain_state._fetch_source", return_value=None)
@patch("utils.on_chain_state.fetch_source", return_value=None)
def test_no_source_returns_empty(self, mock_fetch: MagicMock) -> None:
call = DecodedCall(function_name="setX", signature="setX(uint256)", params=[("uint256", 1)])
self.assertEqual(read_before_state(1, "0xT", call), [])

@patch("utils.on_chain_state._fetch_source")
@patch("utils.on_chain_state.fetch_source")
def test_struct_mapping_returns_empty(self, mock_fetch: MagicMock) -> None:
source = """
mapping(address => CreditLine) public creditLines;
Expand Down
39 changes: 7 additions & 32 deletions tests/test_source_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
from unittest.mock import patch

from utils.source_context import (
SourceContext,
_concat_sources,
_extract_function_body,
_extract_function_snippet,
_extract_state_var_snippet,
_find_state_var_writes,
format_source_context,
extract_state_var_snippet,
find_state_var_writes,
get_source_context,
reset_cache,
)
Expand Down Expand Up @@ -77,7 +75,7 @@ def test_handles_nested_braces(self) -> None:

class TestFindStateVarWrites(unittest.TestCase):
def test_finds_assignment(self) -> None:
writes = _find_state_var_writes(INFINIFI_FARM_SOURCE, "setMaxSlippage")
writes = find_state_var_writes(INFINIFI_FARM_SOURCE, "setMaxSlippage")
self.assertIn("maxSlippage", writes)

def test_ignores_local_and_keyword_assignments(self) -> None:
Expand All @@ -88,7 +86,7 @@ def test_ignores_local_and_keyword_assignments(self) -> None:
_underscoreVar = 5;
}
"""
writes = _find_state_var_writes(source, "f")
writes = find_state_var_writes(source, "f")
self.assertIn("storedVar", writes)
self.assertNotIn("local", writes) # locals can't be distinguished, but this test documents the heuristic limit
self.assertNotIn("_underscoreVar", writes)
Expand All @@ -97,12 +95,12 @@ def test_ignores_local_and_keyword_assignments(self) -> None:

class TestExtractStateVarSnippet(unittest.TestCase):
def test_extracts_natspec_and_declaration(self) -> None:
snippet = _extract_state_var_snippet(INFINIFI_FARM_SOURCE, "maxSlippage")
snippet = extract_state_var_snippet(INFINIFI_FARM_SOURCE, "maxSlippage")
self.assertIn("so actually 1 - slippage", snippet)
self.assertIn("uint256 public maxSlippage", snippet)

def test_missing_var_returns_empty(self) -> None:
snippet = _extract_state_var_snippet(INFINIFI_FARM_SOURCE, "doesNotExist")
snippet = extract_state_var_snippet(INFINIFI_FARM_SOURCE, "doesNotExist")
self.assertEqual(snippet, "")

def test_skips_local_vars_without_visibility(self) -> None:
Expand All @@ -112,7 +110,7 @@ def test_skips_local_vars_without_visibility(self) -> None:
}
"""
# Should not match — no public/private/internal/external modifier
snippet = _extract_state_var_snippet(source, "plainLocal")
snippet = extract_state_var_snippet(source, "plainLocal")
self.assertEqual(snippet, "")


Expand Down Expand Up @@ -239,28 +237,5 @@ def test_proxy_follow_skipped_when_impl_equals_target(self, mock_fetch: object,
self.assertEqual(mock_fetch.call_count, 1) # type: ignore[attr-defined]


class TestFormatSourceContext(unittest.TestCase):
def test_includes_contract_name_and_snippets(self) -> None:
ctx = SourceContext(
contract_name="Farm",
function_snippet="/// @notice foo\nfunction setMaxSlippage(uint256) external;",
state_var_snippets=["/// @notice slippage\nuint256 public maxSlippage;"],
)
result = format_source_context(ctx)
self.assertIn("Contract: Farm", result)
self.assertIn("setMaxSlippage", result)
self.assertIn("Relevant state variables:", result)
self.assertIn("maxSlippage", result)

def test_omits_state_var_section_when_empty(self) -> None:
ctx = SourceContext(
contract_name="Foo",
function_snippet="function pause() external;",
state_var_snippets=[],
)
result = format_source_context(ctx)
self.assertNotIn("Relevant state variables", result)


if __name__ == "__main__":
unittest.main()
Loading