diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go
index 70d85c03c5..cc11682bb9 100644
--- a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go
@@ -7,20 +7,40 @@ import "sync"
 // deferredFaults collects pagefaults that returned EAGAIN so they get
 // retried on the next poll iteration. Safe for concurrent push.
 type deferredFaults struct {
-	mu sync.Mutex
-	pf []*UffdPagefault
+	mu     sync.Mutex
+	pf     []*UffdPagefault
+	byAddr map[uint64]*UffdPagefault
 }
 
+// push queues a deferred fault, skipping addresses already queued so a page
+// faulted by several threads is retried once instead of once per fault.
+// Fault addresses are already page-aligned by the kernel (UFFDIO_COPY rejects
+// unaligned dst), so the raw address keys per page. If the same page is faulted
+// as both read and write, the retained fault is upgraded to write so the retry
+// installs it dirty instead of leaving a later WP fault to catch it.
 func (d *deferredFaults) push(pf *UffdPagefault) {
 	d.mu.Lock()
+	defer d.mu.Unlock()
+	if d.byAddr == nil {
+		d.byAddr = make(map[uint64]*UffdPagefault)
+	}
+	addr := uint64(pf.address)
+	if existing, ok := d.byAddr[addr]; ok {
+		if pf.flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 {
+			existing.flags |= UFFD_PAGEFAULT_FLAG_WRITE
+		}
+
+		return
+	}
+	d.byAddr[addr] = pf
 	d.pf = append(d.pf, pf)
-	d.mu.Unlock()
 }
 
 func (d *deferredFaults) drain() []*UffdPagefault {
 	d.mu.Lock()
 	out := d.pf
 	d.pf = nil
+	d.byAddr = nil
 	d.mu.Unlock()
 
 	return out
diff --git a/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred_test.go b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred_test.go
new file mode 100644
index 0000000000..d28ad24bfe
--- /dev/null
+++ b/packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred_test.go
@@ -0,0 +1,36 @@
+//go:build linux
+
+package userfaultfd
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestDeferredFaultsDedupesByAddress(t *testing.T) {
+	t.Parallel()
+
+	var d deferredFaults
+	d.push(&UffdPagefault{address: 42})
+	d.push(&UffdPagefault{address: 42})
+	d.push(&UffdPagefault{address: 43})
+
+	require.Len(t, d.drain(), 2)
+	require.Empty(t, d.drain())
+
+	d.push(&UffdPagefault{address: 42})
+	require.Len(t, d.drain(), 1)
+}
+
+func TestDeferredFaultsUpgradesReadToWrite(t *testing.T) {
+	t.Parallel()
+
+	var d deferredFaults
+	d.push(&UffdPagefault{address: 42})
+	d.push(&UffdPagefault{address: 42, flags: UFFD_PAGEFAULT_FLAG_WRITE})
+
+	out := d.drain()
+	require.Len(t, out, 1)
+	require.NotZero(t, out[0].flags&UFFD_PAGEFAULT_FLAG_WRITE)
+}