diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/cli.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/cli.py index 849201e6..72b7b718 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/cli.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/cli.py @@ -18,6 +18,14 @@ def serve(): This is meant to be invoked by an MCP-compatible host (e.g. Cursor) as a subprocess. All communication happens over stdin/stdout. """ + import sys + + # Redirect stdout to stderr early, before importing the server module. + # Module-level imports in server.py (gRPC, jumpstarter, etc.) can + # trigger logging or print output that would corrupt MCP JSON-RPC. + # run_server() does a more thorough fd-level redirect later. + sys.stdout = sys.stderr + from jumpstarter_mcp.server import run_server anyio.run(run_server) diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py index 68e650de..5fefb2aa 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py @@ -5,9 +5,14 @@ import asyncio import json import logging +import os +import sys +from io import TextIOWrapper +import anyio from anyio import ClosedResourceError from mcp.server.fastmcp import FastMCP +from mcp.server.stdio import stdio_server from jumpstarter_mcp.connections import ConnectionManager from jumpstarter_mcp.tools import commands as cmd_tools @@ -462,14 +467,36 @@ def _is_closed_resource_error(exc_group: BaseExceptionGroup) -> bool: async def run_server(): """Run the MCP server with stdio transport.""" + # MCP communicates via JSON-RPC over stdin/stdout. Any stray output to + # stdout (from logging, gRPC C extensions, print() in dependencies, etc.) + # corrupts the protocol and kills the session. We duplicate the real stdout + # fd for exclusive MCP use, then redirect fd 1 (and sys.stdout) to stderr + # so all other output is harmless. + # + # Use literal POSIX fd numbers (1=stdout, 2=stderr) because cli.py may + # have already set sys.stdout = sys.stderr before we get here. + real_stdout_fd = os.dup(1) + os.dup2(2, 1) + sys.stdout = sys.stderr + _setup_logging() logger.info("Jumpstarter MCP server starting") mcp, manager = create_server() try: async with manager.running(): - await mcp.run_stdio_async() + mcp_stdout = anyio.wrap_file( + TextIOWrapper(os.fdopen(real_stdout_fd, "wb"), encoding="utf-8") + ) + async with stdio_server(stdout=mcp_stdout) as ( + read_stream, + write_stream, + ): + await mcp._mcp_server.run( + read_stream, + write_stream, + mcp._mcp_server.create_initialization_options(), + ) except asyncio.CancelledError: - # Normal when the MCP host closes stdin or cancels the task; not a bug. logger.info("MCP stdio session ended (cancelled)") except BaseException as exc: if isinstance(exc, ClosedResourceError): diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py index 8075f11f..ebea459c 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import logging import time from dataclasses import dataclass from datetime import datetime @@ -18,7 +19,12 @@ list_drivers, walk_click_tree, ) -from jumpstarter_mcp.server import TOKEN_REFRESH_THRESHOLD_SECONDS, _ensure_fresh_token, create_server +from jumpstarter_mcp.server import ( + TOKEN_REFRESH_THRESHOLD_SECONDS, + _ensure_fresh_token, + _setup_logging, + create_server, +) from jumpstarter_mcp.tools.leases import _lease_status # --------------------------------------------------------------------------- @@ -379,6 +385,26 @@ async def test_j_not_found(self, manager_with_conn): assert "error" in result assert "not found" in result["error"] + @pytest.mark.asyncio + async def test_subprocess_stdin_is_devnull(self, manager_with_conn): + """Subprocess must not inherit MCP's stdin (would consume JSON-RPC input).""" + from jumpstarter_mcp.tools.commands import run_command + + manager, conn_id = manager_with_conn + + mock_proc = AsyncMock() + mock_proc.communicate = AsyncMock(return_value=(b"ok\n", b"")) + mock_proc.returncode = 0 + + with ( + patch("shutil.which", return_value="/usr/bin/j"), + patch("asyncio.create_subprocess_exec", return_value=mock_proc) as mock_exec, + ): + await run_command(manager, conn_id, ["power", "on"]) + + _, kwargs = mock_exec.call_args + assert kwargs["stdin"] == asyncio.subprocess.DEVNULL + # --------------------------------------------------------------------------- # Server creation @@ -536,3 +562,126 @@ async def test_token_without_exp_claim_skips_refresh(self): assert result is config mock_cls.save.assert_not_called() + + +# --------------------------------------------------------------------------- +# _setup_logging +# --------------------------------------------------------------------------- + + +class TestSetupLogging: + def test_adds_file_handler_to_root_logger(self, tmp_path): + root = logging.getLogger() + handlers_before = list(root.handlers) + + try: + with patch("pathlib.Path.home", return_value=tmp_path): + _setup_logging() + + new_file_handlers = [ + h + for h in root.handlers + if isinstance(h, logging.FileHandler) and h not in handlers_before + ] + assert len(new_file_handlers) == 1 + assert "mcp-server.log" in new_file_handlers[0].baseFilename + finally: + root.handlers = handlers_before + + def test_caps_stream_handlers_to_warning(self, tmp_path): + root = logging.getLogger() + handlers_before = list(root.handlers) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + root.addHandler(stream_handler) + + try: + with patch("pathlib.Path.home", return_value=tmp_path): + _setup_logging() + assert stream_handler.level == logging.WARNING + finally: + root.handlers = handlers_before + + def test_sets_logger_levels(self, tmp_path): + root = logging.getLogger() + handlers_before = list(root.handlers) + + loggers = { + "jumpstarter_mcp": logging.getLogger("jumpstarter_mcp"), + "jumpstarter": logging.getLogger("jumpstarter"), + "mcp": logging.getLogger("mcp"), + } + saved_levels = {name: lg.level for name, lg in loggers.items()} + + try: + with patch("pathlib.Path.home", return_value=tmp_path): + _setup_logging() + assert loggers["jumpstarter_mcp"].level == logging.DEBUG + assert loggers["jumpstarter"].level == logging.DEBUG + assert loggers["mcp"].level == logging.WARNING + finally: + root.handlers = handlers_before + for name, level in saved_levels.items(): + loggers[name].setLevel(level) + + +# --------------------------------------------------------------------------- +# Stdout isolation (the fd-redirect pattern used by run_server) +# --------------------------------------------------------------------------- + + +class TestStdoutIsolation: + def test_stray_writes_do_not_reach_saved_stdout(self): + """After the dup/dup2 redirect, writes to sys.stdout (fd 1) must go + to stderr, while only the saved fd still reaches the real stdout. + This is the behavioral property that protects MCP JSON-RPC.""" + import os + import sys + + r_out, w_out = os.pipe() + r_err, w_err = os.pipe() + + saved_real_stdout_fd = os.dup(1) + saved_real_stderr_fd = os.dup(2) + saved_sys_stdout = sys.stdout + saved_sys_stderr = sys.stderr + + try: + os.dup2(w_out, 1) + os.dup2(w_err, 2) + sys.stdout = os.fdopen(1, "w", closefd=False) + sys.stderr = os.fdopen(2, "w", closefd=False) + + # Apply the same redirect pattern as run_server(): + sys.stdout.flush() + mcp_fd = os.dup(sys.stdout.fileno()) # save "real stdout" (pipe) + os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) # fd 1 -> stderr + sys.stdout = sys.stderr + + # Stray write via sys.stdout -- should land in the stderr pipe + sys.stdout.write("stray\n") + sys.stdout.flush() + + # MCP-only write via the saved fd -- should land in the stdout pipe + mcp_file = os.fdopen(mcp_fd, "w", closefd=True) + mcp_file.write("mcp-json\n") + mcp_file.flush() + + os.close(w_out) + os.close(w_err) + stdout_data = os.read(r_out, 4096).decode() + stderr_data = os.read(r_err, 4096).decode() + + assert "mcp-json" in stdout_data + assert "stray" not in stdout_data + assert "stray" in stderr_data + finally: + os.dup2(saved_real_stdout_fd, 1) + os.dup2(saved_real_stderr_fd, 2) + os.close(saved_real_stdout_fd) + os.close(saved_real_stderr_fd) + os.close(r_out) + os.close(r_err) + sys.stdout = saved_sys_stdout + sys.stderr = saved_sys_stderr diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py index df9a606f..79e7b1b1 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py @@ -46,6 +46,7 @@ async def run_command( proc = await asyncio.create_subprocess_exec( j_path, *command, + stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, env=full_env,