diff --git a/docs/2026-04-22-websocket-reverseproxy-design.md b/docs/2026-04-22-websocket-reverseproxy-design.md new file mode 100644 index 0000000..5917bb4 --- /dev/null +++ b/docs/2026-04-22-websocket-reverseproxy-design.md @@ -0,0 +1,134 @@ +# WebSocket Support via ReverseProxy Refactor + +**Date:** 2026-04-22 +**Status:** Approved +**Scope:** `proxy/proxy.go` — `handleConnectWithInterception` + +## Problem + +Gatekeeper's TLS interception path manually reads HTTP requests in a loop (`http.ReadRequest` → `transport.RoundTrip` → `resp.Write`). After a WebSocket upgrade (HTTP 101 Switching Protocols), the client sends binary WebSocket frames which `http.ReadRequest` cannot parse, causing `"malformed HTTP request"` errors and connection drops. + +## Solution + +Replace the manual request loop in `handleConnectWithInterception` with an `http.Server` serving on the client-side TLS connection, using `httputil.ReverseProxy` as the handler. Go 1.25's `ReverseProxy` natively handles WebSocket upgrades — it detects `Upgrade` headers, preserves them through hop-by-hop removal, hijacks both sides on a `101` response, and does bidirectional `io.Copy`. + +## Architecture + +``` +Client ←TLS→ http.Server(tlsClientConn) → ReverseProxy → upstream +``` + +### Flow + +1. CONNECT arrives, proxy hijacks, sends `200 Connection Established` (unchanged) +2. TLS handshake with client using generated cert (unchanged) +3. **New:** Create a single-connection `http.Server` with `httputil.ReverseProxy` as handler +4. `http.Server.Serve()` manages the request loop (replaces manual `for` + `http.ReadRequest`) +5. For normal HTTP: `ReverseProxy` forwards via `Transport.RoundTrip`, credential injection in `Rewrite` +6. 
For WebSocket: `ReverseProxy` detects `101`, hijacks, bidirectional copy — no custom code needed + +### Feature Mapping + +Every feature in the current manual loop maps to a `ReverseProxy` hook: + +| Feature | Current location | New location | +|---|---|---| +| Network policy check | Loop body | Wrapping handler (before ReverseProxy) | +| Keep HTTP policy | Loop body | Wrapping handler (before ReverseProxy) | +| Credential injection (`injectCredentials`) | Loop body | `Rewrite` on `ProxyRequest.Out` | +| MCP credential injection | Loop body | `Rewrite` | +| Extra headers / remove headers | Loop body | `Rewrite` | +| Token substitution | Loop body | `Rewrite` | +| Request ID generation | Loop body | `Rewrite` | +| Host gateway IP rewrite | Loop body, modifies dial target | `Rewrite` (rewrite URL host) or custom `Transport.DialContext` | +| Proxy-Authorization stripping | Loop body | `Rewrite` (read from `ProxyRequest.In` before hop-by-hop removal) | +| Credential resolver (token-exchange) | Loop body | `Rewrite` (read subject from `In.Header`, resolve, set on `Out`) | +| LLM gateway policy | Loop body, post-response | `ModifyResponse` | +| Response transformers | Loop body, post-response | `ModifyResponse` | +| Body capture for logging | Loop body | `ModifyResponse` (response) and `Rewrite` (request) | +| Canonical log line | Loop body | `ModifyResponse` + `ErrorHandler` | +| OTel span/metrics | Loop body via callbacks | Wrapping handler or `ModifyResponse` | +| Transport error → 502 | Loop body | `ErrorHandler` | +| WebSocket upgrade | **Not supported** | Built-in `ReverseProxy.handleUpgradeResponse` | + +### Key Design Decisions + +**Proxy-Authorization before hop-by-hop removal:** `ReverseProxy` strips hop-by-hop headers (including `Proxy-Authorization`) before calling `Rewrite`. For `subject_from: proxy-auth` token exchange, the subject identity must be extracted from `ProxyRequest.In` (which preserves original headers) rather than `ProxyRequest.Out`. 
+ +**Single-connection http.Server:** The `http.Server` serves on a `net.Listener` wrapping the single TLS connection. When the connection closes, `Serve` returns. This replaces the manual `for` loop and gets HTTP keepalive, pipelining, and protocol upgrade handling from the stdlib. + +**Per-connection transport:** The `http.Transport` is created per-CONNECT connection (same as today). `ForceAttemptHTTP2` remains disabled — the intercepted connection reads HTTP/1.1. + +**No behavioral changes:** All external APIs (`Proxy`, `RunContextData`, config) remain identical. This is purely an internal refactor of one function; the only intended change in observable behavior is that WebSocket upgrades now succeed. + +### Handling Policy Denials in Rewrite + +The current loop writes error responses (407, 403, 502) directly to the TLS connection and continues the loop. With `ReverseProxy`, the `Rewrite` function cannot write responses directly. Two options: + +**Option A — Wrapping handler:** A handler that runs policy checks before delegating to `ReverseProxy`. On denial, it writes the error response itself and does not call `ReverseProxy.ServeHTTP`. This is the cleanest approach. + +**Option B — Rewrite sets a sentinel, ErrorHandler acts on it.** `Rewrite` stores a denial in the request context; a custom `RoundTripper` wrapper checks for it and fails the round trip before anything is sent upstream, and `ErrorHandler` turns that failure into the denial response. (`ModifyResponse` is not a suitable hook here: it only runs after the upstream has already answered.) More complex, less readable. + +**Decision:** Option A. The wrapping handler pattern is idiomatic and keeps policy logic separate from forwarding logic. 
+ +```go +func (p *Proxy) interceptHandler(host string, rc *RunContextData, transport *http.Transport) http.Handler { + rp := &httputil.ReverseProxy{ + Rewrite: p.rewriteIntercepted(host, rc), + Transport: transport, + ModifyResponse: p.modifyInterceptedResponse(host, rc), + ErrorHandler: p.interceptErrorHandler(host, rc), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Network policy, Keep HTTP policy checks here + // On denial: write error response, return + // On allow: rp.ServeHTTP(w, r) + }) +} +``` + +## Testing Strategy + +Tests are written first (TDD) against the current code to establish behavioral baselines, then the refactor must keep them passing. + +### New tests to add before refactor + +1. **Normal HTTPS through interception** — credential injection verified on upstream request +2. **WebSocket upgrade through interception** — upgrade succeeds, bidirectional frame exchange works (will fail against current code, pass after refactor) +3. **Multi-request keepalive** — multiple requests over single CONNECT tunnel +4. **Network policy denial on inner request** — 407 returned, connection stays alive +5. **Transport error** — unreachable upstream, 502 returned, canonical log line emitted +6. **Credential resolver via CONNECT** — token-exchange with `subject_from: proxy-auth` +7. **Host gateway through interception** — gateway hostname rewritten to actual IP + +### Existing tests that must keep passing + +All tests in `proxy/proxy_test.go`, particularly: +- `TestProxy_CanonicalLogLine_ConnectTransportError` +- `TestProxy_CanonicalLogLine_ConnectBlocked` +- All credential injection, policy, and logging tests + +## Implementation Plan + +### Phase 1: Test baseline (TDD) +Write the new tests listed above against the current code. All should pass except the WebSocket test. 
+ +### Phase 2: Extract helpers +Extract the inline policy/credential/logging logic from the current loop into named methods that can be called from both the old loop and the new handler. This is a refactor-only step — no behavioral changes. + +### Phase 3: Build the ReverseProxy handler +Implement `interceptHandler` with `Rewrite`, `ModifyResponse`, `ErrorHandler`, and the wrapping handler for policy checks. Wire it into `handleConnectWithInterception` replacing the manual loop. + +### Phase 4: WebSocket test passes +The WebSocket upgrade test should now pass with zero additional code. + +### Phase 5: Verify and clean up +Run full test suite, remove dead code from the old loop, verify OTel instrumentation. + +## Out of Scope + +- Changing the non-interception tunnel path (`handleConnectTunnel`) +- Changing the HTTP relay path (`handleHTTP`) +- Changing the MCP relay handler +- Config schema changes +- New config options for WebSocket-specific behavior diff --git a/proxy/intercept_test.go b/proxy/intercept_test.go new file mode 100644 index 0000000..213c269 --- /dev/null +++ b/proxy/intercept_test.go @@ -0,0 +1,472 @@ +package proxy + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" +) + +// interceptTestSetup creates a proxy with TLS interception enabled and an HTTPS +// backend server. The proxy is configured to trust the backend's TLS cert and +// the returned client trusts the proxy's interception CA. 
+type interceptTestSetup struct { + Proxy *Proxy + ProxyServer *httptest.Server + Backend *httptest.Server + Client *http.Client + CA *CA + BackendHost string // hostname only (e.g., 127.0.0.1) — for credential matching + BackendHostPort string // host:port (e.g., 127.0.0.1:12345) — for extra/remove header matching +} + +func newInterceptTestSetup(t *testing.T, backendHandler http.Handler) *interceptTestSetup { + t.Helper() + + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + backend := httptest.NewTLSServer(backendHandler) + + // Build a CA pool that trusts the backend's TLS cert. + upstreamCAs := x509.NewCertPool() + upstreamCAs.AddCert(backend.Certificate()) + + p := NewProxy() + p.SetCA(ca) + p.SetUpstreamCAs(upstreamCAs) + + proxyServer := httptest.NewServer(p) + + // Client trusts the interception CA and routes through the proxy. + clientCAs := x509.NewCertPool() + clientCAs.AppendCertsFromPEM(ca.certPEM) + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(mustParseURL(proxyServer.URL)), + TLSClientConfig: &tls.Config{RootCAs: clientCAs}, + }, + } + + backendHost := mustParseURL(backend.URL).Host // host:port for extra header matching (uses r.Host) + backendHostname := mustParseURL(backend.URL).Hostname() // hostname only for credential matching + + t.Cleanup(func() { + proxyServer.Close() + backend.Close() + }) + + return &interceptTestSetup{ + Proxy: p, + ProxyServer: proxyServer, + Backend: backend, + Client: client, + CA: ca, + BackendHost: backendHostname, + BackendHostPort: backendHost, + } +} + +func TestIntercept_CredentialInjection(t *testing.T) { + var receivedAuth string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Write([]byte("ok")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer test-token-123", "test-grant") + + resp, err := 
setup.Client.Get(setup.Backend.URL + "/api/data") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedAuth != "Bearer test-token-123" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token-123") + } +} + +func TestIntercept_CredentialInjectionCanonicalLog(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer granted-token", "my-grant") + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + resp, err := setup.Client.Get(setup.Backend.URL + "/resource") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if !logged.AuthInjected { + t.Error("expected AuthInjected=true") + } + if len(logged.Grants) == 0 || logged.Grants[0] != "my-grant" { + t.Errorf("Grants = %v, want [my-grant]", logged.Grants) + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } + if logged.RequestID == "" { + t.Error("expected non-empty RequestID") + } +} + +func TestIntercept_MultiRequestKeepalive(t *testing.T) { + var requestCount int + var mu sync.Mutex + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestCount++ + mu.Unlock() + w.Write([]byte("ok")) + })) + + for i := 0; i < 5; i++ { + resp, err := setup.Client.Get(setup.Backend.URL + "/ping") + if err != nil { + t.Fatalf("request %d: %v", i, err) + } + resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("request %d: status = %d, want 200", i, resp.StatusCode) + } + } + + mu.Lock() + defer mu.Unlock() + if requestCount != 5 { + t.Errorf("requestCount = %d, want 5", requestCount) + } +} + +func 
TestIntercept_NetworkPolicyDenial(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("backend should not be reached on denied request") + })) + + // Strict policy with no allows — denies everything at the inner request level. + setup.Proxy.SetNetworkPolicy("strict", nil, nil) + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + // The CONNECT itself will be denied before TLS interception. + resp, err := setup.Client.Get(setup.Backend.URL + "/blocked") + if err == nil { + resp.Body.Close() + // Under strict policy with no allows, CONNECT is denied with 407. + if resp.StatusCode != http.StatusProxyAuthRequired { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusProxyAuthRequired) + } + } + // The client may get a transport error if CONNECT is blocked. + // Either way, the request should be denied. + if !logged.Denied { + t.Error("expected Denied=true in log") + } +} + +func TestIntercept_TransportError502(t *testing.T) { + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + p := NewProxy() + p.SetCA(ca) + + var logged RequestLogData + p.SetLogger(func(data RequestLogData) { + logged = data + }) + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(ca.certPEM) + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(mustParseURL(proxyServer.URL)), + TLSClientConfig: &tls.Config{RootCAs: caCertPool}, + }, + } + + // Connect to a port nothing listens on. 
+ resp, err := client.Get("https://127.0.0.1:1/nope") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway) + } + if logged.Err == nil { + t.Error("expected error in canonical log") + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } +} + +func TestIntercept_CanonicalLogFields(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("hello")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer tok", "test-grant") + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + resp, err := setup.Client.Get(setup.Backend.URL + "/some/path") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if logged.Method != "GET" { + t.Errorf("Method = %q, want GET", logged.Method) + } + backendHostname := mustParseURL(setup.Backend.URL).Hostname() + if logged.Host != backendHostname { + t.Errorf("Host = %q, want %q", logged.Host, backendHostname) + } + if logged.Path != "/some/path" { + t.Errorf("Path = %q, want /some/path", logged.Path) + } + if logged.StatusCode != 200 { + t.Errorf("StatusCode = %d, want 200", logged.StatusCode) + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } + if !logged.AuthInjected { + t.Error("expected AuthInjected=true") + } +} + +func TestIntercept_ExtraHeaders(t *testing.T) { + var receivedHeaders http.Header + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Write([]byte("ok")) + })) + + setup.Proxy.AddExtraHeader(setup.BackendHost, "X-Custom-Header", "custom-value") + + resp, err := 
setup.Client.Get(setup.Backend.URL + "/test") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedHeaders.Get("X-Custom-Header") != "custom-value" { + t.Errorf("X-Custom-Header = %q, want custom-value", receivedHeaders.Get("X-Custom-Header")) + } +} + +func TestIntercept_RemoveHeaders(t *testing.T) { + var receivedAPIKey string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAPIKey = r.Header.Get("X-Api-Key") + w.Write([]byte("ok")) + })) + + setup.Proxy.RemoveRequestHeader(setup.BackendHost, "X-Api-Key") + + req, _ := http.NewRequest("GET", setup.Backend.URL+"/test", nil) + req.Header.Set("X-Api-Key", "stale-key") + resp, err := setup.Client.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedAPIKey != "" { + t.Errorf("X-Api-Key should be removed, got %q", receivedAPIKey) + } +} + +func TestIntercept_RequestBodyForwarded(t *testing.T) { + var receivedBody string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + receivedBody = string(body) + w.Write([]byte("ok")) + })) + + reqBody := `{"key": "value"}` + resp, err := setup.Client.Post(setup.Backend.URL+"/submit", "application/json", strings.NewReader(reqBody)) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedBody != reqBody { + t.Errorf("body = %q, want %q", receivedBody, reqBody) + } +} + +func TestIntercept_LargeResponseBody(t *testing.T) { + // 1MB response body to verify streaming works. 
+ largeBody := bytes.Repeat([]byte("x"), 1<<20) + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(largeBody) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/large") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if len(body) != len(largeBody) { + t.Errorf("body length = %d, want %d", len(body), len(largeBody)) + } +} + +func TestIntercept_ResponseStatusCodes(t *testing.T) { + codes := []int{200, 201, 204, 301, 400, 404, 500} + + for _, code := range codes { + t.Run(http.StatusText(code), func(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/status") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != code { + t.Errorf("status = %d, want %d", resp.StatusCode, code) + } + }) + } +} + +func TestIntercept_ResponseHeaders(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend-Header", "backend-value") + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{}`)) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/headers") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.Header.Get("X-Backend-Header") != "backend-value" { + t.Errorf("X-Backend-Header = %q, want backend-value", resp.Header.Get("X-Backend-Header")) + } + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("Content-Type = %q, want application/json", resp.Header.Get("Content-Type")) + } +} + +func TestIntercept_XRequestIdInjected(t *testing.T) { + var receivedRequestID string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + receivedRequestID = r.Header.Get("X-Request-Id") + w.Write([]byte("ok")) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/rid") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedRequestID == "" { + t.Error("expected X-Request-Id to be injected") + } + if !strings.HasPrefix(receivedRequestID, "req_") { + t.Errorf("X-Request-Id = %q, expected req_ prefix", receivedRequestID) + } +} + +func TestIntercept_ProxyAuthorizationStripped(t *testing.T) { + var receivedProxyAuth string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedProxyAuth = r.Header.Get("Proxy-Authorization") + w.Write([]byte("ok")) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/strip") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + // Proxy-Authorization should be stripped before forwarding upstream. + if receivedProxyAuth != "" { + t.Errorf("Proxy-Authorization should be stripped, got %q", receivedProxyAuth) + } +} + +func TestIntercept_HTTPMethods(t *testing.T) { + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH"} + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + var receivedMethod string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + w.Write([]byte("ok")) + })) + + req, _ := http.NewRequest(method, setup.Backend.URL+"/method", nil) + resp, err := setup.Client.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedMethod != method { + t.Errorf("method = %q, want %q", receivedMethod, method) + } + }) + } +}