Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions packages/shared/pkg/storage/compress_upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ func (m *memPartUploader) Assemble() []byte {
return buf.Bytes()
}

// inputBufPool is shared across all uploads so frame-sized buffers (almost
// always DefaultCompressFrameSize) are reused between streams instead of being
// reallocated per call. The size guard keeps it correct for any frame size.
var inputBufPool sync.Pool

func getInputBuf(size int) *[]byte {
if v := inputBufPool.Get(); v != nil {
bufPtr := v.(*[]byte)
if cap(*bufPtr) >= size {
*bufPtr = (*bufPtr)[:size]

return bufPtr
}
}
buf := make([]byte, size)

return &buf
}

type frame struct {
uncompressedSize int
compressed []byte
Expand All @@ -83,11 +102,13 @@ func newPart(index int, parentCtx context.Context, workers int) (*part, context.
return p, ctx
}

func (p *part) addFrame(ctx context.Context, uncompressedData []byte, pool *sync.Pool) {
frameInPart := &frame{uncompressedSize: len(uncompressedData)}
func (p *part) addFrame(ctx context.Context, bufPtr *[]byte, n int, pool *sync.Pool) {
frameInPart := &frame{uncompressedSize: n}
p.frames = append(p.frames, frameInPart)
uncompressedData := (*bufPtr)[:n]

p.compress.Go(func() error {
defer inputBufPool.Put(bufPtr)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets not access the inputBufPool from here at all - keep the responsibility to the place where gets are happening. Keeping get - put together

If we need to do async callback, lets do that instead, or move the goroutine schedule to the readLoop function (if possible)

This comment was marked as outdated.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if wrapping in callback make sense. This is fundamentally just a free for the memory, if we don't call it and because we need it to be async putting the callback there does not feel like it clears up the lifecycle. It is the same method, but not it is also closured.

Copy link
Copy Markdown
Contributor

@dobrac dobrac May 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal is to keep the Get/Put combination next to each other - in one place. That way it's possible to validate the buffer free when it's stopped being used.

Now its across two functions which makes it difficult to track

if err := ctx.Err(); err != nil {
return err
}
Expand Down Expand Up @@ -193,17 +214,22 @@ func readLoop(ctx context.Context, in io.Reader, cfg CompressConfig, hasher io.W
return err
}

buf := make([]byte, frameSize)
bufPtr := getInputBuf(frameSize)
buf := *bufPtr
n, err := io.ReadFull(in, buf)

eof := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
if err != nil && !eof {
inputBufPool.Put(bufPtr)

return fmt.Errorf("read frame: %w", err)
}

if n > 0 {
hasher.Write(buf[:n])
p.addFrame(compressCtx, buf[:n], compressors)
p.addFrame(compressCtx, bufPtr, n, compressors)
} else {
inputBufPool.Put(bufPtr)
}

if eof {
Expand Down
64 changes: 46 additions & 18 deletions packages/shared/pkg/storage/gcp_multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,6 @@ func (m *MultipartUploader) initiateUpload(ctx context.Context) (string, error)
}

func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, partNumber int, data []byte) (string, error) {
// Calculate MD5 for data integrity
hasher := md5.New()
hasher.Write(data)
md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil))

url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s",
m.baseURL, m.objectName, partNumber, uploadID)

Expand All @@ -269,7 +264,8 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par

req.Header.Set("Authorization", "Bearer "+m.token)
req.Header.Set("Content-Length", fmt.Sprintf("%d", len(data)))
req.Header.Set("Content-MD5", md5Sum)
sum := md5.Sum(data) //nolint:gosec // GCS multipart uses Content-MD5 for transport integrity.
req.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(sum[:]))

resp, err := m.client.Do(req)
if err != nil {
Expand All @@ -291,29 +287,57 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par
return etag, nil
}

type multiSliceReader struct {
slices [][]byte
idx int
off int
}

func (r *multiSliceReader) Read(p []byte) (int, error) {
if len(p) == 0 {
if r.idx >= len(r.slices) {
return 0, io.EOF
}

return 0, nil
}

var n int
Comment thread
ValentaTomas marked this conversation as resolved.
for len(p) > 0 && r.idx < len(r.slices) {
current := r.slices[r.idx]
if r.off >= len(current) {
r.idx++
r.off = 0

continue
}

copied := copy(p, current[r.off:])
n += copied
r.off += copied
p = p[copied:]
}

if n > 0 {
return n, nil
}

return 0, io.EOF
}

// uploadPartSlices uploads a part from multiple byte slices without concatenating them.
// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads.
func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) {
// Compute MD5 and total length without copying
hasher := md5.New()
totalLen := 0
for _, s := range slices {
hasher.Write(s)
totalLen += len(s)
}
md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil))

url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s",
m.baseURL, m.objectName, partNumber, uploadID)

// Use a ReaderFunc so the retryable client can replay the body on retries
bodyFn := func() (io.Reader, error) {
readers := make([]io.Reader, len(slices))
for i, s := range slices {
readers[i] = bytes.NewReader(s)
}

return io.MultiReader(readers...), nil
return &multiSliceReader{slices: slices}, nil
}

req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn))
Expand All @@ -323,7 +347,11 @@ func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID strin

req.Header.Set("Authorization", "Bearer "+m.token)
req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen))
Comment thread
ValentaTomas marked this conversation as resolved.
req.Header.Set("Content-MD5", md5Sum)
h := md5.New() //nolint:gosec // GCS multipart uses Content-MD5 for transport integrity.
for _, s := range slices {
_, _ = h.Write(s)
}
req.Header.Set("Content-MD5", base64.StdEncoding.EncodeToString(h.Sum(nil)))

resp, err := m.client.Do(req)
if err != nil {
Expand Down
18 changes: 6 additions & 12 deletions packages/shared/pkg/storage/gcp_multipart_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ func TestMultipartUploader_UploadPart_Success(t *testing.T) {
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, testData, body)
sum := md5.Sum(testData) //nolint:gosec // verifying GCS Content-MD5 header.
assert.Equal(t, base64.StdEncoding.EncodeToString(sum[:]), r.Header.Get("Content-MD5"))

w.Header().Set("ETag", expectedETag)
w.WriteHeader(http.StatusOK)
Expand All @@ -121,25 +123,17 @@ func TestMultipartUploader_UploadPartSlices_Success(t *testing.T) {
expectedETag := `"slice-etag"`
slices := [][]byte{[]byte("hello "), []byte("world"), []byte("!")}

// Compute expected MD5 over all slices.
h := md5.New()
for _, s := range slices {
h.Write(s)
}
expectedMD5 := base64.StdEncoding.EncodeToString(h.Sum(nil))

handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
assert.Contains(t, r.URL.RawQuery, "partNumber=3")
assert.Contains(t, r.URL.RawQuery, "uploadId=test-upload-id")

// Verify MD5 matches the expected hash of all slices.
assert.Equal(t, expectedMD5, r.Header.Get("Content-MD5"))

// Verify body is the concatenation of all slices.
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, []byte("hello world!"), body)
expected := []byte("hello world!")
assert.Equal(t, expected, body)
sum := md5.Sum(expected) //nolint:gosec // verifying GCS Content-MD5 header.
assert.Equal(t, base64.StdEncoding.EncodeToString(sum[:]), r.Header.Get("Content-MD5"))

w.Header().Set("ETag", expectedETag)
w.WriteHeader(http.StatusOK)
Expand Down
Loading