diff --git a/README.md b/README.md index 6ac29be..43c7ed2 100644 --- a/README.md +++ b/README.md @@ -293,6 +293,14 @@ Set `stateless: true` in `MCP::Server::Transports::StreamableHTTPTransport.new` transport = MCP::Server::Transports::StreamableHTTPTransport.new(server, stateless: true) ``` +By default, sessions do not expire. To mitigate session hijacking risks, you can set a `session_idle_timeout` (in seconds). +When configured, sessions that receive no HTTP requests for this duration are automatically expired and cleaned up: + +```ruby +# Session timeout of 30 minutes +transport = MCP::Server::Transports::StreamableHTTPTransport.new(server, session_idle_timeout: 1800) +``` + ### Unsupported Features (to be implemented in future versions) - Resource subscriptions diff --git a/lib/mcp/server/transports/streamable_http_transport.rb b/lib/mcp/server/transports/streamable_http_transport.rb index f1a5d9d..21f1796 100644 --- a/lib/mcp/server/transports/streamable_http_transport.rb +++ b/lib/mcp/server/transports/streamable_http_transport.rb @@ -8,18 +8,30 @@ module MCP class Server module Transports class StreamableHTTPTransport < Transport - def initialize(server, stateless: false) + def initialize(server, stateless: false, session_idle_timeout: nil) super(server) - # { session_id => { stream: stream_object } + # Session data structure: `{ session_id => { stream: stream_object, last_active_at: float_from_monotonic_clock } }`. @sessions = {} @mutex = Mutex.new @stateless = stateless + @session_idle_timeout = session_idle_timeout + + if @session_idle_timeout + if @stateless + raise ArgumentError, "session_idle_timeout is not supported in stateless mode." + elsif @session_idle_timeout <= 0 + raise ArgumentError, "session_idle_timeout must be a positive number." + end + end + + start_reaper_thread if @session_idle_timeout end REQUIRED_POST_ACCEPT_TYPES = ["application/json", "text/event-stream"].freeze REQUIRED_GET_ACCEPT_TYPES = ["text/event-stream"].freeze STREAM_WRITE_ERRORS = [IOError, Errno::EPIPE, Errno::ECONNRESET].freeze + SESSION_REAP_INTERVAL = 60 def handle_request(request) case request.env["REQUEST_METHOD"] @@ -35,6 +47,9 @@ def handle_request(request) end def close + @reaper_thread&.kill + @reaper_thread = nil + @mutex.synchronize do @sessions.each_key { |session_id| cleanup_session_unsafe(session_id) } end @@ -56,6 +71,11 @@ def send_notification(method, params = nil, session_id: nil) session = @sessions[session_id] return false unless session && session[:stream] + if session_expired?(session) + cleanup_session_unsafe(session_id) + return false + end + begin send_to_stream(session[:stream], notification) true @@ -75,6 +95,11 @@ def send_notification(method, params = nil, session_id: nil) @sessions.each do |sid, session| next unless session[:stream] + if session_expired?(session) + failed_sessions << sid + next + end + begin send_to_stream(session[:stream], notification) sent_count += 1 @@ -97,6 +122,39 @@ def send_notification(method, params = nil, session_id: nil) private + def start_reaper_thread + @reaper_thread = Thread.new do + loop do + sleep(SESSION_REAP_INTERVAL) + reap_expired_sessions + rescue StandardError => e + MCP.configuration.exception_reporter.call(e, error: "Session reaper error") + end + end + end + + def reap_expired_sessions + return unless @session_idle_timeout + + expired_streams = @mutex.synchronize do + @sessions.each_with_object([]) do |(session_id, session), streams| + next unless session_expired?(session) + + streams << session[:stream] if session[:stream] + @sessions.delete(session_id) + end + end + + expired_streams.each do |stream| + # Closing outside the mutex is safe because expired sessions are already + # removed from `@sessions` above, so other threads will not find them + # and will not attempt to close the same stream. + stream.close + rescue + nil + end + end + def send_to_stream(stream, data) message = data.is_a?(String) ? data : data.to_json stream.write("data: #{message}\n\n") @@ -141,7 +199,9 @@ def handle_get(request) session_id = extract_session_id(request) return missing_session_id_response unless session_id - return session_not_found_response unless session_exists?(session_id) + + error_response = validate_and_touch_session(session_id) + return error_response if error_response return session_already_connected_response if get_session_stream(session_id) setup_sse_stream(session_id) @@ -235,6 +295,7 @@ def handle_initialization(body_string, body) @mutex.synchronize do @sessions[session_id] = { stream: nil, + last_active_at: Process.clock_gettime(Process::CLOCK_MONOTONIC), } end end @@ -256,8 +317,9 @@ def handle_accepted def handle_regular_request(body_string, session_id) unless @stateless - if session_id && !session_exists?(session_id) - return session_not_found_response + if session_id + error_response = validate_and_touch_session(session_id) + return error_response if error_response end end @@ -273,6 +335,22 @@ def handle_regular_request(body_string, session_id) end end + def validate_and_touch_session(session_id) + @mutex.synchronize do + return session_not_found_response unless (session = @sessions[session_id]) + return unless @session_idle_timeout + + if session_expired?(session) + cleanup_session_unsafe(session_id) + return session_not_found_response + end + + session[:last_active_at] = Process.clock_gettime(Process::CLOCK_MONOTONIC) + end + + nil + end + def get_session_stream(session_id) @mutex.synchronize { @sessions[session_id]&.fetch(:stream, nil) } end @@ -378,6 +456,12 @@ def send_keepalive_ping(session_id) ) raise # Re-raise to exit the keepalive loop end + + def session_expired?(session) + return false unless @session_idle_timeout + + Process.clock_gettime(Process::CLOCK_MONOTONIC) - session[:last_active_at] > @session_idle_timeout + end end end end diff --git a/test/mcp/server/transports/streamable_http_transport_test.rb b/test/mcp/server/transports/streamable_http_transport_test.rb index 4b346c4..0d5b3f4 100644 --- a/test/mcp/server/transports/streamable_http_transport_test.rb +++ b/test/mcp/server/transports/streamable_http_transport_test.rb @@ -17,6 +17,10 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase @transport = StreamableHTTPTransport.new(@server) end + teardown do + @transport.close + end + test "handles POST request with valid JSON-RPC message" do # First create a session init_request = create_rack_request( @@ -1331,16 +1335,53 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase assert_equal([], response[2]) end - test "handles POST request with body including JSON-RPC response object and returns with no body" do + test "expired session returns 404 on GET request" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + # Create a session + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "123" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + assert(session_id) + + # Session should now be expired (timeout is 0) + sleep(0.01) + + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = transport.handle_request(get_request) + assert_equal(404, response[0]) + + body = JSON.parse(response[2][0]) + assert_equal("Session not found", body["error"]) + ensure + transport.close + end + + test "expired session returns 404 on POST request" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + # Create a session init_request = create_rack_request( "POST", "/", { "CONTENT_TYPE" => "application/json" }, { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, ) - init_response = @transport.handle_request(init_request) + init_response = transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] + # Session should now be expired (timeout is 0) + sleep(0.01) + request = create_rack_request( "POST", "/", @@ -1348,26 +1389,58 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase "CONTENT_TYPE" => "application/json", "HTTP_MCP_SESSION_ID" => session_id, }, - { jsonrpc: "2.0", result: "success", id: "123" }.to_json, + { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, ) - response = @transport.handle_request(request) - assert_equal 202, response[0] - assert_equal({}, response[1]) - assert_equal([], response[2]) + response = transport.handle_request(request) + assert_equal(404, response[0]) + + body = JSON.parse(response[2][0]) + assert_equal("Session not found", body["error"]) + ensure + transport.close end - test "handle_regular_request checks session existence under mutex" do + test "session_idle_timeout: nil disables session expiry" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: nil) + init_request = create_rack_request( "POST", "/", { "CONTENT_TYPE" => "application/json" }, { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, ) - init_response = @transport.handle_request(init_request) + init_response = transport.handle_request(init_request) session_id = init_response[1]["Mcp-Session-Id"] - @transport.expects(:session_exists?).with(session_id).returns(true) + # Make a request - session should still be valid + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, + ) + + response = transport.handle_request(request) + assert_equal(200, response[0]) + ensure + transport.close + end + + test "session within timeout period remains valid" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 3600) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] request = create_rack_request( "POST", @@ -1378,8 +1451,252 @@ class StreamableHTTPTransportTest < ActiveSupport::TestCase }, { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, ) - response = @transport.handle_request(request) + + response = transport.handle_request(request) + assert_equal(200, response[0]) + ensure + transport.close + end + + test "session activity resets the idle timeout" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.5) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Send requests every 0.2s to keep the session alive. + # Total elapsed time (~0.6s) exceeds timeout (0.5s), but each request + # resets the idle timer so the session remains valid. + 3.times do + sleep(0.2) + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, + ) + response = transport.handle_request(request) + assert_equal(200, response[0]) + end + ensure + transport.close + end + + test "reaper thread cleans up expired sessions" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + assert(session_id) + + # Wait for session to expire + sleep(0.02) + + # Manually trigger reaper since the background thread runs on 60s interval + transport.send(:reap_expired_sessions) + + # Session should have been reaped + get_request = create_rack_request( + "GET", + "/", + { "HTTP_MCP_SESSION_ID" => session_id }, + ) + response = transport.handle_request(get_request) + assert_equal(404, response[0]) + ensure + transport.close + end + + test "reaper thread cleans up expired sessions and POST returns 404" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Wait for the session to exceed the idle timeout (0.01s) + sleep(0.02) + transport.send(:reap_expired_sessions) + + # POST to a reaped session should also return 404 + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, + ) + response = transport.handle_request(request) + assert_equal(404, response[0]) + + body = JSON.parse(response[2][0]) + assert_equal("Session not found", body["error"]) + ensure + transport.close + end + + test "close stops the reaper thread" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 3600) + reaper_thread = transport.instance_variable_get(:@reaper_thread) + assert reaper_thread + assert reaper_thread.alive? + + transport.close + + sleep(0.01) + refute reaper_thread.alive? + assert_nil transport.instance_variable_get(:@reaper_thread) + end + + test "reaper thread is not started when session_idle_timeout is nil" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: nil) + assert_nil(transport.instance_variable_get(:@reaper_thread)) + ensure + transport.close + end + + test "default session_idle_timeout is nil and sessions do not expire" do + transport = StreamableHTTPTransport.new(@server) + assert_nil(transport.instance_variable_get(:@reaper_thread)) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { jsonrpc: "2.0", method: "ping", id: "456" }.to_json, + ) + + response = transport.handle_request(request) assert_equal(200, response[0]) + ensure + transport.close + end + + test "raises ArgumentError when session_idle_timeout is zero" do + error = assert_raises(ArgumentError) do + StreamableHTTPTransport.new(@server, session_idle_timeout: 0) + end + assert_equal("session_idle_timeout must be a positive number.", error.message) + end + + test "raises ArgumentError when session_idle_timeout is negative" do + error = assert_raises(ArgumentError) do + StreamableHTTPTransport.new(@server, session_idle_timeout: -1) + end + assert_equal("session_idle_timeout must be a positive number.", error.message) + end + + test "raises ArgumentError when session_idle_timeout is used with stateless mode" do + error = assert_raises(ArgumentError) do + StreamableHTTPTransport.new(@server, stateless: true, session_idle_timeout: 3600) + end + assert_equal("session_idle_timeout is not supported in stateless mode.", error.message) + end + + test "expired session does not receive targeted notification" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Wait for the session to exceed the idle timeout (0.01s) + sleep(0.02) + + result = transport.send_notification("test/notify", { message: "hello" }, session_id: session_id) + refute(result) + ensure + transport.close + end + + test "expired session is skipped during broadcast notification" do + transport = StreamableHTTPTransport.new(@server, session_idle_timeout: 0.01) + + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + # Attach a mock stream to the session + stream = StringIO.new + transport.instance_variable_get(:@sessions)[session_id][:stream] = stream + + # Wait for the session to exceed the idle timeout (0.01s) + sleep(0.02) + + sent_count = transport.send_notification("test/notify", { message: "hello" }, **{}) + assert_equal(0, sent_count) + ensure + transport.close + end + + test "handles POST request with body including JSON-RPC response object and returns with no body" do + init_request = create_rack_request( + "POST", + "/", + { "CONTENT_TYPE" => "application/json" }, + { jsonrpc: "2.0", method: "initialize", id: "init" }.to_json, + ) + init_response = @transport.handle_request(init_request) + session_id = init_response[1]["Mcp-Session-Id"] + + request = create_rack_request( + "POST", + "/", + { + "CONTENT_TYPE" => "application/json", + "HTTP_MCP_SESSION_ID" => session_id, + }, + { jsonrpc: "2.0", result: "success", id: "123" }.to_json, + ) + + response = @transport.handle_request(request) + assert_equal 202, response[0] + assert_equal({}, response[1]) + assert_equal([], response[2]) end private