From bc3e9fc146cd9485c5b8a7079158f39a324cd979 Mon Sep 17 00:00:00 2001 From: Hendobox <50964581+Hendobox@users.noreply.github.com> Date: Thu, 28 May 2026 14:17:01 +0100 Subject: [PATCH] feat(wallet_screening): unify TRM/scam tx risk index for analysis Merge malicious contracts and TRM-normalized ETH addresses into a tx-risk index built at init, and use it for malicious interaction detection during tx analysis. Add targeted regression tests to verify TRM counterparty detection and index merging across core and additional sources. Closes #138 Relates to #115 --- skills/finance/wallet_screening/skill.py | 89 +++++++++++++++---- tests/skills/finance/test_wallet_screening.py | 88 ++++++++++++++++++ 2 files changed, 162 insertions(+), 15 deletions(-) diff --git a/skills/finance/wallet_screening/skill.py b/skills/finance/wallet_screening/skill.py index ef64d16..c71380e 100644 --- a/skills/finance/wallet_screening/skill.py +++ b/skills/finance/wallet_screening/skill.py @@ -38,6 +38,9 @@ def __init__(self, config: Optional[Dict[str, Any]] = None): # ETH address -> sanctions records (built once; O(1) lookup per screen) self._sanctions_index: Dict[str, List[Dict]] = {} self._build_sanctions_index() + # ETH address -> tx risk records (core malicious + normalized lists) + self._tx_risk_index: Dict[str, List[Dict]] = {} + self._build_tx_risk_index() @property def manifest(self) -> Dict[str, Any]: @@ -138,9 +141,7 @@ def normalize_eth_address(address: str) -> Optional[str]: """Normalize and validate an Ethereum address (EIP-55 checksum not required).""" if not isinstance(address, str): return None - cleaned = address.strip().translate( - {ord(c): None for c in _ZERO_WIDTH_CHARS} - ) + cleaned = address.strip().translate({ord(c): None for c in _ZERO_WIDTH_CHARS}) if not cleaned.lower().startswith("0x"): return None normalized = "0x" + cleaned[2:].lower() @@ -213,6 +214,53 @@ def _lookup_sanctions_hits(self, address: str) -> List[Dict]: return [] return list(self._sanctions_index.get(normalized, [])) + @staticmethod + def _severity_rank(value: str) -> int: + order = {"critical": 4, "high": 3, "medium": 2, "low": 1} + return order.get(str(value).lower(), 0) + + def _record_to_tx_risk_entry(self, record: Dict) -> Dict[str, Any]: + return { + "contract_name": record.get("name") + or record.get("label") + or record.get("caption") + or "Unknown", + "severity": (record.get("severity") or "high").lower(), + "category": record.get("category") or record.get("reason") or "malicious", + "source_file": record.get("__source_file__", "malicious_scs_2025.json"), + "jurisdictions": record.get("jurisdictions_blocked", []), + } + + def _build_tx_risk_index(self) -> None: + """Index normalized ETH addresses used for tx-level risk screening.""" + index: Dict[str, List[Dict]] = {} + for record in self.malicious_contracts: + if not isinstance(record, dict): + continue + for addr in self._eth_addresses_from_record(record): + index.setdefault(addr, []).append(self._record_to_tx_risk_entry(record)) + + for record in self.additional_datasets: + if not isinstance(record, dict): + continue + source = str(record.get("__source_file__", "")).lower() + if ( + "uniswap_trm" not in source + and "trm" not in source + and "malicious" not in source + ): + continue + for addr in self._eth_addresses_from_record(record): + index.setdefault(addr, []).append(self._record_to_tx_risk_entry(record)) + + self._tx_risk_index = index + + def _lookup_tx_risk_entries(self, address: str) -> List[Dict]: + normalized = self.normalize_eth_address(address) + if not normalized: + return [] + return list(self._tx_risk_index.get(normalized, [])) + def _get_price(self, url: str, currency: str) -> float: try: resp = requests.get(url, timeout=10) @@ -271,8 +319,6 @@ def _analyze_transactions( counterparty_counts = {} malicious_interactions = [] - malicious_map = {c["address"].lower(): c for c in self.malicious_contracts} - for tx in txs: from_addr = tx.get("from", "").lower() to_addr = tx.get("to", "").lower() if tx.get("to") else "" @@ -300,21 +346,34 @@ def _analyze_transactions( # Malicious Check other_party = None - if to_addr and to_addr in malicious_map: - other_party = to_addr - elif from_addr and from_addr in malicious_map: - other_party = from_addr - - if other_party: - contract_info = malicious_map[other_party] + tx_risk_entries: List[Dict] = [] + if to_addr: + tx_risk_entries = self._lookup_tx_risk_entries(to_addr) + if tx_risk_entries: + other_party = to_addr + if not tx_risk_entries and from_addr: + tx_risk_entries = self._lookup_tx_risk_entries(from_addr) + if tx_risk_entries: + other_party = from_addr + + if other_party and tx_risk_entries: + primary = max( + tx_risk_entries, + key=lambda item: self._severity_rank(item.get("severity", "")), + ) + sources = sorted( + {entry.get("source_file", "Unknown") for entry in tx_risk_entries} + ) malicious_interactions.append( { "tx_hash": tx.get("hash"), "other_party": other_party, "direction": "out" if from_addr == wallet_addr else "in", - "contract_name": contract_info.get("name"), - "severity": contract_info.get("severity"), - "jurisdictions": contract_info.get("jurisdictions_blocked", []), + "contract_name": primary.get("contract_name"), + "severity": primary.get("severity"), + "jurisdictions": primary.get("jurisdictions", []), + "source_file": primary.get("source_file"), + "sources": sources, "value_eth": value_eth, } ) diff --git a/tests/skills/finance/test_wallet_screening.py b/tests/skills/finance/test_wallet_screening.py index 0e0084f..1aff92d 100644 --- a/tests/skills/finance/test_wallet_screening.py +++ b/tests/skills/finance/test_wallet_screening.py @@ -162,3 +162,91 @@ def test_sanctions_index_real_ftm_publickey_vector(): assert len(hits) >= 1 assert hits[0]["__source_file__"] == "entities.ftm.json" assert SANCTIONED_ETH in hits[0].get("properties", {}).get("publicKey", []) + + +@patch("skills.finance.wallet_screening.skill.requests.get") +def test_tx_risk_detects_uniswap_trm_counterparty(mock_get): + skill = get_skill() + skill.etherscan_api_key = "dummy_key" + trm_addr = "0x009988Ff77eEaa00051238ee32C48f10a174933E" + skill.malicious_contracts = [] + skill.additional_datasets = [ + { + "address": trm_addr, + "name": "TRM Test Address", + "reason": "Scam (High)", + "severity": "high", + "__source_file__": "normalized_uniswap_trm.json", + } + ] + skill._build_sanctions_index() + skill._build_tx_risk_index() + + mock_eth_balance = MagicMock() + mock_eth_balance.json.return_value = {"status": "1", "result": "0"} + mock_txs = MagicMock() + mock_txs.json.return_value = { + "status": "1", + "result": [ + { + "from": "0xd8da6bf26964af9d7eed9e03e53415d37aa96045", + "to": trm_addr, + "value": "10000000000000000", + "isError": "0", + "gasUsed": "21000", + "gasPrice": "1000000000", + "hash": "0xtesthashtrm", + } + ], + } + mock_price = MagicMock() + mock_price.json.return_value = {"ethereum": {"usd": 2000.0, "eur": 1800.0}} + + def get_side_effect(url, **kwargs): + params = kwargs.get("params") or {} + if params.get("action") == "balance": + return mock_eth_balance + if params.get("action") == "txlist": + return mock_txs + return mock_price + + mock_get.side_effect = get_side_effect + result = skill.execute({"address": "0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045"}) + + assert result["summary"]["malicious_interaction_count"] == 1 + interaction = result["risk_details"]["malicious_interactions"][0] + assert interaction["other_party"] == trm_addr.lower() + assert interaction["source_file"] == "normalized_uniswap_trm.json" + assert "normalized_uniswap_trm.json" in interaction["sources"] + + +def test_tx_risk_index_merges_core_and_additional_sources(): + skill = get_skill() + core_addr = "0x1111111111111111111111111111111111111111" + trm_addr = "0x2222222222222222222222222222222222222222" + skill.malicious_contracts = [ + { + "address": core_addr, + "name": "Core Mixer", + "severity": "high", + "jurisdictions_blocked": ["US"], + } + ] + skill.additional_datasets = [ + { + "address": trm_addr, + "name": "TRM Scam Address", + "reason": "Scam (Critical)", + "severity": "critical", + "__source_file__": "normalized_uniswap_trm.json", + } + ] + skill._build_tx_risk_index() + + core_entries = skill._lookup_tx_risk_entries(core_addr) + trm_entries = skill._lookup_tx_risk_entries(trm_addr) + + assert len(core_entries) == 1 + assert core_entries[0]["contract_name"] == "Core Mixer" + assert len(trm_entries) == 1 + assert trm_entries[0]["source_file"] == "normalized_uniswap_trm.json"