From 54412931b44aba3528b20902e4240eee6eef8382 Mon Sep 17 00:00:00 2001 From: air17 Date: Fri, 12 Jun 2026 17:50:56 +0000 Subject: [PATCH] Refine nexus bridge typing --- Editor/nexus_bridge/_transport.py | 13 +- Editor/nexus_bridge/_types.py | 76 +++++++++ Editor/nexus_bridge/routing.py | 252 +++++++++++++++++++----------- Editor/nexus_bridge/schemas.py | 8 +- 4 files changed, 250 insertions(+), 99 deletions(-) create mode 100644 Editor/nexus_bridge/_types.py diff --git a/Editor/nexus_bridge/_transport.py b/Editor/nexus_bridge/_transport.py index b9fabdf..efc7cb2 100644 --- a/Editor/nexus_bridge/_transport.py +++ b/Editor/nexus_bridge/_transport.py @@ -5,7 +5,8 @@ import os import sys import urllib.request -from typing import Any + +from ._types import JsonObject, JsonRpcError, JsonRpcRequest, JsonRpcResponse DEFAULT_PORT: int = 8081 @@ -44,8 +45,8 @@ def _read_timeout() -> float: UNITY_TIMEOUT_SECONDS: float = _read_timeout() -def call_unity(method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: - payload: dict[str, Any] = {"jsonrpc": "2.0", "method": method, "params": params or {}, "id": 1} +def call_unity(method: str, params: JsonObject | None = None) -> JsonRpcResponse: + payload: JsonRpcRequest = {"jsonrpc": "2.0", "method": method, "params": params or {}, "id": 1} data: bytes = json.dumps(payload).encode("utf-8") req: urllib.request.Request = urllib.request.Request( UNITY_URL, @@ -56,4 +57,8 @@ def call_unity(method: str, params: dict[str, Any] | None = None) -> dict[str, A with urllib.request.urlopen(req, timeout=UNITY_TIMEOUT_SECONDS) as response: return json.loads(response.read().decode("utf-8")) except Exception as error: - return {"error": {"code": -32000, "message": f"Unity Server unreachable. Error: {error}"}} + error_payload: JsonRpcError = { + "code": -32000, + "message": f"Unity Server unreachable. Error: {error}", + } + return {"error": error_payload} diff --git a/Editor/nexus_bridge/_types.py b/Editor/nexus_bridge/_types.py new file mode 100644 index 0000000..bae9566 --- /dev/null +++ b/Editor/nexus_bridge/_types.py @@ -0,0 +1,76 @@ +"""Private type definitions for the NexusUnity Python bridge.""" +from __future__ import annotations + +from typing import Any, TypeAlias, TypedDict + +JsonObject: TypeAlias = dict[str, Any] + + +class JsonRpcError(TypedDict): + code: int + message: str + + +class JsonRpcRequest(TypedDict): + jsonrpc: str + method: str + params: JsonObject + id: int + + +class ToolDefinition(TypedDict): + name: str + description: str + inputSchema: JsonObject + + +class ResourceDefinition(TypedDict): + uri: str + name: str + mimeType: str + + +class JsonRpcResponse(TypedDict, total=False): + result: JsonObject + error: JsonRpcError + + +class TransformArguments(TypedDict, total=False): + instance_id: int + position: JsonObject + rotation: JsonObject + scale: JsonObject + eulerAngles: JsonObject + localScale: JsonObject + + +class WriteFileSpec(TypedDict): + path: str + content: str + + +class WriteError(TypedDict): + path: str + error: JsonRpcError + + +class WaitResultPayload(TypedDict): + status: str + time_waited_seconds: float + + +class TestResultsPayload(WaitResultPayload, total=False): + timestamp_utc: str + message: str + result_path: str + trigger: JsonObject + + +class WriteAndCompileSuccessPayload(WaitResultPayload): + compiler_errors: list[JsonObject] + + +class WriteAndCompileFailurePayload(TypedDict): + status: str + message: str + errors: list[WriteError] diff --git a/Editor/nexus_bridge/routing.py b/Editor/nexus_bridge/routing.py index aaddea6..e9db828 100644 --- a/Editor/nexus_bridge/routing.py +++ b/Editor/nexus_bridge/routing.py @@ -7,106 +7,149 @@ from __future__ import annotations import time -from typing import Any +from typing import Any, Mapping, Sequence, cast from ._logging import logger from ._transport import call_unity from .schemas import STATIC_TOOLS - - -def _compact(params: dict[str, Any]) -> dict[str, Any]: +from ._types import ( + JsonObject, + JsonRpcError, + JsonRpcResponse, + TestResultsPayload, + TransformArguments, + WaitResultPayload, + WriteAndCompileFailurePayload, + WriteAndCompileSuccessPayload, + WriteError, + WriteFileSpec, +) + + +def _compact(params: JsonObject) -> JsonObject: return {key: value for key, value in params.items() if value is not None} -def _alias(action: str | None, aliases: dict[str, str]) -> str | None: - return aliases.get(action, action) # type: ignore[arg-type] +def _alias(action_name: str | None, aliases: Mapping[str, str]) -> str | None: + if action_name is None: + return None + return aliases.get(action_name, action_name) -def _invalid_action(action: str | None, valid_actions: list[str]) -> dict[str, Any]: +def _invalid_action(action_name: str | None, valid_actions: Sequence[str]) -> JsonRpcResponse: valid = ", ".join(valid_actions) - return {"error": {"code": -32602, "message": f"Invalid action: {action}. Valid actions: {valid}"}} + error_payload: JsonRpcError = { + "code": -32602, + "message": f"Invalid action: {action_name}. Valid actions: {valid}", + } + return {"error": error_payload} + + +def _result_object(response: JsonRpcResponse | None) -> JsonObject: + if not response: + return {} + result_payload = response.get("result") + return result_payload if isinstance(result_payload, dict) else {} -def _transform_params(args: dict[str, Any], instance_id: int | None = None) -> dict[str, Any]: - params: dict[str, Any] = {"instance_id": instance_id if instance_id is not None else args.get("instance_id")} +def _error_object(response: JsonRpcResponse | None) -> JsonRpcError | None: + if not response: + return None + return response.get("error") + + +def _transform_params(args: JsonObject, instance_id: int | None = None) -> JsonObject: + params: TransformArguments = { + "instance_id": instance_id if instance_id is not None else args.get("instance_id") + } for key in ["position", "rotation", "scale", "eulerAngles", "localScale"]: params[key] = args.get(key) return _compact(params) -def _extract_created_instance_id(response: dict[str, Any]) -> int | None: - if not isinstance(response, dict) or "error" in response: +def _extract_created_instance_id(response: JsonRpcResponse) -> int | None: + if "error" in response: return None - result = response.get("result", {}) - data = result.get("data", {}) if isinstance(result, dict) else {} + result_payload = _result_object(response) + data = result_payload.get("data", {}) return data.get("instance_id") if isinstance(data, dict) else None -def _apply_created_transform(response: dict[str, Any], args: dict[str, Any]) -> dict[str, Any]: +def _apply_created_transform(response: JsonRpcResponse, args: JsonObject) -> JsonRpcResponse: instance_id = _extract_created_instance_id(response) if not instance_id: return response params = _transform_params(args, instance_id) if len(params) <= 1: return response - transform = call_unity("set_transform", params) - if transform and "error" in transform: - return transform + transform_response = call_unity("set_transform", params) + if transform_response and "error" in transform_response: + return transform_response return response -def _run_tests_wait(args: dict[str, Any]) -> dict[str, Any]: +def _run_tests_wait(args: JsonObject) -> JsonRpcResponse: timeout = args.get("timeout_seconds", 180) poll_interval = args.get("poll_interval_seconds", 1.0) start_time = time.time() - before = call_unity("get_test_results") - before_result = before.get("result", {}) if isinstance(before, dict) else {} - before_timestamp = before_result.get("timestamp_utc") if before_result.get("status") == "Success" else None + previous_results_response = call_unity("get_test_results") + previous_results_payload = _result_object(previous_results_response) + previous_timestamp = ( + previous_results_payload.get("timestamp_utc") + if previous_results_payload.get("status") == "Success" + else None + ) run_params = _compact({ "mode": args.get("mode", "EditMode"), "filter": args.get("filter"), }) - trigger = call_unity("run_tests", run_params) - if trigger and "error" in trigger: - return trigger + trigger_response = call_unity("run_tests", run_params) + if trigger_response and "error" in trigger_response: + return trigger_response - trigger_result = trigger.get("result", {}) if isinstance(trigger, dict) else {} - result_path = trigger_result.get("result_path") + trigger_payload = _result_object(trigger_response) + result_path = trigger_payload.get("result_path") while time.time() - start_time < timeout: params = {"result_path": result_path} if result_path else {} - current = call_unity("get_test_results", params) - if current and "error" in current: - return current - - result = current.get("result", {}) if isinstance(current, dict) else {} - if result.get("status") == "Success" and result.get("timestamp_utc") != before_timestamp: - result["time_waited_seconds"] = round(time.time() - start_time, 2) - return {"result": result} + current_results_response = call_unity("get_test_results", params) + if current_results_response and "error" in current_results_response: + return current_results_response + + current_results_payload = _result_object(current_results_response) + if ( + current_results_payload.get("status") == "Success" + and current_results_payload.get("timestamp_utc") != previous_timestamp + ): + test_results = cast(TestResultsPayload, dict(current_results_payload)) + test_results["time_waited_seconds"] = round(time.time() - start_time, 2) + return {"result": test_results} time.sleep(poll_interval) + timeout_result: TestResultsPayload = { + "status": "Timeout", + "message": "Timed out waiting for a new Unity TestResults XML file.", + "time_waited_seconds": round(time.time() - start_time, 2), + "trigger": trigger_payload, + } + if isinstance(result_path, str): + timeout_result["result_path"] = result_path return { - "result": { - "status": "Timeout", - "message": "Timed out waiting for a new Unity TestResults XML file.", - "time_waited_seconds": round(time.time() - start_time, 2), - "result_path": result_path, - "trigger": trigger_result, - } + "result": timeout_result } -def _wait_for_compilation(timeout: float, start_time: float | None = None) -> dict[str, Any]: +def _wait_for_compilation(timeout: float, start_time: float | None = None) -> JsonRpcResponse: start_time = time.time() if start_time is None else start_time status: str = "Ready" reload_started: bool = False while time.time() - start_time < 20: - res: dict[str, Any] = call_unity("initialize") - if res is None or "error" in res: + initialize_response = call_unity("initialize") + if initialize_response is None or "error" in initialize_response: reload_started = True break time.sleep(0.5) @@ -115,61 +158,75 @@ def _wait_for_compilation(timeout: float, start_time: float | None = None) -> di call_unity("refresh_asset_database") while time.time() - start_time < timeout: - res = call_unity("initialize") - if res and "result" in res: + initialize_response = call_unity("initialize") + if initialize_response and "result" in initialize_response: time.sleep(2.0) - state: dict[str, Any] = call_unity("get_editor_state") - if state and "result" in state: - if not state["result"].get("is_compiling") and not state["result"].get("is_updating"): + editor_state_response = call_unity("get_editor_state") + editor_state = _result_object(editor_state_response) + if editor_state_response and "result" in editor_state_response: + if not editor_state.get("is_compiling") and not editor_state.get("is_updating"): break time.sleep(1.0) else: status = "Timeout" + wait_result: WaitResultPayload = { + "status": status, + "time_waited_seconds": round(time.time() - start_time, 2), + } return { - "result": { - "status": status, - "time_waited_seconds": round(time.time() - start_time, 2), - } + "result": wait_result } -def route_tool(name: str, args: dict[str, Any]) -> dict[str, Any]: +def route_tool(name: str, args: JsonObject) -> JsonRpcResponse: if name in ["tools/list", "list_tools", "listTools"]: return {"result": {"tools": STATIC_TOOLS}} if name == "write_and_compile": - files: list[dict[str, Any]] = args.get("files", []) + files: list[WriteFileSpec] = args.get("files", []) start_time: float = time.time() call_unity("clear_logs") - write_errors: list[dict[str, Any]] = [] + write_errors: list[WriteError] = [] for file_info in files: - res = call_unity("write_file", {"path": file_info["path"], "content": file_info["content"]}) - if res and "error" in res: - write_errors.append({"path": file_info["path"], "error": res["error"]}) + write_file_response = call_unity( + "write_file", + {"path": file_info["path"], "content": file_info["content"]}, + ) + write_error = _error_object(write_file_response) + if write_error is not None: + write_errors.append({"path": file_info["path"], "error": write_error}) if write_errors: - return {"result": {"status": "Failed", "message": "Failed to write some files", "errors": write_errors}} + failure_result: WriteAndCompileFailurePayload = { + "status": "Failed", + "message": "Failed to write some files", + "errors": write_errors, + } + return {"result": failure_result} else: - wait_result: dict[str, Any] = _wait_for_compilation(timeout=90, start_time=start_time) - wait_status: str = wait_result["result"]["status"] - time_waited_seconds: float = wait_result["result"]["time_waited_seconds"] + wait_response = _wait_for_compilation(timeout=90, start_time=start_time) + wait_result = _result_object(wait_response) + wait_status: str = wait_result["status"] + time_waited_seconds: float = wait_result["time_waited_seconds"] - compiler_errors: list[dict[str, Any]] = [] + compiler_errors: list[JsonObject] = [] if wait_status == "Ready": - log_res = call_unity("read_logs", {"count": 200}) - if log_res and "result" in log_res: - for log_entry in log_res["result"].get("logs", []): + log_response = call_unity("read_logs", {"count": 200}) + log_payload = _result_object(log_response) + if log_response and "result" in log_response: + for log_entry in log_payload.get("logs", []): if log_entry.get("Type") in ["Error", "Exception", "Assert"]: compiler_errors.append(log_entry) + success_result: WriteAndCompileSuccessPayload = { + "status": "Failed" if compiler_errors else wait_status, + "time_waited_seconds": time_waited_seconds, + "compiler_errors": compiler_errors, + } return { - "result": { - "status": "Failed" if compiler_errors else wait_status, - "time_waited_seconds": time_waited_seconds, - "compiler_errors": compiler_errors - } + "result": success_result } elif name == "scene_manager": @@ -302,37 +359,50 @@ def route_tool(name: str, args: dict[str, Any]) -> dict[str, Any]: else: return _invalid_action(action, ["get", "set", "delete", "list"]) elif name == "wait": - cond: Any = args.get("condition") + condition: Any = args.get("condition") timeout: float = args.get("timeout_seconds", 60) start_time: float = time.time() status: str = "Ready" - if cond == "compilation": + if condition == "compilation": return _wait_for_compilation(timeout=timeout, start_time=start_time) - elif cond == "play_mode": + elif condition == "play_mode": target_state = args.get("state", True) while time.time() - start_time < timeout: - state_res = call_unity("get_editor_state") - if state_res and "result" in state_res: - if state_res["result"].get("is_playing") == target_state: break + editor_state_response = call_unity("get_editor_state") + editor_state = _result_object(editor_state_response) + if editor_state_response and "result" in editor_state_response: + if editor_state.get("is_playing") == target_state: + break time.sleep(1.0) - else: status = "Timeout" - elif cond == "import": + else: + status = "Timeout" + elif condition == "import": while time.time() - start_time < timeout: - res = call_unity("is_asset_import_idle") - if res and "result" in res: - if res["result"].get("is_idle"): break + import_idle_response = call_unity("is_asset_import_idle") + import_idle_state = _result_object(import_idle_response) + if import_idle_response and "result" in import_idle_response: + if import_idle_state.get("is_idle"): + break time.sleep(1.0) - else: status = "Timeout" - elif cond == "editor_idle": + else: + status = "Timeout" + elif condition == "editor_idle": while time.time() - start_time < timeout: - res = call_unity("is_editor_idle") - if res and "result" in res: - if res["result"].get("is_idle"): break + editor_idle_response = call_unity("is_editor_idle") + editor_idle_state = _result_object(editor_idle_response) + if editor_idle_response and "result" in editor_idle_response: + if editor_idle_state.get("is_idle"): + break time.sleep(1.0) - else: status = "Timeout" + else: + status = "Timeout" - return {"result": {"status": status, "time_waited_seconds": round(time.time() - start_time, 2)}} + wait_result: WaitResultPayload = { + "status": status, + "time_waited_seconds": round(time.time() - start_time, 2), + } + return {"result": wait_result} else: return call_unity(name, args) diff --git a/Editor/nexus_bridge/schemas.py b/Editor/nexus_bridge/schemas.py index 8739458..fdb100c 100644 --- a/Editor/nexus_bridge/schemas.py +++ b/Editor/nexus_bridge/schemas.py @@ -8,10 +8,10 @@ """ from __future__ import annotations -from typing import Any +from ._types import JsonObject, ResourceDefinition, ToolDefinition # --- Shared sub-schemas --- -VECTOR3_SCHEMA: dict[str, Any] = { +VECTOR3_SCHEMA: JsonObject = { "type": "object", "properties": { "x": {"type": "number"}, @@ -20,7 +20,7 @@ }, } -STATIC_TOOLS: list[dict[str, Any]] = [ +STATIC_TOOLS: list[ToolDefinition] = [ # --- Consolidated Core Managers --- { "name": "unity_scene_manager", @@ -185,7 +185,7 @@ {"name": "unity_lint_project", "description": "Run Roslyn-based C# audit of the entire project", "inputSchema": {"type": "object", "properties": {}}} ] -STATIC_RESOURCES: list[dict[str, Any]] = [ +STATIC_RESOURCES: list[ResourceDefinition] = [ { "uri": "unity://docs/api-reference", "name": "API Reference",