diff --git a/packages/shared/pkg/storage/compress_upload.go b/packages/shared/pkg/storage/compress_upload.go index c4022c8fcb..ab779df40e 100644 --- a/packages/shared/pkg/storage/compress_upload.go +++ b/packages/shared/pkg/storage/compress_upload.go @@ -62,6 +62,25 @@ func (m *memPartUploader) Assemble() []byte { return buf.Bytes() } +// inputBufPool is shared across all uploads so frame-sized buffers (almost +// always DefaultCompressFrameSize) are reused between streams instead of being +// reallocated per call. The size guard keeps it correct for any frame size. +var inputBufPool sync.Pool + +func getInputBuf(size int) *[]byte { + if v := inputBufPool.Get(); v != nil { + bufPtr := v.(*[]byte) + if cap(*bufPtr) >= size { + *bufPtr = (*bufPtr)[:size] + + return bufPtr + } + } + buf := make([]byte, size) + + return &buf +} + type frame struct { uncompressedSize int compressed []byte @@ -83,11 +102,13 @@ func newPart(index int, parentCtx context.Context, workers int) (*part, context. return p, ctx } -func (p *part) addFrame(ctx context.Context, uncompressedData []byte, pool *sync.Pool) { - frameInPart := &frame{uncompressedSize: len(uncompressedData)} +func (p *part) addFrame(ctx context.Context, bufPtr *[]byte, n int, pool *sync.Pool) { + frameInPart := &frame{uncompressedSize: n} p.frames = append(p.frames, frameInPart) + uncompressedData := (*bufPtr)[:n] p.compress.Go(func() error { + defer inputBufPool.Put(bufPtr) if err := ctx.Err(); err != nil { return err } @@ -193,17 +214,22 @@ func readLoop(ctx context.Context, in io.Reader, cfg CompressConfig, hasher io.W return err } - buf := make([]byte, frameSize) + bufPtr := getInputBuf(frameSize) + buf := *bufPtr n, err := io.ReadFull(in, buf) eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) if err != nil && !eof { + inputBufPool.Put(bufPtr) + return fmt.Errorf("read frame: %w", err) } if n > 0 { hasher.Write(buf[:n]) - p.addFrame(compressCtx, buf[:n], compressors) + p.addFrame(compressCtx, bufPtr, n, compressors) + } else { + inputBufPool.Put(bufPtr) } if eof { diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go index 18d407e22d..d9945be96a 100644 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ b/packages/shared/pkg/storage/gcp_multipart.go @@ -254,11 +254,6 @@ func (m *MultipartUploader) initiateUpload(ctx context.Context) (string, error) } func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, partNumber int, data []byte) (string, error) { - // Calculate MD5 for data integrity - hasher := md5.New() - hasher.Write(data) - md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) - url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", m.baseURL, m.objectName, partNumber, uploadID) @@ -269,7 +264,8 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par req.Header.Set("Authorization", "Bearer "+m.token) req.Header.Set("Content-Length", fmt.Sprintf("%d", len(data))) - req.Header.Set("Content-MD5", md5Sum) + sum := md5.Sum(data) //nolint:gosec // GCS multipart uses Content-MD5 for transport integrity. + req.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(sum[:])) resp, err := m.client.Do(req) if err != nil { @@ -291,29 +287,57 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par return etag, nil } +type multiSliceReader struct { + slices [][]byte + idx int + off int +} + +func (r *multiSliceReader) Read(p []byte) (int, error) { + if len(p) == 0 { + if r.idx >= len(r.slices) { + return 0, io.EOF + } + + return 0, nil + } + + var n int + for len(p) > 0 && r.idx < len(r.slices) { + current := r.slices[r.idx] + if r.off >= len(current) { + r.idx++ + r.off = 0 + + continue + } + + copied := copy(p, current[r.off:]) + n += copied + r.off += copied + p = p[copied:] + } + + if n > 0 { + return n, nil + } + + return 0, io.EOF +} + // uploadPartSlices uploads a part from multiple byte slices without concatenating them. -// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads. func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) { - // Compute MD5 and total length without copying - hasher := md5.New() totalLen := 0 for _, s := range slices { - hasher.Write(s) totalLen += len(s) } - md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", m.baseURL, m.objectName, partNumber, uploadID) // Use a ReaderFunc so the retryable client can replay the body on retries bodyFn := func() (io.Reader, error) { - readers := make([]io.Reader, len(slices)) - for i, s := range slices { - readers[i] = bytes.NewReader(s) - } - - return io.MultiReader(readers...), nil + return &multiSliceReader{slices: slices}, nil } req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn)) @@ -323,7 +347,11 @@ func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID strin req.Header.Set("Authorization", "Bearer "+m.token) req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen)) - req.Header.Set("Content-MD5", md5Sum) + h := md5.New() //nolint:gosec // GCS multipart uses Content-MD5 for transport integrity. + for _, s := range slices { + _, _ = h.Write(s) + } + req.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(h.Sum(nil))) resp, err := m.client.Do(req) if err != nil { diff --git a/packages/shared/pkg/storage/gcp_multipart_test.go b/packages/shared/pkg/storage/gcp_multipart_test.go index c30bb8c117..6479b45a0b 100644 --- a/packages/shared/pkg/storage/gcp_multipart_test.go +++ b/packages/shared/pkg/storage/gcp_multipart_test.go @@ -104,6 +104,8 @@ func TestMultipartUploader_UploadPart_Success(t *testing.T) { body, err := io.ReadAll(r.Body) assert.NoError(t, err) assert.Equal(t, testData, body) + sum := md5.Sum(testData) //nolint:gosec // verifying GCS Content-MD5 header. + assert.Equal(t, base64.StdEncoding.EncodeToString(sum[:]), r.Header.Get("Content-MD5")) w.Header().Set("ETag", expectedETag) w.WriteHeader(http.StatusOK) @@ -121,25 +123,17 @@ func TestMultipartUploader_UploadPartSlices_Success(t *testing.T) { expectedETag := `"slice-etag"` slices := [][]byte{[]byte("hello "), []byte("world"), []byte("!")} - // Compute expected MD5 over all slices. - h := md5.New() - for _, s := range slices { - h.Write(s) - } - expectedMD5 := base64.StdEncoding.EncodeToString(h.Sum(nil)) - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "PUT", r.Method) assert.Contains(t, r.URL.RawQuery, "partNumber=3") assert.Contains(t, r.URL.RawQuery, "uploadId=test-upload-id") - - // Verify MD5 matches the expected hash of all slices. - assert.Equal(t, expectedMD5, r.Header.Get("Content-MD5")) - // Verify body is the concatenation of all slices. body, err := io.ReadAll(r.Body) assert.NoError(t, err) - assert.Equal(t, []byte("hello world!"), body) + expected := []byte("hello world!") + assert.Equal(t, expected, body) + sum := md5.Sum(expected) //nolint:gosec // verifying GCS Content-MD5 header. + assert.Equal(t, base64.StdEncoding.EncodeToString(sum[:]), r.Header.Get("Content-MD5")) w.Header().Set("ETag", expectedETag) w.WriteHeader(http.StatusOK)