Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions packages/orchestrator/pkg/sandbox/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,15 @@ func (f *Factory) ResumeSandbox(
// This is to prevent race condition of reporting unhealthy sandbox
sbx.Checks = NewChecks(sbx, useClickhouseMetrics)

// Set UFFD failure callback to stop sandbox on memory handler failure
fcUffd.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {
if stopErr := sbx.Stop(ctx); stopErr != nil {
logger.L().Error(ctx, "failed to stop sandbox after uffd failure",
logger.WithSandboxID(sandboxID),
zap.Error(stopErr))
}
})

cleanup.AddPriority(ctx, func(ctx context.Context) error {
// Stop the sandbox first if it is still running, otherwise do nothing
return sbx.Stop(ctx)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//go:build linux

package testutils

import (
"context"
"testing"

"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
)

// MockMemfile is a test double for block.ReadonlyDevice. Its only state is
// the *testing.T of the test that created it.
type MockMemfile struct {
	// t is the owning test, stored by NewMockMemfile.
	t *testing.T
}

// NewMockMemfile returns a MockMemfile bound to the given test.
func NewMockMemfile(t *testing.T) *MockMemfile {
	mock := new(MockMemfile)
	mock.t = t

	return mock
}

// ReadAt implements block.ReadonlyDevice. The mock has no backing data: it
// zero-fills all of p regardless of off and reports len(p) bytes read with a
// nil error.
func (m *MockMemfile) ReadAt(ctx context.Context, p []byte, off int64) (int, error) {
	n := len(p)
	for idx := 0; idx < n; idx++ {
		p[idx] = 0
	}

	return n, nil
}

// Size implements block.ReadonlyDevice. The mock always reports a fixed size
// of 1 GiB and never fails.
func (m *MockMemfile) Size(ctx context.Context) (int64, error) {
	const oneGiB int64 = 1 << 30 // 1024 * 1024 * 1024

	return oneGiB, nil
}

// Close implements io.Closer. The mock has nothing to release, so it always
// returns nil.
func (m *MockMemfile) Close() error {
	return nil
}

// Slice implements block.Slicer. It allocates a buffer of the requested
// length, fills it through ReadAt, and returns the buffer together with
// whatever error ReadAt reported.
func (m *MockMemfile) Slice(ctx context.Context, off, length int64) ([]byte, error) {
	buf := make([]byte, length)
	if _, err := m.ReadAt(ctx, buf, off); err != nil {
		return buf, err
	}

	return buf, nil
}

// BlockSize implements block.ReadonlyDevice. The mock reports a fixed block
// size of 4096 bytes.
func (m *MockMemfile) BlockSize() int64 {
	const mockBlockSize int64 = 4096

	return mockBlockSize
}

// Header implements block.ReadonlyDevice. Each call returns a pointer to a
// newly allocated zero-value header.
func (m *MockMemfile) Header() *header.Header {
	return new(header.Header)
}

// SwapHeader implements block.ReadonlyDevice. The mock keeps no header state,
// so the supplied header is discarded.
func (m *MockMemfile) SwapHeader(h *header.Header) {
	// Intentionally empty.
}
40 changes: 39 additions & 1 deletion packages/orchestrator/pkg/sandbox/uffd/uffd.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ type Uffd struct {
memfile block.ReadonlyDevice
handler utils.SetOnce[*userfaultfd.Userfaultfd]
fdExit utils.SetOnce[*fdexit.FdExit]
mu sync.Mutex
onFailure func(context.Context, string, error) // callback when handle fails
failedErr error
failedSbx string
failedCtx context.Context
}
Comment thread
AdaAibaby marked this conversation as resolved.

var _ MemoryBackend = (*Uffd)(nil)
Expand All @@ -55,9 +60,30 @@ func New(memfile block.ReadonlyDevice, socketPath string) *Uffd {
memfile: memfile,
handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](),
fdExit: *utils.NewSetOnce[*fdexit.FdExit](),
onFailure: nil,
}
}

// SetOnFailure sets a callback to be invoked when handle fails.
// The callback receives the context, sandbox ID, and error.
// If a failure has already occurred before this is called, the callback is
// invoked immediately in a new goroutine with the stored failure details.
func (u *Uffd) SetOnFailure(fn func(context.Context, string, error)) {
u.mu.Lock()
u.onFailure = fn
Comment thread
AdaAibaby marked this conversation as resolved.
if u.failedErr != nil && fn != nil {
go fn(u.failedCtx, u.failedSbx, u.failedErr)
}
u.mu.Unlock()
}

// GetOnFailure returns the failure callback that is currently registered, or
// nil when none is set. It exists for tests.
func (u *Uffd) GetOnFailure() func(context.Context, string, error) {
	u.mu.Lock()
	cb := u.onFailure
	u.mu.Unlock()

	return cb
}
Comment thread
AdaAibaby marked this conversation as resolved.

func (u *Uffd) Prefault(ctx context.Context, offset int64, data []byte) error {
handler, err := u.handler.WaitWithContext(ctx)
if err != nil {
Expand Down Expand Up @@ -95,13 +121,25 @@ func (u *Uffd) Start(ctx context.Context, sandboxId string) error {
ctx, span := tracer.Start(ctx, "serve uffd")
defer span.End()

// A handle failure is reported through the onFailure callback below so the owner can stop the sandbox.
handleErr := u.handle(ctx, sandboxId, fdExit)

// If handle failed before setting the handler value, set an error to unblock
// any waiters (e.g., prefetcher goroutines waiting on Prefault).
if handleErr != nil {
u.handler.SetError(handleErr)
logger.L().Error(ctx, "uffd handle failed",
logger.WithSandboxID(sandboxId),
zap.Error(handleErr))
// Persist failure state and invoke callback under lock.
// Storing the state ensures late-registered callbacks still receive the event.
u.mu.Lock()
u.failedErr = handleErr
u.failedSbx = sandboxId
u.failedCtx = ctx
if u.onFailure != nil {
go u.onFailure(ctx, sandboxId, handleErr)
}
u.mu.Unlock()
}

closeErr := u.lis.Close()
Expand Down
236 changes: 236 additions & 0 deletions packages/orchestrator/pkg/sandbox/uffd/uffd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//go:build linux

package uffd

import (
"context"
"path/filepath"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils"
)

// TestSetOnFailure verifies that the failure callback is properly set and can be retrieved.
func TestSetOnFailure(t *testing.T) {
t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)

assert.Nil(t, u.GetOnFailure())

u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})

assert.NotNil(t, u.GetOnFailure())
}

// TestOnFailureNotInvokedWhenCallbackNil verifies that a nil callback doesn't crash.
func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) {
t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

err := u.Start(ctx, "test-sandbox")
require.NoError(t, err)

time.Sleep(100 * time.Millisecond)

stopErr := u.Stop()
assert.NoError(t, stopErr)
}

// TestMultipleCallbackSets verifies that SetOnFailure can be called multiple times,
// with each call replacing the previous callback.
func TestMultipleCallbackSets(t *testing.T) {
t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)

u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
assert.NotNil(t, u.GetOnFailure())

u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
assert.NotNil(t, u.GetOnFailure())

u.SetOnFailure(nil)
assert.Nil(t, u.GetOnFailure())
}

// TestUffdStopAfterFailure verifies that UFFD can be stopped cleanly after a failure.
func TestUffdStopAfterFailure(t *testing.T) {
t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)
u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

err := u.Start(ctx, "test-sandbox")
require.NoError(t, err)

time.Sleep(100 * time.Millisecond)

stopErr := u.Stop()
assert.NoError(t, stopErr)
}

// TestCallbackNotInvokedOnCleanStop verifies that the callback is NOT invoked
// when UFFD is stopped cleanly (no failure).
func TestCallbackNotInvokedOnCleanStop(t *testing.T) {
t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)

invoked := atomic.Bool{}
u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {
invoked.Store(true)
})

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

err := u.Start(ctx, "test-sandbox")
require.NoError(t, err)

_ = u.Stop()

time.Sleep(50 * time.Millisecond)

assert.False(t, invoked.Load(), "callback should not be invoked on clean stop")
}

// TestOnFailureInvokedOnHandleError verifies that the failure callback is invoked
// when handle fails (socket timeout). Long-running: waits ~10s for socket deadline.
func TestOnFailureInvokedOnHandleError(t *testing.T) {
if testing.Short() {
t.Skip("skipping long-running test in short mode")
}

t.Parallel()

memfile := testutils.NewMockMemfile(t)
socketPath := filepath.Join(t.TempDir(), "test.sock")

u := New(memfile, socketPath)

var (
wg sync.WaitGroup
callbackSandbox string
callbackErr error
mu sync.Mutex
)

wg.Add(1)
sandboxID := "test-sandbox-123"
u.SetOnFailure(func(ctx context.Context, sbxID string, err error) {
defer wg.Done()
mu.Lock()
defer mu.Unlock()
callbackSandbox = sbxID
callbackErr = err
})

err := u.Start(context.Background(), sandboxID)
require.NoError(t, err)

done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()

select {
case <-done:
case <-time.After(15 * time.Second):
t.Fatal("callback was not invoked within timeout")
}

mu.Lock()
defer mu.Unlock()
assert.Equal(t, sandboxID, callbackSandbox)
assert.NotNil(t, callbackErr)
}

// TestLateCallbackRegistrationAfterFailure checks that a callback registered
// AFTER a failure has already occurred is still invoked, with the stored
// sandbox ID and error. Long-running: the failure is expected to come from
// the socket deadline (about 10s), so the test is skipped in short mode.
func TestLateCallbackRegistrationAfterFailure(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping long-running test in short mode")
	}

	t.Parallel()

	memfile := testutils.NewMockMemfile(t)
	socketPath := filepath.Join(t.TempDir(), "test.sock")

	u := New(memfile, socketPath)

	sandboxID := "test-sandbox-late"

	// Start without a callback. Nothing connects to the socket, so handle is
	// expected to fail once the socket deadline passes.
	err := u.Start(context.Background(), sandboxID)
	require.NoError(t, err)

	// Wait for the goroutine started by Start to finish. The wait is bounded
	// so that a failure that never happens fails this test instead of
	// hanging the whole test binary.
	select {
	case <-u.readyCh:
	case <-time.After(15 * time.Second):
		t.Fatal("uffd did not fail within timeout")
	}

	// Register the callback AFTER the failure has been recorded.
	// SetOnFailure must start it in a goroutine with the stored details.
	var (
		wg              sync.WaitGroup
		receivedSandbox string
		receivedErr     error
		mu              sync.Mutex // guards receivedSandbox and receivedErr
	)

	wg.Add(1)
	u.SetOnFailure(func(ctx context.Context, sbxID string, cbErr error) {
		defer wg.Done()
		mu.Lock()
		defer mu.Unlock()
		receivedSandbox = sbxID
		receivedErr = cbErr
	})

	// Turn the WaitGroup into a channel so the wait can be bounded.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("late-registered callback was not invoked")
	}

	mu.Lock()
	defer mu.Unlock()
	assert.Equal(t, sandboxID, receivedSandbox)
	assert.NotNil(t, receivedErr)
}