-
Notifications
You must be signed in to change notification settings - Fork 19
fix(mcp): isolate stdout to prevent stray output from corrupting JSON-RPC #383
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
| import time | ||
| from dataclasses import dataclass | ||
| from datetime import datetime | ||
|
|
@@ -18,7 +19,12 @@ | |
| list_drivers, | ||
| walk_click_tree, | ||
| ) | ||
| from jumpstarter_mcp.server import TOKEN_REFRESH_THRESHOLD_SECONDS, _ensure_fresh_token, create_server | ||
| from jumpstarter_mcp.server import ( | ||
| TOKEN_REFRESH_THRESHOLD_SECONDS, | ||
| _ensure_fresh_token, | ||
| _setup_logging, | ||
| create_server, | ||
| ) | ||
| from jumpstarter_mcp.tools.leases import _lease_status | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
|
|
@@ -379,6 +385,26 @@ async def test_j_not_found(self, manager_with_conn): | |
| assert "error" in result | ||
| assert "not found" in result["error"] | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_subprocess_stdin_is_devnull(self, manager_with_conn): | ||
| """Subprocess must not inherit MCP's stdin (would consume JSON-RPC input).""" | ||
| from jumpstarter_mcp.tools.commands import run_command | ||
|
|
||
| manager, conn_id = manager_with_conn | ||
|
|
||
| mock_proc = AsyncMock() | ||
| mock_proc.communicate = AsyncMock(return_value=(b"ok\n", b"")) | ||
| mock_proc.returncode = 0 | ||
|
|
||
| with ( | ||
| patch("shutil.which", return_value="/usr/bin/j"), | ||
| patch("asyncio.create_subprocess_exec", return_value=mock_proc) as mock_exec, | ||
| ): | ||
| await run_command(manager, conn_id, ["power", "on"]) | ||
|
|
||
| _, kwargs = mock_exec.call_args | ||
| assert kwargs["stdin"] == asyncio.subprocess.DEVNULL | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Server creation | ||
|
|
@@ -536,3 +562,126 @@ async def test_token_without_exp_claim_skips_refresh(self): | |
|
|
||
| assert result is config | ||
| mock_cls.save.assert_not_called() | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # _setup_logging | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestSetupLogging: | ||
| def test_adds_file_handler_to_root_logger(self, tmp_path): | ||
| root = logging.getLogger() | ||
| handlers_before = list(root.handlers) | ||
|
|
||
| try: | ||
| with patch("pathlib.Path.home", return_value=tmp_path): | ||
| _setup_logging() | ||
|
|
||
| new_file_handlers = [ | ||
| h | ||
| for h in root.handlers | ||
| if isinstance(h, logging.FileHandler) and h not in handlers_before | ||
| ] | ||
| assert len(new_file_handlers) == 1 | ||
| assert "mcp-server.log" in new_file_handlers[0].baseFilename | ||
| finally: | ||
| root.handlers = handlers_before | ||
|
|
||
| def test_caps_stream_handlers_to_warning(self, tmp_path): | ||
| root = logging.getLogger() | ||
| handlers_before = list(root.handlers) | ||
|
|
||
| stream_handler = logging.StreamHandler() | ||
| stream_handler.setLevel(logging.DEBUG) | ||
| root.addHandler(stream_handler) | ||
|
|
||
| try: | ||
| with patch("pathlib.Path.home", return_value=tmp_path): | ||
| _setup_logging() | ||
| assert stream_handler.level == logging.WARNING | ||
| finally: | ||
| root.handlers = handlers_before | ||
|
|
||
| def test_sets_logger_levels(self, tmp_path): | ||
| root = logging.getLogger() | ||
| handlers_before = list(root.handlers) | ||
|
|
||
| loggers = { | ||
| "jumpstarter_mcp": logging.getLogger("jumpstarter_mcp"), | ||
| "jumpstarter": logging.getLogger("jumpstarter"), | ||
| "mcp": logging.getLogger("mcp"), | ||
| } | ||
| saved_levels = {name: lg.level for name, lg in loggers.items()} | ||
|
|
||
| try: | ||
| with patch("pathlib.Path.home", return_value=tmp_path): | ||
| _setup_logging() | ||
| assert loggers["jumpstarter_mcp"].level == logging.DEBUG | ||
| assert loggers["jumpstarter"].level == logging.DEBUG | ||
| assert loggers["mcp"].level == logging.WARNING | ||
| finally: | ||
| root.handlers = handlers_before | ||
| for name, level in saved_levels.items(): | ||
| loggers[name].setLevel(level) | ||
|
Comment on lines
+573
to
+626
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Restore the logging levels these tests mutate. Line 589, Line 604, and Line 624 only replace 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Stdout isolation (the fd-redirect pattern used by run_server) | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestStdoutIsolation: | ||
| def test_stray_writes_do_not_reach_saved_stdout(self): | ||
| """After the dup/dup2 redirect, writes to sys.stdout (fd 1) must go | ||
| to stderr, while only the saved fd still reaches the real stdout. | ||
| This is the behavioral property that protects MCP JSON-RPC.""" | ||
| import os | ||
| import sys | ||
|
|
||
| r_out, w_out = os.pipe() | ||
| r_err, w_err = os.pipe() | ||
|
|
||
| saved_real_stdout_fd = os.dup(1) | ||
| saved_real_stderr_fd = os.dup(2) | ||
| saved_sys_stdout = sys.stdout | ||
| saved_sys_stderr = sys.stderr | ||
|
|
||
| try: | ||
| os.dup2(w_out, 1) | ||
| os.dup2(w_err, 2) | ||
| sys.stdout = os.fdopen(1, "w", closefd=False) | ||
| sys.stderr = os.fdopen(2, "w", closefd=False) | ||
|
|
||
| # Apply the same redirect pattern as run_server(): | ||
| sys.stdout.flush() | ||
| mcp_fd = os.dup(sys.stdout.fileno()) # save "real stdout" (pipe) | ||
| os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) # fd 1 -> stderr | ||
| sys.stdout = sys.stderr | ||
|
|
||
| # Stray write via sys.stdout -- should land in the stderr pipe | ||
| sys.stdout.write("stray\n") | ||
| sys.stdout.flush() | ||
|
|
||
| # MCP-only write via the saved fd -- should land in the stdout pipe | ||
| mcp_file = os.fdopen(mcp_fd, "w", closefd=True) | ||
| mcp_file.write("mcp-json\n") | ||
| mcp_file.flush() | ||
|
|
||
| os.close(w_out) | ||
| os.close(w_err) | ||
| stdout_data = os.read(r_out, 4096).decode() | ||
| stderr_data = os.read(r_err, 4096).decode() | ||
|
|
||
| assert "mcp-json" in stdout_data | ||
| assert "stray" not in stdout_data | ||
| assert "stray" in stderr_data | ||
| finally: | ||
| os.dup2(saved_real_stdout_fd, 1) | ||
| os.dup2(saved_real_stderr_fd, 2) | ||
| os.close(saved_real_stdout_fd) | ||
| os.close(saved_real_stderr_fd) | ||
| os.close(r_out) | ||
| os.close(r_err) | ||
| sys.stdout = saved_sys_stdout | ||
| sys.stderr = saved_sys_stderr | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is meh, but trying to do the right thing here: modelcontextprotocol/python-sdk#2343 for the long term
They don't let you wire stdio/stdout in the async streams