diff --git a/cache/redis/client_interface.go b/cache/redis/client_interface.go index da1e881..b758b22 100644 --- a/cache/redis/client_interface.go +++ b/cache/redis/client_interface.go @@ -15,6 +15,7 @@ type redisClient interface { Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd Del(ctx context.Context, keys ...string) *redis.IntCmd Pipeline() redis.Pipeliner + Watch(ctx context.Context, fn func(tx *redis.Tx) error, keys ...string) error Ping(ctx context.Context) *redis.StatusCmd Scan(ctx context.Context, cursor uint64, match string, count int64) *redis.ScanCmd diff --git a/cache/redis/redis.go b/cache/redis/redis.go index 861c1f8..c5cb5d3 100755 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -234,6 +234,10 @@ func (r *Redis) Delete(ctx context.Context, key string) error { return nil } +func (r *Redis) Watch(ctx context.Context, fn func(tx *redis.Tx) error, keys ...string) error { + return r.client.Watch(ctx, fn, keys...) +} + func (r *Redis) IsAvailable(ctx context.Context) bool { return r.client.Ping(ctx).Err() == nil } diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index de76dff..cd45a7e 100755 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -18,6 +18,7 @@ import ( "time" "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -214,6 +215,40 @@ func TestRedis_Delete(t *testing.T) { } } +func TestRedis_Watch(t *testing.T) { + redisInitFns := []redisInitFn{redisInit, redisClusterInit} + for _, redisInit := range redisInitFns { + t.Run("", func(t *testing.T) { + r, err := redisInit(t) + assert.Nil(t, err) + + ctx := context.TODO() + key := "test" + err = r.Watch(ctx, func(tx *redis.Tx) error { + n, err := tx.Get(ctx, key).Int() + if err != nil && err != redis.Nil { + return err + } + + // Actual operation (local in optimistic lock). + n++ + + // Operation is commited only if the watched keys remain unchanged. + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, fmt.Sprintf("%d", n), 0) + return nil + }) + return err + }, key) + assert.Nil(t, err) + + newValue, err := r.GetBytes(context.TODO(), key) + assert.Nil(t, err) + assert.Equal(t, "1", string(newValue)) + }) + } +} + func TestRedis_IsAvailable(t *testing.T) { redisInitFns := []redisInitFn{redisInit, redisClusterInit} for _, redisInit := range redisInitFns {