Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions packages/orchestrator/pkg/sandbox/uffd/userfaultfd/deferred.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,40 @@ import "sync"
// deferredFaults collects pagefaults that returned EAGAIN so they get
// retried on the next poll iteration. Safe for concurrent push.
type deferredFaults struct {
mu sync.Mutex
pf []*UffdPagefault
mu sync.Mutex
pf []*UffdPagefault
byAddr map[uint64]*UffdPagefault
}

// push queues a deferred fault, skipping addresses already queued so a page
// faulted by several threads is retried once instead of once per fault.
// Fault addresses are already page-aligned by the kernel (UFFDIO_COPY rejects
// unaligned dst), so the raw address keys per page. If the same page is faulted
// as both read and write, the retained fault is upgraded to write so the retry
// installs it dirty instead of leaving a later WP fault to catch it.
func (d *deferredFaults) push(pf *UffdPagefault) {
d.mu.Lock()
defer d.mu.Unlock()
if d.byAddr == nil {
d.byAddr = make(map[uint64]*UffdPagefault)
}
addr := uint64(pf.address)
if existing, ok := d.byAddr[addr]; ok {
if pf.flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 {
existing.flags |= UFFD_PAGEFAULT_FLAG_WRITE
}

return
}
d.byAddr[addr] = pf
d.pf = append(d.pf, pf)
d.mu.Unlock()
}

func (d *deferredFaults) drain() []*UffdPagefault {
d.mu.Lock()
out := d.pf
d.pf = nil
d.byAddr = nil
d.mu.Unlock()

return out
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//go:build linux

package userfaultfd

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDeferredFaultsDedupesByAddress(t *testing.T) {
t.Parallel()

var d deferredFaults
d.push(&UffdPagefault{address: 42})
d.push(&UffdPagefault{address: 42})
d.push(&UffdPagefault{address: 43})

require.Len(t, d.drain(), 2)
require.Empty(t, d.drain())

d.push(&UffdPagefault{address: 42})
require.Len(t, d.drain(), 1)
}

func TestDeferredFaultsUpgradesReadToWrite(t *testing.T) {
t.Parallel()

var d deferredFaults
d.push(&UffdPagefault{address: 42})
d.push(&UffdPagefault{address: 42, flags: UFFD_PAGEFAULT_FLAG_WRITE})

out := d.drain()
require.Len(t, out, 1)
require.NotZero(t, out[0].flags&UFFD_PAGEFAULT_FLAG_WRITE)
}
Loading