From 271ff3e68bc6458b4b3bd813e061eb65f47c24af Mon Sep 17 00:00:00 2001 From: AdaAibaby Date: Fri, 29 May 2026 16:12:32 +0800 Subject: [PATCH 1/2] implement UFFD failure handling and unit tests --- packages/orchestrator/pkg/sandbox/sandbox.go | 9 + .../sandbox/uffd/testutils/mock_memfile.go | 61 +++++ .../orchestrator/pkg/sandbox/uffd/uffd.go | 21 +- .../pkg/sandbox/uffd/uffd_test.go | 234 ++++++++++++++++++ 4 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go create mode 100644 packages/orchestrator/pkg/sandbox/uffd/uffd_test.go diff --git a/packages/orchestrator/pkg/sandbox/sandbox.go b/packages/orchestrator/pkg/sandbox/sandbox.go index ae8f4ef719..9b2f936ca0 100644 --- a/packages/orchestrator/pkg/sandbox/sandbox.go +++ b/packages/orchestrator/pkg/sandbox/sandbox.go @@ -836,6 +836,15 @@ func (f *Factory) ResumeSandbox( // This is to prevent race condition of reporting unhealthy sandbox sbx.Checks = NewChecks(sbx, useClickhouseMetrics) + // Set UFFD failure callback to stop sandbox on memory handler failure + fcUffd.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + if stopErr := sbx.Stop(ctx); stopErr != nil { + logger.L().Error(ctx, "failed to stop sandbox after uffd failure", + logger.WithSandboxID(sandboxID), + zap.Error(stopErr)) + } + }) + cleanup.AddPriority(ctx, func(ctx context.Context) error { // Stop the sandbox first if it is still running, otherwise do nothing return sbx.Stop(ctx) diff --git a/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go b/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go new file mode 100644 index 0000000000..ddf79b4f43 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/testutils/mock_memfile.go @@ -0,0 +1,61 @@ +//go:build linux + +package testutils + +import ( + "context" + "testing" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// MockMemfile is a mock implementation of block.ReadonlyDevice for testing. +type MockMemfile struct { + t *testing.T +} + +// NewMockMemfile creates a new mock memfile for testing. +func NewMockMemfile(t *testing.T) *MockMemfile { + return &MockMemfile{t: t} +} + +// ReadAt implements block.ReadonlyDevice. +func (m *MockMemfile) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { + // Return zeros for testing + for i := range p { + p[i] = 0 + } + return len(p), nil +} + +// Size implements block.ReadonlyDevice. +func (m *MockMemfile) Size(ctx context.Context) (int64, error) { + return 1024 * 1024 * 1024, nil // 1GB +} + +// Close implements io.Closer. +func (m *MockMemfile) Close() error { + return nil +} + +// Slice implements block.Slicer. +func (m *MockMemfile) Slice(ctx context.Context, off, length int64) ([]byte, error) { + data := make([]byte, length) + _, err := m.ReadAt(ctx, data, off) + return data, err +} + +// BlockSize implements block.ReadonlyDevice. +func (m *MockMemfile) BlockSize() int64 { + return 4096 +} + +// Header implements block.ReadonlyDevice. +func (m *MockMemfile) Header() *header.Header { + return &header.Header{} +} + +// SwapHeader implements block.ReadonlyDevice. +func (m *MockMemfile) SwapHeader(h *header.Header) { + // No-op for mock +} diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd.go b/packages/orchestrator/pkg/sandbox/uffd/uffd.go index 4dd2fc1760..6fca05e8f2 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/uffd.go +++ b/packages/orchestrator/pkg/sandbox/uffd/uffd.go @@ -43,6 +43,7 @@ type Uffd struct { memfile block.ReadonlyDevice handler utils.SetOnce[*userfaultfd.Userfaultfd] fdExit utils.SetOnce[*fdexit.FdExit] + onFailure func(context.Context, string, error) // callback when handle fails } var _ MemoryBackend = (*Uffd)(nil) @@ -55,9 +56,21 @@ func New(memfile block.ReadonlyDevice, socketPath string) *Uffd { memfile: memfile, handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](), fdExit: *utils.NewSetOnce[*fdexit.FdExit](), + onFailure: nil, } } +// SetOnFailure sets a callback to be invoked when handle fails. +// The callback receives the context, sandbox ID, and error. +func (u *Uffd) SetOnFailure(fn func(context.Context, string, error)) { + u.onFailure = fn +} + +// GetOnFailure returns the currently set failure callback (for testing). +func (u *Uffd) GetOnFailure() func(context.Context, string, error) { + return u.onFailure +} + func (u *Uffd) Prefault(ctx context.Context, offset int64, data []byte) error { handler, err := u.handler.WaitWithContext(ctx) if err != nil { @@ -95,13 +108,19 @@ func (u *Uffd) Start(ctx context.Context, sandboxId string) error { ctx, span := tracer.Start(ctx, "serve uffd") defer span.End() - // TODO: If the handle function fails, we should kill the sandbox handleErr := u.handle(ctx, sandboxId, fdExit) // If handle failed before setting the handler value, set an error to unblock // any waiters (e.g., prefetcher goroutines waiting on Prefault). if handleErr != nil { u.handler.SetError(handleErr) + logger.L().Error(ctx, "uffd handle failed", + logger.WithSandboxID(sandboxId), + zap.Error(handleErr)) + // Invoke failure callback to stop the sandbox + if u.onFailure != nil { + u.onFailure(ctx, sandboxId, handleErr) + } } closeErr := u.lis.Close() diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go new file mode 100644 index 0000000000..143632dd5e --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go @@ -0,0 +1,234 @@ +//go:build linux + +package uffd + +import ( + "context" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" +) + +// TestSetOnFailure verifies that the failure callback is properly set and can be invoked. +func TestSetOnFailure(t *testing.T) { + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + // Verify callback is initially nil + assert.Nil(t, u.GetOnFailure()) + + // Set a callback + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + // Callback set + }) + + // Verify callback is set + assert.NotNil(t, u.GetOnFailure()) +} + +// TestOnFailureInvokedOnHandleError verifies that the failure callback is invoked when handle fails. +// This test verifies the callback mechanism by checking that it's called when handle returns an error. +func TestOnFailureInvokedOnHandleError(t *testing.T) { + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + // Track callback invocation + var ( + callbackInvoked atomic.Bool + callbackSandbox string + callbackErr error + mu sync.Mutex + wg sync.WaitGroup + ) + + wg.Add(1) + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + defer wg.Done() + callbackInvoked.Store(true) + mu.Lock() + defer mu.Unlock() + callbackSandbox = sandboxID + callbackErr = err + }) + + // Start UFFD - it will fail because no Firecracker connects + // The socket deadline is 10 seconds, so we need to wait for that + ctx := context.Background() + sandboxID := "test-sandbox-123" + err := u.Start(ctx, sandboxID) + require.NoError(t, err, "Start should not error immediately") + + // Wait for the goroutine to process the failure (socket timeout is 10 seconds) + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Callback was invoked + case <-time.After(15 * time.Second): + t.Fatal("callback was not invoked within timeout") + } + + // Verify callback was invoked + assert.True(t, callbackInvoked.Load(), "callback should have been invoked") + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, sandboxID, callbackSandbox, "callback should receive correct sandbox ID") + assert.NotNil(t, callbackErr, "callback should receive error") +} + +// TestOnFailureNotInvokedWhenCallbackNil verifies that nil callback doesn't crash. +func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) { + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + // Don't set callback - should not crash + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := u.Start(ctx, "test-sandbox") + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Should be able to stop without issues + stopErr := u.Stop() + assert.NoError(t, stopErr, "Stop should not error") +} + +// TestMultipleCallbackSets verifies that SetOnFailure can be called multiple times. +func TestMultipleCallbackSets(t *testing.T) { + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + // Set first callback + firstInvoked := atomic.Bool{} + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + firstInvoked.Store(true) + }) + + // Verify first callback is set + assert.NotNil(t, u.GetOnFailure()) + + // Set second callback (should replace first) + secondInvoked := atomic.Bool{} + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + secondInvoked.Store(true) + }) + + // Verify second callback is set + assert.NotNil(t, u.GetOnFailure()) + + // Note: We don't actually trigger the failure here because it would take 10 seconds + // The important thing is that SetOnFailure can be called multiple times +} + +// TestCallbackReceivesCorrectParameters verifies callback receives correct context, sandbox ID, and error. +// This is a long-running test that waits for the socket timeout (10 seconds). +func TestCallbackReceivesCorrectParameters(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + var ( + receivedCtx context.Context + receivedSandbox string + receivedErr error + mu sync.Mutex + wg sync.WaitGroup + ) + + wg.Add(1) + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + defer wg.Done() + mu.Lock() + defer mu.Unlock() + receivedCtx = ctx + receivedSandbox = sandboxID + receivedErr = err + }) + + sandboxID := "test-sandbox-xyz" + err := u.Start(context.Background(), sandboxID) + require.NoError(t, err) + + // Wait for callback (socket timeout is 10 seconds) + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(15 * time.Second): + t.Fatal("callback was not invoked within timeout") + } + + mu.Lock() + defer mu.Unlock() + + assert.NotNil(t, receivedCtx, "callback should receive context") + assert.Equal(t, sandboxID, receivedSandbox, "callback should receive correct sandbox ID") + assert.NotNil(t, receivedErr, "callback should receive error") +} + +// TestUffdStopAfterFailure verifies that UFFD can be stopped after a failure. +func TestUffdStopAfterFailure(t *testing.T) { + t.Parallel() + + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") + + u := New(memfile, socketPath) + + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + // Callback does nothing + }) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := u.Start(ctx, "test-sandbox") + require.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + // Should be able to stop UFFD after failure + stopErr := u.Stop() + assert.NoError(t, stopErr, "Stop should not error after failure") +} From 36fc7d71e4da06c927298a89a8cb864123daa58d Mon Sep 17 00:00:00 2001 From: AdaAibaby Date: Fri, 29 May 2026 16:25:48 +0800 Subject: [PATCH 2/2] fix(uffd): protect failure callback with mutex and persist failure state Address data race and late-registration issues identified in code review: - Add sync.Mutex to guard onFailure field across goroutines - Persist failure state (failedErr/failedSbx/failedCtx) so callbacks registered after a failure still receive the event immediately - Invoke callback in a goroutine to avoid holding the lock during potentially slow sandbox teardown - Add TestLateCallbackRegistrationAfterFailure to verify the guarantee - Add TestCallbackNotInvokedOnCleanStop to verify no false positives --- .../orchestrator/pkg/sandbox/uffd/uffd.go | 23 +- .../pkg/sandbox/uffd/uffd_test.go | 206 +++++++++--------- 2 files changed, 125 insertions(+), 104 deletions(-) diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd.go b/packages/orchestrator/pkg/sandbox/uffd/uffd.go index 6fca05e8f2..3a264d92f8 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/uffd.go +++ b/packages/orchestrator/pkg/sandbox/uffd/uffd.go @@ -43,7 +43,11 @@ type Uffd struct { memfile block.ReadonlyDevice handler utils.SetOnce[*userfaultfd.Userfaultfd] fdExit utils.SetOnce[*fdexit.FdExit] + mu sync.Mutex onFailure func(context.Context, string, error) // callback when handle fails + failedErr error + failedSbx string + failedCtx context.Context } var _ MemoryBackend = (*Uffd)(nil) @@ -62,12 +66,21 @@ func New(memfile block.ReadonlyDevice, socketPath string) *Uffd { // SetOnFailure sets a callback to be invoked when handle fails. // The callback receives the context, sandbox ID, and error. +// If a failure has already occurred before this is called, the callback is +// invoked immediately in a new goroutine with the stored failure details. func (u *Uffd) SetOnFailure(fn func(context.Context, string, error)) { + u.mu.Lock() u.onFailure = fn + if u.failedErr != nil && fn != nil { + go fn(u.failedCtx, u.failedSbx, u.failedErr) + } + u.mu.Unlock() } // GetOnFailure returns the currently set failure callback (for testing). func (u *Uffd) GetOnFailure() func(context.Context, string, error) { + u.mu.Lock() + defer u.mu.Unlock() return u.onFailure } @@ -117,10 +130,16 @@ func (u *Uffd) Start(ctx context.Context, sandboxId string) error { logger.L().Error(ctx, "uffd handle failed", logger.WithSandboxID(sandboxId), zap.Error(handleErr)) - // Invoke failure callback to stop the sandbox + // Persist failure state and invoke callback under lock. + // Storing the state ensures late-registered callbacks still receive the event. + u.mu.Lock() + u.failedErr = handleErr + u.failedSbx = sandboxId + u.failedCtx = ctx if u.onFailure != nil { - u.onFailure(ctx, sandboxId, handleErr) + go u.onFailure(ctx, sandboxId, handleErr) } + u.mu.Unlock() } closeErr := u.lis.Close() diff --git a/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go index 143632dd5e..918a01d9bd 100644 --- a/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go +++ b/packages/orchestrator/pkg/sandbox/uffd/uffd_test.go @@ -16,7 +16,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/uffd/testutils" ) -// TestSetOnFailure verifies that the failure callback is properly set and can be invoked. +// TestSetOnFailure verifies that the failure callback is properly set and can be retrieved. func TestSetOnFailure(t *testing.T) { t.Parallel() @@ -25,21 +25,15 @@ func TestSetOnFailure(t *testing.T) { u := New(memfile, socketPath) - // Verify callback is initially nil assert.Nil(t, u.GetOnFailure()) - // Set a callback - u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { - // Callback set - }) + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {}) - // Verify callback is set assert.NotNil(t, u.GetOnFailure()) } -// TestOnFailureInvokedOnHandleError verifies that the failure callback is invoked when handle fails. -// This test verifies the callback mechanism by checking that it's called when handle returns an error. -func TestOnFailureInvokedOnHandleError(t *testing.T) { +// TestOnFailureNotInvokedWhenCallbackNil verifies that a nil callback doesn't crash. +func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) { t.Parallel() memfile := testutils.NewMockMemfile(t) @@ -47,65 +41,48 @@ func TestOnFailureInvokedOnHandleError(t *testing.T) { u := New(memfile, socketPath) - // Track callback invocation - var ( - callbackInvoked atomic.Bool - callbackSandbox string - callbackErr error - mu sync.Mutex - wg sync.WaitGroup - ) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() - wg.Add(1) - u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { - defer wg.Done() - callbackInvoked.Store(true) - mu.Lock() - defer mu.Unlock() - callbackSandbox = sandboxID - callbackErr = err - }) + err := u.Start(ctx, "test-sandbox") + require.NoError(t, err) - // Start UFFD - it will fail because no Firecracker connects - // The socket deadline is 10 seconds, so we need to wait for that - ctx := context.Background() - sandboxID := "test-sandbox-123" - err := u.Start(ctx, sandboxID) - require.NoError(t, err, "Start should not error immediately") + time.Sleep(100 * time.Millisecond) - // Wait for the goroutine to process the failure (socket timeout is 10 seconds) - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() + stopErr := u.Stop() + assert.NoError(t, stopErr) +} - select { - case <-done: - // Callback was invoked - case <-time.After(15 * time.Second): - t.Fatal("callback was not invoked within timeout") - } +// TestMultipleCallbackSets verifies that SetOnFailure can be called multiple times, +// with each call replacing the previous callback. +func TestMultipleCallbackSets(t *testing.T) { + t.Parallel() - // Verify callback was invoked - assert.True(t, callbackInvoked.Load(), "callback should have been invoked") + memfile := testutils.NewMockMemfile(t) + socketPath := filepath.Join(t.TempDir(), "test.sock") - mu.Lock() - defer mu.Unlock() - assert.Equal(t, sandboxID, callbackSandbox, "callback should receive correct sandbox ID") - assert.NotNil(t, callbackErr, "callback should receive error") + u := New(memfile, socketPath) + + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {}) + assert.NotNil(t, u.GetOnFailure()) + + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {}) + assert.NotNil(t, u.GetOnFailure()) + + u.SetOnFailure(nil) + assert.Nil(t, u.GetOnFailure()) } -// TestOnFailureNotInvokedWhenCallbackNil verifies that nil callback doesn't crash. -func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) { +// TestUffdStopAfterFailure verifies that UFFD can be stopped cleanly after a failure. +func TestUffdStopAfterFailure(t *testing.T) { t.Parallel() memfile := testutils.NewMockMemfile(t) socketPath := filepath.Join(t.TempDir(), "test.sock") u := New(memfile, socketPath) + u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) {}) - // Don't set callback - should not crash ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() @@ -114,13 +91,13 @@ func TestOnFailureNotInvokedWhenCallbackNil(t *testing.T) { time.Sleep(100 * time.Millisecond) - // Should be able to stop without issues stopErr := u.Stop() - assert.NoError(t, stopErr, "Stop should not error") + assert.NoError(t, stopErr) } -// TestMultipleCallbackSets verifies that SetOnFailure can be called multiple times. -func TestMultipleCallbackSets(t *testing.T) { +// TestCallbackNotInvokedOnCleanStop verifies that the callback is NOT invoked +// when UFFD is stopped cleanly (no failure). +func TestCallbackNotInvokedOnCleanStop(t *testing.T) { t.Parallel() memfile := testutils.NewMockMemfile(t) @@ -128,31 +105,27 @@ func TestMultipleCallbackSets(t *testing.T) { u := New(memfile, socketPath) - // Set first callback - firstInvoked := atomic.Bool{} + invoked := atomic.Bool{} u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { - firstInvoked.Store(true) + invoked.Store(true) }) - // Verify first callback is set - assert.NotNil(t, u.GetOnFailure()) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() - // Set second callback (should replace first) - secondInvoked := atomic.Bool{} - u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { - secondInvoked.Store(true) - }) + err := u.Start(ctx, "test-sandbox") + require.NoError(t, err) - // Verify second callback is set - assert.NotNil(t, u.GetOnFailure()) + _ = u.Stop() - // Note: We don't actually trigger the failure here because it would take 10 seconds - // The important thing is that SetOnFailure can be called multiple times + time.Sleep(50 * time.Millisecond) + + assert.False(t, invoked.Load(), "callback should not be invoked on clean stop") } -// TestCallbackReceivesCorrectParameters verifies callback receives correct context, sandbox ID, and error. -// This is a long-running test that waits for the socket timeout (10 seconds). -func TestCallbackReceivesCorrectParameters(t *testing.T) { +// TestOnFailureInvokedOnHandleError verifies that the failure callback is invoked +// when handle fails (socket timeout). Long-running: waits ~10s for socket deadline. +func TestOnFailureInvokedOnHandleError(t *testing.T) { if testing.Short() { t.Skip("skipping long-running test in short mode") } @@ -165,28 +138,25 @@ func TestCallbackReceivesCorrectParameters(t *testing.T) { u := New(memfile, socketPath) var ( - receivedCtx context.Context - receivedSandbox string - receivedErr error - mu sync.Mutex wg sync.WaitGroup + callbackSandbox string + callbackErr error + mu sync.Mutex ) wg.Add(1) - u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { + sandboxID := "test-sandbox-123" + u.SetOnFailure(func(ctx context.Context, sbxID string, err error) { defer wg.Done() mu.Lock() defer mu.Unlock() - receivedCtx = ctx - receivedSandbox = sandboxID - receivedErr = err + callbackSandbox = sbxID + callbackErr = err }) - sandboxID := "test-sandbox-xyz" err := u.Start(context.Background(), sandboxID) require.NoError(t, err) - // Wait for callback (socket timeout is 10 seconds) done := make(chan struct{}) go func() { wg.Wait() @@ -201,14 +171,18 @@ func TestCallbackReceivesCorrectParameters(t *testing.T) { mu.Lock() defer mu.Unlock() - - assert.NotNil(t, receivedCtx, "callback should receive context") - assert.Equal(t, sandboxID, receivedSandbox, "callback should receive correct sandbox ID") - assert.NotNil(t, receivedErr, "callback should receive error") + assert.Equal(t, sandboxID, callbackSandbox) + assert.NotNil(t, callbackErr) } -// TestUffdStopAfterFailure verifies that UFFD can be stopped after a failure. -func TestUffdStopAfterFailure(t *testing.T) { +// TestLateCallbackRegistrationAfterFailure verifies the key correctness guarantee +// from the code review: a callback registered AFTER a failure has already occurred +// is still invoked immediately with the stored failure details. +func TestLateCallbackRegistrationAfterFailure(t *testing.T) { + if testing.Short() { + t.Skip("skipping long-running test in short mode") + } + t.Parallel() memfile := testutils.NewMockMemfile(t) @@ -216,19 +190,47 @@ func TestUffdStopAfterFailure(t *testing.T) { u := New(memfile, socketPath) - u.SetOnFailure(func(ctx context.Context, sandboxID string, err error) { - // Callback does nothing - }) + sandboxID := "test-sandbox-late" - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - err := u.Start(ctx, "test-sandbox") + // Start without a callback - UFFD will fail after socket timeout (10 seconds) + err := u.Start(context.Background(), sandboxID) require.NoError(t, err) - time.Sleep(100 * time.Millisecond) + // Wait for the internal goroutine to finish (socket timeout = 10 seconds) + <-u.readyCh - // Should be able to stop UFFD after failure - stopErr := u.Stop() - assert.NoError(t, stopErr, "Stop should not error after failure") + // Register callback AFTER failure has already occurred. + // It must be invoked immediately in a goroutine with the stored failure details. + var ( + wg sync.WaitGroup + receivedSandbox string + receivedErr error + mu sync.Mutex + ) + + wg.Add(1) + u.SetOnFailure(func(ctx context.Context, sbxID string, cbErr error) { + defer wg.Done() + mu.Lock() + defer mu.Unlock() + receivedSandbox = sbxID + receivedErr = cbErr + }) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("late-registered callback was not invoked") + } + + mu.Lock() + defer mu.Unlock() + assert.Equal(t, sandboxID, receivedSandbox) + assert.NotNil(t, receivedErr) }