diff --git a/core/services/ocr2/plugins/vault/plugin.go b/core/services/ocr2/plugins/vault/plugin.go
index aac7d494f31..9b00c97a324 100644
--- a/core/services/ocr2/plugins/vault/plugin.go
+++ b/core/services/ocr2/plugins/vault/plugin.go
@@ -470,6 +470,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
 	}
 
 	blobPayloads := make([][]byte, 0, len(localQueueItems))
+	blobPayloadIDs := make([]string, 0, len(localQueueItems))
 	maxObservedLocalQueueItems := 0
 	for _, item := range localQueueItems {
 		// The item is already in the pending queue. We'll be processing it
@@ -502,6 +503,7 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
 		}
 
 		blobPayloads = append(blobPayloads, itemb)
+		blobPayloadIDs = append(blobPayloadIDs, item.Id)
 
 		if len(blobPayloads) >= maxObservedLocalQueueItems {
 			r.lggr.Warnw("Observed local queue exceeds batch size limit, truncating",
@@ -512,35 +514,11 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
 		}
 	}
 
-	observedLocalQueue := make([][]byte, len(blobPayloads))
-	// Broadcast pending-queue blobs in parallel to reduce Observation() latency.
-	// Shortening this phase helps the OCR round finish within DeltaProgress.
-	blobBroadcastStart := time.Now()
-	defer func() {
-		r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(blobPayloads), "elapsed", time.Since(blobBroadcastStart))
-	}()
-	g, broadcastCtx := errgroup.WithContext(ctx)
-	for i, payload := range blobPayloads {
-		g.Go(func() error {
-			blobHandle, ierr2 := blobBroadcastFetcher.BroadcastBlob(broadcastCtx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
-			if ierr2 != nil {
-				return fmt.Errorf("could not broadcast pending queue item as blob: %w", ierr2)
-			}
-
-			blobHandleBytes, ierr2 := r.marshalBlob(blobHandle)
-			if ierr2 != nil {
-				return fmt.Errorf("could not marshal blob handle to bytes: %w", ierr2)
-			}
-
-			observedLocalQueue[i] = blobHandleBytes
-			return nil
-		})
-	}
-	if err = g.Wait(); err != nil {
+	pendingQueueItems, err := r.broadcastBlobPayloads(ctx, blobBroadcastFetcher, seqNr, blobPayloads, blobPayloadIDs)
+	if err != nil {
 		return nil, err
 	}
-
-	obspb.PendingQueueItems = observedLocalQueue
+	obspb.PendingQueueItems = pendingQueueItems
 
 	// Second, generate a random nonce that we'll use to sort the observations.
 	// Each node generates a nonce idepedently, to be concatenated later on.
@@ -563,6 +541,82 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type
 	return types.Observation(obsb), nil
 }
 
+// broadcastBlobPayloads broadcasts each payload as a blob in parallel to reduce
+// Observation() latency (shortening this phase helps the OCR round finish within
+// DeltaProgress). payloads and requestIDs are parallel slices: requestIDs[i] is the
+// ID of the request behind payloads[i] and is used only for logging. A length
+// mismatch is rejected up front instead of panicking inside a goroutine.
+//
+// Individual broadcast or marshal failures are logged and skipped rather than
+// aborting the entire observation, so that one problematic payload does not prevent
+// the remaining items from being observed. If a broadcast fails while the context is
+// cancelled or past its deadline, the wrapped context error is returned once all
+// goroutines have finished, so that expired rounds fail instead of being truncated.
+//
+// The returned handles keep the relative order of the payloads that succeeded.
+func (r *ReportingPlugin) broadcastBlobPayloads(
+	ctx context.Context,
+	fetcher ocr3_1types.BlobBroadcastFetcher,
+	seqNr uint64,
+	payloads [][]byte,
+	requestIDs []string,
+) ([][]byte, error) {
+	if len(requestIDs) != len(payloads) {
+		return nil, fmt.Errorf("got %d request IDs for %d blob payloads", len(requestIDs), len(payloads))
+	}
+
+	results := make([][]byte, len(payloads))
+
+	start := time.Now()
+	defer func() {
+		r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(payloads), "elapsed", time.Since(start))
+	}()
+
+	var g errgroup.Group
+	for i, payload := range payloads {
+		g.Go(func() error {
+			blobHandle, err := fetcher.BroadcastBlob(ctx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2})
+			if err != nil {
+				// A dead context is not a per-item problem, so surface it to the
+				// caller instead of logging it as a skipped item.
+				if ctxErr := ctx.Err(); ctxErr != nil {
+					return fmt.Errorf("could not broadcast pending queue item as blob: %w", ctxErr)
+				}
+				r.lggr.Warnw("failed to broadcast pending queue item as blob, skipping",
+					"seqNr", seqNr,
+					"requestID", requestIDs[i],
+					"err", err)
+				return nil
+			}
+
+			blobHandleBytes, err := r.marshalBlob(blobHandle)
+			if err != nil {
+				r.lggr.Warnw("failed to marshal blob handle, skipping",
+					"seqNr", seqNr,
+					"requestID", requestIDs[i],
+					"err", err)
+				return nil
+			}
+
+			// Each goroutine writes only its own index, so no locking is needed.
+			results[i] = blobHandleBytes
+			return nil
+		})
+	}
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+
+	// Skipped items left a nil slot; drop them while preserving order.
+	filtered := make([][]byte, 0, len(results))
+	for _, item := range results {
+		if item != nil {
+			filtered = append(filtered, item)
+		}
+	}
+	return filtered, nil
+}
+
 func (r *ReportingPlugin) observeGetSecrets(ctx context.Context, reader ReadKVStore, req proto.Message, o *vaultcommon.Observation) {
 	tp := req.(*vaultcommon.GetSecretsRequest)
 	o.RequestType = vaultcommon.RequestType_GET_SECRETS
diff --git a/core/services/ocr2/plugins/vault/plugin_test.go b/core/services/ocr2/plugins/vault/plugin_test.go
index 3ecfe0ff5c5..332762a73b5 100644
--- a/core/services/ocr2/plugins/vault/plugin_test.go
+++ b/core/services/ocr2/plugins/vault/plugin_test.go
@@ -760,7 +760,7 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastsPendingQueueBlobsInPar
 }
 
 func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T) {
-	lggr := logger.TestLogger(t)
+	lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
 	store := requests.NewStore[*vaulttypes.Request]()
 	r := &ReportingPlugin{
 		lggr: lggr,
@@ -803,8 +803,15 @@ func TestPlugin_Observation_PendingQueueEnabled_BroadcastBlobError(t *testing.T)
 	require.NoError(t, store.Add(&vaulttypes.Request{Payload: p, IDVal: "request-1"}))
 
 	rdr := &kv{m: make(map[string]response)}
-	_, err = r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
-	require.ErrorContains(t, err, "could not broadcast pending queue item as blob: boom")
+	obs, err := r.Observation(t.Context(), 1, types.AttributedQuery{}, rdr, &errorBlobBroadcastFetcher{err: errors.New("boom")})
+	require.NoError(t, err)
+	require.NotNil(t, obs)
+
+	warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
+	require.Equal(t, 1, warnLogs.Len())
+	fields := warnLogs.All()[0].ContextMap()
+	assert.Equal(t, "request-1", fields["requestID"])
+	assert.Contains(t, fmt.Sprint(fields["err"]), "boom")
 }
 
 func TestPlugin_Observation_GetSecretsRequest_SecretIdentifierInvalid(t *testing.T) {
@@ -5166,6 +5173,21 @@ func mockMarshalBlob(ocr3_1types.BlobHandle) ([]byte, error) {
 	return []byte{}, nil
 }
 
+type callbackBlobFetcher struct {
+	fn func(payload []byte) error
+}
+
+func (f *callbackBlobFetcher) BroadcastBlob(_ context.Context, payload []byte, _ ocr3_1types.BlobExpirationHint) (ocr3_1types.BlobHandle, error) {
+	if err := f.fn(payload); err != nil {
+		return ocr3_1types.BlobHandle{}, err
+	}
+	return ocr3_1types.BlobHandle{}, nil
+}
+
+func (f *callbackBlobFetcher) FetchBlob(context.Context, ocr3_1types.BlobHandle) ([]byte, error) {
+	panic("FetchBlob should not be called in broadcastBlobPayloads tests")
+}
+
 func TestPlugin_StateTransition_StoresPendingQueue(t *testing.T) {
 	lggr := logger.TestLogger(t)
 	store := requests.NewStore[*vaulttypes.Request]()
@@ -7108,3 +7130,225 @@ func TestLogUserErrorAware(t *testing.T) {
 		assert.Contains(t, fmt.Sprint(fields["error"]), "internal error")
 	})
 }
+
+func TestPlugin_broadcastBlobPayloads(t *testing.T) {
+	t.Run("empty payloads returns empty slice", func(t *testing.T) {
+		lggr := logger.TestLogger(t)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, nil, nil)
+		require.NoError(t, err)
+		assert.Empty(t, result)
+	})
+
+	t.Run("all payloads broadcast successfully", func(t *testing.T) {
+		lggr := logger.TestLogger(t)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
+		payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
+		ids := []string{"req-1", "req-2", "req-3"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
+		require.NoError(t, err)
+		assert.Len(t, result, 3)
+		for _, item := range result {
+			assert.Equal(t, []byte("handle"), item)
+		}
+	})
+
+	t.Run("failed broadcast is skipped and logged", func(t *testing.T) {
+		lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
+			if string(payload) == "p2" {
+				return errors.New("broadcast error")
+			}
+			return nil
+		}}
+
+		payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
+		ids := []string{"req-1", "req-2", "req-3"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 5, payloads, ids)
+		require.NoError(t, err)
+		assert.Len(t, result, 2)
+
+		warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
+		require.Equal(t, 1, warnLogs.Len())
+		fields := warnLogs.All()[0].ContextMap()
+		assert.Equal(t, "req-2", fields["requestID"])
+		assert.Equal(t, uint64(5), fields["seqNr"])
+		assert.Contains(t, fmt.Sprint(fields["err"]), "broadcast error")
+	})
+
+	t.Run("all broadcasts fail returns empty slice", func(t *testing.T) {
+		lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &errorBlobBroadcastFetcher{err: errors.New("network down")}
+		payloads := [][]byte{[]byte("p1"), []byte("p2")}
+		ids := []string{"req-1", "req-2"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
+		require.NoError(t, err)
+		assert.Empty(t, result)
+
+		warnLogs := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
+		assert.Equal(t, 2, warnLogs.Len())
+	})
+
+	t.Run("marshal blob failure skips item and logs warning", func(t *testing.T) {
+		lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return nil, errors.New("marshal error")
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
+		payloads := [][]byte{[]byte("p1"), []byte("p2")}
+		ids := []string{"req-1", "req-2"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
+		require.NoError(t, err)
+		assert.Empty(t, result)
+
+		warnLogs := observed.FilterMessage("failed to marshal blob handle, skipping")
+		assert.Equal(t, 2, warnLogs.Len())
+	})
+
+	t.Run("mix of broadcast and marshal failures", func(t *testing.T) {
+		lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel)
+
+		marshalCallCount := atomic.Int32{}
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				n := marshalCallCount.Add(1)
+				if n == 1 {
+					return nil, errors.New("marshal error")
+				}
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func(payload []byte) error {
+			if string(payload) == "p1" {
+				return errors.New("broadcast error")
+			}
+			return nil
+		}}
+
+		payloads := [][]byte{[]byte("p1"), []byte("p2"), []byte("p3")}
+		ids := []string{"req-1", "req-2", "req-3"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
+		require.NoError(t, err)
+
+		broadcastWarns := observed.FilterMessage("failed to broadcast pending queue item as blob, skipping")
+		marshalWarns := observed.FilterMessage("failed to marshal blob handle, skipping")
+		assert.Equal(t, 1, broadcastWarns.Len())
+		assert.Equal(t, 1, marshalWarns.Len())
+		assert.Len(t, result, 1)
+	})
+
+	t.Run("context cancellation propagates error", func(t *testing.T) {
+		lggr := logger.TestLogger(t)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		ctx, cancel := context.WithCancel(t.Context())
+		cancel()
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error {
+			return ctx.Err()
+		}}
+
+		payloads := [][]byte{[]byte("p1"), []byte("p2")}
+		ids := []string{"req-1", "req-2"}
+
+		result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
+		assert.Nil(t, result)
+		assert.ErrorIs(t, err, context.Canceled)
+	})
+
+	t.Run("context deadline exceeded propagates error", func(t *testing.T) {
+		lggr := logger.TestLogger(t)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		ctx, cancel := context.WithTimeout(t.Context(), 0)
+		defer cancel()
+		<-ctx.Done()
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error {
+			return ctx.Err()
+		}}
+
+		payloads := [][]byte{[]byte("p1")}
+		ids := []string{"req-1"}
+
+		result, err := r.broadcastBlobPayloads(ctx, fetcher, 1, payloads, ids)
+		assert.Nil(t, result)
+		assert.ErrorIs(t, err, context.DeadlineExceeded)
+	})
+
+	t.Run("mismatched payload and request ID counts returns error", func(t *testing.T) {
+		lggr := logger.TestLogger(t)
+		r := &ReportingPlugin{
+			lggr:    lggr,
+			metrics: newTestMetrics(t),
+			marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) {
+				return []byte("handle"), nil
+			},
+		}
+
+		fetcher := &callbackBlobFetcher{fn: func([]byte) error { return nil }}
+		payloads := [][]byte{[]byte("p1"), []byte("p2")}
+		ids := []string{"req-1"}
+
+		result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids)
+		assert.Nil(t, result)
+		require.ErrorContains(t, err, "got 1 request IDs for 2 blob payloads")
+	})
+}