diff --git a/packages/orchestrator/pkg/sandbox/sandbox.go b/packages/orchestrator/pkg/sandbox/sandbox.go
index ae8f4ef719..9b2f936ca0 100644
--- a/packages/orchestrator/pkg/sandbox/sandbox.go
+++ b/packages/orchestrator/pkg/sandbox/sandbox.go
@@ -836,6 +836,18 @@ func (f *Factory) ResumeSandbox(
 	// This is to prevent race condition of reporting unhealthy sandbox
 	sbx.Checks = NewChecks(sbx, useClickhouseMetrics)
 
+	// Stop the sandbox when the UFFD memory handler fails. The handler logs
+	// the failure itself, so only a failed Stop is reported here.
+	// NOTE(review): ctx is the context the failed handler ran with; confirm
+	// that Stop still works if that context has already been cancelled.
+	fcUffd.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {
+		if stopErr := sbx.Stop(ctx); stopErr != nil {
+			logger.L().Error(ctx, "failed to stop sandbox after uffd failure",
+				logger.WithSandboxID(sandboxID),
+				zap.Error(stopErr))
+		}
+	})
+
 	cleanup.AddPriority(ctx, func(ctx context.Context) error {
 		// Stop the sandbox first if it is still running, otherwise do nothing
 		return sbx.Stop(ctx)
diff --git a/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go b/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go
new file mode 100644
index 0000000000..ddf79b4f43
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go
@@ -0,0 +1,73 @@
+//go:build linux
+
+package testutils
+
+import (
+	"context"
+	"testing"
+
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+)
+
+const (
+	// mockMemfileSize is the device size reported by Size (1 GiB).
+	mockMemfileSize = 1024 * 1024 * 1024
+
+	// mockBlockSize is the block size reported by BlockSize.
+	mockBlockSize = 4096
+)
+
+// MockMemfile is a mock implementation of block.ReadonlyDevice for testing.
+// It stores no data: every read is filled with zeros.
+type MockMemfile struct {
+	t *testing.T // not used by any method yet
+}
+
+// NewMockMemfile creates a new mock memfile for testing.
+func NewMockMemfile(t *testing.T) *MockMemfile {
+	return &MockMemfile{t: t}
+}
+
+// ReadAt implements block.ReadonlyDevice. It zero-fills p regardless of off
+// and never fails.
+func (m *MockMemfile) ReadAt(ctx context.Context, p []byte, off int64) (int, error) {
+	for i := range p {
+		p[i] = 0
+	}
+
+	return len(p), nil
+}
+
+// Size implements block.ReadonlyDevice.
+func (m *MockMemfile) Size(ctx context.Context) (int64, error) {
+	return mockMemfileSize, nil
+}
+
+// Close implements io.Closer.
+func (m *MockMemfile) Close() error {
+	return nil
+}
+
+// Slice implements block.Slicer. It returns a fresh zero-filled buffer of the
+// requested length; off is not validated.
+func (m *MockMemfile) Slice(ctx context.Context, off, length int64) ([]byte, error) {
+	data := make([]byte, length)
+	_, err := m.ReadAt(ctx, data, off)
+
+	return data, err
+}
+
+// BlockSize implements block.ReadonlyDevice.
+func (m *MockMemfile) BlockSize() int64 {
+	return mockBlockSize
+}
+
+// Header implements block.ReadonlyDevice. It returns a new empty header.
+func (m *MockMemfile) Header() *header.Header {
+	return &header.Header{}
+}
+
+// SwapHeader implements block.ReadonlyDevice.
+func (m *MockMemfile) SwapHeader(h *header.Header) {
+	// No-op for mock
+}
diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd.go b/packages/orchestrator/pkg/sandbox/uffd/uffd.go
index 4dd2fc1760..3a264d92f8 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/uffd.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/uffd.go
@@ -43,6 +43,15 @@ type Uffd struct {
 	memfile block.ReadonlyDevice
 	handler utils.SetOnce[*userfaultfd.Userfaultfd]
 	fdExit  utils.SetOnce[*fdexit.FdExit]
+
+	// mu guards onFailure and the failed* fields below.
+	mu        sync.Mutex
+	onFailure func(context.Context, string, error) // invoked when handle fails
+	// failedErr, failedSbx and failedCtx record a handle failure so that a
+	// callback registered afterwards can still be notified.
+	failedErr error
+	failedSbx string
+	failedCtx context.Context
 }
 
 var _ MemoryBackend = (*Uffd)(nil)
@@ -55,9 +64,35 @@ func New(memfile block.ReadonlyDevice, socketPath string) *Uffd {
 		memfile: memfile,
 		handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](),
 		fdExit:  *utils.NewSetOnce[*fdexit.FdExit](),
 	}
 }
 
+// SetOnFailure registers fn to be called when handle fails. It replaces any
+// previously registered callback; a nil fn disables notification.
+//
+// fn receives the context handle ran with, the sandbox ID, and the error, and
+// always runs in its own goroutine. If a failure was already recorded before
+// this call, fn is invoked right away with the stored failure details.
+func (u *Uffd) SetOnFailure(fn func(context.Context, string, error)) {
+	u.mu.Lock()
+	defer u.mu.Unlock()
+
+	u.onFailure = fn
+	if fn != nil && u.failedErr != nil {
+		// The arguments are evaluated here, under the lock; only the call
+		// itself runs in the new goroutine.
+		go fn(u.failedCtx, u.failedSbx, u.failedErr)
+	}
+}
+
+// GetOnFailure returns the currently set failure callback (for testing).
+func (u *Uffd) GetOnFailure() func(context.Context, string, error) {
+	u.mu.Lock()
+	defer u.mu.Unlock()
+
+	return u.onFailure
+}
+
 func (u *Uffd) Prefault(ctx context.Context, offset int64, data []byte) error {
 	handler, err := u.handler.WaitWithContext(ctx)
 	if err != nil {
@@ -95,13 +130,27 @@ func (u *Uffd) Start(ctx context.Context, sandboxId string) error {
 		ctx, span := tracer.Start(ctx, "serve uffd")
 		defer span.End()
 
-		// TODO: If the handle function fails, we should kill the sandbox
 		handleErr := u.handle(ctx, sandboxId, fdExit)
 
 		// If handle failed before setting the handler value, set an error to unblock
 		// any waiters (e.g., prefetcher goroutines waiting on Prefault).
 		if handleErr != nil {
 			u.handler.SetError(handleErr)
+			logger.L().Error(ctx, "uffd handle failed",
+				logger.WithSandboxID(sandboxId),
+				zap.Error(handleErr))
+
+			// Record the failure and notify the callback while holding mu, so a
+			// callback registered concurrently through SetOnFailure is notified
+			// exactly once, whichever side takes the lock first.
+			u.mu.Lock()
+			u.failedErr = handleErr
+			u.failedSbx = sandboxId
+			u.failedCtx = ctx
+			if u.onFailure != nil {
+				go u.onFailure(ctx, sandboxId, handleErr)
+			}
+			u.mu.Unlock()
 		}
 
 		closeErr := u.lis.Close()
diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go
new file mode 100644
index 0000000000..918a01d9bd
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go
@@ -0,0 +1,208 @@
+//go:build linux
+
+package uffd
+
+import (
+	"context"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils"
+)
+
+// failureEvent captures the arguments passed to a failure callback.
+type failureEvent struct {
+	sandboxID string
+	err       error
+}
+
+// newTestUffd returns a Uffd backed by a mock memfile and a socket path inside
+// a per-test temporary directory.
+func newTestUffd(t *testing.T) *Uffd {
+	t.Helper()
+
+	return New(testutils.NewMockMemfile(t), filepath.Join(t.TempDir(), "test.sock"))
+}
+
+// recordFailures registers a failure callback on u and returns a channel that
+// receives the first reported failure. The channel is buffered and the send is
+// non-blocking, so the callback goroutine never blocks or leaks.
+func recordFailures(u *Uffd) <-chan failureEvent {
+	events := make(chan failureEvent, 1)
+	u.SetOnFailure(func(_ context.Context, sandboxID string, err error) {
+		select {
+		case events <- failureEvent{sandboxID: sandboxID, err: err}:
+		default:
+		}
+	})
+
+	return events
+}
+
+// TestSetOnFailure verifies that the failure callback is properly set and can be retrieved.
+func TestSetOnFailure(t *testing.T) {
+	t.Parallel()
+
+	u := newTestUffd(t)
+
+	assert.Nil(t, u.GetOnFailure())
+
+	u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
+
+	assert.NotNil(t, u.GetOnFailure())
+}
+
+// TestOnFailureNotInvokedWhenCallbackNil verifies that Start and Stop work when
+// no failure callback has been registered.
+func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) {
+	t.Parallel()
+
+	u := newTestUffd(t)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+
+	err := u.Start(ctx, "test-sandbox")
+	require.NoError(t, err)
+
+	// Let the 50ms context deadline pass before stopping.
+	time.Sleep(100 * time.Millisecond)
+
+	stopErr := u.Stop()
+	assert.NoError(t, stopErr)
+}
+
+// TestMultipleCallbackSets verifies that SetOnFailure can be called multiple times,
+// with each call replacing the previous callback.
+func TestMultipleCallbackSets(t *testing.T) {
+	t.Parallel()
+
+	u := newTestUffd(t)
+
+	u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
+	assert.NotNil(t, u.GetOnFailure())
+
+	u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
+	assert.NotNil(t, u.GetOnFailure())
+
+	u.SetOnFailure(nil)
+	assert.Nil(t, u.GetOnFailure())
+}
+
+// TestUffdStopAfterFailure verifies that Stop succeeds, with a callback
+// registered, after the context passed to Start has expired.
+func TestUffdStopAfterFailure(t *testing.T) {
+	t.Parallel()
+
+	u := newTestUffd(t)
+	u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+
+	err := u.Start(ctx, "test-sandbox")
+	require.NoError(t, err)
+
+	// Let the 50ms context deadline pass before stopping.
+	time.Sleep(100 * time.Millisecond)
+
+	stopErr := u.Stop()
+	assert.NoError(t, stopErr)
+}
+
+// TestCallbackNotInvokedOnCleanStop verifies that the callback is NOT invoked
+// when UFFD is stopped cleanly (no failure).
+func TestCallbackNotInvokedOnCleanStop(t *testing.T) {
+	t.Parallel()
+
+	u := newTestUffd(t)
+
+	var invoked atomic.Bool
+	u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {
+		invoked.Store(true)
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+
+	err := u.Start(ctx, "test-sandbox")
+	require.NoError(t, err)
+
+	// Only the callback matters here, so Stop's result is deliberately ignored.
+	_ = u.Stop()
+
+	// Give a (wrongly) spawned callback goroutine time to run.
+	time.Sleep(50 * time.Millisecond)
+
+	assert.False(t, invoked.Load(), "callback should not be invoked on clean stop")
+}
+
+// TestOnFailureInvokedOnHandleError verifies that the failure callback is invoked
+// when handle fails (socket timeout). Long-running: waits ~10s for socket deadline.
+func TestOnFailureInvokedOnHandleError(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping long-running test in short mode")
+	}
+
+	t.Parallel()
+
+	u := newTestUffd(t)
+	events := recordFailures(u)
+
+	sandboxID := "test-sandbox-123"
+	err := u.Start(context.Background(), sandboxID)
+	require.NoError(t, err)
+
+	select {
+	case got := <-events:
+		assert.Equal(t, sandboxID, got.sandboxID)
+		assert.Error(t, got.err)
+	case <-time.After(15 * time.Second):
+		t.Fatal("callback was not invoked within timeout")
+	}
+}
+
+// TestLateCallbackRegistrationAfterFailure verifies that a callback registered
+// AFTER a failure has already occurred is still invoked immediately with the
+// stored failure details.
+func TestLateCallbackRegistrationAfterFailure(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping long-running test in short mode")
+	}
+
+	t.Parallel()
+
+	u := newTestUffd(t)
+
+	sandboxID := "test-sandbox-late"
+
+	// Start without a callback - UFFD will fail after socket timeout (10 seconds)
+	err := u.Start(context.Background(), sandboxID)
+	require.NoError(t, err)
+
+	// Wait for the internal goroutine to finish (socket timeout = 10 seconds).
+	// NOTE(review): this assumes readyCh is only signalled after the failure has
+	// been recorded; otherwise the callback below is not actually "late" - confirm.
+	select {
+	case <-u.readyCh:
+	case <-time.After(15 * time.Second):
+		t.Fatal("readyCh was not signalled within timeout")
+	}
+
+	// Register the callback AFTER the failure: it must be invoked right away, in
+	// its own goroutine, with the stored failure details.
+	events := recordFailures(u)
+
+	select {
+	case got := <-events:
+		assert.Equal(t, sandboxID, got.sandboxID)
+		assert.Error(t, got.err)
+	case <-time.After(2 * time.Second):
+		t.Fatal("late-registered callback was not invoked")
+	}
+}