Skip to content
81 changes: 69 additions & 12 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
intl "github.com/redis/rueidis/internal/cmds"
)

func slot(key string) uint16 {
return intl.Slot(key)
}

// MGetCache is a helper that consults the client-side caches with multiple keys by grouping keys within the same slot into multiple GETs
func MGetCache(client Client, ctx context.Context, ttl time.Duration, keys []string) (ret map[string]RedisMessage, err error) {
if len(keys) == 0 {
Expand Down Expand Up @@ -50,12 +54,7 @@ func MGet(client Client, ctx context.Context, keys []string) (ret map[string]Red
return clientMGet(client, ctx, client.B().Mget().Key(keys...).Build(), keys)
}

cmds := mgetcmdsp.Get(len(keys), len(keys))
defer mgetcmdsp.Put(cmds)
for i := range cmds.s {
cmds.s[i] = client.B().Get().Key(keys[i]).Build()
}
return doMultiGet(client, ctx, cmds.s, keys)
return clusterMGet(client, ctx, keys)
}

// MSet is a helper that consults the redis directly with multiple keys by grouping keys within the same slot into MSETs or multiple SETs
Expand Down Expand Up @@ -139,12 +138,7 @@ func JsonMGet(client Client, ctx context.Context, keys []string, path string) (r
return clientMGet(client, ctx, client.B().JsonMget().Key(keys...).Path(path).Build(), keys)
}

cmds := mgetcmdsp.Get(len(keys), len(keys))
defer mgetcmdsp.Put(cmds)
for i := range cmds.s {
cmds.s[i] = client.B().JsonGet().Key(keys[i]).Path(path).Build()
}
return doMultiGet(client, ctx, cmds.s, keys)
return clusterJsonMGet(client, ctx, keys, path)
}

// JsonMSet is a helper that consults redis directly with multiple keys by grouping keys within the same slot into JSON.MSETs or multiple JSON.SETs
Expand Down Expand Up @@ -277,6 +271,69 @@ func arrayToKV(m map[string]RedisMessage, arr []RedisMessage, keys []string) map
return m
}

func clusterMGet(client Client, ctx context.Context, keys []string) (ret map[string]RedisMessage, err error) {
ret = make(map[string]RedisMessage, len(keys))
slotGroups := make(map[uint16][]string)
for _, key := range keys {
ks := slot(key)
slotGroups[ks] = append(slotGroups[ks], key)
}
cmds := mgetcmdsp.Get(0, len(slotGroups))
defer mgetcmdsp.Put(cmds)
var cmdKeys [][]string
for _, group := range slotGroups {
cmd := client.B().Mget().Key(group...).Build().Pin()
cmds.s = append(cmds.s, cmd)
cmdKeys = append(cmdKeys, group)
}
resps := client.DoMulti(ctx, cmds.s...)
defer resultsp.Put(&redisresults{s: resps})
for i, resp := range resps {
arr, err := resp.ToArray()
if err != nil {
return nil, err
}
ret = arrayToKV(ret, arr, cmdKeys[i])
}
for i := range cmds.s {
intl.PutCompletedForce(cmds.s[i])
}
return ret, nil
}

func clusterJsonMGet(client Client, ctx context.Context, keys []string, path string) (ret map[string]RedisMessage, err error) {
ret = make(map[string]RedisMessage, len(keys))
slotGroups := make(map[uint16][]string)
for _, key := range keys {
ks := slot(key)
slotGroups[ks] = append(slotGroups[ks], key)
}
if len(slotGroups) == 0 {
return ret, nil
}
cmds := mgetcmdsp.Get(0, len(slotGroups))
defer mgetcmdsp.Put(cmds)
var cmdKeys [][]string
for _, group := range slotGroups {
cmd := client.B().JsonMget().Key(group...).Path(path).Build().Pin()
cmds.s = append(cmds.s, cmd)
cmdKeys = append(cmdKeys, group)
}
resps := client.DoMulti(ctx, cmds.s...)
defer resultsp.Put(&redisresults{s: resps})
for i, resp := range resps {
arr, err := resp.ToArray()
if err != nil {
return nil, err
}
ret = arrayToKV(ret, arr, cmdKeys[i])
}
for i := range cmds.s {
intl.PutCompletedForce(cmds.s[i])
}
return ret, nil
}

// ErrMSetNXNotSet is used in the MSetNX helper when the underlying MSETNX response is 0.
// Ref: https://redis.io/commands/msetnx/
var ErrMSetNXNotSet = errors.New("MSETNX: no key was set")
Expand Down
70 changes: 43 additions & 27 deletions helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,22 @@ func TestMGetCache(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate DisabledCache DoCache", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be MGET with keys from the same slot
commands := c.Commands()
if commands[0] != "MGET" {
t.Fatalf("expected MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
// Build response array with values matching the keys
values := make([]RedisMessage, len(commands)-1)
for j := 1; j < len(commands); j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -200,7 +204,7 @@ func TestMGetCache(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down Expand Up @@ -358,18 +362,22 @@ func TestMGet(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate Do", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"GET", key}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be MGET with keys from the same slot
commands := c.Commands()
if commands[0] != "MGET" {
t.Fatalf("expected MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
// Build response array with values matching the keys
values := make([]RedisMessage, len(commands)-1)
for j := 1; j < len(commands); j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -379,7 +387,7 @@ func TestMGet(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down Expand Up @@ -1162,18 +1170,26 @@ func TestJsonMGet(t *testing.T) {
t.Fatalf("unexpected err %v", err)
}
t.Run("Delegate Do", func(t *testing.T) {
keys := make([]string, 100)
for i := range keys {
keys[i] = strconv.Itoa(i)
}
keys := []string{"{slot1}a", "{slot1}b", "{slot2}a", "{slot2}b"}
m.DoMultiFn = func(cmd ...Completed) *redisresults {
result := make([]RedisResult, len(cmd))
for i, key := range keys {
if !reflect.DeepEqual(cmd[i].Commands(), []string{"JSON.GET", key, "$"}) {
t.Fatalf("unexpected command %v", cmd)
for i, c := range cmd {
// Each command should be JSON.MGET with keys from the same slot and path at the end
commands := c.Commands()
if commands[0] != "JSON.MGET" {
t.Fatalf("expected JSON.MGET command, got %v", commands)
return nil
}
result[i] = newResult(strmsg('+', key), nil)
if commands[len(commands)-1] != "$" {
t.Fatalf("expected $ as last parameter, got %v", commands)
return nil
}
// Build response array with values matching the keys (exclude the path)
values := make([]RedisMessage, len(commands)-2)
for j := 1; j < len(commands)-1; j++ {
values[j-1] = strmsg('+', commands[j])
}
result[i] = newResult(slicemsg('*', values), nil)
}
return &redisresults{s: result}
}
Expand All @@ -1183,7 +1199,7 @@ func TestJsonMGet(t *testing.T) {
}
for _, key := range keys {
if vKey, ok := v[key]; !ok || vKey.string() != key {
t.Fatalf("unexpected response %v", v)
t.Fatalf("unexpected response for key %s: %v", key, v)
}
}
})
Expand Down
Loading