diff --git a/pkg/tools/mcp/remote.go b/pkg/tools/mcp/remote.go index 0e2600323..190a308aa 100644 --- a/pkg/tools/mcp/remote.go +++ b/pkg/tools/mcp/remote.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "net/http" + "net/http/cookiejar" gomcp "github.com/modelcontextprotocol/go-sdk/mcp" @@ -57,7 +58,10 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *gomcp.InitializeReq // oauthTransport so we can enrich Connect errors with the server's own // explanation — without this, a plain `Bad Request` bubbles up and the // user has no idea that, say, the Slack app hasn't been enabled for MCP. - httpClient, oauthT := c.createHTTPClient() + httpClient, oauthT, err := c.createHTTPClient() + if err != nil { + return nil, fmt.Errorf("creating HTTP client: %w", err) + } var transport gomcp.Transport @@ -146,7 +150,7 @@ func (c *remoteMCPClient) SetManagedOAuth(managed bool) { // The oauthTransport is returned alongside the client so callers can inspect // the most recent server-side failure (via lastServerError) when Connect() // returns a bare HTTP-status error and we need to surface the actual cause. -func (c *remoteMCPClient) createHTTPClient() (*http.Client, *oauthTransport) { +func (c *remoteMCPClient) createHTTPClient() (*http.Client, *oauthTransport, error) { base := c.headerTransport() // Then wrap with OAuth support @@ -160,7 +164,13 @@ func (c *remoteMCPClient) createHTTPClient() (*http.Client, *oauthTransport) { oauthHTTPClient: oauthHTTPClientForAllowPrivateIPs(c.allowPrivateIPs), } - return &http.Client{Transport: oauthT}, oauthT + // Persist cookies across requests + // So sticky sessions work if implemented by the server (e.g. in a multiple replica setup) + jar, err := cookiejar.New(nil) + if err != nil { + return nil, nil, fmt.Errorf("creating cookie jar: %w", err) + } + return &http.Client{Transport: oauthT, Jar: jar}, oauthT, nil } func (c *remoteMCPClient) headerTransport() http.RoundTripper { diff --git a/pkg/tools/mcp/remote_test.go b/pkg/tools/mcp/remote_test.go index 5baf61232..be8391c4d 100644 --- a/pkg/tools/mcp/remote_test.go +++ b/pkg/tools/mcp/remote_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" + "sync/atomic" "testing" "time" @@ -343,6 +345,53 @@ func TestInitialize_OAuthDefersWhenElicitationBridgeNotReady(t *testing.T) { } } +// TestCreateHTTPClient_PersistsCookies verifies that the *http.Client returned +// by createHTTPClient has a cookie jar, so sticky-session cookies set by the +// remote MCP ingress are echoed back on subsequent requests. +func TestCreateHTTPClient_PersistsCookies(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := requestCount.Add(1) + switch n { + case 1: + if _, err := r.Cookie("mcp_session"); err == nil { + t.Errorf("first request must not carry mcp_session cookie, got one") + } + w.Header().Set("Set-Cookie", "mcp_session=abc123; Path=/") + w.WriteHeader(http.StatusOK) + default: + cookie := r.Header.Get("Cookie") + if !strings.Contains(cookie, "mcp_session=abc123") { + t.Errorf("subsequent request must carry mcp_session=abc123, got Cookie=%q", cookie) + } + w.WriteHeader(http.StatusOK) + } + })) + defer server.Close() + + client := newRemoteClient(server.URL, "streamable", nil, NewInMemoryTokenStore(), nil, false) + httpClient, _, err := client.createHTTPClient() + require.NoError(t, err) + require.NotNil(t, httpClient.Jar, "createHTTPClient must attach a cookie jar so sticky sessions stick") + + req1, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + resp1, err := httpClient.Do(req1) + require.NoError(t, err) + _ = resp1.Body.Close() + + req2, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) + require.NoError(t, err) + resp2, err := httpClient.Do(req2) + require.NoError(t, err) + _ = resp2.Body.Close() + + require.Equal(t, int32(2), requestCount.Load(), "handler should have served both requests") +} + func TestNewRemoteToolsetWithAllowPrivateIPsPropagatesToClient(t *testing.T) { t.Parallel()