Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions core/capabilities/vault/request_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@ type requestAuthorizer struct {
workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer
replayGuard *DigestReplayGuard
lggr logger.Logger
sleep func(time.Duration)
}

const (
allowlistReadRetryCount = 3
allowlistReadRetryInterval = 3 * time.Second
)

// AuthorizeRequest authorizes a request based on the request digest and the allowlisted requests.
// It does NOT check if the request method is allowed.
func (r *requestAuthorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (isAuthorized bool, owner string, err error) {
Expand All @@ -43,20 +49,8 @@ func (r *requestAuthorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Re
r.lggr.Errorw("AuthorizeRequest workflowRegistrySyncer is nil", "method", req.Method, "requestID", req.ID)
return false, "", errors.New("internal error: workflowRegistrySyncer is nil")
}
allowedRequests := r.workflowRegistrySyncer.GetAllowlistedRequests(ctx)
allowedRequestsStrs := make([]string, 0, len(allowedRequests))
for _, rr := range allowedRequests {
allowedReqStr := fmt.Sprintf("Owner: %s, RequestDigest: %s, ExpiryTimestamp: %d", rr.Owner.Hex(), hex.EncodeToString(rr.RequestDigest[:]), rr.ExpiryTimestamp)
allowedRequestsStrs = append(allowedRequestsStrs, allowedReqStr)
}
r.lggr.Infow("AuthorizeRequest GetAllowlistedRequests", "method", req.Method, "requestID", req.ID, "allowedRequests", allowedRequestsStrs)
allowlistedRequest := r.fetchAllowlistedItem(allowedRequests, requestDigestBytes32)
allowlistedRequest, _ := r.fetchAllowlistedItemWithRetry(ctx, req.Method, req.ID, requestDigest, requestDigestBytes32)
if allowlistedRequest == nil {
r.lggr.Infow("AuthorizeRequest fetchAllowlistedItem request not allowlisted",
"method", req.Method,
"requestID", req.ID,
"digestHexStr", requestDigest,
"allowedRequestsStrs", allowedRequestsStrs)
return false, "", errors.New("request not allowlisted")
}

Expand All @@ -76,6 +70,56 @@ func (r *requestAuthorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Re
return true, allowlistedRequest.Owner.Hex(), nil
}

func (r *requestAuthorizer) fetchAllowlistedItemWithRetry(ctx context.Context, method string, requestID interface{}, requestDigest string, digest [32]byte) (*workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, []string) {
var allowedRequestsStrs []string
for attempt := 0; attempt <= allowlistReadRetryCount; attempt++ {
allowedRequests := r.workflowRegistrySyncer.GetAllowlistedRequests(ctx)
allowedRequestsStrs = make([]string, 0, len(allowedRequests))
for _, rr := range allowedRequests {
allowedReqStr := fmt.Sprintf("Owner: %s, RequestDigest: %s, ExpiryTimestamp: %d", rr.Owner.Hex(), hex.EncodeToString(rr.RequestDigest[:]), rr.ExpiryTimestamp)
allowedRequestsStrs = append(allowedRequestsStrs, allowedReqStr)
}
r.lggr.Infow("AuthorizeRequest GetAllowlistedRequests", "method", method, "requestID", requestID, "attempt", attempt+1, "allowedRequests", allowedRequestsStrs)

allowlistedRequest := r.fetchAllowlistedItem(allowedRequests, digest)
if allowlistedRequest != nil {
return allowlistedRequest, allowedRequestsStrs
}

if attempt == allowlistReadRetryCount {
break
}

r.lggr.Warnw("AuthorizeRequest request not found in allowlist, retrying",
"method", method,
"requestID", requestID,
"digestHexStr", requestDigest,
"attempt", attempt+1,
"retryInterval", allowlistReadRetryInterval,
"allowedRequestsStrs", allowedRequestsStrs)

select {
case <-ctx.Done():
r.lggr.Warnw("AuthorizeRequest allowlist retry canceled",
"method", method,
"requestID", requestID,
"digestHexStr", requestDigest,
"attempt", attempt+1)
return nil, allowedRequestsStrs
default:
}

r.sleep(allowlistReadRetryInterval)
}

r.lggr.Infow("AuthorizeRequest fetchAllowlistedItem request not allowlisted",
"method", method,
"requestID", requestID,
"digestHexStr", requestDigest,
"allowedRequestsStrs", allowedRequestsStrs)
return nil, allowedRequestsStrs
}

func (r *requestAuthorizer) fetchAllowlistedItem(allowListedRequests []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, digest [32]byte) *workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest {
for _, item := range allowListedRequests {
if item.RequestDigest == digest {
Expand All @@ -90,5 +134,6 @@ func NewRequestAuthorizer(lggr logger.Logger, workflowRegistrySyncer workflowsyn
workflowRegistrySyncer: workflowRegistrySyncer,
lggr: logger.Named(lggr, "VaultRequestAuthorizer"),
replayGuard: NewDigestReplayGuard(),
sleep: time.Sleep,
}
}
95 changes: 95 additions & 0 deletions core/capabilities/vault/request_authorizer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vault

import (
"context"
"encoding/hex"
"encoding/json"
"testing"
Expand Down Expand Up @@ -202,3 +203,97 @@ func testAuthForRequests(t *testing.T, allowlistedRequest, notAllowlistedRequest
require.False(t, isAuthorized)
require.ErrorContains(t, err, "not allowlisted")
}

func TestRequestAuthorizer_RetriesAllowlistReadsUntilDigestAppears(t *testing.T) {
lggr := logger.TestLogger(t)
owner := common.Address{1, 2, 3}
req := makeListSecretsRequest(t, "123", "b")

digest, err := req.Digest()
require.NoError(t, err)
digestBytes, err := hex.DecodeString(digest)
require.NoError(t, err)

allowlisted := []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{
{
RequestDigest: [32]byte(digestBytes),
Owner: owner,
ExpiryTimestamp: uint32(time.Now().UTC().Unix() + 100), //nolint:gosec // test fixture expiry is bounded and safe here
},
}

mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t)
mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once()
mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once()
mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return(allowlisted).Once()

auth := NewRequestAuthorizer(lggr, mockSyncer)
sleepCalls := 0
auth.sleep = func(d time.Duration) {
require.Equal(t, allowlistReadRetryInterval, d)
sleepCalls++
}

isAuthorized, gotOwner, err := auth.AuthorizeRequest(t.Context(), req)
require.True(t, isAuthorized, err)
require.NoError(t, err)
require.Equal(t, owner.Hex(), gotOwner)
require.Equal(t, 2, sleepCalls)
}

func TestRequestAuthorizer_FailsAfterAllowlistReadRetries(t *testing.T) {
lggr := logger.TestLogger(t)
req := makeListSecretsRequest(t, "123", "b")

mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t)
mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Times(allowlistReadRetryCount + 1)

auth := NewRequestAuthorizer(lggr, mockSyncer)
sleepCalls := 0
auth.sleep = func(d time.Duration) {
require.Equal(t, allowlistReadRetryInterval, d)
sleepCalls++
}

isAuthorized, _, err := auth.AuthorizeRequest(t.Context(), req)
require.False(t, isAuthorized)
require.ErrorContains(t, err, "not allowlisted")
require.Equal(t, allowlistReadRetryCount, sleepCalls)
}

func TestRequestAuthorizer_StopsRetriesWhenContextCanceled(t *testing.T) {
lggr := logger.TestLogger(t)
req := makeListSecretsRequest(t, "123", "b")

ctx, cancel := context.WithCancel(t.Context())
cancel()

mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t)
mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once()

auth := NewRequestAuthorizer(lggr, mockSyncer)
sleepCalls := 0
auth.sleep = func(time.Duration) {
sleepCalls++
}

isAuthorized, _, err := auth.AuthorizeRequest(ctx, req)
require.False(t, isAuthorized)
require.ErrorContains(t, err, "not allowlisted")
require.Zero(t, sleepCalls)
}

func makeListSecretsRequest(t *testing.T, id, namespace string) jsonrpc.Request[json.RawMessage] {
t.Helper()

params, err := json.Marshal(vaultcommon.ListSecretIdentifiersRequest{
Namespace: namespace,
})
require.NoError(t, err)

return jsonrpc.Request[json.RawMessage]{
ID: id,
Method: vaulttypes.MethodSecretsList,
Params: (*json.RawMessage)(&params),
}
}
13 changes: 10 additions & 3 deletions core/services/workflows/syncerlimiter/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@ type Config struct {
}

type keyedOwnerSettings struct {
key string
vals map[string]string
getter settings.Getter
key string
vals map[string]string
}

func (k keyedOwnerSettings) GetScoped(ctx context.Context, scope settings.Scope, key string) (value string, err error) {
if k.getter != nil {
value, err = k.getter.GetScoped(ctx, scope, key)
}
if value != "" {
return
}
if k.key != key || scope != settings.ScopeOwner {
return "", nil
}
Expand All @@ -50,7 +57,7 @@ func NewWorkflowLimits(lggr logger.Logger, cfg Config, lf limits.Factory) (limit
for k, v := range cfg.PerOwnerOverrides {
perOwner[k] = strconv.Itoa(int(v))
}
lf.Settings = keyedOwnerSettings{key: ownerLimit.Key, vals: perOwner}
lf.Settings = keyedOwnerSettings{getter: lf.Settings, key: ownerLimit.Key, vals: perOwner}
owner, err := limits.MakeResourcePoolLimiter(lf, ownerLimit)
if err != nil {
return nil, fmt.Errorf("failed to create owner resource limiter: %w", err)
Expand Down
13 changes: 10 additions & 3 deletions core/services/workflows/v2/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,17 @@ func TestEngine_Init(t *testing.T) {

func TestEngine_Start_RateLimited(t *testing.T) {
t.Parallel()
getter, err := settings.NewTOMLGetter([]byte(`
[global]
WorkflowExecutionConcurrencyLimit = "2"
[global.PerOwner]
WorkflowExecutionConcurrencyLimit = "1"
`))
require.NoError(t, err)
sLimiter, err := syncerlimiter.NewWorkflowLimits(logger.Test(t), syncerlimiter.Config{
Global: 2,
PerOwner: 1,
}, limits.Factory{})
Global: 0,
PerOwner: 0,
}, limits.Factory{Settings: getter})
require.NoError(t, err)

module := modulemocks.NewModuleV2(t)
Expand Down
Loading
Loading