Skip to content

Commit 2fa9693

Browse files
authored
feat: Record requests from the same test into one file (#17)
Users can group replay requests & response into the same file by setting the request header `Test-Name`
1 parent cd965dc commit 2fa9693

4 files changed

Lines changed: 116 additions & 34 deletions

File tree

internal/record/recording_https_proxy.go

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333

3434
type RecordingHTTPSProxy struct {
3535
prevRequestSHA string
36+
seenFiles map[string]struct{}
3637
config *config.EndpointConfig
3738
recordingDir string
3839
redactor *redact.Redact
@@ -41,6 +42,7 @@ type RecordingHTTPSProxy struct {
4142
func NewRecordingHTTPSProxy(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *RecordingHTTPSProxy {
4243
return &RecordingHTTPSProxy{
4344
prevRequestSHA: store.HeadSHA,
45+
seenFiles: make(map[string]struct{}),
4446
config: cfg,
4547
recordingDir: recordingDir,
4648
redactor: redactor,
@@ -70,7 +72,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
7072
}
7173
fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String())
7274

73-
recReq, err := r.recordRequest(req)
75+
recReq, err := r.redactRequest(req)
7476
if err != nil {
7577
fmt.Printf("Error recording request: %v\n", err)
7678
http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError)
@@ -82,6 +84,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
8284
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
8385
return
8486
}
87+
if _, ok := r.seenFiles[fileName]; !ok {
88+
// Reset to HeadSHA when first time seen a request from the given file.
89+
recReq.PreviousRequest=store.HeadSHA
90+
}
8591

8692
if req.Header.Get("Upgrade") == "websocket" {
8793
fmt.Printf("Upgrading connection to websocket...\n")
@@ -95,17 +101,20 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
95101
http.Error(w, fmt.Sprintf("Error proxying request: %v", err), http.StatusInternalServerError)
96102
return
97103
}
98-
99-
err = r.recordResponse(resp, fileName, respBody)
100-
104+
shaSum := recReq.ComputeSum()
105+
err = r.recordResponse(recReq, resp, fileName, shaSum, respBody)
101106
if err != nil {
102107
fmt.Printf("Error recording response: %v\n", err)
103108
http.Error(w, fmt.Sprintf("Error recording response: %v", err), http.StatusInternalServerError)
104109
return
105110
}
111+
if (fileName != shaSum) {
112+
r.prevRequestSHA = shaSum
113+
}
114+
r.seenFiles[fileName] = struct{}{}
106115
}
107116

108-
func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) {
117+
func (r *RecordingHTTPSProxy) redactRequest(req *http.Request) (*store.RecordedRequest, error) {
109118
recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config)
110119
if err != nil {
111120
return recordedRequest, err
@@ -117,17 +126,6 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedR
117126
r.redactor.Headers(recordedRequest.Header)
118127
recordedRequest.Request = r.redactor.String(recordedRequest.Request)
119128
recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body)
120-
121-
fileName, err := recordedRequest.GetRecordingFileName()
122-
if err != nil {
123-
fmt.Printf("Invalid recording file name: %v\n", err)
124-
return recordedRequest, err
125-
}
126-
recordPath := filepath.Join(r.recordingDir, fileName+".req")
127-
err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644)
128-
if err != nil {
129-
return recordedRequest, err
130-
}
131129
return recordedRequest, nil
132130
}
133131

@@ -178,21 +176,47 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ
178176
return resp, respBodyBytes, nil
179177
}
180178

181-
func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error {
179+
func (r *RecordingHTTPSProxy) recordResponse(recReq *store.RecordedRequest, resp *http.Response, fileName string, shaSum string, body []byte) error {
182180
recordedResponse, err := store.NewRecordedResponse(resp, body)
183181
if err != nil {
184182
return err
185183
}
184+
recordPath := filepath.Join(r.recordingDir, fileName+".http.log")
186185

187-
recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body)
186+
// Default to overwriting the file, assuming it's a new file to record.
187+
fileMode := os.O_TRUNC
188+
// If we've seen requests with the same file name before, change the mode to append.
189+
if _, ok := r.seenFiles[fileName]; ok {
190+
fileMode = os.O_APPEND
191+
}
192+
file, err := os.OpenFile(recordPath, fileMode|os.O_CREATE|os.O_WRONLY , 0644)
193+
if err != nil {
194+
return err
195+
}
196+
defer file.Close()
188197

189-
recordPath := filepath.Join(r.recordingDir, fileName+".resp")
190-
fmt.Printf("Writing response to: %s\n", recordPath)
191-
err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644)
198+
fmt.Printf("Writing request to: %s\n", recordPath)
199+
serializedReq := recReq.Serialize()
200+
_, err = file.WriteString(fmt.Sprintf("%s.req %d\n", shaSum, len(serializedReq)))
201+
if err != nil {
202+
return err
203+
}
204+
_, err = file.WriteString(serializedReq)
192205
if err != nil {
193206
return err
194207
}
195208

209+
fmt.Printf("Writing response to: %s\n", recordPath)
210+
recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body)
211+
serializedResp := recordedResponse.Serialize()
212+
_, err = file.WriteString(fmt.Sprintf("\n%s.resp %d\n", shaSum, len(serializedResp)))
213+
if err != nil {
214+
return err
215+
}
216+
_, err = file.WriteString(serializedResp)
217+
if err != nil {
218+
return err
219+
}
196220
return nil
197221
}
198222

@@ -230,7 +254,7 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re
230254
go pumpWebsocket(clientConn, conn, c, quit, ">")
231255
go pumpWebsocket(conn, clientConn, c, quit, "<")
232256

233-
recordPath := filepath.Join(r.recordingDir, fileName+".websocket")
257+
recordPath := filepath.Join(r.recordingDir, fileName+".websocket.log")
234258
f, err := os.Create(recordPath)
235259
if err != nil {
236260
fmt.Printf("Error creating websocket recording file: %v\n", err)

internal/replay/replay_http_server.go

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"strconv"
2525
"strings"
2626
"unicode"
27+
"bufio"
28+
"io"
2729

2830
"github.com/google/test-server/internal/config"
2931
"github.com/google/test-server/internal/redact"
@@ -33,6 +35,7 @@ import (
3335

3436
type ReplayHTTPServer struct {
3537
prevRequestSHA string
38+
seenFiles map[string]struct{}
3639
config *config.EndpointConfig
3740
recordingDir string
3841
redactor *redact.Redact
@@ -41,6 +44,7 @@ type ReplayHTTPServer struct {
4144
func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *ReplayHTTPServer {
4245
return &ReplayHTTPServer{
4346
prevRequestSHA: store.HeadSHA,
47+
seenFiles: make(map[string]struct{}),
4448
config: cfg,
4549
recordingDir: recordingDir,
4650
redactor: redactor,
@@ -78,6 +82,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
7882
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
7983
return
8084
}
85+
if _, ok := r.seenFiles[fileName]; !ok {
86+
// Reset to HeadSHA when first time seen request from the given file.
87+
redactedReq.PreviousRequest=store.HeadSHA
88+
}
8189
if req.Header.Get("Upgrade") == "websocket" {
8290
fmt.Printf("Upgrading connection to websocket...\n")
8391

@@ -92,7 +100,8 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
92100
return
93101
}
94102
fmt.Printf("Replaying http request: %s\n", redactedReq.Request)
95-
resp, err := r.loadResponse(fileName)
103+
shaSum := redactedReq.ComputeSum()
104+
resp, err := r.loadResponse(fileName, shaSum)
96105
if err != nil {
97106
fmt.Printf("Error loading response: %v\n", err)
98107
http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError)
@@ -104,6 +113,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
104113
fmt.Printf("Error writing response: %v\n", err)
105114
panic(err)
106115
}
116+
if (fileName != shaSum) {
117+
r.prevRequestSHA = shaSum
118+
}
119+
r.seenFiles[fileName] = struct{}{}
107120
}
108121

109122
func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.RecordedRequest, error) {
@@ -122,14 +135,59 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco
122135
return recordedRequest, nil
123136
}
124137

125-
func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) {
126-
responseFile := filepath.Join(r.recordingDir, fileName+".resp")
127-
fmt.Printf("loading response from : %s\n", responseFile)
128-
responseData, err := os.ReadFile(responseFile)
138+
func (r *ReplayHTTPServer) loadResponse(fileName string, shaSum string) (*store.RecordedResponse, error) {
139+
// 1. Open the replay log file for reading.
140+
filePath := filepath.Join(r.recordingDir, fileName+".http.log")
141+
fmt.Printf("loading response from : %s with shaSum: %s\n", filePath, shaSum)
142+
file, err := os.Open(filePath)
129143
if err != nil {
130-
return nil, err
144+
return nil, fmt.Errorf("could not open file %s: %w", filePath, err)
145+
}
146+
defer file.Close()
147+
148+
reader := bufio.NewReader(file)
149+
expectedKey := shaSum + ".resp"
150+
// 2. Scan the file line by line using the reader directly.
151+
for {
152+
// Read one line, including the newline character.
153+
line, err := reader.ReadString('\n')
154+
if err != nil {
155+
if err == io.EOF {
156+
return nil, fmt.Errorf("response with shaSum %s not found in file", shaSum)
157+
}
158+
return nil, fmt.Errorf("error while reading file: %w", err)
159+
}
160+
trimmedLine := strings.TrimSpace(line)
161+
parts := strings.Fields(trimmedLine)
162+
if len(parts) != 2 {
163+
continue
164+
}
165+
166+
fileKey := parts[0]
167+
sizeStr := parts[1]
168+
169+
size, err := strconv.Atoi(sizeStr)
170+
if err != nil {
171+
return nil, fmt.Errorf("invalid size format on delimiter line: '%s'", trimmedLine)
172+
}
173+
fmt.Printf("Bytes to load: %d\n", size)
174+
if size < 0 {
175+
return nil, fmt.Errorf("invalid negative size on delimiter line: '%s'", trimmedLine)
176+
}
177+
178+
// 3. Read the exact number of bytes for the payload.
179+
data := make([]byte, size)
180+
if _, err := io.ReadFull(reader, data); err != nil {
181+
return nil, fmt.Errorf("failed to read %d bytes after delimiter: %w", size, err)
182+
}
183+
184+
// 4. Return the response when it matches our target shaSum.
185+
if fileKey == expectedKey {
186+
return store.DeserializeResponse(data)
187+
} else {
188+
continue
189+
}
131190
}
132-
return store.DeserializeResponse(responseData)
133191
}
134192

135193
func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error {
@@ -175,8 +233,8 @@ func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Reque
175233
replayWebsocket(clientConn, chunks)
176234
}
177235

178-
func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) {
179-
responseFile := filepath.Join(r.recordingDir, sha+".websocket")
236+
func (r *ReplayHTTPServer) loadWebsocketChunks(fileName string) ([]string, error) {
237+
responseFile := filepath.Join(r.recordingDir, fileName+".websocket.log")
180238
fmt.Printf("loading websocket response from : %s\n", responseFile)
181239
bytes, err := os.ReadFile(responseFile)
182240
var chunks = make([]string, 0)

internal/store/store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ import (
3030
"github.com/google/test-server/internal/config"
3131
)
3232

33-
// A sha of an invalid RecordRequest to be used as the head of all chains.
3433
const HeadSHA = "b4d6e60a9b97e7b98c63df9308728c5c88c0b40c398046772c63447b94608b4d"
3534

3635
type RecordedRequest struct {
@@ -102,7 +101,8 @@ func (r *RecordedRequest) GetRecordingFileName() (string, error) {
102101
return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName)
103102
}
104103
if testName != "" {
105-
return testName, nil
104+
fileName := strings.ReplaceAll(testName, " ", "_")
105+
return fileName, nil
106106
}
107107
return r.ComputeSum(), nil
108108
}

internal/store/store_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) {
349349
Port: 0,
350350
Protocol: "",
351351
},
352-
expected: "random test name",
352+
expected: "random_test_name",
353353
expectedErr: false,
354354
},
355355
{

0 commit comments

Comments
 (0)