Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 66 additions & 27 deletions core/services/ocr2/plugins/vault/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}

blobPayloads := make([][]byte, 0, len(localQueueItems))
blobPayloadIDs := make([]string, 0, len(localQueueItems))
maxObservedLocalQueueItems := 0
for _, item := range localQueueItems {
// The item is already in the pending queue. We'll be processing it
Expand Down Expand Up @@ -502,6 +503,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}

blobPayloads = append(blobPayloads, itemb)
blobPayloadIDs = append(blobPayloadIDs, item.Id)

if len(blobPayloads) >= maxObservedLocalQueueItems {
r.lggr.Warnw("Observed local queue exceeds batch size limit, truncating",
Expand All @@ -512,35 +514,11 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
}
}

observedLocalQueue := make([][]byte, len(blobPayloads))
// Broadcast pending-queue blobs in parallel to reduce Observation() latency.
// Shortening this phase helps the OCR round finish within DeltaProgress.
blobBroadcastStart := time.Now()
defer func() {
r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(blobPayloads), "elapsed", time.Since(blobBroadcastStart))
}()
g, broadcastCtx := errgroup.WithContext(ctx)
for i, payload := range blobPayloads {
g.Go(func() error {
blobHandle, ierr2 := blobBroadcastFetcher.BroadcastBlob(broadcastCtx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
if ierr2 != nil {
return fmt.Errorf("could not broadcast pending queue item as blob: %w", ierr2)
}

blobHandleBytes, ierr2 := r.marshalBlob(blobHandle)
if ierr2 != nil {
return fmt.Errorf("could not marshal blob handle to bytes: %w", ierr2)
}

observedLocalQueue[i] = blobHandleBytes
return nil
})
}
if err = g.Wait(); err != nil {
pendingQueueItems, err := r.broadcastBlobPayloads(ctx, blobBroadcastFetcher, seqNr, blobPayloads, blobPayloadIDs)
if err != nil {
return nil, err
}

obspb.PendingQueueItems = observedLocalQueue
obspb.PendingQueueItems = pendingQueueItems

// Second, generate a random nonce that we'll use to sort the observations.
// Each node generates a nonce idepedently, to be concatenated later on.
Expand All @@ -563,6 +541,67 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
return types.Observation(obsb), nil
}

// broadcastBlobPayloads broadcasts each payload as a blob in parallel to reduce
// Observation() latency (shortening this phase helps the OCR round finish within
// DeltaProgress). Individual broadcast failures are logged and skipped rather than
// aborting the entire observation, so that one problematic payload does not prevent
// the remaining items from being observed. Context cancellation/deadline errors are
// propagated immediately so that expired rounds fail fast.
func (r *ReportingPlugin) broadcastBlobPayloads(
ctx context.Context,
fetcher ocr3_1types.BlobBroadcastFetcher,
seqNr uint64,
payloads [][]byte,
requestIDs []string,
) ([][]byte, error) {
results := make([][]byte, len(payloads))

start := time.Now()
defer func() {
r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(payloads), "elapsed", time.Since(start))
}()

var g errgroup.Group
for i, payload := range payloads {
g.Go(func() error {
blobHandle, err := fetcher.BroadcastBlob(ctx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
r.lggr.Warnw("failed to broadcast pending queue item as blob, skipping",
"seqNr", seqNr,
"requestID", requestIDs[i],
"err", err)
return nil
}

blobHandleBytes, err := r.marshalBlob(blobHandle)
if err != nil {
r.lggr.Warnw("failed to marshal blob handle, skipping",
"seqNr", seqNr,
"requestID", requestIDs[i],
"err", err)
return nil
}

results[i] = blobHandleBytes
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}

filtered := make([][]byte, 0, len(results))
for _, item := range results {
if item != nil {
filtered = append(filtered, item)
}
}
return filtered, nil
}

func (r *ReportingPlugin) observeGetSecrets(ctx context.Context, reader ReadKVStore, req proto.Message, o *vaultcommon.Observation) {
tp := req.(*vaultcommon.GetSecretsRequest)
o.RequestType = vaultcommon.RequestType_GET_SECRETS
Expand Down
231 changes: 228 additions & 3 deletions core/services/ocr2/plugins/vault/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastsPendingQueueBlobsInPar
}

func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T) {
lggr := logger.TestLogger(t)
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
store := requests.NewStore[*vaulttypes.Request]()
r := &ReportingPlugin{
lggr: lggr,
Expand Down Expand Up @@ -803,8 +803,15 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T)
require.NoError(t, store.Add(&vaulttypes.Request{Payload: p, IDVal: "request-1"}))
rdr := &kv{m: make(map[string]response)}

_, err = r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
require.ErrorContains(t, err, "could not broadcast pending queue item as blob: boom")
obs, err := r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
require.NoError(t, err)
require.NotNil(t, obs)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 1, warnLogs.Len())
fields := warnLogs.All()[0].ContextMap()
assert.Equal(t, "request-1", fields["requestID"])
assert.Contains(t, fmt.Sprint(fields["err"]), "boom")
}

func TestPlugin_Observation_GetSecretsRequest_SecretIdentifierInvalid(t *testing.T) {
Expand Down Expand Up @@ -5166,6 +5173,21 @@ func mockMarshalBlob(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte{}, nil
}

type callbackBlobFetcher struct {
fn func(payload []byte) error
}

func (f *callbackBlobFetcher) BroadcastBlob(_ context.Context, payload []byte, _ ocr3_1types.BlobExpirationHint) (ocr3_1types.BlobHandle, error) {
if err := f.fn(payload); err != nil {
return ocr3_1types.BlobHandle{}, err
}
return ocr3_1types.BlobHandle{}, nil
}

func (f *callbackBlobFetcher) FetchBlob(context.Context, ocr3_1types.BlobHandle) ([]byte, error) {
panic("FetchBlob should not be called in broadcastBlobPayloads tests")
}

func TestPlugin_StateTransition_StoresPendingQueue(t *testing.T) {
lggr := logger.TestLogger(t)
store := requests.NewStore[*vaulttypes.Request]()
Expand Down Expand Up @@ -7108,3 +7130,206 @@ func TestLogUserErrorAware(t *testing.T) {
assert.Contains(t, fmt.Sprint(fields["error"]), "internal error")
})
}

func TestPlugin_broadcastBlobPayloads(t *testing.T) {
t.Run("empty payloads returns empty slice", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, nil, nil)
require.NoError(t, err)
assert.Empty(t, result)
})

t.Run("all payloads broadcast successfully", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Len(t, result, 3)
for _, item := range result {
assert.Equal(t, []byte("handle"), item)
}
})

t.Run("failed broadcast is skipped and logged", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
if string(payload) == "p2" {
return errors.New("broadcast error")
}
return nil
}}

payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 5, payloads, ids)
require.NoError(t, err)
assert.Len(t, result, 2)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 1, warnLogs.Len())
fields := warnLogs.All()[0].ContextMap()
assert.Equal(t, "req-2", fields["requestID"])
assert.Equal(t, uint64(5), fields["seqNr"])
assert.Contains(t, fmt.Sprint(fields["err"]), "broadcast error")
})

t.Run("all broadcasts fail returns empty slice", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

fetcher := &errorBlobBroadcastFetcher{err: errors.New("network down")}
payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Empty(t, result)

warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
assert.Equal(t, 2, warnLogs.Len())
})

t.Run("marshal blob failure skips item and logs warning", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return nil, errors.New("marshal error")
},
}

fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)
assert.Empty(t, result)

warnLogs := observed.FilterMessage("failed to marshal blob handle, skipping")
assert.Equal(t, 2, warnLogs.Len())
})

t.Run("mix of broadcast and marshal failures", func(t *testing.T) {
lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)

marshalCallCount := atomic.Int32{}
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
n := marshalCallCount.Add(1)
if n == 1 {
return nil, errors.New("marshal error")
}
return []byte("handle"), nil
},
}

fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
if string(payload) == "p1" {
return errors.New("broadcast error")
}
return nil
}}

payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
ids := []string{"req-1", "req-2", "req-3"}

result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
require.NoError(t, err)

broadcastWarns := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
marshalWarns := observed.FilterMessage("failed to marshal blob handle, skipping")
assert.Equal(t, 1, broadcastWarns.Len())
assert.Equal(t, 1, marshalWarns.Len())
assert.Len(t, result, 1)
})

t.Run("context cancellation propagates error", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

ctx, cancel := context.WithCancel(t.Context())
cancel()

fetcher := &callbackBlobFetcher{fn: func([]byte) error {
return ctx.Err()
}}

payloads := [][]byte{[]byte("p1"), []byte("p2")}
ids := []string{"req-1", "req-2"}

result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
assert.Nil(t, result)
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("context deadline exceeded propagates error", func(t *testing.T) {
lggr := logger.TestLogger(t)
r := &ReportingPlugin{
lggr: lggr,
metrics: newTestMetrics(t),
marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
return []byte("handle"), nil
},
}

ctx, cancel := context.WithTimeout(t.Context(), 0)
defer cancel()
<-ctx.Done()

fetcher := &callbackBlobFetcher{fn: func([]byte) error {
return ctx.Err()
}}

payloads := [][]byte{[]byte("p1")}
ids := []string{"req-1"}

result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
assert.Nil(t, result)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}
Loading