From 09574a0f978a932d7e1ed1ec2ea11f5c6a89b7d8 Mon Sep 17 00:00:00 2001 From: dcbickfo Date: Thu, 19 Feb 2026 09:31:35 -0500 Subject: [PATCH 1/3] Refactored set --- cacheaside.go | 82 ++-- cacheaside_test.go | 168 +++++++ errors.go | 48 ++ errors_test.go | 53 +++ internal/cmdx/slot.go | 15 +- internal/cmdx/slot_test.go | 37 ++ internal/lockpool/lockpool.go | 37 ++ internal/lockpool/lockpool_test.go | 65 +++ internal/mapsx/maps.go | 3 + internal/syncx/map.go | 2 + internal/syncx/wait.go | 56 +-- internal/syncx/wait_test.go | 15 +- lua_scripts.go | 113 +++++ primeable_cacheaside.go | 234 ++++++++++ primeable_cacheaside_test.go | 682 +++++++++++++++++++++++++++++ primeable_setmulti_helpers.go | 358 +++++++++++++++ 16 files changed, 1902 insertions(+), 66 deletions(-) create mode 100644 errors.go create mode 100644 errors_test.go create mode 100644 internal/lockpool/lockpool.go create mode 100644 internal/lockpool/lockpool_test.go create mode 100644 lua_scripts.go create mode 100644 primeable_cacheaside.go create mode 100644 primeable_cacheaside_test.go create mode 100644 primeable_setmulti_helpers.go diff --git a/cacheaside.go b/cacheaside.go index 536a2ad..6e90ac5 100644 --- a/cacheaside.go +++ b/cacheaside.go @@ -62,9 +62,7 @@ import ( "context" "errors" "fmt" - "iter" "log/slog" - "maps" "strconv" "strings" "sync" @@ -95,6 +93,8 @@ type Logger interface { Debug(msg string, args ...any) } +// CacheAside provides a cache-aside pattern backed by Redis with distributed locking +// and client-side caching via rueidis invalidation messages. type CacheAside struct { client rueidis.Client locks syncx.Map[string, *lockEntry] @@ -103,10 +103,13 @@ type CacheAside struct { lockPrefix string } +// CacheAsideOption configures a CacheAside instance. type CacheAsideOption struct { // LockTTL is the maximum time a lock can be held, and also the timeout for waiting // on locks when handling lost Redis invalidation messages. Defaults to 10 seconds. 
- LockTTL time.Duration + LockTTL time.Duration + // ClientBuilder optionally overrides how the rueidis.Client is created. + // When nil, rueidis.NewClient is used. ClientBuilder func(option rueidis.ClientOption) (rueidis.Client, error) // Logger for logging errors and debug information. Defaults to slog.Default(). // The logger should handle log levels internally (e.g., only log Debug if level is enabled). @@ -116,6 +119,7 @@ type CacheAsideOption struct { LockPrefix string } +// NewRedCacheAside creates a CacheAside with the given Redis client and cache-aside options. func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*CacheAside, error) { // Validate client options if len(clientOption.InitAddress) == 0 { @@ -165,6 +169,15 @@ func (rca *CacheAside) Client() rueidis.Client { return rca.client } +// Close cancels all pending lock entries. It does NOT close the underlying +// Redis client — that is the caller's responsibility. +func (rca *CacheAside) Close() { + rca.locks.Range(func(_ string, entry *lockEntry) bool { + entry.cancel() + return true + }) +} + func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { for _, m := range messages { key, err := m.ToString() @@ -179,11 +192,6 @@ func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { } } -var ( - delKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) - setKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) -) - func (rca *CacheAside) register(key string) <-chan struct{} { retry: // Create new entry with context that auto-cancels after lockTTL @@ -232,6 +240,9 @@ retry: } } +// Get returns the cached value for key, populating the cache by calling fn on a miss. 
+// Only one goroutine across all instances executes fn for a given key at a time; +// other callers wait for the result via Redis invalidation messages. func (rca *CacheAside) Get( ctx context.Context, ttl time.Duration, @@ -254,17 +265,18 @@ retry: val, err = rca.trySetKeyFunc(ctx, ttl, key, fn) } - if err != nil && !errors.Is(err, errLockFailed) { + if err != nil && !errors.Is(err, errLockFailed) && !errors.Is(err, ErrLockLost) { return "", err } - if val == "" { - // Wait for lock release (channel auto-closes after lockTTL or on invalidation) + if val == "" || errors.Is(err, ErrLockLost) { + // Wait for lock release (channel auto-closes after lockTTL or on invalidation). + // ErrLockLost means another operation (e.g., ForceSet) stole our lock — retry. select { case <-wait: goto retry case <-ctx.Done(): - // Parent context cancelled + // Parent context cancelled. return "", ctx.Err() } } @@ -272,10 +284,12 @@ retry: return val, err } +// Del removes a key from Redis, triggering invalidation on all clients. func (rca *CacheAside) Del(ctx context.Context, key string) error { return rca.client.Do(ctx, rca.client.B().Del().Key(key).Build()).Error() } +// DelMulti removes multiple keys from Redis, triggering invalidation on all clients. func (rca *CacheAside) DelMulti(ctx context.Context, keys ...string) error { cmds := make(rueidis.Commands, 0, len(keys)) for _, key := range keys { @@ -295,9 +309,7 @@ var ( errLockFailed = errors.New("lock failed") ) -// ErrLockLost indicates the distributed lock was lost or expired before the value could be set. -// This can occur if the lock TTL expires during callback execution or if Redis invalidates the lock. -var ErrLockLost = errors.New("lock was lost or expired before value could be set") +// ErrLockLost is defined in errors.go. 
func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string) (string, error) { resp := rca.client.DoCache(ctx, rca.client.B().Get().Key(key).Cache(), ttl) @@ -348,27 +360,33 @@ func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key func (rca *CacheAside) tryLock(ctx context.Context, key string) (string, error) { uuidv7, err := uuid.NewV7() if err != nil { - return "", fmt.Errorf("failed to generate lock UUID for key %q: %w", key, err) + return "", fmt.Errorf("lock UUID for key %q: %w", key, err) } lockVal := rca.lockPrefix + uuidv7.String() err = rca.client.Do(ctx, rca.client.B().Set().Key(key).Value(lockVal).Nx().Get().Px(rca.lockTTL).Build()).Error() if !rueidis.IsRedisNil(err) { rca.logger.Debug("lock contention - failed to acquire lock", "key", key) - return "", fmt.Errorf("failed to acquire lock for key %q: %w", key, errLockFailed) + return "", fmt.Errorf("lock key %q: %w", key, errLockFailed) } rca.logger.Debug("lock acquired", "key", key, "lockVal", lockVal) return lockVal, nil } func (rca *CacheAside) setWithLock(ctx context.Context, ttl time.Duration, key string, valLock valAndLock) (string, error) { - err := setKeyLua.Exec(ctx, rca.client, []string{key}, []string{valLock.lockVal, valLock.val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error() - if err != nil { + resp := setKeyLua.Exec(ctx, rca.client, []string{key}, []string{valLock.lockVal, valLock.val, strconv.FormatInt(ttl.Milliseconds(), 10)}) + if err := resp.Error(); err != nil { if !rueidis.IsRedisNil(err) { - return "", fmt.Errorf("failed to set value for key %q: %w", key, err) + return "", fmt.Errorf("set key %q: %w", key, err) } rca.logger.Debug("lock lost during set operation", "key", key) return "", fmt.Errorf("lock lost for key %q: %w", key, ErrLockLost) } + // The Lua script returns 0 when the lock was lost (CAS mismatch). + // .Error() returns nil for integer responses, so we must check the value. 
+ if val, _ := resp.AsInt64(); val == 0 { + rca.logger.Debug("lock lost during set operation", "key", key) + return "", fmt.Errorf("lock lost for key %q: %w", key, ErrLockLost) + } rca.logger.Debug("value set successfully", "key", key) return valLock.val, nil } @@ -377,6 +395,8 @@ func (rca *CacheAside) unlock(ctx context.Context, key string, lock string) erro return delKeyLua.Exec(ctx, rca.client, []string{key}, []string{lock}).Error() } +// GetMulti returns cached values for the given keys, populating any misses by calling fn. +// SET operations are grouped by Redis cluster slot for efficient batching. func (rca *CacheAside) GetMulti( ctx context.Context, ttl time.Duration, @@ -391,7 +411,7 @@ func (rca *CacheAside) GetMulti( } retry: - waitLock = rca.registerAll(maps.Keys(waitLock), len(waitLock)) + waitLock = rca.registerAll(mapsx.Keys(waitLock)) vals, err := rca.tryGetMulti(ctx, ttl, mapsx.Keys(waitLock)) if err != nil && !rueidis.IsRedisNil(err) { @@ -416,7 +436,7 @@ retry: if len(waitLock) > 0 { // Wait for lock releases (channels auto-close after lockTTL or on invalidation) - err = syncx.WaitForAll(ctx, maps.Values(waitLock), len(waitLock)) + err = syncx.WaitForAll(ctx, mapsx.Values(waitLock)) if err != nil { // Parent context cancelled or deadline exceeded return nil, ctx.Err() @@ -426,9 +446,9 @@ retry: return res, err } -func (rca *CacheAside) registerAll(keys iter.Seq[string], length int) map[string]<-chan struct{} { - res := make(map[string]<-chan struct{}, length) - for key := range keys { +func (rca *CacheAside) registerAll(keys []string) map[string]<-chan struct{} { + res := make(map[string]<-chan struct{}, len(keys)) + for _, key := range keys { res[key] = rca.register(key) } return res @@ -448,14 +468,14 @@ func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys res := make(map[string]string) for i, resp := range resps { val, err := resp.ToString() - if err != nil && rueidis.IsRedisNil(err) { + if rueidis.IsRedisNil(err) { 
continue - } else if err != nil { - return nil, fmt.Errorf("failed to get key %q: %w", keys[i], err) + } + if err != nil { + return nil, fmt.Errorf("key %q: %w", keys[i], err) } if !strings.HasPrefix(val, rca.lockPrefix) { res[keys[i]] = val - continue } } return res, nil @@ -599,7 +619,7 @@ func (rca *CacheAside) executeSetStatements(ctx context.Context, stmts map[uint1 return nil, err } - out := make([]string, 0) + var out []string for _, keys := range keyByStmt { out = append(out, keys...) } @@ -623,7 +643,7 @@ func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]stri Args: []string{lockVal}, }) } - wg := sync.WaitGroup{} + var wg sync.WaitGroup for _, stmts := range delStmts { wg.Add(1) go func() { diff --git a/cacheaside_test.go b/cacheaside_test.go index 54d3bec..e396953 100644 --- a/cacheaside_test.go +++ b/cacheaside_test.go @@ -839,3 +839,171 @@ func TestConcurrentInvalidation(t *testing.T) { defer mu.Unlock() assert.Greater(t, callCount, initialCount, "callbacks should be invoked after invalidation") } + +func TestCacheAside_Get_CallbackError(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + cbErr := fmt.Errorf("callback failed") + + // First Get: callback returns error. + _, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "", cbErr + }) + require.ErrorIs(t, err, cbErr) + + // Second Get: lock should have been cleaned up, so a fresh callback succeeds. 
+ val := "good-val:" + uuid.New().String() + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return val, nil + }) + require.NoError(t, err) + assert.Equal(t, val, res) +} + +func TestCacheAside_GetMulti_CallbackError(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + keys := []string{ + "key:0:" + uuid.New().String(), + "key:1:" + uuid.New().String(), + } + cbErr := fmt.Errorf("multi callback failed") + + // Callback returns error — locks should be cleaned up. + _, err := client.GetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + return nil, cbErr + }) + require.ErrorIs(t, err, cbErr) + + // Retry should succeed — locks were released. + vals := map[string]string{ + keys[0]: "val:0:" + uuid.New().String(), + keys[1]: "val:1:" + uuid.New().String(), + } + res, err := client.GetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + out := make(map[string]string, len(ks)) + for _, k := range ks { + out[k] = vals[k] + } + return out, nil + }) + require.NoError(t, err) + if diff := cmp.Diff(vals, res); diff != "" { + t.Errorf("GetMulti() mismatch (-want +got):\n%s", diff) + } +} + +func TestCacheAside_Close(t *testing.T) { + client := makeClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + + // Place a long-lived lock. 
+ innerClient := client.Client() + lockVal := "__redcache:lock:" + uuid.New().String() + err := innerClient.Do(ctx, innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() + require.True(t, rueidis.IsRedisNil(err)) + + getCtx, getCancel := context.WithTimeout(ctx, 2*time.Second) + defer getCancel() + + errCh := make(chan error, 1) + go func() { + _, err := client.Get(getCtx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "val", nil + }) + errCh <- err + }() + + time.Sleep(100 * time.Millisecond) + client.Close() + + select { + case <-time.After(5 * time.Second): + t.Fatal("Get did not return after Close") + case err := <-errCh: + // Close wakes up the waiter; since the lock persists it will eventually timeout. + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + } +} + +func TestNewRedCacheAside_Validation(t *testing.T) { + t.Run("empty InitAddress", func(t *testing.T) { + _, err := redcache.NewRedCacheAside( + rueidis.ClientOption{}, + redcache.CacheAsideOption{}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "InitAddress") + }) + + t.Run("negative LockTTL", func(t *testing.T) { + _, err := redcache.NewRedCacheAside( + rueidis.ClientOption{InitAddress: addr}, + redcache.CacheAsideOption{LockTTL: -1 * time.Second}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "negative") + }) + + t.Run("too small LockTTL", func(t *testing.T) { + _, err := redcache.NewRedCacheAside( + rueidis.ClientOption{InitAddress: addr}, + redcache.CacheAsideOption{LockTTL: 10 * time.Millisecond}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "100ms") + }) +} + +// TestCacheAside_Get_ErrLockLostRetry verifies that when a ForceSet steals the lock +// during a Get callback, Get retries and eventually sees the ForceSet value. 
+func TestCacheAside_Get_ErrLockLostRetry(t *testing.T) { + client, err := redcache.NewPrimeableCacheAside( + rueidis.ClientOption{InitAddress: addr}, + redcache.CacheAsideOption{LockTTL: time.Second * 2}, + ) + require.NoError(t, err) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + forcedVal := "forced:" + uuid.New().String() + + getStarted := make(chan struct{}) + + go func() { + // Get acquires lock, then we steal it with ForceSet during callback. + _, _ = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + close(getStarted) + time.Sleep(300 * time.Millisecond) + return "get-val", nil + }) + }() + + <-getStarted + time.Sleep(50 * time.Millisecond) + + // ForceSet steals the lock — Get's setWithLock will see CAS mismatch (ErrLockLost). + err = client.ForceSet(ctx, time.Second*10, key, forcedVal) + require.NoError(t, err) + + // Wait for Get to complete its retry. + time.Sleep(500 * time.Millisecond) + + // Key should have a value (either the forced value or a re-populated value). + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called — value should exist") + return "", nil + }) + require.NoError(t, err) + assert.NotEmpty(t, res) +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..8b9aa68 --- /dev/null +++ b/errors.go @@ -0,0 +1,48 @@ +package redcache + +import ( + "errors" + "fmt" + "strings" +) + +// ErrLockLost indicates the distributed lock was lost or expired before the value could be set. +// This can occur if the lock TTL expires during callback execution or if Redis invalidates the lock. +var ErrLockLost = errors.New("lock was lost or expired before value could be set") + +// BatchError represents partial failures in a multi-key operation. +// Some keys may have succeeded while others failed. 
+type BatchError struct { + // Failed maps each failed key to its error. + Failed map[string]error + // Succeeded lists the keys that were set successfully. + Succeeded []string +} + +// Error returns a human-readable summary of the batch failure. +func (e *BatchError) Error() string { + var b strings.Builder + fmt.Fprintf(&b, "batch operation partially failed: %d succeeded, %d failed", len(e.Succeeded), len(e.Failed)) + for key, err := range e.Failed { + fmt.Fprintf(&b, "; key %q: %s", key, err) + } + return b.String() +} + +// HasFailures returns true if any keys failed. +func (e *BatchError) HasFailures() bool { + return len(e.Failed) > 0 +} + +// NewBatchError creates a BatchError from the given failures and successes. +// Returns nil (untyped) if there are no failures, so it is safe to return +// directly as an error interface value. +func NewBatchError(failed map[string]error, succeeded []string) error { + if len(failed) == 0 { + return nil + } + return &BatchError{ + Failed: failed, + Succeeded: succeeded, + } +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..f4a3d23 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,53 @@ +package redcache_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache" +) + +func TestBatchError_Error(t *testing.T) { + be := &redcache.BatchError{ + Failed: map[string]error{"key1": errors.New("timeout"), "key2": errors.New("lock lost")}, + Succeeded: []string{"key3"}, + } + msg := be.Error() + assert.Contains(t, msg, "1 succeeded") + assert.Contains(t, msg, "2 failed") + assert.Contains(t, msg, "key1") + assert.Contains(t, msg, "key2") +} + +func TestBatchError_HasFailures(t *testing.T) { + be := &redcache.BatchError{ + Failed: map[string]error{"key1": errors.New("err")}, + Succeeded: []string{"key2"}, + } + assert.True(t, be.HasFailures()) + + beNoFail := &redcache.BatchError{ + Failed: map[string]error{}, 
+ Succeeded: []string{"key1"}, + } + assert.False(t, beNoFail.HasFailures()) +} + +func TestNewBatchError_NilWhenNoFailures(t *testing.T) { + be := redcache.NewBatchError(map[string]error{}, []string{"key1"}) + assert.Nil(t, be) +} + +func TestNewBatchError_ReturnsErrorWhenFailures(t *testing.T) { + failed := map[string]error{"key1": errors.New("oops")} + succeeded := []string{"key2"} + err := redcache.NewBatchError(failed, succeeded) + require.NotNil(t, err) + var be *redcache.BatchError + require.ErrorAs(t, err, &be) + assert.Equal(t, failed, be.Failed) + assert.Equal(t, succeeded, be.Succeeded) +} diff --git a/internal/cmdx/slot.go b/internal/cmdx/slot.go index 2c5ee85..850680c 100644 --- a/internal/cmdx/slot.go +++ b/internal/cmdx/slot.go @@ -1,12 +1,23 @@ +// Package cmdx provides Redis cluster slot calculation utilities. package cmdx -// https://redis.io/topics/cluster-spec - const ( // RedisClusterSlots is the maximum slot number in a Redis cluster (16384 total slots, numbered 0-16383). RedisClusterSlots = 16383 ) +// GroupBySlot groups items by their Redis cluster slot, using keyFn to extract +// the key for slot computation. +func GroupBySlot[V any](items []V, keyFn func(V) string) map[uint16][]V { + groups := make(map[uint16][]V) + for _, item := range items { + slot := Slot(keyFn(item)) + groups[slot] = append(groups[slot], item) + } + return groups +} + +// Slot returns the Redis cluster slot for the given key, following the CRC16 hash tag spec. 
func Slot(key string) uint16 { var s, e int for ; s < len(key); s++ { diff --git a/internal/cmdx/slot_test.go b/internal/cmdx/slot_test.go index 54ca2b0..56f8766 100644 --- a/internal/cmdx/slot_test.go +++ b/internal/cmdx/slot_test.go @@ -164,6 +164,43 @@ func TestSlot_BoundaryValues(t *testing.T) { } } +func TestGroupBySlot(t *testing.T) { + type item struct { + key string + value int + } + + items := []item{ + {key: "{user:1}:a", value: 1}, + {key: "{user:1}:b", value: 2}, + {key: "{user:2}:a", value: 3}, + {key: "standalone", value: 4}, + } + + groups := cmdx.GroupBySlot(items, func(i item) string { return i.key }) + + // {user:1}:a and {user:1}:b should be in the same slot (same hash tag). + slot1 := cmdx.Slot("{user:1}:a") + assert.Len(t, groups[slot1], 2) + assert.Equal(t, 1, groups[slot1][0].value) + assert.Equal(t, 2, groups[slot1][1].value) + + // {user:2}:a should be in its own slot. + slot2 := cmdx.Slot("{user:2}:a") + assert.Len(t, groups[slot2], 1) + assert.Equal(t, 3, groups[slot2][0].value) + + // standalone should be in its own slot. + slot3 := cmdx.Slot("standalone") + assert.Len(t, groups[slot3], 1) + assert.Equal(t, 4, groups[slot3][0].value) +} + +func TestGroupBySlot_Empty(t *testing.T) { + groups := cmdx.GroupBySlot([]string{}, func(s string) string { return s }) + assert.Empty(t, groups) +} + func BenchmarkSlot(b *testing.B) { keys := []string{ "simple", diff --git a/internal/lockpool/lockpool.go b/internal/lockpool/lockpool.go new file mode 100644 index 0000000..1f93d3a --- /dev/null +++ b/internal/lockpool/lockpool.go @@ -0,0 +1,37 @@ +// Package lockpool provides fast lock value generation using an atomic counter +// and a per-instance UUID prefix. +package lockpool + +import ( + "strconv" + "sync/atomic" + + "github.com/google/uuid" +) + +// Pool generates unique lock values by combining a fixed instance UUID with an +// atomic counter. This avoids calling uuid.NewV7() per lock, which is expensive +// under high concurrency. 
+type Pool struct { + prefix string + instanceID string + counter atomic.Uint64 +} + +// New creates a Pool with the given lock prefix (e.g., "__redcache:lock:"). +func New(prefix string) (*Pool, error) { + id, err := uuid.NewV7() + if err != nil { + return nil, err + } + return &Pool{ + prefix: prefix, + instanceID: id.String(), + }, nil +} + +// Generate returns a unique lock value: prefix + instanceID + ":" + counter. +func (p *Pool) Generate() string { + n := p.counter.Add(1) + return p.prefix + p.instanceID + ":" + strconv.FormatUint(n, 10) +} diff --git a/internal/lockpool/lockpool_test.go b/internal/lockpool/lockpool_test.go new file mode 100644 index 0000000..613dbbd --- /dev/null +++ b/internal/lockpool/lockpool_test.go @@ -0,0 +1,65 @@ +package lockpool_test + +import ( + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache/internal/lockpool" +) + +func TestPool_Generate_Prefix(t *testing.T) { + prefix := "__redcache:lock:" + pool, err := lockpool.New(prefix) + require.NoError(t, err) + + val := pool.Generate() + assert.True(t, strings.HasPrefix(val, prefix), "expected prefix %q, got %q", prefix, val) +} + +func TestPool_Generate_Uniqueness(t *testing.T) { + pool, err := lockpool.New("lock:") + require.NoError(t, err) + + seen := make(map[string]struct{}) + for range 1000 { + val := pool.Generate() + _, exists := seen[val] + assert.False(t, exists, "duplicate lock value: %s", val) + seen[val] = struct{}{} + } +} + +func TestPool_Generate_ConcurrentSafety(t *testing.T) { + pool, err := lockpool.New("lock:") + require.NoError(t, err) + + const goroutines = 100 + const perGoroutine = 100 + + results := make(chan string, goroutines*perGoroutine) + var wg sync.WaitGroup + wg.Add(goroutines) + + for range goroutines { + go func() { + defer wg.Done() + for range perGoroutine { + results <- pool.Generate() + } + }() + } + wg.Wait() + close(results) + + seen := 
make(map[string]struct{}) + for val := range results { + _, exists := seen[val] + assert.False(t, exists, "duplicate lock value under concurrency: %s", val) + seen[val] = struct{}{} + } + assert.Len(t, seen, goroutines*perGoroutine) +} diff --git a/internal/mapsx/maps.go b/internal/mapsx/maps.go index c8e8c4b..653281e 100644 --- a/internal/mapsx/maps.go +++ b/internal/mapsx/maps.go @@ -1,5 +1,7 @@ +// Package mapsx provides generic helpers for map operations. package mapsx +// Keys returns the keys of the map m in unspecified order. func Keys[M ~map[K]V, K comparable, V any](m M) []K { keys := make([]K, 0, len(m)) for k := range m { @@ -8,6 +10,7 @@ func Keys[M ~map[K]V, K comparable, V any](m M) []K { return keys } +// Values returns the values of the map m in unspecified order. func Values[M ~map[K]V, K comparable, V any](m M) []V { values := make([]V, 0, len(m)) for _, v := range m { diff --git a/internal/syncx/map.go b/internal/syncx/map.go index 11872d4..84f2cda 100644 --- a/internal/syncx/map.go +++ b/internal/syncx/map.go @@ -1,7 +1,9 @@ +// Package syncx provides generic typed wrappers around standard library sync primitives. package syncx import "sync" +// Map is a generic typed wrapper around sync.Map that avoids interface{} casts at call sites. type Map[K comparable, V any] struct { m sync.Map } diff --git a/internal/syncx/wait.go b/internal/syncx/wait.go index 3f167a2..a306f5b 100644 --- a/internal/syncx/wait.go +++ b/internal/syncx/wait.go @@ -2,35 +2,41 @@ package syncx import ( "context" - "iter" - "reflect" + "sync" ) -func WaitForAll[C ~<-chan V, V any](ctx context.Context, waitLock iter.Seq[C], length int) error { - cases := setupCases(ctx, waitLock, length) - for range length { - chosen, _, _ := reflect.Select(cases) - if ctx.Err() != nil { - return ctx.Err() - } - cases = append(cases[:chosen], cases[chosen+1:]...) +// WaitForAll blocks until all channels are closed (or receive a value) or the context is cancelled. 
+// Returns nil if all channels signaled, or the context error if cancelled first. +func WaitForAll[C ~<-chan V, V any](ctx context.Context, channels []C) error { + if len(channels) == 0 { + return nil } - return nil -} -func setupCases[C ~<-chan V, V any](ctx context.Context, waitLock iter.Seq[C], length int) []reflect.SelectCase { - cases := make([]reflect.SelectCase, length+1) - i := 0 - for ch := range waitLock { - cases[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(ch), - } - i++ + // Merged channel: each goroutine sends one signal when its channel fires. + done := make(chan struct{}, len(channels)) + var wg sync.WaitGroup + wg.Add(len(channels)) + + for _, ch := range channels { + go func() { + defer wg.Done() + select { + case <-ch: + done <- struct{}{} + case <-ctx.Done(): + } + }() } - cases[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(ctx.Done()), + + // Wait for all signals or context cancellation. + for range len(channels) { + select { + case <-done: + case <-ctx.Done(): + // Ensure all goroutines finish before returning. 
+ wg.Wait() + return ctx.Err() + } } - return cases + return nil } diff --git a/internal/syncx/wait_test.go b/internal/syncx/wait_test.go index 89da67a..ea1c4a1 100644 --- a/internal/syncx/wait_test.go +++ b/internal/syncx/wait_test.go @@ -2,7 +2,6 @@ package syncx_test import ( "context" - "slices" "testing" "time" @@ -36,7 +35,7 @@ func TestWaitForAll_Success(t *testing.T) { waitLock := []<-chan struct{}{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.NoErrorf(t, err, "expected no error, got %v", err) } @@ -50,7 +49,7 @@ func TestWaitForAll_SuccessClosed(t *testing.T) { waitLock := []<-chan struct{}{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.NoErrorf(t, err, "expected no error, got %v", err) } @@ -66,7 +65,7 @@ func TestWaitForAll_ContextCancelled(t *testing.T) { waitLock := []<-chan int{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.ErrorIsf(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded, got %v", err) } @@ -82,7 +81,7 @@ func TestWaitForAll_PartialCompleteContextCancelled(t *testing.T) { waitLock := []<-chan int{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.ErrorIsf(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded, got %v", err) } @@ -90,7 +89,7 @@ func TestWaitForAll_NoChannels(t *testing.T) { ctx := context.Background() var waitLock []<-chan int - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.NoErrorf(t, err, "expected no error, got %v", err) } @@ -103,7 +102,7 @@ func TestWaitForAll_ImmediateContextCancel(t *testing.T) { waitLock := []<-chan int{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), 
len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.ErrorIsf(t, err, context.Canceled, "expected context.Canceled, got %v", err) } @@ -117,6 +116,6 @@ func TestWaitForAll_ChannelAlreadyClosed(t *testing.T) { waitLock := []<-chan int{ch1, ch2} - err := syncx.WaitForAll(ctx, slices.Values(waitLock), len(waitLock)) + err := syncx.WaitForAll(ctx, waitLock) assert.NoErrorf(t, err, "expected no error, got %v", err) } diff --git a/lua_scripts.go b/lua_scripts.go new file mode 100644 index 0000000..698fd61 --- /dev/null +++ b/lua_scripts.go @@ -0,0 +1,113 @@ +package redcache + +import "github.com/redis/rueidis" + +// Lua scripts for CacheAside lock operations. +var ( + // delKeyLua atomically deletes a key only if the current value matches the lock. + delKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`) + // setKeyLua atomically sets a value only if the current value matches the lock (CAS). + setKeyLua = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`) +) + +// Lua scripts for PrimeableCacheAside write-lock operations. +var ( + // acquireWriteLockScript atomically acquires a write lock. + // Unlike SET NX (used by Get), this allows overwriting real values but + // refuses to overwrite an existing lock, preventing Set from stomping + // on an active Get operation's lock. + // Returns 1 on success, 0 if an existing lock is present. 
+ acquireWriteLockScript = rueidis.NewLuaScript(` + local key = KEYS[1] + local lock_value = ARGV[1] + local ttl = ARGV[2] + local lock_prefix = ARGV[3] + + local current = redis.call("GET", key) + + if current == false then + redis.call("SET", key, lock_value, "PX", ttl) + return 1 + end + + if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then + return 0 + end + + redis.call("SET", key, lock_value, "PX", ttl) + return 1 + `) + + // acquireWriteLockWithBackupScript acquires a lock and returns the previous value + // for rollback in SetMulti. + // Returns: [success (0 or 1), previous_value or false]. + acquireWriteLockWithBackupScript = rueidis.NewLuaScript(` + local key = KEYS[1] + local lock_value = ARGV[1] + local ttl = ARGV[2] + local lock_prefix = ARGV[3] + + local current = redis.call("GET", key) + + if current == false then + redis.call("SET", key, lock_value, "PX", ttl) + return {1, false} + end + + if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then + return {0, current} + end + + redis.call("SET", key, lock_value, "PX", ttl) + return {1, current} + `) + + // restoreValueOrDeleteScript CAS-restores a saved value or deletes the key. + // Used during SetMulti rollback. Only acts if we still hold our lock. + restoreValueOrDeleteScript = rueidis.NewLuaScript(` + local key = KEYS[1] + local expected_lock = ARGV[1] + local restore_value = ARGV[2] + + if redis.call("GET", key) == expected_lock then + if restore_value and restore_value ~= "" then + redis.call("SET", key, restore_value) + else + redis.call("DEL", key) + end + return 1 + else + return 0 + end + `) + + // setWithWriteLockScript is a strict CAS: SET value only if we hold the exact lock. + // Returns 1 on success, 0 if lock was lost. 
+ setWithWriteLockScript = rueidis.NewLuaScript(` + local key = KEYS[1] + local value = ARGV[1] + local ttl = ARGV[2] + local expected_lock = ARGV[3] + + if redis.call("GET", key) == expected_lock then + redis.call("SET", key, value, "PX", ttl) + return 1 + else + return 0 + end + `) + + // refreshLockScript CAS-refreshes a lock's TTL using PEXPIRE. + // Unlike re-SET, PEXPIRE does not trigger Redis invalidation messages + // and cannot overwrite a value written by another operation between + // lock acquisition and TTL refresh. + // Returns 1 on success, 0 if lock was lost. + refreshLockScript = rueidis.NewLuaScript(` + if redis.call("GET", KEYS[1]) == ARGV[1] then + redis.call("PEXPIRE", KEYS[1], ARGV[2]) + return 1 + else + return 0 + end + `) +) diff --git a/primeable_cacheaside.go b/primeable_cacheaside.go new file mode 100644 index 0000000..b7b7358 --- /dev/null +++ b/primeable_cacheaside.go @@ -0,0 +1,234 @@ +package redcache + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/redis/rueidis" + + "github.com/dcbickfo/redcache/internal/lockpool" + "github.com/dcbickfo/redcache/internal/mapsx" + "github.com/dcbickfo/redcache/internal/syncx" +) + +// PrimeableCacheAside extends CacheAside with explicit Set operations for cache +// priming and coordinated cache updates. +// +// It inherits all Get/GetMulti/Del/DelMulti capabilities and adds: +// - Set/SetMulti for coordinated cache updates with write locking +// - ForceSet/ForceSetMulti for unconditional writes bypassing locks +type PrimeableCacheAside struct { + *CacheAside + lockPool *lockpool.Pool +} + +// NewPrimeableCacheAside creates a PrimeableCacheAside that wraps a CacheAside +// with additional Set operations. 
+func NewPrimeableCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*PrimeableCacheAside, error) { + rca, err := NewRedCacheAside(clientOption, caOption) + if err != nil { + return nil, err + } + lp, err := lockpool.New(rca.lockPrefix) + if err != nil { + return nil, fmt.Errorf("lock pool: %w", err) + } + return &PrimeableCacheAside{ + CacheAside: rca, + lockPool: lp, + }, nil +} + +// Close cancels all pending lock entries. It does NOT close the underlying Redis client. +func (pca *PrimeableCacheAside) Close() { + pca.CacheAside.Close() +} + +// Set acquires a write lock on the key, calls fn to produce the value, and atomically +// sets it in Redis. If another operation holds a lock, Set waits for it to complete. +// +// The callback fn receives the key and should return the value to cache. +// Set respects context cancellation for timeouts. +func (pca *PrimeableCacheAside) Set( + ctx context.Context, + ttl time.Duration, + key string, + fn func(ctx context.Context, key string) (string, error), +) error { + lockVal := pca.lockPool.Generate() + lockTTLMs := strconv.FormatInt(pca.lockTTL.Milliseconds(), 10) + +retry: + waitChan := pca.register(key) + + // Subscribe + read current value. + resp := pca.client.DoCache(ctx, pca.client.B().Get().Key(key).Cache(), pca.lockTTL) + val, err := resp.ToString() + if err != nil && !rueidis.IsRedisNil(err) { + return fmt.Errorf("read key %q: %w", key, err) + } + + // If current value is a lock, wait for it to be released. + if !rueidis.IsRedisNil(err) && strings.HasPrefix(val, pca.lockPrefix) { + select { + case <-waitChan: + goto retry + case <-ctx.Done(): + return ctx.Err() + } + } + + // Try to acquire write lock. + result, err := acquireWriteLockScript.Exec(ctx, pca.client, []string{key}, []string{lockVal, lockTTLMs, pca.lockPrefix}).AsInt64() + if err != nil { + return fmt.Errorf("write lock for key %q: %w", key, err) + } + if result == 0 { + // Another lock appeared between DoCache and Exec. 
+ select { + case <-waitChan: + goto retry + case <-ctx.Done(): + return ctx.Err() + } + } + + // Lock acquired — execute callback. + newVal, err := fn(ctx, key) + if err != nil { + pca.bestEffortUnlock(ctx, key, lockVal) + return err + } + + // CAS set the value. + casResult, err := setWithWriteLockScript.Exec(ctx, pca.client, []string{key}, []string{newVal, strconv.FormatInt(ttl.Milliseconds(), 10), lockVal}).AsInt64() + if err != nil { + return fmt.Errorf("set key %q: %w", key, err) + } + if casResult == 0 { + return fmt.Errorf("key %q: %w", key, ErrLockLost) + } + return nil +} + +// SetMulti acquires write locks on all keys, calls fn once with all keys, +// and atomically sets the returned values. Locks are acquired in sorted order +// to prevent deadlocks. +// +// On partial CAS failure, returns a *BatchError listing succeeded and failed keys. +// On full success, returns nil. +func (pca *PrimeableCacheAside) SetMulti( + ctx context.Context, + ttl time.Duration, + keys []string, + fn func(ctx context.Context, keys []string) (map[string]string, error), +) error { + if len(keys) == 0 { + return nil + } + + // Wait for any existing read locks on these keys. + if err := pca.waitForReadLocks(ctx, keys); err != nil { + return err + } + + // Acquire write locks in sorted order. + lockValues, savedValues, err := pca.acquireMultiWriteLocks(ctx, keys) + if err != nil { + return err + } + + // Execute the callback with all locked keys. + vals, err := fn(ctx, mapsx.Keys(lockValues)) + if err != nil { + pca.restoreMultiValues(ctx, lockValues, savedValues) + return err + } + + // CAS batch set. + succeeded, failed := pca.setMultiValuesWithCAS(ctx, ttl, vals, lockValues) + + // Unlock any keys that weren't successfully written. 
+ toUnlock := make(map[string]string) + for key, lockVal := range lockValues { + found := false + for _, s := range succeeded { + if s == key { + found = true + break + } + } + if !found { + toUnlock[key] = lockVal + } + } + if len(toUnlock) > 0 { + pca.unlockMultiKeys(ctx, toUnlock) + } + + return NewBatchError(failed, succeeded) +} + +// ForceSet unconditionally writes a value to Redis, bypassing all locks. +// Any in-progress Get or Set on this key will see ErrLockLost and retry. +func (pca *PrimeableCacheAside) ForceSet(ctx context.Context, ttl time.Duration, key, value string) error { + return pca.client.Do(ctx, pca.client.B().Set().Key(key).Value(value).Px(ttl).Build()).Error() +} + +// ForceSetMulti unconditionally writes multiple values to Redis, bypassing all locks. +func (pca *PrimeableCacheAside) ForceSetMulti(ctx context.Context, ttl time.Duration, values map[string]string) error { + if len(values) == 0 { + return nil + } + cmds := make(rueidis.Commands, 0, len(values)) + for key, val := range values { + cmds = append(cmds, pca.client.B().Set().Key(key).Value(val).Px(ttl).Build()) + } + resps := pca.client.DoMulti(ctx, cmds...) + for _, resp := range resps { + if err := resp.Error(); err != nil { + return err + } + } + return nil +} + +// waitForReadLocks registers all keys, batch-reads them, and waits for any that +// currently hold a lock value. Uses correct ordering: register first, then DoCache. +func (pca *PrimeableCacheAside) waitForReadLocks(ctx context.Context, keys []string) error { + // 1. Register ALL keys first so onInvalidate can find the lockEntries. + waitChans := make(map[string]<-chan struct{}, len(keys)) + for _, key := range keys { + waitChans[key] = pca.register(key) + } + + // 2. DoMultiCache to subscribe and read values. 
+ multi := make([]rueidis.CacheableTTL, len(keys)) + for i, key := range keys { + multi[i] = rueidis.CacheableTTL{ + Cmd: pca.client.B().Get().Key(key).Cache(), + TTL: pca.lockTTL, + } + } + resps := pca.client.DoMultiCache(ctx, multi...) + + // 3. Collect channels for keys that have locks. + var lockedChans []<-chan struct{} + for i, resp := range resps { + val, err := resp.ToString() + if err != nil { + continue // Redis nil or error — no lock. + } + if strings.HasPrefix(val, pca.lockPrefix) { + lockedChans = append(lockedChans, waitChans[keys[i]]) + } + } + + if len(lockedChans) == 0 { + return nil + } + return syncx.WaitForAll(ctx, lockedChans) +} diff --git a/primeable_cacheaside_test.go b/primeable_cacheaside_test.go new file mode 100644 index 0000000..bf20287 --- /dev/null +++ b/primeable_cacheaside_test.go @@ -0,0 +1,682 @@ +package redcache_test + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/redis/rueidis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dcbickfo/redcache" +) + +func makePrimeableClient(t *testing.T, addr []string) *redcache.PrimeableCacheAside { + t.Helper() + client, err := redcache.NewPrimeableCacheAside( + rueidis.ClientOption{ + InitAddress: addr, + }, + redcache.CacheAsideOption{ + LockTTL: time.Second * 1, + }, + ) + require.NoError(t, err) + return client +} + +func TestPrimeableCacheAside_Set_Basic(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + val := "val:" + uuid.New().String() + + err := client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + assert.Equal(t, key, k) + return val, nil + }) + require.NoError(t, err) + + // Subsequent Get should return cached value without callback. 
+ called := false + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + called = true + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, val, res) + assert.False(t, called, "Get callback should not be invoked after Set") +} + +func TestPrimeableCacheAside_Set_Overwrites(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + val1 := "val1:" + uuid.New().String() + val2 := "val2:" + uuid.New().String() + + // Set initial value via Get. + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return val1, nil + }) + require.NoError(t, err) + assert.Equal(t, val1, res) + + // Overwrite with Set. + err = client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return val2, nil + }) + require.NoError(t, err) + + // Verify new value. + res, err = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, val2, res) +} + +func TestPrimeableCacheAside_Set_WaitsForExistingReadLock(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + getVal := "get-val:" + uuid.New().String() + setVal := "set-val:" + uuid.New().String() + + getStarted := make(chan struct{}) + getComplete := make(chan struct{}) + + // Start a Get that holds a lock for a while. + go func() { + _, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + close(getStarted) + time.Sleep(200 * time.Millisecond) + return getVal, nil + }) + assert.NoError(t, err) + close(getComplete) + }() + + // Wait for Get to acquire its lock. 
+ <-getStarted + time.Sleep(50 * time.Millisecond) + + // Set should wait for the Get lock to be released, then proceed. + err := client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return setVal, nil + }) + require.NoError(t, err) + + <-getComplete + + // The Set value should be the final value. + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, setVal, res) +} + +func TestPrimeableCacheAside_Set_Concurrent(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + var callCount atomic.Int32 + + var wg sync.WaitGroup + for i := range 10 { + wg.Add(1) + go func() { + defer wg.Done() + err := client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + callCount.Add(1) + return "val-from-" + uuid.New().String(), nil + }) + // Either success or ErrLockLost is acceptable. + if err != nil { + assert.ErrorIs(t, err, redcache.ErrLockLost, "iteration %d", i) + } + }() + } + wg.Wait() + + // At least one should have succeeded. 
+ assert.GreaterOrEqual(t, callCount.Load(), int32(1)) +} + +func TestPrimeableCacheAside_SetMulti_Basic(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + keyAndVals := map[string]string{ + "key:0:" + uuid.New().String(): "val:0:" + uuid.New().String(), + "key:1:" + uuid.New().String(): "val:1:" + uuid.New().String(), + "key:2:" + uuid.New().String(): "val:2:" + uuid.New().String(), + } + keys := make([]string, 0, len(keyAndVals)) + for k := range keyAndVals { + keys = append(keys, k) + } + + err := client.SetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + res := make(map[string]string, len(ks)) + for _, k := range ks { + res[k] = keyAndVals[k] + } + return res, nil + }) + require.NoError(t, err) + + // Verify all keys cached. + res, err := client.GetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + t.Fatal("GetMulti callback should not be called after SetMulti") + return nil, nil + }) + require.NoError(t, err) + if diff := cmp.Diff(keyAndVals, res); diff != "" { + t.Errorf("GetMulti() mismatch (-want +got):\n%s", diff) + } +} + +func TestPrimeableCacheAside_SetMulti_NoDeadlock(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Two overlapping key sets — sorted order prevents deadlock. 
+ keys1 := []string{ + "key:a:" + uuid.New().String(), + "key:b:" + uuid.New().String(), + "key:c:" + uuid.New().String(), + } + keys2 := []string{ + keys1[1], // overlap on key:b + "key:d:" + uuid.New().String(), + "key:e:" + uuid.New().String(), + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for range 5 { + _ = client.SetMulti(ctx, time.Second*10, keys1, func(ctx context.Context, ks []string) (map[string]string, error) { + res := make(map[string]string, len(ks)) + for _, k := range ks { + res[k] = "val-1:" + uuid.New().String() + } + return res, nil + }) + } + }() + + go func() { + defer wg.Done() + for range 5 { + _ = client.SetMulti(ctx, time.Second*10, keys2, func(ctx context.Context, ks []string) (map[string]string, error) { + res := make(map[string]string, len(ks)) + for _, k := range ks { + res[k] = "val-2:" + uuid.New().String() + } + return res, nil + }) + } + }() + + // If there's a deadlock, the test will timeout. + wg.Wait() +} + +func TestPrimeableCacheAside_ForceSet_Basic(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + val := "forced-val:" + uuid.New().String() + + err := client.ForceSet(ctx, time.Second*10, key, val) + require.NoError(t, err) + + // Get should return the force-set value. + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, val, res) +} + +func TestPrimeableCacheAside_ForceSet_StealsLock(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + forcedVal := "forced:" + uuid.New().String() + + getStarted := make(chan struct{}) + + // Start a slow Get that holds a lock. 
+ go func() { + _, _ = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + close(getStarted) + time.Sleep(300 * time.Millisecond) + return "get-val", nil + }) + }() + + <-getStarted + time.Sleep(50 * time.Millisecond) + + // ForceSet overwrites the lock. + err := client.ForceSet(ctx, time.Second*10, key, forcedVal) + require.NoError(t, err) + + // Wait for Get to complete (it will see ErrLockLost and retry). + time.Sleep(500 * time.Millisecond) + + // The forced value should be present (or Get retried with its own value). + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called — value should exist") + return "", nil + }) + require.NoError(t, err) + assert.NotEmpty(t, res, "expected a value to be cached") +} + +func TestPrimeableCacheAside_ForceSetMulti_Basic(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + values := map[string]string{ + "key:0:" + uuid.New().String(): "val:0:" + uuid.New().String(), + "key:1:" + uuid.New().String(): "val:1:" + uuid.New().String(), + } + + err := client.ForceSetMulti(ctx, time.Second*10, values) + require.NoError(t, err) + + // Verify via direct reads. + for key, expected := range values { + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, expected, res, "key %s", key) + } +} + +func TestPrimeableCacheAside_Set_ContextCancellation(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + + key := "key:" + uuid.New().String() + + // Place a lock so Set will wait. 
+ innerClient := client.Client() + lockVal := "__redcache:lock:" + uuid.New().String() + err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() + require.True(t, rueidis.IsRedisNil(err)) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err = client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "val", nil + }) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestPrimeableCacheAside_Close_CancelsPendingLocks(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + + // Place a lock so operations will wait. + innerClient := client.Client() + lockVal := "__redcache:lock:" + uuid.New().String() + err := innerClient.Do(ctx, innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() + require.True(t, rueidis.IsRedisNil(err)) + + // Use a context with timeout so Set doesn't loop forever after Close. + setCtx, setCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer setCancel() + + errCh := make(chan error, 1) + go func() { + errCh <- client.Set(setCtx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "val", nil + }) + }() + + // Give Set time to start waiting. + time.Sleep(100 * time.Millisecond) + + // Close should cancel pending lock entries, causing Set to wake up and retry. + client.Close() + + select { + case <-time.After(5 * time.Second): + t.Fatal("Set did not return after Close") + case err := <-errCh: + // Set should eventually fail with context deadline exceeded because + // the external lock persists, but Close woke it up at least once. 
+ require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + } +} + +func TestPrimeableCacheAside_SetMulti_ContextCancellation(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + + keys := []string{ + "key:0:" + uuid.New().String(), + "key:1:" + uuid.New().String(), + } + + // Place locks on all keys. + innerClient := client.Client() + for _, key := range keys { + lockVal := "__redcache:lock:" + uuid.New().String() + err := innerClient.Do(context.Background(), innerClient.B().Set().Key(key).Value(lockVal).Nx().Get().Px(time.Second*30).Build()).Error() + require.True(t, rueidis.IsRedisNil(err)) + } + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + err := client.SetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + t.Fatal("callback should not be called when waiting for locks") + return nil, nil + }) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestPrimeableCacheAside_Set_CallbackError(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + cbErr := fmt.Errorf("set callback failed") + + // Set with failing callback. + err := client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "", cbErr + }) + require.ErrorIs(t, err, cbErr) + + // Lock should have been cleaned up — a subsequent Set should succeed. + val := "good-val:" + uuid.New().String() + err = client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return val, nil + }) + require.NoError(t, err) + + // Verify the value is there. 
+ res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, val, res) +} + +func TestPrimeableCacheAside_SetMulti_CallbackError_RestoresValues(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + keys := []string{ + "key:0:" + uuid.New().String(), + "key:1:" + uuid.New().String(), + } + + // Pre-populate with known values via Get. + originalVals := map[string]string{ + keys[0]: "original:0:" + uuid.New().String(), + keys[1]: "original:1:" + uuid.New().String(), + } + res, err := client.GetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + out := make(map[string]string, len(ks)) + for _, k := range ks { + out[k] = originalVals[k] + } + return out, nil + }) + require.NoError(t, err) + if diff := cmp.Diff(originalVals, res); diff != "" { + t.Fatalf("setup mismatch: %s", diff) + } + + // SetMulti with failing callback — should restore original values. + cbErr := fmt.Errorf("setmulti callback failed") + err = client.SetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + return nil, cbErr + }) + require.ErrorIs(t, err, cbErr) + + // Give invalidation a moment to propagate, then verify original values were restored. + time.Sleep(100 * time.Millisecond) + res, err = client.GetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + // If the callback fires, it means the originals were NOT restored — the keys were + // left as lock values (deleted). This is acceptable but we need to return values. 
+ out := make(map[string]string, len(ks)) + for _, k := range ks { + out[k] = originalVals[k] + } + return out, nil + }) + require.NoError(t, err) + // Either the originals are restored or the callback re-populated them. + if diff := cmp.Diff(originalVals, res); diff != "" { + t.Errorf("values after rollback mismatch (-want +got):\n%s", diff) + } +} + +func TestPrimeableCacheAside_SetMulti_PartialCASFailure_BatchError(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key1 := "key:0:" + uuid.New().String() + key2 := "key:1:" + uuid.New().String() + keys := []string{key1, key2} + + // Use SetMulti with a callback that calls ForceSet on key2 to steal its lock + // between lock acquisition and CAS write. + forcedVal := "forced:" + uuid.New().String() + err := client.SetMulti(ctx, time.Second*10, keys, func(ctx context.Context, ks []string) (map[string]string, error) { + // Steal key2's lock while we hold it. + forceErr := client.ForceSet(ctx, time.Second*10, key2, forcedVal) + if forceErr != nil { + return nil, forceErr + } + // Return values for both keys — but CAS on key2 should fail. + return map[string]string{ + key1: "val1:" + uuid.New().String(), + key2: "val2:" + uuid.New().String(), + }, nil + }) + + if err != nil { + // Should be a BatchError with key2 failed. + var batchErr *redcache.BatchError + if assert.ErrorAs(t, err, &batchErr) { + assert.True(t, batchErr.HasFailures()) + assert.Contains(t, batchErr.Failed, key2, "key2 should have failed CAS") + assert.ErrorIs(t, batchErr.Failed[key2], redcache.ErrLockLost) + } + } + // Either way, key2 should have the forced value. 
+ res, getErr := client.Get(ctx, time.Second*10, key2, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called — forced value should exist") + return "", nil + }) + require.NoError(t, getErr) + assert.Equal(t, forcedVal, res) +} + +func TestPrimeableCacheAside_ForceSet_OverwritesExistingValue(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + originalVal := "original:" + uuid.New().String() + forcedVal := "forced:" + uuid.New().String() + + // Populate via Get. + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return originalVal, nil + }) + require.NoError(t, err) + assert.Equal(t, originalVal, res) + + // ForceSet overwrites the real value. + err = client.ForceSet(ctx, time.Second*10, key, forcedVal) + require.NoError(t, err) + + // Allow invalidation message to propagate to the client-side cache. + time.Sleep(100 * time.Millisecond) + + // Verify forced value is returned. 
+ res, err = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called") + return "", nil + }) + require.NoError(t, err) + assert.Equal(t, forcedVal, res) +} + +func TestNewPrimeableCacheAside_Validation(t *testing.T) { + t.Run("empty InitAddress", func(t *testing.T) { + _, err := redcache.NewPrimeableCacheAside( + rueidis.ClientOption{}, + redcache.CacheAsideOption{}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "InitAddress") + }) + + t.Run("negative LockTTL", func(t *testing.T) { + _, err := redcache.NewPrimeableCacheAside( + rueidis.ClientOption{InitAddress: addr}, + redcache.CacheAsideOption{LockTTL: -1 * time.Second}, + ) + require.Error(t, err) + }) +} + +func TestPrimeableCacheAside_MultiClient_SetGet(t *testing.T) { + client1 := makePrimeableClient(t, addr) + defer client1.Client().Close() + client2 := makePrimeableClient(t, addr) + defer client2.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + setVal := "set-val:" + uuid.New().String() + + // client1 does Set. + err := client1.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return setVal, nil + }) + require.NoError(t, err) + + // client2 does Get — should see the Set value without invoking callback. 
+ called := false + res, err := client2.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + called = true + return "other-val", nil + }) + require.NoError(t, err) + assert.Equal(t, setVal, res) + assert.False(t, called, "client2 Get callback should not be called") +} + +func TestPrimeableCacheAside_ConcurrentSetAndGet(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + key := "key:" + uuid.New().String() + + var wg sync.WaitGroup + for range 50 { + wg.Add(2) + go func() { + defer wg.Done() + _ = client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "set:" + uuid.New().String(), nil + }) + }() + go func() { + defer wg.Done() + _, _ = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "get:" + uuid.New().String(), nil + }) + }() + } + wg.Wait() + + // Key should have a value — no deadlock, no panic. 
+ res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + t.Fatal("callback should not be called — value should exist") + return "", nil + }) + require.NoError(t, err) + assert.NotEmpty(t, res) +} + +func TestPrimeableCacheAside_SetMulti_EmptyKeys(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + err := client.SetMulti(ctx, time.Second*10, nil, func(ctx context.Context, ks []string) (map[string]string, error) { + t.Fatal("callback should not be called for empty keys") + return nil, nil + }) + require.NoError(t, err) +} + +func TestPrimeableCacheAside_ForceSetMulti_EmptyMap(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + err := client.ForceSetMulti(ctx, time.Second*10, nil) + require.NoError(t, err) + + err = client.ForceSetMulti(ctx, time.Second*10, map[string]string{}) + require.NoError(t, err) +} diff --git a/primeable_setmulti_helpers.go b/primeable_setmulti_helpers.go new file mode 100644 index 0000000..95f4238 --- /dev/null +++ b/primeable_setmulti_helpers.go @@ -0,0 +1,358 @@ +package redcache + +import ( + "context" + "fmt" + "sort" + "strconv" + "time" + + "github.com/redis/rueidis" + "golang.org/x/sync/errgroup" + + "github.com/dcbickfo/redcache/internal/cmdx" +) + +// acquireMultiWriteLocks acquires write locks on all keys in sorted order with rollback. +// Returns lockValues map and savedValues map (for rollback of previous real values). 
+func (pca *PrimeableCacheAside) acquireMultiWriteLocks( + ctx context.Context, + keys []string, +) (lockValues map[string]string, savedValues map[string]string, err error) { + sorted := make([]string, len(keys)) + copy(sorted, keys) + sort.Strings(sorted) + + lockValues = make(map[string]string, len(sorted)) + savedValues = make(map[string]string, len(sorted)) + remaining := sorted + + for len(remaining) > 0 { + firstFailed, err := pca.tryAcquireRemaining(ctx, remaining, lockValues, savedValues) + if err != nil { + pca.restoreMultiValues(ctx, lockValues, savedValues) + return nil, nil, err + } + + if firstFailed == "" { + break + } + + if err := pca.waitForFailedKey(ctx, firstFailed, lockValues, savedValues); err != nil { + return nil, nil, err + } + + remaining = pca.computeRemaining(sorted, lockValues) + } + + return lockValues, savedValues, nil +} + +// tryAcquireRemaining generates lock values and batch-acquires locks for remaining keys. +// On partial failure, rolls back locks after the first failed key and returns firstFailed. 
+func (pca *PrimeableCacheAside) tryAcquireRemaining( + ctx context.Context, + remaining []string, + lockValues map[string]string, + savedValues map[string]string, +) (firstFailed string, err error) { + lockTTLMs := strconv.FormatInt(pca.lockTTL.Milliseconds(), 10) + batchLocks := make(map[string]string, len(remaining)) + for _, key := range remaining { + if _, ok := lockValues[key]; !ok { + batchLocks[key] = pca.lockPool.Generate() + } + } + + acquired, backups, firstFailed, err := pca.batchAcquireWithBackup(ctx, remaining, batchLocks, lockTTLMs) + if err != nil { + return "", err + } + + for key, lockVal := range acquired { + lockValues[key] = lockVal + if backup, ok := backups[key]; ok { + savedValues[key] = backup + } + } + + if firstFailed != "" { + pca.rollbackAfterFirstFailure(ctx, remaining, firstFailed, lockValues, savedValues, acquired) + pca.touchMultiLocks(ctx, lockValues) + } + + return firstFailed, nil +} + +// waitForFailedKey registers, subscribes, and waits for a failed key's lock to release. +func (pca *PrimeableCacheAside) waitForFailedKey( + ctx context.Context, + firstFailed string, + lockValues map[string]string, + savedValues map[string]string, +) error { + waitChan := pca.register(firstFailed) + pca.client.DoCache(ctx, pca.client.B().Get().Key(firstFailed).Cache(), pca.lockTTL) + + select { + case <-waitChan: + return nil + case <-ctx.Done(): + pca.restoreMultiValues(ctx, lockValues, savedValues) + return ctx.Err() + } +} + +// computeRemaining returns the sorted keys that haven't been locked yet. +func (pca *PrimeableCacheAside) computeRemaining(sorted []string, lockValues map[string]string) []string { + remaining := make([]string, 0, len(sorted)) + for _, key := range sorted { + if _, ok := lockValues[key]; !ok { + remaining = append(remaining, key) + } + } + return remaining +} + +type lockAcquireEntry struct { + key string + lockVal string +} + +// batchAcquireWithBackup attempts to acquire write locks on keys, grouped by slot. 
+// Returns acquired locks, saved backups, the first failed key (in input order), and any error. +func (pca *PrimeableCacheAside) batchAcquireWithBackup( + ctx context.Context, + keys []string, + batchLocks map[string]string, + lockTTLMs string, +) (acquired map[string]string, backups map[string]string, firstFailed string, err error) { + acquired = make(map[string]string, len(keys)) + backups = make(map[string]string) + + entries := make([]lockAcquireEntry, 0, len(keys)) + for _, key := range keys { + entries = append(entries, lockAcquireEntry{key: key, lockVal: batchLocks[key]}) + } + slotGroups := cmdx.GroupBySlot(entries, func(e lockAcquireEntry) string { return e.key }) + + for _, group := range slotGroups { + if err := pca.execSlotAcquire(ctx, group, lockTTLMs, acquired, backups); err != nil { + return nil, nil, "", err + } + } + + // Find first failed key in input order. + for _, key := range keys { + if _, ok := acquired[key]; !ok { + return acquired, backups, key, nil + } + } + + return acquired, backups, "", nil +} + +// execSlotAcquire executes lock acquisitions for a single slot group and populates acquired/backups. +func (pca *PrimeableCacheAside) execSlotAcquire( + ctx context.Context, + group []lockAcquireEntry, + lockTTLMs string, + acquired map[string]string, + backups map[string]string, +) error { + stmts := make([]rueidis.LuaExec, len(group)) + for i, entry := range group { + stmts[i] = rueidis.LuaExec{ + Keys: []string{entry.key}, + Args: []string{entry.lockVal, lockTTLMs, pca.lockPrefix}, + } + } + resps := acquireWriteLockWithBackupScript.ExecMulti(ctx, pca.client, stmts...) + for i, resp := range resps { + arr, err := resp.ToArray() + if err != nil { + // Release any locks we've already acquired in previous slots. 
+ for k, v := range acquired { + pca.bestEffortUnlock(ctx, k, v) + } + return fmt.Errorf("lock key %q: %w", group[i].key, err) + } + success, _ := arr[0].AsInt64() + if success != 1 { + continue + } + acquired[group[i].key] = group[i].lockVal + if backupVal, bErr := arr[1].ToString(); bErr == nil && backupVal != "" { + backups[group[i].key] = backupVal + } + } + return nil +} + +// rollbackAfterFirstFailure releases locks acquired AFTER the first failed key +// (in sorted order), keeping locks before it. +func (pca *PrimeableCacheAside) rollbackAfterFirstFailure( + ctx context.Context, + sorted []string, + firstFailed string, + lockValues map[string]string, + savedValues map[string]string, + justAcquired map[string]string, +) { + pastFailure := false + for _, key := range sorted { + if key == firstFailed { + pastFailure = true + continue + } + if !pastFailure { + continue + } + // Release locks acquired in this batch that are after the first failure. + if lockVal, ok := justAcquired[key]; ok { + if saved, hasSaved := savedValues[key]; hasSaved { + pca.restoreValue(ctx, key, lockVal, saved) + delete(savedValues, key) + } else { + pca.bestEffortUnlock(ctx, key, lockVal) + } + delete(lockValues, key) + } + } +} + +// touchMultiLocks refreshes TTL on held locks using CAS PEXPIRE. +// Removes any locks that were lost (stolen by another operation). +func (pca *PrimeableCacheAside) touchMultiLocks(ctx context.Context, lockValues map[string]string) { + lockTTLMs := strconv.FormatInt(pca.lockTTL.Milliseconds(), 10) + for key, lockVal := range lockValues { + result, err := refreshLockScript.Exec(ctx, pca.client, []string{key}, []string{lockVal, lockTTLMs}).AsInt64() + if err != nil || result == 0 { + // Lock was lost — remove from our set. 
+ pca.logger.Debug("lock refresh failed, removing from held locks", "key", key) + delete(lockValues, key) + } + } +} + +type casSetEntry struct { + key string + setStmt rueidis.LuaExec +} + +// setMultiValuesWithCAS batch-sets values using CAS, grouped by Redis cluster slot. +// Returns succeeded keys and a map of failed keys to their errors. +func (pca *PrimeableCacheAside) setMultiValuesWithCAS( + ctx context.Context, + ttl time.Duration, + values map[string]string, + lockValues map[string]string, +) (succeeded []string, failed map[string]error) { + failed = make(map[string]error) + + ttlMs := strconv.FormatInt(ttl.Milliseconds(), 10) + entries := make([]casSetEntry, 0, len(values)) + for key, val := range values { + lockVal, ok := lockValues[key] + if !ok { + continue + } + entries = append(entries, casSetEntry{ + key: key, + setStmt: rueidis.LuaExec{ + Keys: []string{key}, + Args: []string{val, ttlMs, lockVal}, + }, + }) + } + slotGroups := cmdx.GroupBySlot(entries, func(e casSetEntry) string { return e.key }) + + type slotResult struct { + entries []casSetEntry + resps []rueidis.RedisResult + } + + eg, egCtx := errgroup.WithContext(ctx) + resultsCh := make(chan slotResult, len(slotGroups)) + + for _, group := range slotGroups { + eg.Go(func() error { + stmts := make([]rueidis.LuaExec, len(group)) + for i, e := range group { + stmts[i] = e.setStmt + } + resps := setWithWriteLockScript.ExecMulti(egCtx, pca.client, stmts...) + resultsCh <- slotResult{entries: group, resps: resps} + return nil + }) + } + + _ = eg.Wait() + close(resultsCh) + + for sr := range resultsCh { + pca.collectCASResults(sr.entries, sr.resps, &succeeded, failed) + } + + return succeeded, failed +} + +// collectCASResults processes Lua CAS responses, populating succeeded/failed. 
+func (pca *PrimeableCacheAside) collectCASResults( + entries []casSetEntry, + resps []rueidis.RedisResult, + succeeded *[]string, + failed map[string]error, +) { + for i, resp := range resps { + key := entries[i].key + if err := resp.Error(); err != nil { + failed[key] = fmt.Errorf("CAS set key %q: %w", key, err) + continue + } + val, _ := resp.AsInt64() + if val == 0 { + failed[key] = ErrLockLost + continue + } + *succeeded = append(*succeeded, key) + } +} + +// restoreMultiValues restores saved values or deletes keys for all held locks. +func (pca *PrimeableCacheAside) restoreMultiValues(ctx context.Context, lockValues, savedValues map[string]string) { + cleanupCtx := context.WithoutCancel(ctx) + toCtx, cancel := context.WithTimeout(cleanupCtx, pca.lockTTL) + defer cancel() + + for key, lockVal := range lockValues { + saved := savedValues[key] + pca.restoreValue(toCtx, key, lockVal, saved) + } +} + +// restoreValue restores a single key's previous value or deletes it. +func (pca *PrimeableCacheAside) restoreValue(ctx context.Context, key, lockVal, savedValue string) { + err := restoreValueOrDeleteScript.Exec(ctx, pca.client, []string{key}, []string{lockVal, savedValue}).Error() + if err != nil { + pca.logger.Error("failed to restore value", "key", key, "error", err) + } +} + +// bestEffortUnlock releases a lock using delKeyLua. +func (pca *PrimeableCacheAside) bestEffortUnlock(ctx context.Context, key, lockVal string) { + cleanupCtx := context.WithoutCancel(ctx) + toCtx, cancel := context.WithTimeout(cleanupCtx, pca.lockTTL) + defer cancel() + if err := pca.unlock(toCtx, key, lockVal); err != nil { + pca.logger.Error("failed to unlock key", "key", key, "error", err) + } +} + +// unlockMultiKeys releases multiple locks. 
+func (pca *PrimeableCacheAside) unlockMultiKeys(ctx context.Context, lockVals map[string]string) { + cleanupCtx := context.WithoutCancel(ctx) + toCtx, cancel := context.WithTimeout(cleanupCtx, pca.lockTTL) + defer cancel() + pca.unlockMulti(toCtx, lockVals) +} From e5036af66b2a50ee52bf3c50c0634e418eba99d4 Mon Sep 17 00:00:00 2001 From: dcbickfo Date: Sun, 15 Mar 2026 13:09:03 -0400 Subject: [PATCH 2/3] Fix Set to backup and restore previous value on callback error --- lua_scripts.go | 32 ++++---------------------------- primeable_cacheaside.go | 11 ++++++----- primeable_cacheaside_test.go | 34 ++++++++++++++++++++++++++++++++++ primeable_setmulti_helpers.go | 19 +++++++++++++++++++ 4 files changed, 63 insertions(+), 33 deletions(-) diff --git a/lua_scripts.go b/lua_scripts.go index 698fd61..4e65095 100644 --- a/lua_scripts.go +++ b/lua_scripts.go @@ -12,34 +12,10 @@ var ( // Lua scripts for PrimeableCacheAside write-lock operations. var ( - // acquireWriteLockScript atomically acquires a write lock. - // Unlike SET NX (used by Get), this allows overwriting real values but - // refuses to overwrite an existing lock, preventing Set from stomping - // on an active Get operation's lock. - // Returns 1 on success, 0 if an existing lock is present. - acquireWriteLockScript = rueidis.NewLuaScript(` - local key = KEYS[1] - local lock_value = ARGV[1] - local ttl = ARGV[2] - local lock_prefix = ARGV[3] - - local current = redis.call("GET", key) - - if current == false then - redis.call("SET", key, lock_value, "PX", ttl) - return 1 - end - - if string.sub(current, 1, string.len(lock_prefix)) == lock_prefix then - return 0 - end - - redis.call("SET", key, lock_value, "PX", ttl) - return 1 - `) - - // acquireWriteLockWithBackupScript acquires a lock and returns the previous value - // for rollback in SetMulti. + // acquireWriteLockWithBackupScript atomically acquires a write lock and + // returns the previous value for rollback. 
Unlike SET NX (used by Get), + // this allows overwriting real values but refuses to overwrite an existing + // lock, preventing Set from stomping on an active Get operation's lock. // Returns: [success (0 or 1), previous_value or false]. acquireWriteLockWithBackupScript = rueidis.NewLuaScript(` local key = KEYS[1] diff --git a/primeable_cacheaside.go b/primeable_cacheaside.go index b7b7358..5a00d22 100644 --- a/primeable_cacheaside.go +++ b/primeable_cacheaside.go @@ -49,6 +49,7 @@ func (pca *PrimeableCacheAside) Close() { // Set acquires a write lock on the key, calls fn to produce the value, and atomically // sets it in Redis. If another operation holds a lock, Set waits for it to complete. +// If the callback returns an error, the previous value is restored (matching SetMulti behavior). // // The callback fn receives the key and should return the value to cache. // Set respects context cancellation for timeouts. @@ -81,12 +82,12 @@ retry: } } - // Try to acquire write lock. - result, err := acquireWriteLockScript.Exec(ctx, pca.client, []string{key}, []string{lockVal, lockTTLMs, pca.lockPrefix}).AsInt64() + // Try to acquire write lock, capturing the previous value for rollback. + acquired, savedValue, err := pca.tryAcquireWriteLock(ctx, key, lockVal, lockTTLMs) if err != nil { - return fmt.Errorf("write lock for key %q: %w", key, err) + return err } - if result == 0 { + if !acquired { // Another lock appeared between DoCache and Exec. select { case <-waitChan: @@ -99,7 +100,7 @@ retry: // Lock acquired — execute callback. 
newVal, err := fn(ctx, key) if err != nil { - pca.bestEffortUnlock(ctx, key, lockVal) + pca.restoreValue(ctx, key, lockVal, savedValue) return err } diff --git a/primeable_cacheaside_test.go b/primeable_cacheaside_test.go index bf20287..5ffdcb9 100644 --- a/primeable_cacheaside_test.go +++ b/primeable_cacheaside_test.go @@ -449,6 +449,40 @@ func TestPrimeableCacheAside_Set_CallbackError(t *testing.T) { assert.Equal(t, val, res) } +func TestPrimeableCacheAside_Set_CallbackError_RestoresValue(t *testing.T) { + client := makePrimeableClient(t, addr) + defer client.Client().Close() + ctx := context.Background() + + key := "key:" + uuid.New().String() + originalVal := "original:" + uuid.New().String() + + // Pre-populate with a known value via Get. + res, err := client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return originalVal, nil + }) + require.NoError(t, err) + assert.Equal(t, originalVal, res) + + // Set with failing callback — should restore the original value. + cbErr := fmt.Errorf("set callback failed") + err = client.Set(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + return "", cbErr + }) + require.ErrorIs(t, err, cbErr) + + // Give invalidation a moment to propagate. + time.Sleep(100 * time.Millisecond) + + // Original value should be restored. + res, err = client.Get(ctx, time.Second*10, key, func(ctx context.Context, k string) (string, error) { + // If callback fires, the original was NOT restored — it was deleted. 
+ return originalVal, nil + }) + require.NoError(t, err) + assert.Equal(t, originalVal, res) +} + func TestPrimeableCacheAside_SetMulti_CallbackError_RestoresValues(t *testing.T) { client := makePrimeableClient(t, addr) defer client.Client().Close() diff --git a/primeable_setmulti_helpers.go b/primeable_setmulti_helpers.go index 95f4238..53d607a 100644 --- a/primeable_setmulti_helpers.go +++ b/primeable_setmulti_helpers.go @@ -13,6 +13,25 @@ import ( "github.com/dcbickfo/redcache/internal/cmdx" ) +// tryAcquireWriteLock attempts to acquire a write lock on a single key, returning the +// previous value for rollback. Returns (acquired, savedValue, error). +func (pca *PrimeableCacheAside) tryAcquireWriteLock(ctx context.Context, key, lockVal, lockTTLMs string) (bool, string, error) { + resp := acquireWriteLockWithBackupScript.Exec(ctx, pca.client, []string{key}, []string{lockVal, lockTTLMs, pca.lockPrefix}) + arr, err := resp.ToArray() + if err != nil { + return false, "", fmt.Errorf("write lock for key %q: %w", key, err) + } + success, _ := arr[0].AsInt64() + if success == 0 { + return false, "", nil + } + var savedValue string + if backupVal, bErr := arr[1].ToString(); bErr == nil { + savedValue = backupVal + } + return true, savedValue, nil +} + // acquireMultiWriteLocks acquires write locks on all keys in sorted order with rollback. // Returns lockValues map and savedValues map (for rollback of previous real values). 
func (pca *PrimeableCacheAside) acquireMultiWriteLocks( From 84a751e12774dc39f1eaeeccd950d1c14b829a6e Mon Sep 17 00:00:00 2001 From: dcbickfo Date: Sat, 14 Mar 2026 08:15:46 -0400 Subject: [PATCH 3/3] add refresh ahead --- cacheaside.go | 320 +++++++++++++++++++++++++++++++++---- cacheaside_test.go | 386 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 679 insertions(+), 27 deletions(-) diff --git a/cacheaside.go b/cacheaside.go index 6e90ac5..f66649b 100644 --- a/cacheaside.go +++ b/cacheaside.go @@ -96,11 +96,16 @@ type Logger interface { // CacheAside provides a cache-aside pattern backed by Redis with distributed locking // and client-side caching via rueidis invalidation messages. type CacheAside struct { - client rueidis.Client - locks syncx.Map[string, *lockEntry] - lockTTL time.Duration - logger Logger - lockPrefix string + client rueidis.Client + locks syncx.Map[string, *lockEntry] + lockTTL time.Duration + logger Logger + lockPrefix string + refreshAfter float64 // 0 = disabled + refreshing syncx.Map[string, struct{}] // dedup in-flight refreshes (local) + refreshPrefix string // prefix for distributed refresh lock keys + refreshQueue chan func() // worker pool job queue (nil when disabled) + refreshWg sync.WaitGroup // tracks active refresh workers } // CacheAsideOption configures a CacheAside instance. @@ -117,21 +122,37 @@ type CacheAsideOption struct { // LockPrefix for distributed locks. Defaults to "__redcache:lock:". // Choose a prefix unlikely to conflict with your data keys. LockPrefix string + // RefreshAfterFraction enables refresh-ahead caching. When a cached value + // is returned and more than this fraction of its TTL has elapsed, a + // background worker refreshes the value while the stale one is returned + // immediately. For example, 0.8 means "refresh after 80% of TTL has passed" + // (i.e., when 20% remains). Set to 0 (default) to disable. Must be in [0, 1). 
+ // + // The refresh threshold is based on the client-side cache TTL (CachePTTL), + // which tracks the remaining lifetime of the locally cached entry. This closely + // approximates the server-side TTL when the same ttl parameter is used + // consistently for a given key across Get calls. + RefreshAfterFraction float64 + // RefreshWorkers is the number of background workers that process refresh-ahead + // jobs. Defaults to 4 when RefreshAfterFraction > 0. Must be > 0 when refresh + // is enabled. + RefreshWorkers int + // RefreshQueueSize is the maximum number of pending refresh jobs. When the queue + // is full, new refresh requests are silently dropped — the stale value continues + // to be served until the next access. Defaults to 64 when RefreshAfterFraction > 0. + // Must be > 0 when refresh is enabled. + RefreshQueueSize int } -// NewRedCacheAside creates a CacheAside with the given Redis client and cache-aside options. -func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*CacheAside, error) { - // Validate client options +func validateAndApplyDefaults(clientOption rueidis.ClientOption, caOption *CacheAsideOption) error { if len(clientOption.InitAddress) == 0 { - return nil, errors.New("at least one Redis address must be provided in InitAddress") + return errors.New("at least one Redis address must be provided in InitAddress") } - - // Validate and set defaults for cache aside options if caOption.LockTTL < 0 { - return nil, errors.New("LockTTL must not be negative") + return errors.New("LockTTL must not be negative") } if caOption.LockTTL > 0 && caOption.LockTTL < 100*time.Millisecond { - return nil, errors.New("LockTTL should be at least 100ms to avoid excessive lock churn") + return errors.New("LockTTL should be at least 100ms to avoid excessive lock churn") } if caOption.LockTTL == 0 { caOption.LockTTL = 10 * time.Second @@ -142,11 +163,43 @@ func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOpti 
if caOption.LockPrefix == "" { caOption.LockPrefix = "__redcache:lock:" } + return validateRefreshDefaults(caOption) +} + +func validateRefreshDefaults(caOption *CacheAsideOption) error { + if caOption.RefreshAfterFraction < 0 || caOption.RefreshAfterFraction >= 1 { + return errors.New("RefreshAfterFraction must be in range [0, 1)") + } + if caOption.RefreshAfterFraction == 0 { + return nil + } + if caOption.RefreshWorkers < 0 { + return errors.New("RefreshWorkers must not be negative") + } + if caOption.RefreshQueueSize < 0 { + return errors.New("RefreshQueueSize must not be negative") + } + if caOption.RefreshWorkers == 0 { + caOption.RefreshWorkers = 4 + } + if caOption.RefreshQueueSize == 0 { + caOption.RefreshQueueSize = 64 + } + return nil +} + +// NewRedCacheAside creates a CacheAside with the given Redis client and cache-aside options. +func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOption) (*CacheAside, error) { + if err := validateAndApplyDefaults(clientOption, &caOption); err != nil { + return nil, err + } rca := &CacheAside{ - lockTTL: caOption.LockTTL, - logger: caOption.Logger, - lockPrefix: caOption.LockPrefix, + lockTTL: caOption.LockTTL, + logger: caOption.Logger, + lockPrefix: caOption.LockPrefix, + refreshAfter: caOption.RefreshAfterFraction, + refreshPrefix: "__redcache:refresh:", } clientOption.OnInvalidations = rca.onInvalidate @@ -159,6 +212,12 @@ func NewRedCacheAside(clientOption rueidis.ClientOption, caOption CacheAsideOpti if err != nil { return nil, err } + + if rca.refreshAfter > 0 { + rca.refreshQueue = make(chan func(), caOption.RefreshQueueSize) + rca.startRefreshWorkers(caOption.RefreshWorkers) + } + return rca, nil } @@ -169,13 +228,31 @@ func (rca *CacheAside) Client() rueidis.Client { return rca.client } -// Close cancels all pending lock entries. It does NOT close the underlying -// Redis client — that is the caller's responsibility. 
+// Close cancels all pending lock entries and shuts down refresh workers. +// It does NOT close the underlying Redis client — that is the caller's responsibility. +// If refresh-ahead is enabled, Close waits for in-flight refresh jobs to complete +// (bounded by LockTTL). func (rca *CacheAside) Close() { rca.locks.Range(func(_ string, entry *lockEntry) bool { entry.cancel() return true }) + if rca.refreshQueue != nil { + close(rca.refreshQueue) + rca.refreshWg.Wait() + } +} + +func (rca *CacheAside) startRefreshWorkers(n int) { + for range n { + rca.refreshWg.Add(1) + go func() { + defer rca.refreshWg.Done() + for job := range rca.refreshQueue { + job() + } + }() + } } func (rca *CacheAside) onInvalidate(messages []rueidis.RedisMessage) { @@ -251,13 +328,16 @@ func (rca *CacheAside) Get( ) (string, error) { retry: wait := rca.register(key) - val, err := rca.tryGet(ctx, ttl, key) + val, pttl, err := rca.tryGet(ctx, ttl, key) if err != nil && !errors.Is(err, errNotFound) { return "", err } if err == nil && val != "" { + if rca.shouldRefresh(pttl, ttl) { + rca.triggerRefresh(ctx, ttl, key, fn) + } return val, nil } @@ -311,7 +391,7 @@ var ( // ErrLockLost is defined in errors.go. 
-func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string) (string, error) { +func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string) (string, int64, error) { resp := rca.client.DoCache(ctx, rca.client.B().Get().Key(key).Cache(), ttl) val, err := resp.ToString() if rueidis.IsRedisNil(err) || strings.HasPrefix(val, rca.lockPrefix) { // no response or is a lock value @@ -320,13 +400,13 @@ func (rca *CacheAside) tryGet(ctx context.Context, ttl time.Duration, key string } else { rca.logger.Debug("cache miss - lock value found", "key", key) } - return "", errNotFound + return "", 0, errNotFound } if err != nil { - return "", err + return "", 0, err } rca.logger.Debug("cache hit", "key", key) - return val, nil + return val, resp.CachePTTL(), nil } func (rca *CacheAside) trySetKeyFunc(ctx context.Context, ttl time.Duration, key string, fn func(ctx context.Context, key string) (string, error)) (val string, err error) { @@ -413,7 +493,7 @@ func (rca *CacheAside) GetMulti( retry: waitLock = rca.registerAll(mapsx.Keys(waitLock)) - vals, err := rca.tryGetMulti(ctx, ttl, mapsx.Keys(waitLock)) + vals, needRefresh, err := rca.tryGetMulti(ctx, ttl, mapsx.Keys(waitLock)) if err != nil && !rueidis.IsRedisNil(err) { return nil, err } @@ -423,6 +503,10 @@ retry: delete(waitLock, k) } + if len(needRefresh) > 0 { + rca.triggerMultiRefresh(ctx, ttl, needRefresh, fn) + } + if len(waitLock) > 0 { vals, err := rca.trySetMultiKeyFn(ctx, ttl, mapsx.Keys(waitLock), fn) if err != nil { @@ -454,7 +538,7 @@ func (rca *CacheAside) registerAll(keys []string) map[string]<-chan struct{} { return res } -func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys []string) (map[string]string, error) { +func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys []string) (map[string]string, []string, error) { multi := make([]rueidis.CacheableTTL, len(keys)) for i, key := range keys { cmd := 
rca.client.B().Get().Key(key).Cache() @@ -466,19 +550,23 @@ func (rca *CacheAside) tryGetMulti(ctx context.Context, ttl time.Duration, keys resps := rca.client.DoMultiCache(ctx, multi...) res := make(map[string]string) + var needRefresh []string for i, resp := range resps { val, err := resp.ToString() if rueidis.IsRedisNil(err) { continue } if err != nil { - return nil, fmt.Errorf("key %q: %w", keys[i], err) + return nil, nil, fmt.Errorf("key %q: %w", keys[i], err) } if !strings.HasPrefix(val, rca.lockPrefix) { res[keys[i]] = val + if rca.shouldRefresh(resp.CachePTTL(), ttl) { + needRefresh = append(needRefresh, keys[i]) + } } } - return res, nil + return res, needRefresh, nil } func (rca *CacheAside) trySetMultiKeyFn( @@ -631,6 +719,184 @@ func (rca *CacheAside) setMultiWithLock(ctx context.Context, ttl time.Duration, return rca.executeSetStatements(ctx, stmts) } +func (rca *CacheAside) shouldRefresh(cachePTTL int64, ttl time.Duration) bool { + if rca.refreshAfter == 0 || cachePTTL <= 0 { + return false + } + threshold := time.Duration(float64(ttl) * (1 - rca.refreshAfter)) + return time.Duration(cachePTTL)*time.Millisecond < threshold +} + +// triggerRefresh enqueues a single-key refresh job to the worker pool. +// Two-level dedup: local syncx.Map + distributed SET NX on a separate refresh key. +// If the queue is full, the refresh is silently dropped (stale value is still served). +func (rca *CacheAside) triggerRefresh( + ctx context.Context, + ttl time.Duration, + key string, + fn func(ctx context.Context, key string) (string, error), +) { + // Local dedup: skip if this process is already refreshing this key. + if _, loaded := rca.refreshing.LoadOrStore(key, struct{}{}); loaded { + return + } + + job := func() { + defer rca.refreshing.Delete(key) + rca.doSingleRefresh(ctx, ttl, key, fn) + } + + select { + case rca.refreshQueue <- job: + default: + // Queue full — drop refresh, stale value is fine. 
+ rca.refreshing.Delete(key) + } +} + +// doSingleRefresh acquires a distributed refresh lock, calls fn, and writes the result. +func (rca *CacheAside) doSingleRefresh( + ctx context.Context, + ttl time.Duration, + key string, + fn func(ctx context.Context, key string) (string, error), +) { + refreshCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), rca.lockTTL) + defer cancel() + + // Distributed dedup: SET NX on a separate refresh lock key. + refreshKey := rca.refreshPrefix + key + err := rca.client.Do(refreshCtx, rca.client.B().Set().Key(refreshKey).Value("1").Nx().Px(rca.lockTTL).Build()).Error() + if err != nil { + return // NX failed or Redis error — another process is refreshing. + } + defer func() { + cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), rca.lockTTL) + defer cleanupCancel() + rca.client.Do(cleanupCtx, rca.client.B().Del().Key(refreshKey).Build()) + }() + + val, err := fn(refreshCtx, key) + if err != nil { + rca.logger.Error("refresh-ahead callback failed", "key", key, "error", err) + return + } + + if err := rca.client.Do(refreshCtx, rca.client.B().Set().Key(key).Value(val).Px(ttl).Build()).Error(); err != nil { + rca.logger.Error("refresh-ahead set failed", "key", key, "error", err) + } +} + +// triggerMultiRefresh enqueues a multi-key refresh job to the worker pool. +// Two-level dedup: local syncx.Map + distributed SET NX on separate refresh keys. +// If the queue is full, the refresh is silently dropped (stale values are still served). +func (rca *CacheAside) triggerMultiRefresh( + ctx context.Context, + ttl time.Duration, + keys []string, + fn func(ctx context.Context, keys []string) (map[string]string, error), +) { + // Local dedup: filter to keys not already being refreshed. 
+ var toRefresh []string + for _, key := range keys { + if _, loaded := rca.refreshing.LoadOrStore(key, struct{}{}); !loaded { + toRefresh = append(toRefresh, key) + } + } + if len(toRefresh) == 0 { + return + } + + job := func() { + defer func() { + for _, key := range toRefresh { + rca.refreshing.Delete(key) + } + }() + rca.doMultiRefresh(ctx, ttl, toRefresh, fn) + } + + select { + case rca.refreshQueue <- job: + default: + // Queue full — drop refresh, stale values are fine. + for _, key := range toRefresh { + rca.refreshing.Delete(key) + } + } +} + +// doMultiRefresh acquires distributed refresh locks, calls fn, and writes results. +func (rca *CacheAside) doMultiRefresh( + ctx context.Context, + ttl time.Duration, + keys []string, + fn func(ctx context.Context, keys []string) (map[string]string, error), +) { + refreshCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), rca.lockTTL) + defer cancel() + + lockedKeys := rca.acquireRefreshLocks(refreshCtx, keys) + if len(lockedKeys) == 0 { + return + } + defer rca.deleteRefreshLocks(ctx, lockedKeys) + + vals, err := fn(refreshCtx, lockedKeys) + if err != nil { + rca.logger.Error("refresh-ahead multi callback failed", "error", err) + return + } + + rca.setRefreshedValues(refreshCtx, ttl, vals) +} + +// acquireRefreshLocks batch-acquires distributed SET NX locks for refresh keys. +func (rca *CacheAside) acquireRefreshLocks(ctx context.Context, keys []string) []string { + cmds := make(rueidis.Commands, len(keys)) + for i, key := range keys { + cmds[i] = rca.client.B().Set().Key(rca.refreshPrefix + key).Value("1").Nx().Px(rca.lockTTL).Build() + } + resps := rca.client.DoMulti(ctx, cmds...) + + var locked []string + for i, resp := range resps { + if err := resp.Error(); err != nil { + continue + } + locked = append(locked, keys[i]) + } + return locked +} + +// deleteRefreshLocks removes distributed refresh lock keys (best effort). 
+func (rca *CacheAside) deleteRefreshLocks(ctx context.Context, keys []string) { + cleanupCtx, cleanupCancel := context.WithTimeout(context.WithoutCancel(ctx), rca.lockTTL) + defer cleanupCancel() + delCmds := make(rueidis.Commands, len(keys)) + for i, key := range keys { + delCmds[i] = rca.client.B().Del().Key(rca.refreshPrefix + key).Build() + } + rca.client.DoMulti(cleanupCtx, delCmds...) +} + +// setRefreshedValues writes refreshed values to Redis via direct SET. +func (rca *CacheAside) setRefreshedValues(ctx context.Context, ttl time.Duration, vals map[string]string) { + setCmds := make(rueidis.Commands, 0, len(vals)) + for key, val := range vals { + setCmds = append(setCmds, rca.client.B().Set().Key(key).Value(val).Px(ttl).Build()) + } + if len(setCmds) == 0 { + return + } + resps := rca.client.DoMulti(ctx, setCmds...) + for _, resp := range resps { + if err := resp.Error(); err != nil { + rca.logger.Error("refresh-ahead multi set failed", "error", err) + } + } +} + func (rca *CacheAside) unlockMulti(ctx context.Context, lockVals map[string]string) { if len(lockVals) == 0 { return diff --git a/cacheaside_test.go b/cacheaside_test.go index e396953..2fccaba 100644 --- a/cacheaside_test.go +++ b/cacheaside_test.go @@ -6,6 +6,7 @@ import ( "math/rand/v2" "slices" "sync" + "sync/atomic" "testing" "time" @@ -935,6 +936,391 @@ func TestCacheAside_Close(t *testing.T) { } } +func makeRefreshClient(t *testing.T, addr []string, fraction float64) *redcache.CacheAside { + t.Helper() + client, err := redcache.NewRedCacheAside( + rueidis.ClientOption{ + InitAddress: addr, + }, + redcache.CacheAsideOption{ + LockTTL: time.Second * 2, + RefreshAfterFraction: fraction, + }, + ) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + client.Close() + client.Client().Close() + }) + return client +} + +func TestRefreshAhead_TriggersBackgroundRefresh(t *testing.T) { + // fraction=0.5 means refresh when >50% of TTL elapsed (i.e. <50% remaining). 
+ client := makeRefreshClient(t, addr, 0.5) + ctx := context.Background() + + key := "key:" + uuid.New().String() + callCount := 0 + var mu sync.Mutex + + cb := func(_ context.Context, _ string) (string, error) { + mu.Lock() + callCount++ + c := callCount + mu.Unlock() + return fmt.Sprintf("val-%d", c), nil + } + + ttl := 2 * time.Second + + // First call: populates cache — fn called once. + res, err := client.Get(ctx, ttl, key, cb) + require.NoError(t, err) + assert.Equal(t, "val-1", res) + + // Wait until >50% of TTL has elapsed so remaining < threshold. + time.Sleep(1200 * time.Millisecond) + + // This Get should return the stale value and trigger a background refresh. + res, err = client.Get(ctx, ttl, key, cb) + require.NoError(t, err) + assert.Equal(t, "val-1", res) // stale value returned immediately + + // Give the background goroutine time to complete. + time.Sleep(500 * time.Millisecond) + + // Next Get should see the refreshed value. + res, err = client.Get(ctx, ttl, key, cb) + require.NoError(t, err) + assert.Equal(t, "val-2", res) + + mu.Lock() + assert.Equal(t, 2, callCount, "fn should have been called exactly twice") + mu.Unlock() +} + +func TestRefreshAhead_Dedup(t *testing.T) { + client := makeRefreshClient(t, addr, 0.5) + ctx := context.Background() + + key := "key:" + uuid.New().String() + var refreshCount int64 + var mu sync.Mutex + firstCall := true + + cb := func(_ context.Context, _ string) (string, error) { + mu.Lock() + defer mu.Unlock() + if firstCall { + firstCall = false + return "initial", nil + } + refreshCount++ + time.Sleep(200 * time.Millisecond) // slow enough to overlap concurrent Gets + return "refreshed", nil + } + + ttl := 2 * time.Second + + // Populate cache. + _, err := client.Get(ctx, ttl, key, cb) + require.NoError(t, err) + + // Wait until threshold is crossed. + time.Sleep(1200 * time.Millisecond) + + // Fire many concurrent Gets — all should return stale and trigger at most one refresh. 
+	var wg sync.WaitGroup
+	for i := 0; i < 20; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			res, err := client.Get(ctx, ttl, key, cb)
+			assert.NoError(t, err)
+			assert.Equal(t, "initial", res) // stale returned
+		}()
+	}
+	wg.Wait()
+
+	// Wait for refresh goroutine to finish.
+	time.Sleep(500 * time.Millisecond)
+
+	mu.Lock()
+	assert.Equal(t, int64(1), refreshCount, "refresh callback should be called exactly once")
+	mu.Unlock()
+}
+
+func TestRefreshAhead_Disabled(t *testing.T) {
+	// Default (fraction=0) — no refresh-ahead.
+	client := makeClient(t, addr)
+	defer client.Client().Close()
+	ctx := context.Background()
+
+	key := "key:" + uuid.New().String()
+	callCount := 0
+	var mu sync.Mutex
+
+	cb := func(_ context.Context, _ string) (string, error) {
+		mu.Lock()
+		callCount++
+		mu.Unlock()
+		return "val", nil
+	}
+
+	_, err := client.Get(ctx, 2*time.Second, key, cb)
+	require.NoError(t, err)
+
+	// Wait until TTL is nearly expired.
+	time.Sleep(1500 * time.Millisecond)
+
+	_, err = client.Get(ctx, 2*time.Second, key, cb)
+	require.NoError(t, err)
+
+	time.Sleep(300 * time.Millisecond)
+
+	mu.Lock()
+	assert.Equal(t, 1, callCount, "fn should only be called once with refresh-ahead disabled")
+	mu.Unlock()
+}
+
+func TestRefreshAhead_ErrorLogged(t *testing.T) {
+	client := makeRefreshClient(t, addr, 0.5)
+	ctx := context.Background()
+
+	key := "key:" + uuid.New().String()
+	firstCall := true
+	var mu sync.Mutex
+
+	cb := func(_ context.Context, _ string) (string, error) {
+		mu.Lock()
+		defer mu.Unlock()
+		if firstCall {
+			firstCall = false
+			return "initial", nil
+		}
+		return "", fmt.Errorf("refresh failed")
+	}
+
+	ttl := 2 * time.Second
+
+	// Populate cache.
+	res, err := client.Get(ctx, ttl, key, cb)
+	require.NoError(t, err)
+	assert.Equal(t, "initial", res)
+
+	// Wait until threshold is crossed.
+	time.Sleep(1200 * time.Millisecond)
+
+	// Get triggers background refresh which will fail — stale value returned.
+	res, err = client.Get(ctx, ttl, key, cb)
+	require.NoError(t, err)
+	assert.Equal(t, "initial", res)
+
+	// Wait for refresh goroutine to complete (error is logged, not returned).
+	time.Sleep(500 * time.Millisecond)
+
+	// Stale value should still be present — no panic.
+	res, err = client.Get(ctx, ttl, key, cb)
+	require.NoError(t, err)
+	assert.Equal(t, "initial", res)
+}
+
+func TestRefreshAhead_GetMulti(t *testing.T) {
+	client := makeRefreshClient(t, addr, 0.5)
+	ctx := context.Background()
+
+	keys := []string{
+		"key:0:" + uuid.New().String(),
+		"key:1:" + uuid.New().String(),
+	}
+	callCount := 0
+	var mu sync.Mutex
+
+	cb := func(_ context.Context, ks []string) (map[string]string, error) {
+		mu.Lock()
+		callCount++
+		c := callCount
+		mu.Unlock()
+		res := make(map[string]string, len(ks))
+		for _, k := range ks {
+			res[k] = fmt.Sprintf("val-%d", c)
+		}
+		return res, nil
+	}
+
+	ttl := 2 * time.Second
+
+	// Populate cache.
+	res, err := client.GetMulti(ctx, ttl, keys, cb)
+	require.NoError(t, err)
+	for _, k := range keys {
+		assert.Equal(t, "val-1", res[k])
+	}
+
+	// Wait until threshold is crossed.
+	time.Sleep(1200 * time.Millisecond)
+
+	// GetMulti returns stale values and triggers background refresh.
+	res, err = client.GetMulti(ctx, ttl, keys, cb)
+	require.NoError(t, err)
+	for _, k := range keys {
+		assert.Equal(t, "val-1", res[k]) // stale
+	}
+
+	// Wait for refresh.
+	time.Sleep(500 * time.Millisecond)
+
+	// Next GetMulti should see refreshed values.
+	res, err = client.GetMulti(ctx, ttl, keys, cb)
+	require.NoError(t, err)
+	for _, k := range keys {
+		assert.Equal(t, "val-2", res[k])
+	}
+
+	mu.Lock()
+	assert.Equal(t, 2, callCount, "fn should have been called exactly twice")
+	mu.Unlock()
+}
+
+func TestRefreshAhead_Backpressure(t *testing.T) {
+	// Tiny pool: 1 worker, queue size 1.
+	// The worker sleeps during refresh, so the queue fills fast.
+	client, err := redcache.NewRedCacheAside(
+		rueidis.ClientOption{InitAddress: addr},
+		redcache.CacheAsideOption{
+			LockTTL:              time.Second * 3,
+			RefreshAfterFraction: 0.5,
+			RefreshWorkers:       1,
+			RefreshQueueSize:     1,
+		},
+	)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		client.Close()
+		client.Client().Close()
+	})
+	ctx := context.Background()
+
+	// Create many distinct keys so each triggers a separate refresh.
+	const numKeys = 20
+	keys := make([]string, numKeys)
+	for i := range numKeys {
+		keys[i] = fmt.Sprintf("key:%d:%s", i, uuid.New().String())
+	}
+
+	ttl := 3 * time.Second
+	var refreshCount atomic.Int64
+
+	populateCb := func(_ context.Context, _ string) (string, error) {
+		return "initial", nil
+	}
+
+	// Populate all keys.
+	for _, key := range keys {
+		_, err := client.Get(ctx, ttl, key, populateCb)
+		require.NoError(t, err)
+	}
+
+	// Wait until >50% of TTL has elapsed so refresh triggers.
+	time.Sleep(1700 * time.Millisecond)
+
+	refreshCb := func(_ context.Context, _ string) (string, error) {
+		refreshCount.Add(1)
+		time.Sleep(500 * time.Millisecond) // slow — keeps the single worker busy
+		return "refreshed", nil
+	}
+
+	// Fire concurrent Gets on all 20 keys. With 1 worker and queue size 1,
+	// at most ~2 refresh jobs can be accepted (1 executing + 1 queued).
+	// The rest are silently dropped.
+	var wg sync.WaitGroup
+	for _, key := range keys {
+		k := key
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			res, err := client.Get(ctx, ttl, k, refreshCb)
+			assert.NoError(t, err)
+			assert.Equal(t, "initial", res) // stale value always returned
+		}()
+	}
+	wg.Wait()
+
+	// Wait for all enqueued refreshes to finish.
+	time.Sleep(1500 * time.Millisecond)
+
+	// With 1 worker processing a 500ms job and queue size 1, far fewer than
+	// 20 refreshes should have executed.
+	count := refreshCount.Load()
+	assert.Less(t, count, int64(numKeys),
+		"expected fewer than %d refreshes, got %d — backpressure should drop excess jobs", numKeys, count)
+	assert.Greater(t, count, int64(0), "at least one refresh should have executed")
+}
+
+func TestRefreshAhead_FractionValidation(t *testing.T) {
+	t.Run("negative fraction", func(t *testing.T) {
+		_, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: -0.1},
+		)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "RefreshAfterFraction")
+	})
+	t.Run("fraction equals 1", func(t *testing.T) {
+		_, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 1.0},
+		)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "RefreshAfterFraction")
+	})
+	t.Run("fraction greater than 1", func(t *testing.T) {
+		_, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 1.5},
+		)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "RefreshAfterFraction")
+	})
+	t.Run("valid fraction", func(t *testing.T) {
+		client, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 0.8},
+		)
+		require.NoError(t, err)
+		client.Close()
+		client.Client().Close()
+	})
+	t.Run("negative RefreshWorkers", func(t *testing.T) {
+		_, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 0.8, RefreshWorkers: -1},
+		)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "RefreshWorkers")
+	})
+	t.Run("negative RefreshQueueSize", func(t *testing.T) {
+		_, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 0.8, RefreshQueueSize: -1},
+		)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "RefreshQueueSize")
+	})
+	t.Run("custom workers and queue", func(t *testing.T) {
+		client, err := redcache.NewRedCacheAside(
+			rueidis.ClientOption{InitAddress: addr},
+			redcache.CacheAsideOption{RefreshAfterFraction: 0.8, RefreshWorkers: 2, RefreshQueueSize: 16},
+		)
+		require.NoError(t, err)
+		client.Close()
+		client.Client().Close()
+	})
+}
+
 func TestNewRedCacheAside_Validation(t *testing.T) {
 	t.Run("empty InitAddress", func(t *testing.T) {
 		_, err := redcache.NewRedCacheAside(