From f0a72d4be3ea7b5fa1a7da9bee0b26f1b81ec288 Mon Sep 17 00:00:00 2001 From: harmlessii <919320834@qq.com> Date: Sat, 30 May 2026 23:54:17 +0800 Subject: [PATCH] feat(orchestrator): add ublk block device provider as alternative to NBD --- packages/orchestrator/go.mod | 1 + packages/orchestrator/go.sum | 2 + packages/orchestrator/pkg/factories/run.go | 10 +- .../orchestrator/pkg/sandbox/rootfs/ublk.go | 180 ++++++++ .../pkg/sandbox/rootfs/ublk_backend.go | 173 ++++++++ .../pkg/sandbox/rootfs/ublk_test.go | 414 ++++++++++++++++++ packages/orchestrator/pkg/sandbox/sandbox.go | 52 ++- .../orchestrator/pkg/sandbox/ublk/pool.go | 151 +++++++ 8 files changed, 968 insertions(+), 15 deletions(-) create mode 100644 packages/orchestrator/pkg/sandbox/rootfs/ublk.go create mode 100644 packages/orchestrator/pkg/sandbox/rootfs/ublk_backend.go create mode 100644 packages/orchestrator/pkg/sandbox/rootfs/ublk_test.go create mode 100644 packages/orchestrator/pkg/sandbox/ublk/pool.go diff --git a/packages/orchestrator/go.mod b/packages/orchestrator/go.mod index cd5f85fa96..c184832c5b 100644 --- a/packages/orchestrator/go.mod +++ b/packages/orchestrator/go.mod @@ -30,6 +30,7 @@ require ( github.com/dustin/go-humanize v1.0.1 github.com/e2b-dev/infra/packages/clickhouse v0.0.0 github.com/e2b-dev/infra/packages/shared v0.0.0 + github.com/e2b-dev/ublk-go v0.1.3 github.com/edsrzf/mmap-go v1.2.1-0.20241212181136-fad1cd13edbd github.com/firecracker-microvm/firecracker-go-sdk v1.0.0 github.com/getkin/kin-openapi v0.137.0 diff --git a/packages/orchestrator/go.sum b/packages/orchestrator/go.sum index 6528e26dc2..58b4a29a71 100644 --- a/packages/orchestrator/go.sum +++ b/packages/orchestrator/go.sum @@ -415,6 +415,8 @@ github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+m github.com/dvyukov/go-fuzz v0.0.0-20210914135545-4980593459a1/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= github.com/e2b-dev/go-nfs v0.0.0-20260318224420-f59b77ca8555 
h1:cUImAB4xt5kewJV1cd0gJMspWgmuMgbpj3Q8qIOUWrY= github.com/e2b-dev/go-nfs v0.0.0-20260318224420-f59b77ca8555/go.mod h1:VhNccO67Oug787VNXcyx9JDI3ZoSpqoKMT/lWMhUIDg= +github.com/e2b-dev/ublk-go v0.1.3 h1:N2UT3KPI46UfeGhdo+7o4rSxVd/fIsSAtIiAUvKjSyY= +github.com/e2b-dev/ublk-go v0.1.3/go.mod h1:OKUhrFueUGZ2ygXv2EWX3RbJLQxf9VHOnRDxnzDnSF0= github.com/ebitengine/purego v0.10.0 h1:QIw4xfpWT6GWTzaW5XEKy3HXoqrJGx1ijYHzTF0/ISU= github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/edsrzf/mmap-go v1.2.1-0.20241212181136-fad1cd13edbd h1:I4PrRZuNMeDP3VbFrak4QsqwO5tWkQf0tqrrr1L2DsU= diff --git a/packages/orchestrator/pkg/factories/run.go b/packages/orchestrator/pkg/factories/run.go index 0a963000df..65b522d559 100644 --- a/packages/orchestrator/pkg/factories/run.go +++ b/packages/orchestrator/pkg/factories/run.go @@ -45,6 +45,7 @@ import ( blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/cgroup" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/nbd" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/ublk" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/network" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template/peerclient" @@ -530,6 +531,13 @@ func run(config cfg.Config, opts Options) (success bool) { }) closers = append(closers, closer{"device pool", devicePool.Close}) + // ublk device pool + ublkPool, err := ublk.NewDevicePool(0) + if err != nil { + logger.L().Fatal(ctx, "failed to create ublk pool", zap.Error(err)) + } + closers = append(closers, closer{"ublk pool", ublkPool.Shutdown}) + // network pool slotStorage, err := newStorage(ctx, nodeID, config.NetworkConfig, egressSetup.Proxy) if err != nil { @@ -544,7 +552,7 @@ func run(config cfg.Config, opts Options) (success bool) { closers = append(closers, 
closer{"network pool", networkPool.Close}) // sandbox factory - sandboxFactory := sandbox.NewFactory(config.BuilderConfig, networkPool, devicePool, featureFlags, hostStatsDelivery, cgroupManager, egressSetup.Proxy, sandboxes) + sandboxFactory := sandbox.NewFactory(config.BuilderConfig, networkPool, devicePool, ublkPool, featureFlags, hostStatsDelivery, cgroupManager, egressSetup.Proxy, sandboxes) // isolated filesystems cache (for nfs proxy) builder := chrooted.NewBuilder(config)
diff --git a/packages/orchestrator/pkg/sandbox/rootfs/ublk.go b/packages/orchestrator/pkg/sandbox/rootfs/ublk.go
new file mode 100644
index 0000000000..96be04a5ab
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/rootfs/ublk.go
@@ -0,0 +1,180 @@
+package rootfs
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"os"
+
+	"github.com/e2b-dev/ublk-go/ublk"
+	"go.uber.org/zap"
+	"golang.org/x/sys/unix"
+
+	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
+	ublkpool "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/ublk"
+	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+	"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
+	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
+)
+
+// UblkProvider exposes a rootfs as a ublk block device. Reads and writes from
+// the device go through an overlay made of the read-only rootfs and a local
+// cache file; the cache can later be exported with ExportDiff.
+type UblkProvider struct {
+	ctx     context.Context
+	cancel  context.CancelFunc
+	overlay *block.Overlay
+	backend *ublkBackend
+	dev     *ublk.Device
+	pool    *ublkpool.DevicePool
+
+	ready              *utils.SetOnce[string]
+	finishedOperations chan struct{}
+	blockSize          int64
+}
+
+// NewUblkProvider builds a provider whose overlay stacks a cache at cachePath
+// on top of rootfs. No device exists until Start is called; pool is used to
+// create and later close it.
+func NewUblkProvider(
+	ctx context.Context,
+	rootfs block.ReadonlyDevice,
+	cachePath string,
+	pool *ublkpool.DevicePool,
+) (Provider, error) {
+	size, err := rootfs.Size(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("error getting device size: %w", err)
+	}
+
+	blockSize := rootfs.BlockSize()
+
+	cache, err := block.NewCache(size, blockSize, cachePath, false)
+	if err != nil {
+		return nil, fmt.Errorf("error creating cache: %w", err)
+	}
+
+	overlay := block.NewOverlay(rootfs, cache)
+
+	// Use a background context so the ublk backend outlives the CreateSandbox
+	// request context. It is only cancelled explicitly in Close().
+	runCtx, cancel := context.WithCancel(context.Background())
+
+	return &UblkProvider{
+		ctx:                runCtx,
+		cancel:             cancel,
+		overlay:            overlay,
+		backend:            newUblkBackend(runCtx, overlay),
+		pool:               pool,
+		ready:              utils.NewSetOnce[string](),
+		finishedOperations: make(chan struct{}, 1),
+		blockSize:          blockSize,
+	}, nil
+}
+
+// Start creates the ublk device through the pool and publishes its path (or
+// the creation error) to everyone waiting in Path.
+func (u *UblkProvider) Start(ctx context.Context) error {
+	size, err := u.overlay.Size(ctx)
+	if err != nil {
+		return u.ready.SetError(err)
+	}
+
+	telemetry.ReportEvent(ctx, "creating ublk device")
+
+	dev, err := u.pool.New(ctx, u.backend, uint64(size))
+	if err != nil {
+		return u.ready.SetError(fmt.Errorf("ublk.New: %w", err))
+	}
+
+	// Record the device before publishing the path so that Close, which waits
+	// on the path in sync, sees it.
+	u.dev = dev
+	telemetry.ReportEvent(ctx, "ublk device created")
+
+	return u.ready.SetValue(dev.Path())
+}
+
+// Path waits for the result published by Start and returns the device path.
+func (u *UblkProvider) Path() (string, error) { return u.ready.Wait() }
+
+// Close flushes the device, closes it through the pool, cancels the backend
+// context, signals ExportDiff and closes the overlay. Every step is attempted
+// even if an earlier one failed; all failures are returned joined.
+func (u *UblkProvider) Close(ctx context.Context) error {
+	ctx, span := tracer.Start(ctx, "ublk-close")
+	defer span.End()
+
+	var errs []error
+
+	if err := u.sync(ctx); err != nil {
+		errs = append(errs, fmt.Errorf("ublk flush: %w", err))
+	}
+
+	if u.dev != nil {
+		if err := u.pool.Close(ctx, u.dev); err != nil {
+			errs = append(errs, fmt.Errorf("ublk close: %w", err))
+		}
+	}
+	u.cancel()
+
+	// Tell ExportDiff that the device is released. The channel buffers a
+	// single signal; the send must never block, otherwise a Close issued while
+	// a previous signal is still pending (for example a repeated Close with no
+	// ExportDiff draining it) would hang forever.
+	select {
+	case u.finishedOperations <- struct{}{}:
+	default:
+	}
+
+	if err := u.overlay.Close(); err != nil {
+		errs = append(errs, fmt.Errorf("overlay close: %w", err))
+	}
+	logger.L().Info(ctx, "ublk overlay device released")
+
+	return errors.Join(errs...)
+}
+
+// ExportDiff ejects the cache from the overlay, runs closeSandbox in the
+// background, waits until Close has released the device and then writes the
+// cache contents to out. The cache is closed before returning, whether or not
+// the export succeeded.
+func (u *UblkProvider) ExportDiff(
+	ctx context.Context,
+	out *os.File,
+	closeSandbox func(context.Context) error,
+) (*header.DiffMetadata, error) {
+	ctx, span := tracer.Start(ctx, "ublk-export")
+	defer span.End()
+
+	cache, err := u.overlay.EjectCache()
+	if err != nil {
+		return nil, fmt.Errorf("eject cache: %w", err)
+	}
+
+	go func() {
+		err := closeSandbox(ctx)
+		if err != nil {
+			logger.L().Error(ctx, "stop sandbox on cow export", zap.Error(err))
+		}
+	}()
+
+	select {
+	case <-u.finishedOperations:
+	case <-ctx.Done():
+		// Keep the cause (deadline vs. cancellation) in the error chain.
+		return nil, fmt.Errorf("timeout waiting for ublk device released: %w", ctx.Err())
+	}
+	telemetry.ReportEvent(ctx, "sandbox stopped")
+
+	m, err := cache.ExportToDiff(ctx, out)
+	if err != nil {
+		err = fmt.Errorf("error exporting cache: %w", err)
+		// The cache was ejected from the overlay above, so overlay.Close no
+		// longer owns it; close it here or it is never released.
+		if closeErr := cache.Close(); closeErr != nil {
+			err = errors.Join(err, fmt.Errorf("error closing cache: %w", closeErr))
+		}
+
+		return nil, err
+	}
+	telemetry.ReportEvent(ctx, "cache exported")
+
+	if err := cache.Close(); err != nil {
+		return nil, fmt.Errorf("error closing cache: %w", err)
+	}
+
+	return m, nil
+}
+
+// sync waits for the device path, issues BLKFLSBUF on the device and then
+// calls flush on the same path.
+func (u *UblkProvider) sync(ctx context.Context) error {
+	ctx, span := tracer.Start(ctx, "ublk-sync")
+	defer span.End()
+
+	path, err := u.Path()
+	if err != nil {
+		return fmt.Errorf("failed to get cow path: %w", err)
+	}
+
+	file, err := os.Open(path)
+	if err != nil {
+		return fmt.Errorf("open %s: %w", path, err)
+	}
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			logger.L().Error(ctx, "failed to close ublk file", zap.Error(err))
+		}
+	}()
+
+	if err := unix.IoctlSetInt(int(file.Fd()), unix.BLKFLSBUF, 0); err != nil {
+		return fmt.Errorf("ioctl BLKFLSBUF: %w", err)
+	}
+
+	return flush(ctx, path)
+}
diff --git a/packages/orchestrator/pkg/sandbox/rootfs/ublk_backend.go b/packages/orchestrator/pkg/sandbox/rootfs/ublk_backend.go
new file mode 100644
index 0000000000..aee654c9dc
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/rootfs/ublk_backend.go
@@ -0,0 +1,173 @@
+package rootfs
+
+import (
+	"context"
+	"fmt"
+	"sort"
+	"sync"
+
+	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block"
+)
+
+const ublkBackendLockStripes = 256
+
+// ublkBackend adapts a block.Device to the ReadAt/WriteAt methods that ublk-go
+// calls. Requests that are not aligned to the device block size are widened to
+// whole blocks before they reach the device; writers are serialized per block
+// through a fixed set of striped mutexes (see lockRange).
+//
+// ctx is stored on the struct because ReadAt/WriteAt carry no context of their
+// own; it is the context handed to every dev.ReadAt call.
+type ublkBackend struct {
+	ctx       context.Context
+	dev       block.Device
+	blockSize int64
+	bufPool   sync.Pool
+	locks     []sync.Mutex
+}
+
+// newUblkBackend wraps dev, taking the block size from dev.BlockSize().
+func newUblkBackend(ctx context.Context, dev block.Device) *ublkBackend {
+	bs := dev.BlockSize()
+
+	return &ublkBackend{
+		ctx:       ctx,
+		dev:       dev,
+		blockSize: bs,
+		locks:     make([]sync.Mutex, ublkBackendLockStripes),
+		bufPool: sync.Pool{
+			New: func() any {
+				b := make([]byte, bs)
+				return &b
+			},
+		},
+	}
+}
+
+// ReadAt fills p with the bytes at off. On success it always reports len(p).
+// On failure the returned count never exceeds len(p).
+func (b *ublkBackend) ReadAt(p []byte, off int64) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	if b.isAligned(off, len(p)) {
+		// Whole blocks: let the device fill p directly instead of going
+		// through a scratch buffer of the same size.
+		n, err := b.dev.ReadAt(b.ctx, p, off)
+		if err != nil {
+			return min(n, len(p)), err
+		}
+
+		return len(p), nil
+	}
+
+	alignedOff, alignedLen := b.alignedRange(off, len(p))
+	tmp := make([]byte, alignedLen)
+
+	if _, err := b.dev.ReadAt(b.ctx, tmp, alignedOff); err != nil {
+		// Nothing has been copied into p yet; the device's count refers to
+		// the widened scratch buffer and may be larger than len(p).
+		return 0, err
+	}
+
+	start := int(off - alignedOff)
+	copy(p, tmp[start:start+len(p)])
+
+	return len(p), nil
+}
+
+// WriteAt stores p at off. Aligned requests are passed to the device as-is.
+// Unaligned requests become a read-modify-write of the enclosing blocks: the
+// blocks that are only partly covered are read first, p is patched in, and the
+// widened range is written back in one call. The block locks are held for the
+// whole operation so two partial writes to the same block cannot interleave.
+func (b *ublkBackend) WriteAt(p []byte, off int64) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+
+	alignedOff, alignedLen := b.alignedRange(off, len(p))
+	unlock := b.lockRange(alignedOff, alignedLen)
+	defer unlock()
+
+	if b.isAligned(off, len(p)) {
+		return b.dev.WriteAt(p, off)
+	}
+
+	tmp := make([]byte, alignedLen)
+	requestEnd := off + int64(len(p))
+	alignedEnd := alignedOff + int64(alignedLen)
+
+	for blockOff := alignedOff; blockOff < alignedEnd; blockOff += b.blockSize {
+		blockEnd := blockOff + b.blockSize
+		blockBuf := tmp[int(blockOff-alignedOff):int(blockEnd-alignedOff)]
+
+		writeStart := max(blockOff, off)
+		writeEnd := min(blockEnd, requestEnd)
+
+		// Only partly covered blocks need their current contents.
+		if writeStart > blockOff || writeEnd < blockEnd {
+			if _, err := b.dev.ReadAt(b.ctx, blockBuf, blockOff); err != nil {
+				// Nothing has been written; a read count is not a write count.
+				return 0, err
+			}
+		}
+
+		copy(blockBuf[int(writeStart-blockOff):], p[int(writeStart-off):int(writeEnd-off)])
+	}
+
+	n, err := b.dev.WriteAt(tmp, alignedOff)
+	if err != nil {
+		return b.requestBytesWritten(n, off-alignedOff, len(p)), err
+	}
+
+	if n != alignedLen {
+		return b.requestBytesWritten(n, off-alignedOff, len(p)),
+			fmt.Errorf("short aligned write: wrote %d want %d", n, alignedLen)
+	}
+
+	return len(p), nil
+}
+
+// requestBytesWritten converts a byte count of the widened write (n, starting
+// lead bytes before the caller's offset) into how many of the caller's reqLen
+// bytes it covers, clamped to [0, reqLen].
+func (b *ublkBackend) requestBytesWritten(n int, lead int64, reqLen int) int {
+	return int(min(max(int64(n)-lead, 0), int64(reqLen)))
+}
+
+// isAligned reports whether both off and length are multiples of the block size.
+func (b *ublkBackend) isAligned(off int64, length int) bool {
+	return off%b.blockSize == 0 && int64(length)%b.blockSize == 0
+}
+
+// alignedRange widens [off, off+length) to block boundaries and returns the
+// start offset and length of the widened range.
+func (b *ublkBackend) alignedRange(off int64, length int) (int64, int) {
+	alignedOff := (off / b.blockSize) * b.blockSize
+	end := off + int64(length)
+	alignedEnd := ((end + b.blockSize - 1) / b.blockSize) * b.blockSize
+
+	return alignedOff, int(alignedEnd - alignedOff)
+}
+
+// lockRange locks the stripe of every block touched by [off, off+length) and
+// returns the function that releases them. Stripes are taken in ascending
+// order so that overlapping callers cannot deadlock.
+func (b *ublkBackend) lockRange(off int64, length int) func() {
+	if length <= 0 {
+		return func() {}
+	}
+
+	startBlock := off / b.blockSize
+	endBlock := (off + int64(length) - 1) / b.blockSize
+	stripeCount := int64(len(b.locks))
+
+	// Consecutive blocks map to consecutive stripes (block % stripeCount), so
+	// the first stripeCount blocks already name every distinct stripe.
+	distinct := min(endBlock-startBlock+1, stripeCount)
+	stripes := make([]int, 0, distinct)
+	for i := int64(0); i < distinct; i++ {
+		stripes = append(stripes, int((startBlock+i)%stripeCount))
+	}
+
+	sort.Ints(stripes)
+	for _, stripe := range stripes {
+		b.locks[stripe].Lock()
+	}
+
+	return func() {
+		for i := len(stripes) - 1; i >= 0; i-- {
+			b.locks[stripes[i]].Unlock()
+		}
+	}
+}
diff --git a/packages/orchestrator/pkg/sandbox/rootfs/ublk_test.go b/packages/orchestrator/pkg/sandbox/rootfs/ublk_test.go
new file mode 100644
index 0000000000..7f4e775609
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/rootfs/ublk_test.go
@@ -0,0 +1,414 @@
+package rootfs
+
+import (
+ "bytes" + "context" + "crypto/rand" + "fmt" + "os" + "os/exec" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/nbd/testutils" + ublkpool "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/ublk" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestUblk_Write(t *testing.T) { + t.Parallel() + + size := int64(5 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, os.O_RDWR) + + const writeSize = 1024 * 1024 + testData := make([]byte, writeSize) + _, err := rand.Read(testData) + require.NoError(t, err, "failed to generate random data") + + n, err := deviceFile.WriteAt(testData, 0) + require.NoError(t, err, "failed to write data to device") + require.Equal(t, len(testData), n, "partial write") + + readData := make([]byte, writeSize) + n, err = deviceFile.ReadAt(readData, 0) + require.NoError(t, err, "failed to read data from device") + require.Equal(t, len(readData), n, "partial read") + require.Equal(t, testData, readData, "data mismatch") +} + +func TestUblk_WriteAtOffset(t *testing.T) { + t.Parallel() + + size := int64(5 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, os.O_RDWR) + + const writeSize = 512 * 1024 + const writeOffset = 512 * 1024 + testData := make([]byte, writeSize) + _, err := rand.Read(testData) + require.NoError(t, err, "failed to generate random data") + + n, err := deviceFile.WriteAt(testData, writeOffset) + require.NoError(t, err, "failed to write data to device") + require.Equal(t, len(testData), n, "partial write") + + readData := make([]byte, writeSize) + n, err = deviceFile.ReadAt(readData, writeOffset) + require.NoError(t, err, "failed to read data from device") + require.Equal(t, len(readData), n, "partial read") + require.Equal(t, testData, readData, 
"data mismatch") +} + +func TestUblk_Direct4MBWrite(t *testing.T) { + t.Parallel() + + size := int64(10 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, unix.O_DIRECT|unix.O_RDWR) + + const bs = 4 * 1024 * 1024 + buf, err := unix.Mmap(-1, 0, bs, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_ANON) + require.NoError(t, err, "failed to mmap") + t.Cleanup(func() { + unix.Munmap(buf) + }) + + n, err := deviceFile.WriteAt(buf, 0) + require.NoError(t, err, "failed to write to device") + require.Equal(t, len(buf), n, "partial write") + + readData := make([]byte, bs) + n, err = deviceFile.ReadAt(readData, 0) + require.NoError(t, err, "failed to read from device") + require.Equal(t, len(readData), n, "partial read") + require.Equal(t, buf, readData, "data mismatch") +} + +func TestUblk_Direct32MBWrite(t *testing.T) { + t.Parallel() + + size := int64(256 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, unix.O_DIRECT|unix.O_RDWR) + + const bs = 32 * 1024 * 1024 + buf, err := unix.Mmap(-1, 0, bs, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_SHARED|unix.MAP_ANON) + require.NoError(t, err, "failed to mmap") + t.Cleanup(func() { + unix.Munmap(buf) + }) + + n, err := deviceFile.WriteAt(buf, 0) + require.NoError(t, err, "failed to write to device") + require.Equal(t, len(buf), n, "partial write") + + readData := make([]byte, bs) + n, err = deviceFile.ReadAt(readData, 0) + require.NoError(t, err, "failed to read from device") + require.Equal(t, len(readData), n, "partial read") + require.Equal(t, buf, readData, "data mismatch") +} + +func TestUblk_LargeWrite(t *testing.T) { + t.Parallel() + + size := int64(1200 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, os.O_RDWR) + + time.Sleep(1 * time.Second) + cmd := exec.CommandContext(t.Context(), "dd", "if=/dev/zero", "of="+deviceFile.Name(), "bs=1G", "count=1") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err := cmd.Run() + require.NoError(t, err, "failed to execute dd command") +} + +func 
TestUblk_LargeRead(t *testing.T) { + t.Parallel() + + size := int64(1200 * 1024 * 1024) + deviceFile := setupUblkDevice(t, size, os.O_RDONLY) + + time.Sleep(1 * time.Second) + cmd := exec.CommandContext(t.Context(), "dd", "if="+deviceFile.Name(), "of=/dev/null", "bs=1G", "count=1") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err := cmd.Run() + require.NoError(t, err, "failed to execute dd command") +} + +func TestUblkBackend_UnalignedWritePreservesBlockData(t *testing.T) { + t.Parallel() + + const blockSize = header.RootfsBlockSize + const size = 2 * blockSize + + backend, overlay := setupUblkBackendForTest(t, size) + + base := bytes.Repeat([]byte{0xAA}, int(blockSize)) + n, err := overlay.WriteAt(base, 0) + require.NoError(t, err) + require.Equal(t, len(base), n) + + patch := bytes.Repeat([]byte{0xBB}, 512) + n, err = backend.WriteAt(patch, 512) + require.NoError(t, err) + require.Equal(t, len(patch), n) + + got := make([]byte, blockSize) + n, err = overlay.ReadAt(context.Background(), got, 0) + require.NoError(t, err) + require.Equal(t, len(got), n) + + require.Equal(t, base[:512], got[:512]) + require.Equal(t, patch, got[512:1024]) + require.Equal(t, base[1024:], got[1024:]) +} + +func TestUblkBackend_UnalignedReadAcrossBlockBoundary(t *testing.T) { + t.Parallel() + + const blockSize = header.RootfsBlockSize + const size = 2 * blockSize + + backend, overlay := setupUblkBackendForTest(t, size) + + first := bytes.Repeat([]byte{0x11}, int(blockSize)) + second := bytes.Repeat([]byte{0x22}, int(blockSize)) + + n, err := overlay.WriteAt(first, 0) + require.NoError(t, err) + require.Equal(t, len(first), n) + + n, err = overlay.WriteAt(second, blockSize) + require.NoError(t, err) + require.Equal(t, len(second), n) + + buf := make([]byte, 1024) + n, err = backend.ReadAt(buf, blockSize-512) + require.NoError(t, err) + require.Equal(t, len(buf), n) + + require.Equal(t, bytes.Repeat([]byte{0x11}, 512), buf[:512]) + require.Equal(t, bytes.Repeat([]byte{0x22}, 512), 
buf[512:]) +} + +func TestUblkBackend_SerializesRMWOnSameBlock(t *testing.T) { + t.Parallel() + + const blockSize = header.RootfsBlockSize + const size = 2 * blockSize + + backend, overlay := setupUblkBackendForTest(t, size) + guard := newSerializingTestDevice(overlay, blockSize) + backend.dev = guard + + base := bytes.Repeat([]byte{0xAA}, int(blockSize)) + n, err := overlay.WriteAt(base, 0) + require.NoError(t, err) + require.Equal(t, len(base), n) + + firstDone := make(chan error, 1) + go func() { + _, err := backend.WriteAt(bytes.Repeat([]byte{0x11}, 512), 0) + firstDone <- err + }() + + guard.waitFirstReadEntered(t) + + secondDone := make(chan error, 1) + go func() { + _, err := backend.WriteAt(bytes.Repeat([]byte{0x22}, 512), 512) + secondDone <- err + }() + + guard.assertNoSecondReadEntry(t) + guard.releaseFirstRead() + + require.NoError(t, <-firstDone) + require.NoError(t, <-secondDone) + + got := make([]byte, blockSize) + n, err = overlay.ReadAt(context.Background(), got, 0) + require.NoError(t, err) + require.Equal(t, len(got), n) + require.Equal(t, bytes.Repeat([]byte{0x11}, 512), got[:512]) + require.Equal(t, bytes.Repeat([]byte{0x22}, 512), got[512:1024]) + require.Equal(t, base[1024:], got[1024:]) + require.Equal(t, int32(2), guard.readCalls.Load()) + guard.assertReadsNotConcurrent(t) +} + +func setupUblkDevice(t *testing.T, size int64, flags int) *os.File { + t.Helper() + + const blockSize = header.RootfsBlockSize + + require.Equal(t, 0, os.Geteuid(), "ublk requires root privileges to run") + + emptyDevice, err := testutils.NewZeroDevice(size, blockSize) + require.NoError(t, err, "failed to create zero device") + + cowCachePath := filepath.Join(os.TempDir(), fmt.Sprintf("test-rootfs.ext4.cow.cache-%s", uuid.New().String())) + t.Cleanup(func() { + os.RemoveAll(cowCachePath) + }) + + cache, err := block.NewCache( + size, + blockSize, + cowCachePath, + false, + ) + require.NoError(t, err, "failed to create cache") + + overlay := 
block.NewOverlay(emptyDevice, cache) + t.Cleanup(func() { + overlay.Close() + }) + + ctx := context.Background() + + pool, err := ublkpool.NewDevicePool(0) + require.NoError(t, err, "failed to create ublk pool") + t.Cleanup(func() { + pool.Shutdown(context.Background()) + }) + + backend := newUblkBackend(ctx, overlay) + + dev, err := pool.New(ctx, backend, uint64(size)) + require.NoError(t, err, "failed to create ublk device") + t.Cleanup(func() { + pool.Close(context.Background(), dev) + }) + + devicePath := dev.Path() + t.Logf("ublk device path: %s", devicePath) + + deviceFile, err := os.OpenFile(devicePath, flags, 0) + require.NoError(t, err, "failed to open device") + t.Cleanup(func() { + deviceFile.Close() + }) + + return deviceFile +} + +func setupUblkBackendForTest(t *testing.T, size int64) (*ublkBackend, *block.Overlay) { + t.Helper() + + const blockSize = header.RootfsBlockSize + + zeroDevice, err := testutils.NewZeroDevice(size, blockSize) + require.NoError(t, err) + + cachePath := filepath.Join(t.TempDir(), fmt.Sprintf("test-ublk-backend-%s.cache", uuid.New().String())) + cache, err := block.NewCache(size, blockSize, cachePath, false) + require.NoError(t, err) + + overlay := block.NewOverlay(zeroDevice, cache) + t.Cleanup(func() { + require.NoError(t, overlay.Close()) + }) + + return newUblkBackend(context.Background(), overlay), overlay +} + +type serializingTestDevice struct { + block.Device + blockSize int64 + + firstReadEntered chan struct{} + releaseFirstReadCh chan struct{} + secondReadEntered chan struct{} + + readCalls atomic.Int32 + readsInFlight atomic.Int32 + maxReadsInFlight atomic.Int32 + + once sync.Once +} + +func newSerializingTestDevice(dev block.Device, blockSize int64) *serializingTestDevice { + return &serializingTestDevice{ + Device: dev, + blockSize: blockSize, + firstReadEntered: make(chan struct{}), + releaseFirstReadCh: make(chan struct{}), + secondReadEntered: make(chan struct{}, 1), + } +} + +func (d *serializingTestDevice) 
ReadAt(ctx context.Context, p []byte, off int64) (int, error) { + if len(p) == 0 || int64(len(p)) != d.blockSize || off%d.blockSize != 0 { + return d.Device.ReadAt(ctx, p, off) + } + + readNum := d.readCalls.Add(1) + inFlight := d.readsInFlight.Add(1) + d.recordMaxInFlight(inFlight) + defer d.readsInFlight.Add(-1) + + if readNum == 1 { + d.once.Do(func() { close(d.firstReadEntered) }) + <-d.releaseFirstReadCh + } else if readNum == 2 { + select { + case d.secondReadEntered <- struct{}{}: + default: + } + } + + return d.Device.ReadAt(ctx, p, off) +} + +func (d *serializingTestDevice) waitFirstReadEntered(t *testing.T) { + t.Helper() + select { + case <-d.firstReadEntered: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first read to enter") + } +} + +func (d *serializingTestDevice) releaseFirstRead() { + close(d.releaseFirstReadCh) +} + +func (d *serializingTestDevice) assertNoSecondReadEntry(t *testing.T) { + t.Helper() + select { + case <-d.secondReadEntered: + t.Fatal("second RMW entered device.ReadAt for the same block before first completed") + case <-time.After(150 * time.Millisecond): + } +} + +func (d *serializingTestDevice) assertReadsNotConcurrent(t *testing.T) { + t.Helper() + if d.maxReadsInFlight.Load() > 1 { + t.Fatalf("expected block reads to be serialized, max in-flight reads = %d", d.maxReadsInFlight.Load()) + } +} + +func (d *serializingTestDevice) recordMaxInFlight(v int32) { + for { + current := d.maxReadsInFlight.Load() + if v <= current { + return + } + if d.maxReadsInFlight.CompareAndSwap(current, v) { + return + } + } +} diff --git a/packages/orchestrator/pkg/sandbox/sandbox.go b/packages/orchestrator/pkg/sandbox/sandbox.go index f57fa26bd4..ab4380f85a 100644 --- a/packages/orchestrator/pkg/sandbox/sandbox.go +++ b/packages/orchestrator/pkg/sandbox/sandbox.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net/http" + "os" "sync" "time" @@ -30,6 +31,7 @@ import ( 
"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/network" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/rootfs" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/template" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/ublk" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/prefetch" "github.com/e2b-dev/infra/packages/orchestrator/pkg/template/metadata" @@ -286,6 +288,7 @@ type Factory struct { config cfg.BuilderConfig networkPool *network.Pool devicePool *nbd.DevicePool + ublkPool *ublk.DevicePool featureFlags *featureflags.Client hostStatsDelivery hoststats.Delivery cgroupManager cgroup.Manager @@ -296,6 +299,7 @@ func NewFactory( config cfg.BuilderConfig, networkPool *network.Pool, devicePool *nbd.DevicePool, + ublkPool *ublk.DevicePool, featureFlags *featureflags.Client, hostStatsDelivery hoststats.Delivery, cgroupManager cgroup.Manager, @@ -307,6 +311,7 @@ func NewFactory( config: config, networkPool: networkPool, devicePool: devicePool, + ublkPool: ublkPool, featureFlags: featureFlags, hostStatsDelivery: hostStatsDelivery, cgroupManager: cgroupManager, @@ -365,13 +370,22 @@ func (f *Factory) CreateSandbox( var rootfsProvider rootfs.Provider if rootfsCachePath == "" { - rootfsProvider, err = rootfs.NewNBDProvider( - ctx, - rootFS, - sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), - f.devicePool, - f.featureFlags, - ) + if os.Getenv("USE_UBLK_PROVIDER") == "true" { + rootfsProvider, err = rootfs.NewUblkProvider( + ctx, + rootFS, + sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), + f.ublkPool, + ) + } else { + rootfsProvider, err = rootfs.NewNBDProvider( + ctx, + rootFS, + sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), + f.devicePool, + f.featureFlags, + ) + } } else { rootfsProvider, err = rootfs.NewDirectProvider( ctx, @@ -675,13 +689,23 @@ func (f *Factory) ResumeSandbox( 
telemetry.ReportEvent(ctx, "got template rootfs") - overlay, err := rootfs.NewNBDProvider( - ctx, - readonlyRootfs, - sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), - f.devicePool, - f.featureFlags, - ) + var overlay rootfs.Provider + if os.Getenv("USE_UBLK_PROVIDER") == "true" { + overlay, err = rootfs.NewUblkProvider( + ctx, + readonlyRootfs, + sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), + f.ublkPool, + ) + } else { + overlay, err = rootfs.NewNBDProvider( + ctx, + readonlyRootfs, + sandboxFiles.SandboxCacheRootfsPath(f.config.StorageConfig), + f.devicePool, + f.featureFlags, + ) + } if err != nil { return nil, fmt.Errorf("failed to create rootfs overlay: %w", err) } diff --git a/packages/orchestrator/pkg/sandbox/ublk/pool.go b/packages/orchestrator/pkg/sandbox/ublk/pool.go new file mode 100644 index 0000000000..346974c039 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/ublk/pool.go @@ -0,0 +1,151 @@ +package ublk + +import ( + "context" + "fmt" + "go.uber.org/zap" + "sync" + "time" + + "github.com/e2b-dev/ublk-go/ublk" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + + "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +const defaultMaxDevices = 4096 + +var ( + meter = otel.Meter("github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/ublk") + inUseCounter = utils.Must(meter.Int64UpDownCounter("orchestrator.ublk.devices_in_use", + metric.WithDescription("Number of ublk devices currently in use."), + metric.WithUnit("{device}"), + )) + acquiredCounter = utils.Must(meter.Int64Counter("orchestrator.ublk.devices_acquired", + metric.WithDescription("Total number of ublk devices acquired."), + metric.WithUnit("{device}"), + )) + releasedCounter = utils.Must(meter.Int64Counter("orchestrator.ublk.devices_released", + metric.WithDescription("Total number of ublk devices released."), + metric.WithUnit("{device}"), + )) + newLatency = 
utils.Must(meter.Int64Histogram("orchestrator.ublk.new_latency_ms", + metric.WithDescription("ublk.New latency in ms."), + metric.WithUnit("ms"), + )) + closeLatency = utils.Must(meter.Int64Histogram("orchestrator.ublk.close_latency_ms", + metric.WithDescription("ublk.Device.Close latency in ms."), + metric.WithUnit("ms"), + )) +) + +type DevicePool struct { + mu sync.Mutex + devices map[*ublk.Device]struct{} // Hold devices, used for cleanup when closing + + sem chan struct{} // Concurrent limit, default 4096 + closed chan struct{} +} + +func NewDevicePool(maxDevices int) (*DevicePool, error) { + if maxDevices <= 0 { + maxDevices = defaultMaxDevices + } + return &DevicePool{ + devices: make(map[*ublk.Device]struct{}, maxDevices), + sem: make(chan struct{}, maxDevices), + closed: make(chan struct{}), + }, nil +} + +// New Create a ublk device. Support context cancellation and automatically release the semaphore on failure. +func (p *DevicePool) New(ctx context.Context, backend ublk.Backend, size uint64) (*ublk.Device, error) { + select { + case <-p.closed: + return nil, fmt.Errorf("ublk pool closed") + case p.sem <- struct{}{}: + case <-ctx.Done(): + return nil, ctx.Err() + } + + t0 := time.Now() + dev, err := ublk.New(backend, size) + newLatency.Record(ctx, time.Since(t0).Milliseconds()) + if err != nil { + <-p.sem + return nil, err + } + + p.mu.Lock() + p.devices[dev] = struct{}{} + p.mu.Unlock() + + inUseCounter.Add(ctx, 1) + acquiredCounter.Add(ctx, 1) + + logger.L().Debug(ctx, "ublk device created", + zap.String("path", dev.Path()), + ) + return dev, nil +} + +// Close Close a single device. Multiple internal calls are safe, and the semaphore will be released after closure. 
+func (p *DevicePool) Close(ctx context.Context, dev *ublk.Device) error { + p.mu.Lock() + if _, ok := p.devices[dev]; !ok { + p.mu.Unlock() + return nil + } + delete(p.devices, dev) + p.mu.Unlock() + + t0 := time.Now() + err := dev.Close() + closeLatency.Record(ctx, time.Since(t0).Milliseconds()) + if err != nil { + logger.L().Error(ctx, "ublk device close error", + zap.String("path", dev.Path()), + zap.Error(err), + ) + } + + inUseCounter.Add(ctx, -1) + releasedCounter.Add(ctx, 1) + <-p.sem + return err +} + +// Shutdown Called when the process exits, closing all unclosed devices in parallel. +func (p *DevicePool) Shutdown(ctx context.Context) error { + close(p.closed) + p.mu.Lock() + devs := make([]*ublk.Device, 0, len(p.devices)) + for d := range p.devices { + devs = append(devs, d) + } + p.mu.Unlock() + + if len(devs) == 0 { + return nil + } + logger.L().Info(ctx, "shutting down ublk pool", zap.Int("remaining", len(devs))) + + var wg sync.WaitGroup + for _, d := range devs { + wg.Add(1) + go func(d *ublk.Device) { + defer wg.Done() + if err := d.Close(); err != nil { + logger.L().Error(ctx, "ublk shutdown: device close error", + zap.String("path", d.Path()), + zap.Error(err), + ) + } + }(d) + } + wg.Wait() + logger.L().Info(ctx, "ublk pool shutdown complete") + return nil +}