Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/packages/jumpstarter-mcp/jumpstarter_mcp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ def serve():
This is meant to be invoked by an MCP-compatible host (e.g. Cursor)
as a subprocess. All communication happens over stdin/stdout.
"""
import sys

# Redirect stdout to stderr early, before importing the server module.
# Module-level imports in server.py (gRPC, jumpstarter, etc.) can
# trigger logging or print output that would corrupt MCP JSON-RPC.
# run_server() does a more thorough fd-level redirect later.
sys.stdout = sys.stderr

from jumpstarter_mcp.server import run_server

anyio.run(run_server)
31 changes: 29 additions & 2 deletions python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
import asyncio
import json
import logging
import os
import sys
from io import TextIOWrapper

import anyio
from anyio import ClosedResourceError
from mcp.server.fastmcp import FastMCP
from mcp.server.stdio import stdio_server

from jumpstarter_mcp.connections import ConnectionManager
from jumpstarter_mcp.tools import commands as cmd_tools
Expand Down Expand Up @@ -462,14 +467,36 @@ def _is_closed_resource_error(exc_group: BaseExceptionGroup) -> bool:

async def run_server():
"""Run the MCP server with stdio transport."""
# MCP communicates via JSON-RPC over stdin/stdout. Any stray output to
# stdout (from logging, gRPC C extensions, print() in dependencies, etc.)
# corrupts the protocol and kills the session. We duplicate the real stdout
# fd for exclusive MCP use, then redirect fd 1 (and sys.stdout) to stderr
# so all other output is harmless.
#
# Use literal POSIX fd numbers (1=stdout, 2=stderr) because cli.py may
# have already set sys.stdout = sys.stderr before we get here.
real_stdout_fd = os.dup(1)
os.dup2(2, 1)
sys.stdout = sys.stderr

_setup_logging()
logger.info("Jumpstarter MCP server starting")
mcp, manager = create_server()
try:
async with manager.running():
await mcp.run_stdio_async()
mcp_stdout = anyio.wrap_file(
TextIOWrapper(os.fdopen(real_stdout_fd, "wb"), encoding="utf-8")
)
async with stdio_server(stdout=mcp_stdout) as (
read_stream,
write_stream,
):
await mcp._mcp_server.run(
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is meh, but trying to do the right thing here: modelcontextprotocol/python-sdk#2343 for the long term

They don't let you wire stdio/stdout in the async streams

read_stream,
write_stream,
mcp._mcp_server.create_initialization_options(),
)
except asyncio.CancelledError:
# Normal when the MCP host closes stdin or cancels the task; not a bug.
logger.info("MCP stdio session ended (cancelled)")
except BaseException as exc:
if isinstance(exc, ClosedResourceError):
Expand Down
151 changes: 150 additions & 1 deletion python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass
from datetime import datetime
Expand All @@ -18,7 +19,12 @@
list_drivers,
walk_click_tree,
)
from jumpstarter_mcp.server import TOKEN_REFRESH_THRESHOLD_SECONDS, _ensure_fresh_token, create_server
from jumpstarter_mcp.server import (
TOKEN_REFRESH_THRESHOLD_SECONDS,
_ensure_fresh_token,
_setup_logging,
create_server,
)
from jumpstarter_mcp.tools.leases import _lease_status

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -379,6 +385,26 @@ async def test_j_not_found(self, manager_with_conn):
assert "error" in result
assert "not found" in result["error"]

@pytest.mark.asyncio
async def test_subprocess_stdin_is_devnull(self, manager_with_conn):
"""Subprocess must not inherit MCP's stdin (would consume JSON-RPC input)."""
from jumpstarter_mcp.tools.commands import run_command

manager, conn_id = manager_with_conn

mock_proc = AsyncMock()
mock_proc.communicate = AsyncMock(return_value=(b"ok\n", b""))
mock_proc.returncode = 0

with (
patch("shutil.which", return_value="/usr/bin/j"),
patch("asyncio.create_subprocess_exec", return_value=mock_proc) as mock_exec,
):
await run_command(manager, conn_id, ["power", "on"])

_, kwargs = mock_exec.call_args
assert kwargs["stdin"] == asyncio.subprocess.DEVNULL


# ---------------------------------------------------------------------------
# Server creation
Expand Down Expand Up @@ -536,3 +562,126 @@ async def test_token_without_exp_claim_skips_refresh(self):

assert result is config
mock_cls.save.assert_not_called()


# ---------------------------------------------------------------------------
# _setup_logging
# ---------------------------------------------------------------------------


class TestSetupLogging:
    """Tests for ``_setup_logging``.

    ``_setup_logging`` mutates process-wide logging state: it adds a file
    handler to the root logger, caps every pre-existing ``StreamHandler``
    at WARNING, and sets levels on the ``jumpstarter_mcp`` /
    ``jumpstarter`` / ``mcp`` loggers.  The previous version of these
    tests only restored ``root.handlers``, so the level changes leaked
    into later tests depending on execution order.  Each test now
    snapshots *all* of that state up front and restores it in
    ``finally``.
    """

    # Named loggers whose levels _setup_logging() changes.
    _LOGGER_NAMES = ("jumpstarter_mcp", "jumpstarter", "mcp")

    @staticmethod
    def _snapshot():
        """Capture (root, handler list, handler levels, named-logger levels)."""
        root = logging.getLogger()
        handlers = list(root.handlers)
        handler_levels = [(h, h.level) for h in handlers]
        logger_levels = {
            name: logging.getLogger(name).level
            for name in TestSetupLogging._LOGGER_NAMES
        }
        return root, handlers, handler_levels, logger_levels

    @staticmethod
    def _restore(root, handlers, handler_levels, logger_levels):
        """Undo every mutation _setup_logging() may have made."""
        # Close handlers added during the test (e.g. the FileHandler that
        # _setup_logging attaches) so file descriptors don't leak.
        for handler in root.handlers:
            if handler not in handlers:
                handler.close()
        root.handlers = handlers
        # The surviving handler objects keep their mutated levels across
        # the list swap -- reset them explicitly.
        for handler, level in handler_levels:
            handler.setLevel(level)
        for name, level in logger_levels.items():
            logging.getLogger(name).setLevel(level)

    def test_adds_file_handler_to_root_logger(self, tmp_path):
        state = self._snapshot()
        root, handlers_before = state[0], state[1]

        try:
            with patch("pathlib.Path.home", return_value=tmp_path):
                _setup_logging()

            new_file_handlers = [
                h
                for h in root.handlers
                if isinstance(h, logging.FileHandler) and h not in handlers_before
            ]
            assert len(new_file_handlers) == 1
            assert "mcp-server.log" in new_file_handlers[0].baseFilename
        finally:
            self._restore(*state)

    def test_caps_stream_handlers_to_warning(self, tmp_path):
        # Snapshot BEFORE attaching our handler, so restoring
        # root.handlers detaches it again afterwards.
        state = self._snapshot()

        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        logging.getLogger().addHandler(stream_handler)

        try:
            with patch("pathlib.Path.home", return_value=tmp_path):
                _setup_logging()
            assert stream_handler.level == logging.WARNING
        finally:
            self._restore(*state)

    def test_sets_logger_levels(self, tmp_path):
        state = self._snapshot()

        try:
            with patch("pathlib.Path.home", return_value=tmp_path):
                _setup_logging()
            assert logging.getLogger("jumpstarter_mcp").level == logging.DEBUG
            assert logging.getLogger("jumpstarter").level == logging.DEBUG
            assert logging.getLogger("mcp").level == logging.WARNING
        finally:
            self._restore(*state)
Comment on lines +573 to +626
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Restore the logging levels these tests mutate.

Line 589, Line 604, and Line 624 only replace root.handlers, but _setup_logging() also changes every existing root StreamHandler level and the jumpstarter_mcp/jumpstarter/mcp logger levels. Those objects survive the list swap, so later tests can inherit altered logging behavior depending on execution order.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py` around lines
573 - 626, These tests call _setup_logging() which mutates global logger and
handler state; capture and restore any mutated levels: record pre-test levels of
all root StreamHandler instances and the per-name logger levels (e.g., loggers
"jumpstarter_mcp", "jumpstarter", "mcp") before calling _setup_logging(), then
after the test restore root.handlers, reset each StreamHandler.level to its
saved value, and reset each logger level using the saved_levels mapping; update
test_adds_file_handler_to_root_logger, test_caps_stream_handlers_to_warning, and
test_sets_logger_levels to save and restore these handler and logger levels (use
variables like handlers_before, stream_handler, saved_levels, and the loggers
dict to locate code).



# ---------------------------------------------------------------------------
# Stdout isolation (the fd-redirect pattern used by run_server)
# ---------------------------------------------------------------------------


class TestStdoutIsolation:
    # Verifies the dup/dup2 pattern that run_server() uses to keep stray
    # output off the MCP JSON-RPC channel.  The exact order of fd
    # operations below matters; do not reorder.
    def test_stray_writes_do_not_reach_saved_stdout(self):
        """After the dup/dup2 redirect, writes to sys.stdout (fd 1) must go
        to stderr, while only the saved fd still reaches the real stdout.
        This is the behavioral property that protects MCP JSON-RPC."""
        import os
        import sys

        # Two pipes stand in for the process's real stdout/stderr so we
        # can read back exactly what each fd received.
        r_out, w_out = os.pipe()
        r_err, w_err = os.pipe()

        # Save the genuine fds 1/2 and stream objects so the finally
        # block can put everything back no matter what fails.
        saved_real_stdout_fd = os.dup(1)
        saved_real_stderr_fd = os.dup(2)
        saved_sys_stdout = sys.stdout
        saved_sys_stderr = sys.stderr

        try:
            # Point fd 1 at the "stdout" pipe and fd 2 at the "stderr"
            # pipe, then rebuild sys.stdout/sys.stderr on top of them
            # (closefd=False: fds 1/2 are restored manually below).
            os.dup2(w_out, 1)
            os.dup2(w_err, 2)
            sys.stdout = os.fdopen(1, "w", closefd=False)
            sys.stderr = os.fdopen(2, "w", closefd=False)

            # Apply the same redirect pattern as run_server():
            sys.stdout.flush()
            mcp_fd = os.dup(sys.stdout.fileno())  # save "real stdout" (pipe)
            os.dup2(sys.stderr.fileno(), sys.stdout.fileno())  # fd 1 -> stderr
            sys.stdout = sys.stderr

            # Stray write via sys.stdout -- should land in the stderr pipe
            sys.stdout.write("stray\n")
            sys.stdout.flush()

            # MCP-only write via the saved fd -- should land in the stdout pipe
            mcp_file = os.fdopen(mcp_fd, "w", closefd=True)
            mcp_file.write("mcp-json\n")
            mcp_file.flush()

            # Close the write ends first so the reads below are bounded.
            os.close(w_out)
            os.close(w_err)
            stdout_data = os.read(r_out, 4096).decode()
            stderr_data = os.read(r_err, 4096).decode()

            assert "mcp-json" in stdout_data
            assert "stray" not in stdout_data
            assert "stray" in stderr_data
        finally:
            # Restore the real fds 1/2, drop our temporary descriptors,
            # and put the original stream objects back.
            os.dup2(saved_real_stdout_fd, 1)
            os.dup2(saved_real_stderr_fd, 2)
            os.close(saved_real_stdout_fd)
            os.close(saved_real_stderr_fd)
            os.close(r_out)
            os.close(r_err)
            sys.stdout = saved_sys_stdout
            sys.stderr = saved_sys_stderr
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def run_command(
proc = await asyncio.create_subprocess_exec(
j_path,
*command,
stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=full_env,
Expand Down
Loading