From 08ce5786179f35f28baf0589477c5e2f9551cb20 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Thu, 5 Mar 2026 01:22:50 -0800 Subject: [PATCH 01/13] feat: resolve TEE endpoint and TLS cert from on-chain registry Instead of blindly trusting the TLS certificate presented by the TEE server (TOFU), the SDK now queries the on-chain TEERegistry contract to discover active LLM proxy endpoints and their verified certificates. Key changes: - Add TEERegistry.abi and tee_registry.py to query the registry contract - Replace TOFU cert fetch in llm.py with registry-verified DER cert - Client.init queries the registry by default; og_llm_server_url still works as an explicit override (falls back to system CA verification) - Add DEFAULT_TEE_REGISTRY_ADDRESS and DEFAULT_TEE_REGISTRY_RPC_URL to defaults.py (OG EVM chain at http://13.59.43.94:8545) - Surface tee_id, tee_endpoint, tee_payment_address on every TextGenerationOutput and on the final StreamChunk so callers can audit which enclave served their request - Print TEE node info (endpoint, TEE ID, payment address) in all three CLI print helpers (completion, chat, streaming) Co-Authored-By: Claude Sonnet 4.6 --- src/opengradient/abi/TEERegistry.abi | 48 ++++++++ src/opengradient/cli.py | 32 ++++- src/opengradient/client/client.py | 65 +++++++++- src/opengradient/client/llm.py | 97 ++++++--------- src/opengradient/client/tee_registry.py | 152 ++++++++++++++++++++++++ src/opengradient/defaults.py | 7 +- src/opengradient/types.py | 15 +++ 7 files changed, 344 insertions(+), 72 deletions(-) create mode 100644 src/opengradient/abi/TEERegistry.abi create mode 100644 src/opengradient/client/tee_registry.py diff --git a/src/opengradient/abi/TEERegistry.abi b/src/opengradient/abi/TEERegistry.abi new file mode 100644 index 0000000..9dcf4a5 --- /dev/null +++ b/src/opengradient/abi/TEERegistry.abi @@ -0,0 +1,48 @@ +[ + { + "inputs": [], + "name": "getActiveTEEs", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getTEEsByType", + "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "getTEE", + "outputs": [ + { + "components": [ + {"internalType": "address", "name": "owner", "type": "address"}, + {"internalType": "address", "name": "paymentAddress", "type": "address"}, + {"internalType": "string", "name": "endpoint", "type": "string"}, + {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, + {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, + {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, + {"internalType": "uint8", "name": "teeType", "type": "uint8"}, + {"internalType": "bool", "name": "active", "type": "bool"}, + {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, + {"internalType": "uint256", "name": "lastUpdatedAt", "type": "uint256"} + ], + "internalType": "struct TEERegistry.TEEInfo", + "name": "", + "type": "tuple" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "isActive", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "stateMutability": "view", + "type": "function" + } +] diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index 979f676..bbec05e 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -413,13 +413,29 @@ def completion( x402_settlement_mode=x402SettlementModes[x402_settlement_mode], ) - print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False) + print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False, result=completion_output) except Exception as e: click.echo(f"Error running LLM completion: {str(e)}") -def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True): +def _print_tee_info(tee_id, tee_endpoint, tee_payment_address): + """Print TEE node info if available.""" + if not any([tee_id, tee_endpoint, tee_payment_address]): + return + click.secho("TEE Node:", fg="magenta", bold=True) + if tee_endpoint: + click.echo(" Endpoint: ", nl=False) + click.secho(tee_endpoint, fg="magenta") + if tee_id: + click.echo(" TEE ID: ", nl=False) + click.secho(tee_id, fg="magenta") + if tee_payment_address: + click.echo(" Payment address: ", nl=False) + click.secho(tee_payment_address, fg="magenta") + + +def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True, result=None): click.secho("✅ LLM completion Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -435,6 +451,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output, is_vanilla=True) click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("LLM Output:", fg="yellow", bold=True) click.echo() @@ -578,13 +597,13 @@ def chat( if stream: print_streaming_chat_result(model_cid, result, is_tee=True) else: - print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False) + print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False, result=result) except Exception as e: click.echo(f"Error running LLM chat inference: {str(e)}") -def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True): +def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_vanilla=True, result=None): click.secho("✅ LLM Chat Successful", fg="green", bold=True) click.echo("──────────────────────────────────────") click.echo("Model: ", nl=False) @@ -600,6 +619,9 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.echo("Source: ", nl=False) click.secho("OpenGradient TEE", fg="cyan", bold=True) + if result is not None: + _print_tee_info(result.tee_id, result.tee_endpoint, result.tee_payment_address) + click.echo("──────────────────────────────────────") click.secho("Finish Reason: ", fg="yellow", bold=True) click.echo() @@ -673,6 +695,8 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): click.echo("Finish reason: ", nl=False) click.secho(chunk.choices[0].finish_reason, fg="green") + _print_tee_info(chunk.tee_id, chunk.tee_endpoint, chunk.tee_payment_address) + click.echo("──────────────────────────────────────") click.echo(f"Chunks received: {chunk_count}") click.echo(f"Content length: {len(''.join(content_parts))} characters") diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index 8bb35e6..1af1dae 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -1,5 +1,6 @@ """Main Client class that unifies all OpenGradient service namespaces.""" +import logging from typing import Optional from web3 import Web3 @@ -7,15 +8,18 @@ from ..defaults import ( DEFAULT_API_URL, DEFAULT_INFERENCE_CONTRACT_ADDRESS, - DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, DEFAULT_RPC_URL, + DEFAULT_TEE_REGISTRY_ADDRESS, + DEFAULT_TEE_REGISTRY_RPC_URL, ) from .alpha import Alpha from .llm import LLM from .model_hub import ModelHub +from .tee_registry import TEERegistry from .twins import Twins +logger = logging.getLogger(__name__) + class Client: """ @@ -62,8 +66,10 @@ def __init__( rpc_url: str = DEFAULT_RPC_URL, api_url: str = DEFAULT_API_URL, contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, - og_llm_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_SERVER_URL, - og_llm_streaming_server_url: Optional[str] = DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL, + og_llm_server_url: Optional[str] = None, + og_llm_streaming_server_url: Optional[str] = None, + tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, + tee_registry_rpc_url: str = DEFAULT_TEE_REGISTRY_RPC_URL, ): """ Initialize the OpenGradient client. @@ -74,6 +80,11 @@ def __init__( You can supply a separate ``alpha_private_key`` so each chain uses its own funded wallet. When omitted, ``private_key`` is used for both. + By default the LLM server endpoint and its TLS certificate are fetched from + the on-chain TEE Registry, which stores certificates that were verified during + enclave attestation. You can override the endpoint by passing + ``og_llm_server_url`` explicitly (the system CA bundle is used for that URL). + Args: private_key: Private key whose wallet holds **Base Sepolia OPG tokens** for x402 LLM payments. @@ -86,8 +97,15 @@ def __init__( rpc_url: RPC URL for the OpenGradient Alpha Testnet. api_url: API URL for the OpenGradient API. contract_address: Inference contract address. - og_llm_server_url: OpenGradient LLM server URL. - og_llm_streaming_server_url: OpenGradient LLM streaming server URL. + og_llm_server_url: Override the LLM server URL instead of using the + registry-discovered endpoint. When set, the TLS certificate is + validated against the system CA bundle rather than the registry. + og_llm_streaming_server_url: Override the LLM streaming server URL. + Defaults to ``og_llm_server_url`` when that is provided. + tee_registry_address: Address of the TEERegistry contract used to + discover active LLM proxy endpoints and their verified TLS certs. + tee_registry_rpc_url: RPC endpoint for the chain that hosts the + TEERegistry contract. """ blockchain = Web3(Web3.HTTPProvider(rpc_url)) wallet_account = blockchain.eth.account.from_key(private_key) @@ -102,6 +120,38 @@ def __init__( if email is not None: hub_user = ModelHub._login_to_hub(email, password) + # Resolve LLM server URL and TLS certificate. + # If the caller provided explicit URLs, use those with standard CA verification. + # Otherwise, discover the endpoint and registry-verified cert from the TEE Registry. + llm_tls_cert_der: Optional[bytes] = None + tee = None + if og_llm_server_url is None: + try: + registry = TEERegistry( + rpc_url=tee_registry_rpc_url, + registry_address=tee_registry_address, + ) + tee = registry.get_llm_tee() + if tee is not None: + og_llm_server_url = tee.endpoint + og_llm_streaming_server_url = og_llm_streaming_server_url or tee.endpoint + llm_tls_cert_der = tee.tls_cert_der + logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) + else: + raise ValueError( + "No active LLM proxy TEE found in the registry. " + "Pass og_llm_server_url explicitly to override." + ) + except ValueError: + raise + except Exception as e: + raise RuntimeError( + f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address}): {e}. " + "Pass og_llm_server_url explicitly to override." + ) from e + else: + og_llm_streaming_server_url = og_llm_streaming_server_url or og_llm_server_url + # Create namespaces self.model_hub = ModelHub(hub_user=hub_user) self.wallet_address = wallet_account.address @@ -110,6 +160,9 @@ def __init__( wallet_account=wallet_account, og_llm_server_url=og_llm_server_url, og_llm_streaming_server_url=og_llm_streaming_server_url, + tls_cert_der=llm_tls_cert_der, + tee_id=tee.tee_id if tee is not None else None, + tee_payment_address=tee.payment_address if tee is not None else None, ) self.alpha = Alpha( diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index f95acd3..2258f7c 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,13 +2,10 @@ import asyncio import json +import ssl import threading from queue import Queue from typing import AsyncGenerator, Dict, List, Optional, Union -import ssl -import socket -import tempfile -from urllib.parse import urlparse import httpx from eth_account.account import LocalAccount @@ -21,6 +18,7 @@ from ..types import TEE_LLM, StreamChunk, TextGenerationOutput, TextGenerationStream, x402SettlementMode from .exceptions import OpenGradientError from .opg_token import Permit2ApprovalResult, ensure_opg_approval +from .tee_registry import build_ssl_context_from_der X402_PROCESSING_HASH_HEADER = "x-processing-hash" X402_PLACEHOLDER_API_KEY = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" @@ -40,53 +38,6 @@ ) -def _fetch_tls_cert_as_ssl_context(server_url: str) -> Optional[ssl.SSLContext]: - """ - Connect to a server, retrieve its TLS certificate (TOFU), - and return an ssl.SSLContext that trusts ONLY that certificate. - - Hostname verification is disabled because the TEE server's cert - is typically issued for a hostname but we may connect via IP address. - The pinned certificate itself provides the trust anchor. - - Returns None if the server is not HTTPS or unreachable. - """ - parsed = urlparse(server_url) - if parsed.scheme != "https": - return None - - hostname = parsed.hostname - port = parsed.port or 443 - - # Connect without verification to retrieve the server's certificate - fetch_ctx = ssl.create_default_context() - fetch_ctx.check_hostname = False - fetch_ctx.verify_mode = ssl.CERT_NONE - - try: - with socket.create_connection((hostname, port), timeout=10) as sock: - with fetch_ctx.wrap_socket(sock, server_hostname=hostname) as ssock: - der_cert = ssock.getpeercert(binary_form=True) - pem_cert = ssl.DER_cert_to_PEM_cert(der_cert) - except Exception: - return None - - # Write PEM to a temp file so we can load it into the SSLContext - cert_file = tempfile.NamedTemporaryFile( - prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w" - ) - cert_file.write(pem_cert) - cert_file.flush() - cert_file.close() - - # Build an SSLContext that trusts ONLY this cert, with hostname check disabled - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.load_verify_locations(cert_file.name) - ctx.check_hostname = False # Cert is for a hostname, but we connect via IP - ctx.verify_mode = ssl.CERT_REQUIRED # Still verify the cert itself - return ctx - - class LLM: """ LLM inference namespace. @@ -110,17 +61,34 @@ class LLM: result = client.llm.completion(model=TEE_LLM.CLAUDE_3_5_HAIKU, prompt="Hello") """ - def __init__(self, wallet_account: LocalAccount, og_llm_server_url: str, og_llm_streaming_server_url: str): + def __init__( + self, + wallet_account: LocalAccount, + og_llm_server_url: str, + og_llm_streaming_server_url: str, + tls_cert_der: Optional[bytes] = None, + tee_id: Optional[str] = None, + tee_payment_address: Optional[str] = None, + ): self._wallet_account = wallet_account self._og_llm_server_url = og_llm_server_url self._og_llm_streaming_server_url = og_llm_streaming_server_url - self._tls_verify: Union[ssl.SSLContext, bool] = ( - _fetch_tls_cert_as_ssl_context(self._og_llm_server_url) or True - ) - self._streaming_tls_verify: Union[ssl.SSLContext, bool] = ( - _fetch_tls_cert_as_ssl_context(self._og_llm_streaming_server_url) or True - ) + # TEE metadata surfaced on every response so callers can verify/audit which + # enclave served the request. + self._tee_id = tee_id + self._tee_endpoint = og_llm_server_url + self._tee_payment_address = tee_payment_address + + if tls_cert_der: + # Use the registry-verified certificate as the sole trust anchor. + ssl_ctx = build_ssl_context_from_der(tls_cert_der) + self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + self._streaming_tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx + else: + # No cert from registry — fall back to default system CA verification. + self._tls_verify = True + self._streaming_tls_verify = True signer = EthAccountSignerv2(self._wallet_account) self._x402_client = x402Clientv2() @@ -283,6 +251,9 @@ async def make_request_v2(): completion_output=result.get("completion"), tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -422,6 +393,9 @@ async def make_request_v2(): chat_output=message, tee_signature=result.get("tee_signature"), tee_timestamp=result.get("tee_timestamp"), + tee_id=self._tee_id, + tee_endpoint=self._tee_endpoint, + tee_payment_address=self._tee_payment_address, ) except Exception as e: @@ -560,7 +534,12 @@ async def _parse_sse_response(response) -> AsyncGenerator[StreamChunk, None]: try: data = json.loads(data_str) - yield StreamChunk.from_sse_data(data) + chunk = StreamChunk.from_sse_data(data) + if chunk.is_final: + chunk.tee_id = self._tee_id + chunk.tee_endpoint = self._tee_endpoint + chunk.tee_payment_address = self._tee_payment_address + yield chunk except json.JSONDecodeError: continue diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py new file mode 100644 index 0000000..83e92b0 --- /dev/null +++ b/src/opengradient/client/tee_registry.py @@ -0,0 +1,152 @@ +"""TEE Registry client for fetching verified TEE endpoints and TLS certificates.""" + +import logging +import ssl +import tempfile +from dataclasses import dataclass +from typing import List, Optional + +from web3 import Web3 + +from ._utils import get_abi + +logger = logging.getLogger(__name__) + +# TEE types as defined in the registry contract +TEE_TYPE_LLM_PROXY = 0 +TEE_TYPE_VALIDATOR = 1 + + +@dataclass +class TEEEndpoint: + """A verified TEE with its endpoint URL and TLS certificate from the registry.""" + + tee_id: str + endpoint: str + tls_cert_der: bytes + payment_address: str + + +class TEERegistry: + """ + Queries the on-chain TEE Registry contract to retrieve verified TEE endpoints + and their TLS certificates. + + Instead of blindly trusting the TLS certificate presented by a TEE server + (TOFU), this class fetches the certificate that was submitted and verified + during TEE registration. Any certificate that does not match the one stored + in the registry should be rejected. + + Args: + rpc_url: RPC endpoint for the chain where the registry is deployed. + registry_address: Address of the deployed TEERegistry contract. + """ + + def __init__(self, rpc_url: str, registry_address: str): + self._web3 = Web3(Web3.HTTPProvider(rpc_url)) + abi = get_abi("TEERegistry.abi") + self._contract = self._web3.eth.contract( + address=Web3.to_checksum_address(registry_address), + abi=abi, + ) + + def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: + """ + Return all active TEEs of the given type with their endpoints and TLS certs. + + Args: + tee_type: Integer TEE type (0=LLMProxy, 1=Validator). + + Returns: + List of TEEEndpoint objects for active TEEs of that type. + """ + type_label = {TEE_TYPE_LLM_PROXY: "LLMProxy", TEE_TYPE_VALIDATOR: "Validator"}.get(tee_type, str(tee_type)) + + try: + tee_ids: List[bytes] = self._contract.functions.getTEEsByType(tee_type).call() + except Exception as e: + logger.warning("Failed to fetch TEE IDs from registry (type=%s): %s", type_label, e) + return [] + + logger.debug("Registry returned %d TEE ID(s) for type=%s", len(tee_ids), type_label) + + endpoints: List[TEEEndpoint] = [] + for tee_id in tee_ids: + tee_id_hex = tee_id.hex() + try: + info = self._contract.functions.getTEE(tee_id).call() + # TEEInfo tuple order: owner, paymentAddress, endpoint, publicKey, + # tlsCertificate, pcrHash, teeType, active, + # registeredAt, lastUpdatedAt + owner, payment_address, endpoint, _pub_key, tls_cert_der, _pcr_hash, _tee_type, active, _reg_at, _upd_at = info + if not active: + logger.debug(" teeId=%s status=inactive endpoint=%s (skipped)", tee_id_hex, endpoint) + continue + if not endpoint or not tls_cert_der: + logger.warning(" teeId=%s missing endpoint or TLS cert (skipped)", tee_id_hex) + continue + logger.info( + " teeId=%s endpoint=%s paymentAddress=%s certBytes=%d", + tee_id_hex, + endpoint, + payment_address, + len(tls_cert_der), + ) + endpoints.append( + TEEEndpoint( + tee_id=tee_id_hex, + endpoint=endpoint, + tls_cert_der=bytes(tls_cert_der), + payment_address=payment_address, + ) + ) + except Exception as e: + logger.warning("Failed to fetch TEE info for teeId=%s: %s", tee_id_hex, e) + + logger.info("Discovered %d active %s TEE(s) from registry", len(endpoints), type_label) + return endpoints + + def get_llm_tee(self) -> Optional[TEEEndpoint]: + """ + Return the first active LLM proxy TEE from the registry. + + Returns: + TEEEndpoint for an active LLM proxy TEE, or None if none are available. + """ + logger.debug("Querying TEE registry for active LLM proxy TEEs...") + tees = self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if tees: + logger.info("Selected LLM TEE: endpoint=%s teeId=%s", tees[0].endpoint, tees[0].tee_id) + else: + logger.warning("No active LLM proxy TEEs found in registry") + return tees[0] if tees else None + + +def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: + """ + Build an ssl.SSLContext that trusts *only* the given DER-encoded certificate. + + Hostname verification is disabled because TEE servers are typically addressed + by IP while the cert may be issued for a different hostname. The pinned + certificate itself is the trust anchor — only that cert is accepted. + + Args: + der_cert: DER-encoded X.509 certificate bytes as stored in the registry. + + Returns: + ssl.SSLContext configured to accept only the pinned certificate. + """ + pem = ssl.DER_cert_to_PEM_cert(der_cert) + + cert_file = tempfile.NamedTemporaryFile( + prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w" + ) + cert_file.write(pem) + cert_file.flush() + cert_file.close() + + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(cert_file.name) + ctx.check_hostname = False # TEE cert may be issued for a hostname; we connect via IP + ctx.verify_mode = ssl.CERT_REQUIRED + return ctx diff --git a/src/opengradient/defaults.py b/src/opengradient/defaults.py index c053225..1df851e 100644 --- a/src/opengradient/defaults.py +++ b/src/opengradient/defaults.py @@ -6,6 +6,7 @@ DEFAULT_INFERENCE_CONTRACT_ADDRESS = "0x8383C9bD7462F12Eb996DD02F78234C0421A6FaE" DEFAULT_SCHEDULER_ADDRESS = "0x7179724De4e7FF9271FA40C0337c7f90C0508eF6" DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/" -# TODO (Kyle): Add a process to fetch these IPs from the TEE registry -DEFAULT_OPENGRADIENT_LLM_SERVER_URL = "https://3.15.214.21:443" -DEFAULT_OPENGRADIENT_LLM_STREAMING_SERVER_URL = "https://3.15.214.21:443" +# TEE Registry contract on the OG EVM chain — used to discover LLM proxy endpoints +# and fetch their registry-verified TLS certificates instead of blindly trusting TOFU. +DEFAULT_TEE_REGISTRY_ADDRESS = "0x3d641a2791533b4a0000345ea8d509d01e1ec301" +DEFAULT_TEE_REGISTRY_RPC_URL = "http://13.59.43.94:8545" diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 3a13aac..98b1d47 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -241,6 +241,9 @@ class StreamChunk: is_final: Whether this is the final chunk (before [DONE]) tee_signature: RSA-PSS signature over the response, present on the final chunk tee_timestamp: ISO timestamp from the TEE at signing time, present on the final chunk + tee_id: On-chain TEE registry ID of the enclave that served this request (final chunk only) + tee_endpoint: Endpoint URL of the TEE that served this request (final chunk only) + tee_payment_address: Payment address registered for the TEE (final chunk only) """ choices: List[StreamChoice] @@ -249,6 +252,9 @@ class StreamChunk: is_final: bool = False tee_signature: Optional[str] = None tee_timestamp: Optional[str] = None + tee_id: Optional[str] = None + tee_endpoint: Optional[str] = None + tee_payment_address: Optional[str] = None @classmethod def from_sse_data(cls, data: Dict) -> "StreamChunk": @@ -396,6 +402,15 @@ class TextGenerationOutput: tee_timestamp: Optional[str] = None """ISO timestamp from the TEE at signing time.""" + tee_id: Optional[str] = None + """On-chain TEE registry ID (keccak256 of the enclave's public key) of the TEE that served this request.""" + + tee_endpoint: Optional[str] = None + """Endpoint URL of the TEE that served this request, as registered on-chain.""" + + tee_payment_address: Optional[str] = None + """Payment address registered for the TEE that served this request.""" + @dataclass class AbiFunction: From bfedcae15aff2e21975d05d771765db20d1d7daa Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 01:48:02 -0800 Subject: [PATCH 02/13] fix: fall back to non-streaming endpoint when tool calls are requested with stream=True MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The TEE streaming endpoint returns an empty delta ("delta": {}) and no tool call content in SSE events when tools is supplied — the server-side streaming path simply does not emit tool call data. Introduce _tee_llm_chat_tools_as_stream which transparently calls the non-streaming /v1/chat/completions endpoint and wraps the complete TextGenerationOutput as a single final StreamChunk (with tool_calls populated in delta). chat() now routes stream=True + tools to this method, preserving the streaming iterator interface for callers and the CLI while returning correct results. Also removes the temporary [SSE RAW] debug print added during diagnosis, and fixes from_sse_data to accept "message" as a fallback for "delta" when the proxy sends a non-streaming format in SSE events. Co-Authored-By: Claude Sonnet 4.6 --- src/opengradient/client/llm.py | 72 +++++++++++++++++++++++++++++++++- src/opengradient/types.py | 4 +- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 2258f7c..9a5db0e 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,11 +2,14 @@ import asyncio import json +import logging import ssl import threading from queue import Queue from typing import AsyncGenerator, Dict, List, Optional, Union +logger = logging.getLogger(__name__) + import httpx from eth_account.account import LocalAccount from x402v2 import x402Client as x402Clientv2 @@ -15,7 +18,7 @@ from x402v2.mechanisms.evm.exact.register import register_exact_evm_client as register_exact_evm_clientv2 from x402v2.mechanisms.evm.upto.register import register_upto_evm_client as register_upto_evm_clientv2 -from ..types import TEE_LLM, StreamChunk, TextGenerationOutput, TextGenerationStream, x402SettlementMode +from ..types import TEE_LLM, StreamChunk, StreamChoice, StreamDelta, TextGenerationOutput, TextGenerationStream, x402SettlementMode from .exceptions import OpenGradientError from .opg_token import Permit2ApprovalResult, ensure_opg_approval from .tee_registry import build_ssl_context_from_der @@ -305,6 +308,20 @@ def chat( OpenGradientError: If the inference fails. """ if stream: + if tools: + # The TEE streaming endpoint omits tool call content from SSE events. + # Fall back transparently to the non-streaming endpoint and emit a + # single final StreamChunk so callers get the complete tool call data. + return self._tee_llm_chat_tools_as_stream( + model=model.split("/")[1], + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) # Use threading bridge for true sync streaming return self._tee_llm_chat_stream_sync( model=model.split("/")[1], @@ -408,6 +425,59 @@ async def make_request_v2(): except Exception as e: raise OpenGradientError(f"TEE LLM chat failed: {str(e)}") + def _tee_llm_chat_tools_as_stream( + self, + model: str, + messages: List[Dict], + max_tokens: int = 100, + stop_sequence: Optional[List[str]] = None, + temperature: float = 0.0, + tools: Optional[List[Dict]] = None, + tool_choice: Optional[str] = None, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, + ): + """ + Transparent non-streaming fallback for tool-call requests with stream=True. + + The TEE streaming endpoint returns an empty delta when tools are present — + tool call content is not emitted as SSE events. This method calls the + non-streaming endpoint instead and emits a single final StreamChunk that + carries the complete tool call response, preserving the streaming interface + for callers (including the CLI). + """ + result = self._tee_llm_chat( + model=model, + messages=messages, + max_tokens=max_tokens, + stop_sequence=stop_sequence, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + x402_settlement_mode=x402_settlement_mode, + ) + + chat_output = result.chat_output or {} + delta = StreamDelta( + role=chat_output.get("role"), + content=chat_output.get("content"), + tool_calls=chat_output.get("tool_calls"), + ) + choice = StreamChoice( + delta=delta, + index=0, + finish_reason=result.finish_reason, + ) + yield StreamChunk( + choices=[choice], + model=model, + is_final=True, + tee_signature=result.tee_signature, + tee_timestamp=result.tee_timestamp, + tee_id=result.tee_id, + tee_endpoint=result.tee_endpoint, + tee_payment_address=result.tee_payment_address, + ) + def _tee_llm_chat_stream_sync( self, model: str, diff --git a/src/opengradient/types.py b/src/opengradient/types.py index 98b1d47..1a703c4 100644 --- a/src/opengradient/types.py +++ b/src/opengradient/types.py @@ -269,7 +269,9 @@ def from_sse_data(cls, data: Dict) -> "StreamChunk": """ choices = [] for choice_data in data.get("choices", []): - delta_data = choice_data.get("delta", {}) + # The TEE proxy sometimes sends SSE events using the non-streaming "message" + # key instead of the standard streaming "delta" key. Fall back gracefully. + delta_data = choice_data.get("delta") or choice_data.get("message") or {} delta = StreamDelta(content=delta_data.get("content"), role=delta_data.get("role"), tool_calls=delta_data.get("tool_calls")) choice = StreamChoice(delta=delta, index=choice_data.get("index", 0), finish_reason=choice_data.get("finish_reason")) choices.append(choice) From 7057ddb9f190aae8a2955bcc23caca4878dd2f3f Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 01:50:57 -0800 Subject: [PATCH 03/13] fix(tests): update client_test fixtures to mock TEERegistry and fix TEE_LLM references Two issues caused CI failures: 1. The mock_web3 fixture didn't patch TEERegistry, so Client.__init__ tried to instantiate a real TEERegistry (with a live Web3 connection) even in unit tests. The mock_abi_files fallback returned {} for TEERegistry.abi but web3.eth.contract() requires a list, causing ValueError. Fix: patch src.opengradient.client.client.TEERegistry inside mock_web3 and return a fake TEEEndpoint with a stub endpoint/tee_id/payment_address. 2. Three LLM tests referenced TEE_LLM.GPT_4O which no longer exists in the enum. Updated to TEE_LLM.GPT_5. Co-Authored-By: Claude Sonnet 4.6 --- tests/client_test.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/client_test.py b/tests/client_test.py index f17283b..721f382 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -21,7 +21,10 @@ @pytest.fixture def mock_web3(): """Create a mock Web3 instance.""" - with patch("src.opengradient.client.client.Web3") as mock: + with ( + patch("src.opengradient.client.client.Web3") as mock, + patch("src.opengradient.client.client.TEERegistry") as mock_tee_registry, + ): mock_instance = MagicMock() mock.return_value = mock_instance mock.HTTPProvider.return_value = MagicMock() @@ -31,6 +34,14 @@ def mock_web3(): mock_instance.eth.gas_price = 1000000000 mock_instance.eth.contract.return_value = MagicMock() + # Return a fake active TEE endpoint so Client.__init__ doesn't need a live registry + mock_tee = MagicMock() + mock_tee.endpoint = "https://test.tee.server" + mock_tee.tls_cert_der = None + mock_tee.tee_id = "test-tee-id" + mock_tee.payment_address = "0xTestPaymentAddress" + mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee + yield mock_instance @@ -194,7 +205,7 @@ def test_llm_completion_success(self, client): ) result = client.llm.completion( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, prompt="Hello", max_tokens=100, ) @@ -215,7 +226,7 @@ def test_llm_chat_success_non_streaming(self, client): ) result = client.llm.chat( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hello"}], stream=False, ) @@ -233,7 +244,7 @@ def test_llm_chat_streaming(self, client): mock_stream.return_value = iter(mock_chunks) result = client.llm.chat( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hello"}], stream=True, ) From 7cf92d10a20893efbcb20a03b70f22deb6717a32 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 02:02:19 -0800 Subject: [PATCH 04/13] fix(cli): guard against empty choices in streaming loop and format tool calls consistently Two bugs caused tool call results to be invisible in the CLI: 1. print_streaming_chat_result accessed chunk.choices[0] unconditionally. Usage-only SSE chunks carry an empty choices list, which caused an IndexError that was silently swallowed by the outer except block, truncating output before tool calls or finish_reason were printed. Fixed by guarding all choices[0] accesses with `if chunk.choices:`. 2. print_llm_chat_result (non-streaming) printed tool_calls as a raw Python dict repr. Updated to use the same formatted output as the streaming path: "Tool Calls: / Function: ... / Arguments: ...". Co-Authored-By: Claude Sonnet 4.6 --- src/opengradient/cli.py | 57 ++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index bbec05e..650a86e 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -630,16 +630,24 @@ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output, is_van click.secho("Chat Output:", fg="yellow", bold=True) click.echo() for key, value in chat_output.items(): - if value is not None and value not in ("", "[]", []): + if value is None or value in ("", "[]", []): + continue + if key == "tool_calls": + # Format tool calls the same way as the streaming path + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in value: + fn = tool_call.get("function", {}) + click.echo(f" Function: {fn.get('name', '')}") + click.echo(f" Arguments: {fn.get('arguments', '')}") + elif key == "content" and isinstance(value, list): # Normalize list-of-blocks content (e.g. Gemini 3 thought signatures) - if key == "content" and isinstance(value, list): - text = " ".join( - block.get("text", "") for block in value - if isinstance(block, dict) and block.get("type") == "text" - ).strip() - click.echo(f"{key}: {text}") - else: - click.echo(f"{key}: {value}") + text = " ".join( + block.get("text", "") for block in value + if isinstance(block, dict) and block.get("type") == "text" + ).strip() + click.echo(f"{key}: {text}") + else: + click.echo(f"{key}: {value}") click.echo() @@ -663,20 +671,21 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): for chunk in stream: chunk_count += 1 - if chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - sys.stdout.write(content) - sys.stdout.flush() - content_parts.append(content) - - # Handle tool calls - if chunk.choices[0].delta.tool_calls: - sys.stdout.write("\n") - sys.stdout.flush() - click.secho("Tool Calls:", fg="yellow", bold=True) - for tool_call in chunk.choices[0].delta.tool_calls: - click.echo(f" Function: {tool_call['function']['name']}") - click.echo(f" Arguments: {tool_call['function']['arguments']}") + if chunk.choices: + if chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + sys.stdout.write(content) + sys.stdout.flush() + content_parts.append(content) + + # Handle tool calls + if chunk.choices[0].delta.tool_calls: + sys.stdout.write("\n") + sys.stdout.flush() + click.secho("Tool Calls:", fg="yellow", bold=True) + for tool_call in chunk.choices[0].delta.tool_calls: + click.echo(f" Function: {tool_call['function']['name']}") + click.echo(f" Arguments: {tool_call['function']['arguments']}") # Print final info when stream completes if chunk.is_final: @@ -691,7 +700,7 @@ def print_streaming_chat_result(model_cid, stream, is_tee=True): click.echo(f" Total tokens: {chunk.usage.total_tokens}") click.echo() - if chunk.choices[0].finish_reason: + if chunk.choices and chunk.choices[0].finish_reason: click.echo("Finish reason: ", nl=False) click.secho(chunk.choices[0].finish_reason, fg="green") From 304e066fa44e05c18f11ce84b0cd09fe83dda0b7 Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 02:03:54 -0800 Subject: [PATCH 05/13] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 262637b..fd0428d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "opengradient" -version = "0.7.5" +version = "0.7.6" description = "Python SDK for OpenGradient decentralized model management & inference services" authors = [{name = "OpenGradient", email = "adam@vannalabs.ai"}] readme = "README.md" From 8b9cc269cb5bb14975d00f07cf556ceb3ff8152f Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 02:06:40 -0800 Subject: [PATCH 06/13] fix(makefile): add system prompt to chat-tool targets to reliably trigger tool calls With tool_choice="auto", models like GPT-5 require a system prompt instructing them to use tools, otherwise they respond with finish_reason "stop" and empty content. Updated chat-tool and chat-stream-tool to include a system message matching the pattern that works in the SDK. Co-Authored-By: Claude Sonnet 4.6 --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index acc46ee..0bda45d 100644 --- a/Makefile +++ b/Makefile @@ -87,14 +87,14 @@ chat-stream: chat-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ + --messages '[{"role":"system","content":"You are a helpful assistant. Use the available tools when needed."},{"role":"user","content":"What is the weather in Tokyo? Give me the temperature in celsius."}]' \ --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ --max-tokens 100 chat-stream-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"user","content":"What is the weather in Tokyo?"}]' \ + --messages '[{"role":"system","content":"You are a helpful assistant. Use the available tools when needed."},{"role":"user","content":"What is the weather in Tokyo? Give me the temperature in celsius."}]' \ --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ --max-tokens 100 \ --stream From 56a1ce4560297150630e9c41e2b2558578cb183b Mon Sep 17 00:00:00 2001 From: kylexqian Date: Fri, 6 Mar 2026 02:15:46 -0800 Subject: [PATCH 07/13] fix(makefile): use proven Dallas/Texas payload to reliably trigger GPT-5 tool calls The previous Tokyo/get_weather(location) payload still failed with finish_reason: stop on GPT-5. Switch chat-tool and chat-stream-tool to use get_current_weather(city, state, unit) with Dallas, Texas which consistently triggers tool_calls finish_reason. Co-Authored-By: Claude Sonnet 4.6 --- Makefile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 0bda45d..cdb51c6 100644 --- a/Makefile +++ b/Makefile @@ -87,16 +87,16 @@ chat-stream: chat-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"system","content":"You are a helpful assistant. Use the available tools when needed."},{"role":"user","content":"What is the weather in Tokyo? Give me the temperature in celsius."}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 chat-stream-tool: python -m opengradient.cli chat \ --model $(MODEL) \ - --messages '[{"role":"system","content":"You are a helpful assistant. Use the available tools when needed."},{"role":"user","content":"What is the weather in Tokyo? Give me the temperature in celsius."}]' \ - --tools '[{"type":"function","function":{"name":"get_weather","description":"Get weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}]' \ - --max-tokens 100 \ + --messages '[{"role":"system","content":"You are a helpful assistant. Use tools when needed."},{"role":"user","content":"What'\''s the weather like in Dallas, Texas? Give me the temperature in fahrenheit."}]' \ + --tools '[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"city":{"type":"string"},"state":{"type":"string"},"unit":{"type":"string","enum":["fahrenheit","celsius"]}},"required":["city","state","unit"]}}}]' \ + --max-tokens 200 \ --stream .PHONY: install build publish check docs test utils_test client_test langchain_adapter_test opg_token_test integrationtest examples \ From cc6919e52eccc61ef73a81e018f21f796ec5f141 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 18:00:20 -0400 Subject: [PATCH 08/13] in memory cert + registry test --- Makefile | 5 +- src/opengradient/client/tee_registry.py | 10 +- tests/langchain_adapter_test.py | 12 +- tests/tee_registry_test.py | 250 ++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 16 deletions(-) create mode 100644 tests/tee_registry_test.py diff --git a/Makefile b/Makefile index cdb51c6..d4fabf8 100644 --- a/Makefile +++ b/Makefile @@ -31,7 +31,7 @@ docs: # Testing # ============================================================================ -test: utils_test client_test langchain_adapter_test opg_token_test +test: utils_test client_test langchain_adapter_test opg_token_test tee_registry_test utils_test: pytest tests/utils_test.py -v @@ -45,6 +45,9 @@ langchain_adapter_test: opg_token_test: pytest tests/opg_token_test.py -v +tee_registry_test: + pytest tests/tee_registry_test.py -v + integrationtest: python integrationtest/agent/test_agent.py python integrationtest/workflow_models/test_workflow_models.py diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index 83e92b0..e61b0a4 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -2,7 +2,6 @@ import logging import ssl -import tempfile from dataclasses import dataclass from typing import List, Optional @@ -138,15 +137,8 @@ def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: """ pem = ssl.DER_cert_to_PEM_cert(der_cert) - cert_file = tempfile.NamedTemporaryFile( - prefix="og_tee_tls_", suffix=".pem", delete=False, mode="w" - ) - cert_file.write(pem) - cert_file.flush() - cert_file.close() - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - ctx.load_verify_locations(cert_file.name) + ctx.load_verify_locations(cadata=pem) ctx.check_hostname = False # TEE cert may be issued for a hostname; we connect via IP ctx.verify_mode = ssl.CERT_REQUIRED return ctx diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index ab290a8..1671c7f 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -26,34 +26,34 @@ def mock_client(): @pytest.fixture def model(mock_client): """Create an OpenGradientChatModel with a mocked client.""" - return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_4O) + return OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.GPT_5) class TestOpenGradientChatModel: def test_initialization(self, model): """Test model initializes with correct fields.""" - assert model.model_cid == TEE_LLM.GPT_4O + assert model.model_cid == TEE_LLM.GPT_5 assert model.max_tokens == 300 assert model.x402_settlement_mode == x402SettlementMode.SETTLE_BATCH assert model._llm_type == "opengradient" def test_initialization_custom_max_tokens(self, mock_client): """Test model initializes with custom max_tokens.""" - model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_3_5_HAIKU, max_tokens=1000) + model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, max_tokens=1000) assert model.max_tokens == 1000 def test_initialization_custom_settlement_mode(self, mock_client): """Test model initializes with custom settlement mode.""" model = OpenGradientChatModel( private_key="0x" + "a" * 64, - model_cid=TEE_LLM.GPT_4O, + model_cid=TEE_LLM.GPT_5, x402_settlement_mode=x402SettlementMode.SETTLE, ) assert model.x402_settlement_mode == x402SettlementMode.SETTLE def test_identifying_params(self, model): """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_4O} + assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} class TestGenerate: @@ -210,7 +210,7 @@ def test_passes_correct_params_to_client(self, model, mock_client): model._generate([HumanMessage(content="Hi")], stop=["END"]) mock_client.llm.chat.assert_called_once_with( - model=TEE_LLM.GPT_4O, + model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], max_tokens=300, diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py new file mode 100644 index 0000000..72469de --- /dev/null +++ b/tests/tee_registry_test.py @@ -0,0 +1,250 @@ +import os +import ssl +import sys +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) + +from src.opengradient.client.tee_registry import ( + TEE_TYPE_LLM_PROXY, + TEE_TYPE_VALIDATOR, + TEEEndpoint, + TEERegistry, + build_ssl_context_from_der, +) + + +# --- Helpers --- + + +def _make_tee_info( + endpoint="https://tee.example.com", + payment_address="0xPayment", + tls_cert_der=b"\x01\x02\x03", + active=True, +): + """Build a tuple matching the TEEInfo struct order from the contract.""" + return ( + "0xOwner", # owner + payment_address, # paymentAddress + endpoint, # endpoint + b"pubkey", # publicKey + tls_cert_der, # tlsCertificate + b"pcrhash", # pcrHash + 0, # teeType + active, # active + 1000, # registeredAt + 2000, # lastUpdatedAt + ) + + +def _make_self_signed_der() -> bytes: + """Generate a minimal self-signed DER certificate for testing.""" + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + import datetime + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + return cert.public_bytes(serialization.Encoding.DER) + + +# --- Fixtures --- + + +@pytest.fixture +def mock_contract(): + """Create a TEERegistry with a mocked Web3 contract.""" + with ( + patch("src.opengradient.client.tee_registry.Web3") as mock_web3_cls, + patch("src.opengradient.client.tee_registry.get_abi") as mock_get_abi, + ): + mock_get_abi.return_value = [] + mock_web3 = MagicMock() + mock_web3_cls.return_value = mock_web3 + mock_web3_cls.HTTPProvider.return_value = MagicMock() + mock_web3_cls.to_checksum_address.side_effect = lambda x: x + + contract = MagicMock() + mock_web3.eth.contract.return_value = contract + + registry = TEERegistry(rpc_url="http://localhost:8545", registry_address="0xRegistry") + yield registry, contract + + +# --- TEERegistry Tests --- + + +class TestGetActiveTeesByType: + def test_returns_active_tees(self, mock_contract): + registry, contract = mock_contract + + tee_id = b"\xaa" * 32 + contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] + contract.functions.getTEE.return_value.call.return_value = _make_tee_info() + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + + assert len(result) == 1 + assert result[0].tee_id == tee_id.hex() + assert result[0].endpoint == "https://tee.example.com" + assert result[0].payment_address == "0xPayment" + assert result[0].tls_cert_der == b"\x01\x02\x03" + + def test_skips_inactive_tees(self, mock_contract): + registry, contract = mock_contract + + tee_id = b"\xbb" * 32 + contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] + contract.functions.getTEE.return_value.call.return_value = _make_tee_info(active=False) + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 0 + + def test_skips_tee_with_empty_endpoint(self, mock_contract): + registry, contract = mock_contract + + tee_id = b"\xcc" * 32 + contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] + contract.functions.getTEE.return_value.call.return_value = _make_tee_info(endpoint="") + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 0 + + def test_skips_tee_with_empty_cert(self, mock_contract): + registry, contract = mock_contract + + tee_id = b"\xdd" * 32 + contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] + contract.functions.getTEE.return_value.call.return_value = _make_tee_info(tls_cert_der=b"") + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 0 + + def test_returns_empty_on_rpc_failure(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getTEEsByType.return_value.call.side_effect = Exception("RPC error") + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert result == [] + + def test_skips_individual_tee_on_lookup_failure(self, mock_contract): + registry, contract = mock_contract + + good_id = b"\xaa" * 32 + bad_id = b"\xbb" * 32 + contract.functions.getTEEsByType.return_value.call.return_value = [bad_id, good_id] + + def get_tee_side_effect(tee_id): + mock = MagicMock() + if tee_id == bad_id: + mock.call.side_effect = Exception("lookup failed") + else: + mock.call.return_value = _make_tee_info() + return mock + + contract.functions.getTEE.side_effect = get_tee_side_effect + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 1 + assert result[0].tee_id == good_id.hex() + + def test_multiple_active_tees(self, mock_contract): + registry, contract = mock_contract + + ids = [b"\x01" * 32, b"\x02" * 32, b"\x03" * 32] + contract.functions.getTEEsByType.return_value.call.return_value = ids + + def get_tee_side_effect(tee_id): + mock = MagicMock() + mock.call.return_value = _make_tee_info( + endpoint=f"https://tee-{tee_id.hex()[:4]}.example.com" + ) + return mock + + contract.functions.getTEE.side_effect = get_tee_side_effect + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 3 + + def test_validator_type_label(self, mock_contract): + """Ensure validator type queries work the same way.""" + registry, contract = mock_contract + + contract.functions.getTEEsByType.return_value.call.return_value = [] + + result = registry.get_active_tees_by_type(TEE_TYPE_VALIDATOR) + assert result == [] + contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_VALIDATOR) + + +class TestGetLlmTee: + def test_returns_first_active_tee(self, mock_contract): + registry, contract = mock_contract + + ids = [b"\x01" * 32, b"\x02" * 32] + contract.functions.getTEEsByType.return_value.call.return_value = ids + contract.functions.getTEE.return_value.call.return_value = _make_tee_info() + + result = registry.get_llm_tee() + + assert result is not None + assert result.tee_id == ids[0].hex() + + def test_returns_none_when_no_tees(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getTEEsByType.return_value.call.return_value = [] + + result = registry.get_llm_tee() + assert result is None + + def test_queries_llm_proxy_type(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getTEEsByType.return_value.call.return_value = [] + registry.get_llm_tee() + + contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_LLM_PROXY) + + +# --- build_ssl_context_from_der Tests --- + + +class TestBuildSslContextFromDer: + def test_returns_ssl_context(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert isinstance(ctx, ssl.SSLContext) + + def test_hostname_check_disabled(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert ctx.check_hostname is False + + def test_cert_required(self): + der_cert = _make_self_signed_der() + ctx = build_ssl_context_from_der(der_cert) + + assert ctx.verify_mode == ssl.CERT_REQUIRED + + def test_rejects_invalid_der(self): + with pytest.raises(Exception): + build_ssl_context_from_der(b"not-a-valid-cert") From ea0e0bf2354d7e92e9985b99c75b8415483b82ef Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 18:36:14 -0400 Subject: [PATCH 09/13] use new abi --- src/opengradient/abi/TEERegistry.abi | 40 +++++++++-- src/opengradient/client/tee_registry.py | 77 +++++++++++--------- tests/tee_registry_test.py | 95 +++++++------------------ 3 files changed, 106 insertions(+), 106 deletions(-) diff --git a/src/opengradient/abi/TEERegistry.abi b/src/opengradient/abi/TEERegistry.abi index 9dcf4a5..51fe5b3 100644 --- a/src/opengradient/abi/TEERegistry.abi +++ b/src/opengradient/abi/TEERegistry.abi @@ -1,7 +1,32 @@ [ { - "inputs": [], + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], "name": "getActiveTEEs", + "outputs": [ + { + "components": [ + {"internalType": "address", "name": "owner", "type": "address"}, + {"internalType": "address", "name": "paymentAddress", "type": "address"}, + {"internalType": "string", "name": "endpoint", "type": "string"}, + {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, + {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, + {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, + {"internalType": "uint8", "name": "teeType", "type": "uint8"}, + {"internalType": "bool", "name": "enabled", "type": "bool"}, + {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, + {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} + ], + "internalType": "struct TEERegistry.TEEInfo[]", + "name": "", + "type": "tuple[]" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], + "name": "getEnabledTEEs", "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], "stateMutability": "view", "type": "function" @@ -26,9 +51,9 @@ {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, {"internalType": "uint8", "name": "teeType", "type": "uint8"}, - {"internalType": "bool", "name": "active", "type": "bool"}, + {"internalType": "bool", "name": "enabled", "type": "bool"}, {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, - {"internalType": "uint256", "name": "lastUpdatedAt", "type": "uint256"} + {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} ], "internalType": "struct TEERegistry.TEEInfo", "name": "", @@ -40,7 +65,14 @@ }, { "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], - "name": "isActive", + "name": "isTEEActive", + "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], + "name": "isTEEEnabled", "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], "stateMutability": "view", "type": "function" diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index e61b0a4..99294c0 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -3,7 +3,7 @@ import logging import ssl from dataclasses import dataclass -from typing import List, Optional +from typing import List, NamedTuple, Optional from web3 import Web3 @@ -16,6 +16,21 @@ TEE_TYPE_VALIDATOR = 1 +class TEEInfo(NamedTuple): + """Mirrors the on-chain TEERegistry.TEEInfo struct.""" + + owner: str + payment_address: str + endpoint: str + public_key: bytes + tls_certificate: bytes + pcr_hash: bytes + tee_type: int + enabled: bool + registered_at: int + last_heartbeat_at: int + + @dataclass class TEEEndpoint: """A verified TEE with its endpoint URL and TLS certificate from the registry.""" @@ -53,6 +68,10 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: """ Return all active TEEs of the given type with their endpoints and TLS certs. + Uses the contract's ``getActiveTEEs(teeType)`` which returns only TEEs that + are enabled, have a valid (non-revoked) PCR, and a fresh heartbeat — all in + a single on-chain call. + Args: tee_type: Integer TEE type (0=LLMProxy, 1=Validator). @@ -62,45 +81,35 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: type_label = {TEE_TYPE_LLM_PROXY: "LLMProxy", TEE_TYPE_VALIDATOR: "Validator"}.get(tee_type, str(tee_type)) try: - tee_ids: List[bytes] = self._contract.functions.getTEEsByType(tee_type).call() + tee_infos = self._contract.functions.getActiveTEEs(tee_type).call() except Exception as e: - logger.warning("Failed to fetch TEE IDs from registry (type=%s): %s", type_label, e) + logger.warning("Failed to fetch active TEEs from registry (type=%s): %s", type_label, e) return [] - logger.debug("Registry returned %d TEE ID(s) for type=%s", len(tee_ids), type_label) + logger.debug("Registry returned %d active TEE(s) for type=%s", len(tee_infos), type_label) endpoints: List[TEEEndpoint] = [] - for tee_id in tee_ids: - tee_id_hex = tee_id.hex() - try: - info = self._contract.functions.getTEE(tee_id).call() - # TEEInfo tuple order: owner, paymentAddress, endpoint, publicKey, - # tlsCertificate, pcrHash, teeType, active, - # registeredAt, lastUpdatedAt - owner, payment_address, endpoint, _pub_key, tls_cert_der, _pcr_hash, _tee_type, active, _reg_at, _upd_at = info - if not active: - logger.debug(" teeId=%s status=inactive endpoint=%s (skipped)", tee_id_hex, endpoint) - continue - if not endpoint or not tls_cert_der: - logger.warning(" teeId=%s missing endpoint or TLS cert (skipped)", tee_id_hex) - continue - logger.info( - " teeId=%s endpoint=%s paymentAddress=%s certBytes=%d", - tee_id_hex, - endpoint, - payment_address, - len(tls_cert_der), - ) - endpoints.append( - TEEEndpoint( - tee_id=tee_id_hex, - endpoint=endpoint, - tls_cert_der=bytes(tls_cert_der), - payment_address=payment_address, - ) + for raw in tee_infos: + tee = TEEInfo(*raw) + tee_id_hex = Web3.keccak(tee.public_key).hex() + if not tee.endpoint or not tee.tls_certificate: + logger.warning(" teeId=%s missing endpoint or TLS cert (skipped)", tee_id_hex) + continue + logger.info( + " teeId=%s endpoint=%s paymentAddress=%s certBytes=%d", + tee_id_hex, + tee.endpoint, + tee.payment_address, + len(tee.tls_certificate), + ) + endpoints.append( + TEEEndpoint( + tee_id=tee_id_hex, + endpoint=tee.endpoint, + tls_cert_der=bytes(tee.tls_certificate), + payment_address=tee.payment_address, ) - except Exception as e: - logger.warning("Failed to fetch TEE info for teeId=%s: %s", tee_id_hex, e) + ) logger.info("Discovered %d active %s TEE(s) from registry", len(endpoints), type_label) return endpoints diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py index 72469de..8002a9c 100644 --- a/tests/tee_registry_test.py +++ b/tests/tee_registry_test.py @@ -22,21 +22,21 @@ def _make_tee_info( endpoint="https://tee.example.com", payment_address="0xPayment", + pub_key=b"pubkey", tls_cert_der=b"\x01\x02\x03", - active=True, ): - """Build a tuple matching the TEEInfo struct order from the contract.""" + """Build a tuple matching the TEEInfo struct order from the new contract.""" return ( "0xOwner", # owner payment_address, # paymentAddress endpoint, # endpoint - b"pubkey", # publicKey + pub_key, # publicKey tls_cert_der, # tlsCertificate - b"pcrhash", # pcrHash + b"\x00" * 32, # pcrHash 0, # teeType - active, # active + True, # enabled (always True from getActiveTEEs) 1000, # registeredAt - 2000, # lastUpdatedAt + 2000, # lastHeartbeatAt ) @@ -78,6 +78,7 @@ def mock_contract(): mock_web3_cls.return_value = mock_web3 mock_web3_cls.HTTPProvider.return_value = MagicMock() mock_web3_cls.to_checksum_address.side_effect = lambda x: x + mock_web3_cls.keccak.side_effect = lambda data: b"\xaa" * 32 if data == b"pubkey" else b"\xbb" * 32 contract = MagicMock() mock_web3.eth.contract.return_value = contract @@ -93,34 +94,20 @@ class TestGetActiveTeesByType: def test_returns_active_tees(self, mock_contract): registry, contract = mock_contract - tee_id = b"\xaa" * 32 - contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] - contract.functions.getTEE.return_value.call.return_value = _make_tee_info() + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info()] result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) assert len(result) == 1 - assert result[0].tee_id == tee_id.hex() assert result[0].endpoint == "https://tee.example.com" assert result[0].payment_address == "0xPayment" assert result[0].tls_cert_der == b"\x01\x02\x03" - - def test_skips_inactive_tees(self, mock_contract): - registry, contract = mock_contract - - tee_id = b"\xbb" * 32 - contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] - contract.functions.getTEE.return_value.call.return_value = _make_tee_info(active=False) - - result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) - assert len(result) == 0 + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) def test_skips_tee_with_empty_endpoint(self, mock_contract): registry, contract = mock_contract - tee_id = b"\xcc" * 32 - contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] - contract.functions.getTEE.return_value.call.return_value = _make_tee_info(endpoint="") + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info(endpoint="")] result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) assert len(result) == 0 @@ -128,9 +115,7 @@ def test_skips_tee_with_empty_endpoint(self, mock_contract): def test_skips_tee_with_empty_cert(self, mock_contract): registry, contract = mock_contract - tee_id = b"\xdd" * 32 - contract.functions.getTEEsByType.return_value.call.return_value = [tee_id] - contract.functions.getTEE.return_value.call.return_value = _make_tee_info(tls_cert_der=b"") + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info(tls_cert_der=b"")] result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) assert len(result) == 0 @@ -138,46 +123,19 @@ def test_skips_tee_with_empty_cert(self, mock_contract): def test_returns_empty_on_rpc_failure(self, mock_contract): registry, contract = mock_contract - contract.functions.getTEEsByType.return_value.call.side_effect = Exception("RPC error") + contract.functions.getActiveTEEs.return_value.call.side_effect = Exception("RPC error") result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) assert result == [] - def test_skips_individual_tee_on_lookup_failure(self, mock_contract): - registry, contract = mock_contract - - good_id = b"\xaa" * 32 - bad_id = b"\xbb" * 32 - contract.functions.getTEEsByType.return_value.call.return_value = [bad_id, good_id] - - def get_tee_side_effect(tee_id): - mock = MagicMock() - if tee_id == bad_id: - mock.call.side_effect = Exception("lookup failed") - else: - mock.call.return_value = _make_tee_info() - return mock - - contract.functions.getTEE.side_effect = get_tee_side_effect - - result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) - assert len(result) == 1 - assert result[0].tee_id == good_id.hex() - def test_multiple_active_tees(self, mock_contract): registry, contract = mock_contract - ids = [b"\x01" * 32, b"\x02" * 32, b"\x03" * 32] - contract.functions.getTEEsByType.return_value.call.return_value = ids - - def get_tee_side_effect(tee_id): - mock = MagicMock() - mock.call.return_value = _make_tee_info( - endpoint=f"https://tee-{tee_id.hex()[:4]}.example.com" - ) - return mock - - contract.functions.getTEE.side_effect = get_tee_side_effect + infos = [ + _make_tee_info(endpoint=f"https://tee-{i}.example.com", pub_key=f"pubkey{i}".encode()) + for i in range(3) + ] + contract.functions.getActiveTEEs.return_value.call.return_value = infos result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) assert len(result) == 3 @@ -186,30 +144,31 @@ def test_validator_type_label(self, mock_contract): """Ensure validator type queries work the same way.""" registry, contract = mock_contract - contract.functions.getTEEsByType.return_value.call.return_value = [] + contract.functions.getActiveTEEs.return_value.call.return_value = [] result = registry.get_active_tees_by_type(TEE_TYPE_VALIDATOR) assert result == [] - contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_VALIDATOR) + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_VALIDATOR) class TestGetLlmTee: def test_returns_first_active_tee(self, mock_contract): registry, contract = mock_contract - ids = [b"\x01" * 32, b"\x02" * 32] - contract.functions.getTEEsByType.return_value.call.return_value = ids - contract.functions.getTEE.return_value.call.return_value = _make_tee_info() + contract.functions.getActiveTEEs.return_value.call.return_value = [ + _make_tee_info(endpoint="https://tee-1.example.com"), + _make_tee_info(endpoint="https://tee-2.example.com"), + ] result = registry.get_llm_tee() assert result is not None - assert result.tee_id == ids[0].hex() + assert result.endpoint == "https://tee-1.example.com" def test_returns_none_when_no_tees(self, mock_contract): registry, contract = mock_contract - contract.functions.getTEEsByType.return_value.call.return_value = [] + contract.functions.getActiveTEEs.return_value.call.return_value = [] result = registry.get_llm_tee() assert result is None @@ -217,10 +176,10 @@ def test_returns_none_when_no_tees(self, mock_contract): def test_queries_llm_proxy_type(self, mock_contract): registry, contract = mock_contract - contract.functions.getTEEsByType.return_value.call.return_value = [] + contract.functions.getActiveTEEs.return_value.call.return_value = [] registry.get_llm_tee() - contract.functions.getTEEsByType.assert_called_once_with(TEE_TYPE_LLM_PROXY) + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) # --- build_ssl_context_from_der Tests --- From 89db68c38448579c9d17bdccb90d88d3d6927449 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 18:51:49 -0400 Subject: [PATCH 10/13] fix types --- src/opengradient/cli.py | 8 ++++++-- src/opengradient/client/client.py | 5 +---- src/opengradient/client/llm.py | 10 +++++----- tests/tee_registry_test.py | 25 +++++++++++-------------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/src/opengradient/cli.py b/src/opengradient/cli.py index c65bfab..cfca61b 100644 --- a/src/opengradient/cli.py +++ b/src/opengradient/cli.py @@ -413,7 +413,9 @@ def completion( x402_settlement_mode=x402SettlementModes[x402_settlement_mode], ) - print_llm_completion_result(model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False, result=completion_output) + print_llm_completion_result( + model_cid, completion_output.transaction_hash, completion_output.completion_output, is_vanilla=False, result=completion_output + ) except Exception as e: click.echo(f"Error running LLM completion: {str(e)}") @@ -597,7 +599,9 @@ def chat( if stream: print_streaming_chat_result(model_cid, result, is_tee=True) else: - print_llm_chat_result(model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False, result=result) + print_llm_chat_result( + model_cid, result.transaction_hash, result.finish_reason, result.chat_output, is_vanilla=False, result=result + ) except Exception as e: click.echo(f"Error running LLM chat inference: {str(e)}") diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index 43aab93..bf8c955 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -138,10 +138,7 @@ def __init__( llm_tls_cert_der = tee.tls_cert_der logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) else: - raise ValueError( - "No active LLM proxy TEE found in the registry. " - "Pass og_llm_server_url explicitly to override." - ) + raise ValueError("No active LLM proxy TEE found in the registry. Pass og_llm_server_url explicitly to override.") except ValueError: raise except Exception as e: diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 794862c..1c0e6e2 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -176,7 +176,7 @@ def completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, ) -> TextGenerationOutput: """ Perform inference on an LLM model using completions via TEE. @@ -218,7 +218,7 @@ def _tee_llm_completion( max_tokens: int = 100, stop_sequence: Optional[List[str]] = None, temperature: float = 0.0, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, ) -> TextGenerationOutput: """ Route completion request to OpenGradient TEE LLM server with x402 payments. @@ -278,7 +278,7 @@ def chat( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, stream: bool = False, ) -> Union[TextGenerationOutput, TextGenerationStream]: """ @@ -431,8 +431,8 @@ def _tee_llm_chat_tools_as_stream( temperature: float = 0.0, tools: Optional[List[Dict]] = None, tool_choice: Optional[str] = None, - x402_settlement_mode: x402SettlementMode = x402SettlementMode.SETTLE_BATCH, - ): + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + ) -> TextGenerationStream: """ Transparent non-streaming fallback for tool-call requests with stream=True. diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py index 8002a9c..15b15d2 100644 --- a/tests/tee_registry_test.py +++ b/tests/tee_registry_test.py @@ -27,16 +27,16 @@ def _make_tee_info( ): """Build a tuple matching the TEEInfo struct order from the new contract.""" return ( - "0xOwner", # owner - payment_address, # paymentAddress - endpoint, # endpoint - pub_key, # publicKey - tls_cert_der, # tlsCertificate - b"\x00" * 32, # pcrHash - 0, # teeType - True, # enabled (always True from getActiveTEEs) - 1000, # registeredAt - 2000, # lastHeartbeatAt + "0xOwner", # owner + payment_address, # paymentAddress + endpoint, # endpoint + pub_key, # publicKey + tls_cert_der, # tlsCertificate + b"\x00" * 32, # pcrHash + 0, # teeType + True, # enabled (always True from getActiveTEEs) + 1000, # registeredAt + 2000, # lastHeartbeatAt ) @@ -131,10 +131,7 @@ def test_returns_empty_on_rpc_failure(self, mock_contract): def test_multiple_active_tees(self, mock_contract): registry, contract = mock_contract - infos = [ - _make_tee_info(endpoint=f"https://tee-{i}.example.com", pub_key=f"pubkey{i}".encode()) - for i in range(3) - ] + infos = [_make_tee_info(endpoint=f"https://tee-{i}.example.com", pub_key=f"pubkey{i}".encode()) for i in range(3)] contract.functions.getActiveTEEs.return_value.call.return_value = infos result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) From a1c254d072da11186741b21a716edb4387408722 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 19:13:34 -0400 Subject: [PATCH 11/13] rm unused --- src/opengradient/client/llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 1c0e6e2..438a1d3 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -2,14 +2,11 @@ import asyncio import json -import logging import ssl import threading from queue import Queue from typing import AsyncGenerator, Dict, List, Optional, Union -logger = logging.getLogger(__name__) - import httpx from eth_account.account import LocalAccount from x402v2 import x402Client as x402Clientv2 From 4e5f43a7726b2f3a2b3ec0c67bd2a13dd7feb5dc Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 20:00:23 -0400 Subject: [PATCH 12/13] rm tee rpc and update contract --- src/opengradient/client/client.py | 8 ++------ src/opengradient/defaults.py | 3 +-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index bf8c955..ff2a0e7 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -10,7 +10,6 @@ DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL, DEFAULT_TEE_REGISTRY_ADDRESS, - DEFAULT_TEE_REGISTRY_RPC_URL, ) from .alpha import Alpha from .llm import LLM @@ -69,7 +68,6 @@ def __init__( og_llm_server_url: Optional[str] = None, og_llm_streaming_server_url: Optional[str] = None, tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, - tee_registry_rpc_url: str = DEFAULT_TEE_REGISTRY_RPC_URL, ): """ Initialize the OpenGradient client. @@ -104,8 +102,6 @@ def __init__( Defaults to ``og_llm_server_url`` when that is provided. tee_registry_address: Address of the TEERegistry contract used to discover active LLM proxy endpoints and their verified TLS certs. - tee_registry_rpc_url: RPC endpoint for the chain that hosts the - TEERegistry contract. """ blockchain = Web3(Web3.HTTPProvider(rpc_url)) wallet_account = blockchain.eth.account.from_key(private_key) @@ -128,7 +124,7 @@ def __init__( if og_llm_server_url is None: try: registry = TEERegistry( - rpc_url=tee_registry_rpc_url, + rpc_url=rpc_url, registry_address=tee_registry_address, ) tee = registry.get_llm_tee() @@ -143,7 +139,7 @@ def __init__( raise except Exception as e: raise RuntimeError( - f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address}): {e}. " + f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {rpc_url}): {e}. " "Pass og_llm_server_url explicitly to override." ) from e else: diff --git a/src/opengradient/defaults.py b/src/opengradient/defaults.py index 1df851e..ba2bfc5 100644 --- a/src/opengradient/defaults.py +++ b/src/opengradient/defaults.py @@ -8,5 +8,4 @@ DEFAULT_BLOCKCHAIN_EXPLORER = "https://explorer.opengradient.ai/tx/" # TEE Registry contract on the OG EVM chain — used to discover LLM proxy endpoints # and fetch their registry-verified TLS certificates instead of blindly trusting TOFU. -DEFAULT_TEE_REGISTRY_ADDRESS = "0x3d641a2791533b4a0000345ea8d509d01e1ec301" -DEFAULT_TEE_REGISTRY_RPC_URL = "http://13.59.43.94:8545" +DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" From 9996bd37acac934a06f116fbf0942af7b6b959b8 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Tue, 10 Mar 2026 20:04:21 -0400 Subject: [PATCH 13/13] rm _og_llm_streaming_server_url --- src/opengradient/client/client.py | 7 ------- src/opengradient/client/llm.py | 4 +--- tests/client_test.py | 5 +---- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/opengradient/client/client.py b/src/opengradient/client/client.py index ff2a0e7..08dc9ac 100644 --- a/src/opengradient/client/client.py +++ b/src/opengradient/client/client.py @@ -66,7 +66,6 @@ def __init__( api_url: str = DEFAULT_API_URL, contract_address: str = DEFAULT_INFERENCE_CONTRACT_ADDRESS, og_llm_server_url: Optional[str] = None, - og_llm_streaming_server_url: Optional[str] = None, tee_registry_address: str = DEFAULT_TEE_REGISTRY_ADDRESS, ): """ @@ -98,8 +97,6 @@ def __init__( og_llm_server_url: Override the LLM server URL instead of using the registry-discovered endpoint. When set, the TLS certificate is validated against the system CA bundle rather than the registry. - og_llm_streaming_server_url: Override the LLM streaming server URL. - Defaults to ``og_llm_server_url`` when that is provided. tee_registry_address: Address of the TEERegistry contract used to discover active LLM proxy endpoints and their verified TLS certs. """ @@ -130,7 +127,6 @@ def __init__( tee = registry.get_llm_tee() if tee is not None: og_llm_server_url = tee.endpoint - og_llm_streaming_server_url = og_llm_streaming_server_url or tee.endpoint llm_tls_cert_der = tee.tls_cert_der logger.info("Using TEE endpoint from registry: %s (teeId=%s)", tee.endpoint, tee.tee_id) else: @@ -142,8 +138,6 @@ def __init__( f"Failed to fetch LLM TEE endpoint from registry ({tee_registry_address} on {rpc_url}): {e}. " "Pass og_llm_server_url explicitly to override." ) from e - else: - og_llm_streaming_server_url = og_llm_streaming_server_url or og_llm_server_url # Create namespaces self.model_hub = ModelHub(hub_user=hub_user) @@ -152,7 +146,6 @@ def __init__( self.llm = LLM( wallet_account=wallet_account, og_llm_server_url=og_llm_server_url, - og_llm_streaming_server_url=og_llm_streaming_server_url, tls_cert_der=llm_tls_cert_der, tee_id=tee.tee_id if tee is not None else None, tee_payment_address=tee.payment_address if tee is not None else None, diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index 438a1d3..84868cd 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -65,14 +65,12 @@ def __init__( self, wallet_account: LocalAccount, og_llm_server_url: str, - og_llm_streaming_server_url: str, tls_cert_der: Optional[bytes] = None, tee_id: Optional[str] = None, tee_payment_address: Optional[str] = None, ): self._wallet_account = wallet_account self._og_llm_server_url = og_llm_server_url - self._og_llm_streaming_server_url = og_llm_streaming_server_url # TEE metadata surfaced on every response so callers can verify/audit which # enclave served the request. @@ -610,7 +608,7 @@ async def _parse_sse_response(response) -> AsyncGenerator[StreamChunk, None]: endpoint = "/v1/chat/completions" async with self._stream_client.stream( "POST", - self._og_llm_streaming_server_url + endpoint, + self._og_llm_server_url + endpoint, json=payload, headers=headers, timeout=60, diff --git a/tests/client_test.py b/tests/client_test.py index 5860972..dadb5a0 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -114,9 +114,8 @@ def test_client_initialization_with_auth(self, mock_web3, mock_abi_files): assert client.model_hub._hub_user["idToken"] == "test_token" def test_client_initialization_custom_llm_urls(self, mock_web3, mock_abi_files): - """Test client initialization with custom LLM server URLs.""" + """Test client initialization with custom LLM server URL.""" custom_llm_url = "https://custom.llm.server" - custom_streaming_url = "https://custom.streaming.server" client = Client( private_key="0x" + "a" * 64, @@ -124,11 +123,9 @@ def test_client_initialization_custom_llm_urls(self, mock_web3, mock_abi_files): api_url="https://test.api.url", contract_address="0x" + "b" * 40, og_llm_server_url=custom_llm_url, - og_llm_streaming_server_url=custom_streaming_url, ) assert client.llm._og_llm_server_url == custom_llm_url - assert client.llm._og_llm_streaming_server_url == custom_streaming_url class TestAlphaProperty: