From fef06c0210dab9b69c97aac88ed144091b106c7c Mon Sep 17 00:00:00 2001 From: Ola Adebayo Date: Mon, 16 Mar 2026 19:40:43 +0000 Subject: [PATCH] fix(gateway): buffer response headers to prevent superfluous WriteHeader The inject middleware's responseRecorder was letting upstream handlers write headers directly to the real ResponseWriter. When httputil.ReverseProxy called Flush(), this prematurely committed headers and caused a second WriteHeader call when the middleware later wrote its own status code. Fix by giving the recorder its own http.Header map so upstream headers are buffered until the middleware is ready to write back. Flush() is now a no-op since the middleware must buffer the full response before deciding whether to inject. Headers are copied back via copyHeaders() before WriteHeader. --- .gitignore | 1 + gateway/internal/inject/inject.go | 31 +++++++++++++--- gateway/internal/inject/inject_test.go | 51 ++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 813bc48..dd3225a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ _cgo_gotypes.go _cgo_export.* _obj/ _test/ +.gocache/ # ─── Node / TypeScript (sdk, cli) ───────────────────────────────────────────── **/node_modules/ diff --git a/gateway/internal/inject/inject.go b/gateway/internal/inject/inject.go index d88624d..2003dc0 100644 --- a/gateway/internal/inject/inject.go +++ b/gateway/internal/inject/inject.go @@ -65,6 +65,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { rec := &responseRecorder{ ResponseWriter: w, body: &bytes.Buffer{}, + header: make(http.Header), statusCode: http.StatusOK, } @@ -72,9 +73,10 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { m.next.ServeHTTP(rec, r) // Check if the response is HTML. - contentType := rec.Header().Get("Content-Type") + contentType := rec.header.Get("Content-Type") if !isHTML(contentType) { // Not HTML — write the response as-is. + copyHeaders(w.Header(), rec.header) w.WriteHeader(rec.statusCode) if _, err := w.Write(rec.body.Bytes()); err != nil { m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) @@ -84,7 +86,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Decompress the body if the upstream ignored our Accept-Encoding removal. body := rec.body.Bytes() - encoding := rec.Header().Get("Content-Encoding") + encoding := rec.header.Get("Content-Encoding") if encoding != "" { decompressed, err := decompressBody(body, encoding) if err != nil { @@ -93,6 +95,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { "path", r.URL.Path, "reason", "unsupported Content-Encoding: "+encoding, ) + copyHeaders(w.Header(), rec.header) w.WriteHeader(rec.statusCode) if _, err := w.Write(body); err != nil { m.logger.Debug("rep.inject.write_error", "path", r.URL.Path, "error", err) @@ -111,6 +114,8 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Inject the REP script tag into the HTML. injected := injectIntoHTML(body, tag) + copyHeaders(w.Header(), rec.header) + // Update Content-Length to reflect the injected content. w.Header().Set("Content-Length", strconv.Itoa(len(injected))) @@ -260,11 +265,16 @@ func isHTML(contentType string) bool { // responseRecorder captures the upstream response for inspection. type responseRecorder struct { http.ResponseWriter + header http.Header body *bytes.Buffer statusCode int wroteHeader bool } +func (r *responseRecorder) Header() http.Header { + return r.header +} + func (r *responseRecorder) WriteHeader(code int) { r.statusCode = code r.wroteHeader = true @@ -277,12 +287,23 @@ func (r *responseRecorder) Write(b []byte) (int, error) { // Flush implements http.Flusher for streaming support. func (r *responseRecorder) Flush() { - if f, ok := r.ResponseWriter.(http.Flusher); ok { - f.Flush() - } + // Intentionally do nothing. The middleware buffers the full upstream response + // before deciding whether to inject, so flushing here would prematurely commit + // headers/body to the client. } // ReadFrom implements io.ReaderFrom for efficient copies. func (r *responseRecorder) ReadFrom(src io.Reader) (int64, error) { return r.body.ReadFrom(src) } + +func copyHeaders(dst, src http.Header) { + for k := range dst { + dst.Del(k) + } + for k, values := range src { + for _, value := range values { + dst.Add(k, value) + } + } +} diff --git a/gateway/internal/inject/inject_test.go b/gateway/internal/inject/inject_test.go index ad6ae8c..45d390a 100644 --- a/gateway/internal/inject/inject_test.go +++ b/gateway/internal/inject/inject_test.go @@ -191,6 +191,57 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) { _ = expectedLen // The header value is set by the middleware. } +func TestMiddleware_UpstreamFlushDoesNotCommitEarly(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusCreated) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + _, _ = w.Write([]byte(`flushed`)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("expected status %d, got %d", http.StatusCreated, rec.Code) + } + body := rec.Body.String() + if !strings.Contains(body, testScriptTag) { + t.Fatal("expected injected script tag in flushed HTML response") + } + if !strings.Contains(body, "flushed") { + t.Fatal("expected original body content to be preserved after flush") + } +} + +func TestMiddleware_BuffersHeadersUntilWriteback(t *testing.T) { + upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("X-REP-Test", "buffered") + _, _ = w.Write([]byte(``)) + }) + + m := New(upstream, testScriptTag, slog.Default()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + m.ServeHTTP(rec, req) + + if got := rec.Header().Get("X-REP-Test"); got != "buffered" { + t.Fatalf("expected buffered header to be copied back, got %q", got) + } + if got := rec.Header().Get("Content-Type"); got != "text/html" { + t.Fatalf("expected content type to survive buffering, got %q", got) + } +} + func TestIsHTML(t *testing.T) { tests := []struct { ct string