diff --git a/CHANGELOG.md b/CHANGELOG.md index 823f4b2..9028789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,7 +111,7 @@ ### v2.16.1 -- fix ZINTERSTORE with wets (thanks @lingjl2010 and @okhowang) +- fix ZINTERSTORE with sets (thanks @lingjl2010 and @okhowang) - fix exclusive ranges in XRANGE (thanks @joseotoro) diff --git a/README.md b/README.md index 14a1345..46d8bbd 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,7 @@ Implemented commands: - ZCARD - ZCOUNT - ZINCRBY + - ZINTER - ZINTERSTORE - ZLEXCOUNT - ZPOPMIN diff --git a/cmd_sorted_set.go b/cmd_sorted_set.go index 1cc20ec..ab943ad 100644 --- a/cmd_sorted_set.go +++ b/cmd_sorted_set.go @@ -19,7 +19,8 @@ func commandsSortedSet(m *Miniredis) { m.srv.Register("ZCARD", m.cmdZcard) m.srv.Register("ZCOUNT", m.cmdZcount) m.srv.Register("ZINCRBY", m.cmdZincrby) - m.srv.Register("ZINTERSTORE", m.cmdZinterstore) + m.srv.Register("ZINTER", m.makeCmdZinter(false)) + m.srv.Register("ZINTERSTORE", m.makeCmdZinter(true)) m.srv.Register("ZLEXCOUNT", m.cmdZlexcount) m.srv.Register("ZRANGE", m.cmdZrange) m.srv.Register("ZRANGEBYLEX", m.makeCmdZrangebylex(false)) @@ -324,145 +325,192 @@ func (m *Miniredis) cmdZincrby(c *server.Peer, cmd string, args []string) { }) } -// ZINTERSTORE -func (m *Miniredis) cmdZinterstore(c *server.Peer, cmd string, args []string) { - if len(args) < 3 { - setDirty(c) - c.WriteError(errWrongNumber(cmd)) - return - } - if !m.handleAuth(c) { - return - } - if m.checkPubsub(c, cmd) { - return - } +// ZINTERSTORE and ZINTER +func (m *Miniredis) makeCmdZinter(store bool) func(c *server.Peer, cmd string, args []string) { + return func(c *server.Peer, cmd string, args []string) { + minArgs := 2 + if store { + minArgs++ + } + if len(args) < minArgs { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c, cmd) { + return + } - destination := args[0] - numKeys, err := strconv.Atoi(args[1]) - if err != nil { - setDirty(c) - c.WriteError(msgInvalidInt) - return - } - args = args[2:] - if len(args) < numKeys { - setDirty(c) - c.WriteError(msgSyntaxError) - return - } - if numKeys <= 0 { - setDirty(c) - c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE") - return - } - keys := args[:numKeys] - args = args[numKeys:] + var opts = struct { + Store bool // if true this is ZINTERSTORE + Destination string // only relevant if $store is true + Keys []string + Aggregate string + WithWeights bool + Weights []float64 + WithScores bool // only for ZINTER + }{ + Store: store, + Aggregate: "sum", + } - withWeights := false - weights := []float64{} - aggregate := "sum" - for len(args) > 0 { - switch strings.ToLower(args[0]) { - case "weights": - if len(args) < numKeys+1 { - setDirty(c) - c.WriteError(msgSyntaxError) - return - } - for i := 0; i < numKeys; i++ { - f, err := strconv.ParseFloat(args[i+1], 64) - if err != nil { + if store { + opts.Destination = args[0] + args = args[1:] + } + numKeys, err := strconv.Atoi(args[0]) + if err != nil { + setDirty(c) + c.WriteError(msgInvalidInt) + return + } + args = args[1:] + if len(args) < numKeys { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + if numKeys <= 0 { + setDirty(c) + c.WriteError("ERR at least 1 input key is needed for ZUNIONSTORE/ZINTERSTORE") + return + } + opts.Keys = args[:numKeys] + args = args[numKeys:] + + for len(args) > 0 { + switch strings.ToLower(args[0]) { + case "weights": + if len(args) < numKeys+1 { setDirty(c) - c.WriteError("ERR weight value is not a float") + c.WriteError(msgSyntaxError) return } - weights = append(weights, f) - } - withWeights = true - args = args[numKeys+1:] - case "aggregate": - if len(args) < 2 { - setDirty(c) - c.WriteError(msgSyntaxError) - return - } - aggregate = strings.ToLower(args[1]) - switch aggregate { - case "sum", "min", "max": + for i := 0; i < numKeys; i++ { + f, err := strconv.ParseFloat(args[i+1], 64) + if err != nil { + setDirty(c) + c.WriteError("ERR weight value is not a float") + return + } + opts.Weights = append(opts.Weights, f) + } + opts.WithWeights = true + args = args[numKeys+1:] + case "aggregate": + if len(args) < 2 { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + aggregate := strings.ToLower(args[1]) + switch aggregate { + case "sum", "min", "max": + opts.Aggregate = aggregate + default: + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + args = args[2:] + case "withscores": + if store { + setDirty(c) + c.WriteError(msgSyntaxError) + return + } + opts.WithScores = true + args = args[1:] default: setDirty(c) c.WriteError(msgSyntaxError) return } - args = args[2:] - default: - setDirty(c) - c.WriteError(msgSyntaxError) - return } - } - withTx(m, c, func(c *server.Peer, ctx *connCtx) { - db := m.db(ctx.selectedDB) - db.del(destination, true) - - // We collect everything and remove all keys which turned out not to be - // present in every set. - sset := map[string]float64{} - counts := map[string]int{} - for i, key := range keys { - if !db.exists(key) { - continue + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + if opts.Store { + db.del(opts.Destination, true) } - var set map[string]float64 - switch db.t(key) { - case "set": - set = map[string]float64{} - for elem := range db.setKeys[key] { - set[elem] = 1.0 - } - case "zset": - set = db.sortedSet(key) - default: - c.WriteError(msgWrongType) - return - } - for member, score := range set { - if withWeights { - score *= weights[i] - } - counts[member]++ - old, ok := sset[member] - if !ok { - sset[member] = score + // We collect everything and remove all keys which turned out not to be + // present in every set. + sset := map[string]float64{} + counts := map[string]int{} + for i, key := range opts.Keys { + if !db.exists(key) { continue } - switch aggregate { + + var set map[string]float64 + switch db.t(key) { + case "set": + set = map[string]float64{} + for elem := range db.setKeys[key] { + set[elem] = 1.0 + } + case "zset": + set = db.sortedSet(key) default: - panic("Invalid aggregate") - case "sum": - sset[member] += score - case "min": - if score < old { - sset[member] = score + c.WriteError(msgWrongType) + return + } + for member, score := range set { + if opts.WithWeights { + score *= opts.Weights[i] } - case "max": - if score > old { + counts[member]++ + old, ok := sset[member] + if !ok { sset[member] = score + continue + } + switch opts.Aggregate { + default: + panic("Invalid aggregate") + case "sum": + sset[member] += score + case "min": + if score < old { + sset[member] = score + } + case "max": + if score > old { + sset[member] = score + } } } } - } - for key, count := range counts { - if count != numKeys { - delete(sset, key) + for key, count := range counts { + if count != numKeys { + delete(sset, key) + } } - } - db.ssetSet(destination, sset) - c.WriteInt(len(sset)) - }) + + if opts.Store { + // ZINTERSTORE mode + db.ssetSet(opts.Destination, sset) + c.WriteInt(len(sset)) + return + } + // ZINTER mode + size := len(sset) + if opts.WithScores { + size *= 2 + } + c.WriteLen(size) + for _, l := range sortedKeys(sset) { + c.WriteBulk(l) + if opts.WithScores { + c.WriteFloat(sset[l]) + } + } + }) + } } // ZLEXCOUNT @@ -1947,3 +1995,12 @@ func parseLexrange(s string) (string, bool, error) { return "", false, errors.New(msgInvalidRangeItem) } } + +func sortedKeys(m map[string]float64) []string { + var keys []string + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} diff --git a/cmd_sorted_set_test.go b/cmd_sorted_set_test.go index 9acec4b..db3d8b6 100644 --- a/cmd_sorted_set_test.go +++ b/cmd_sorted_set_test.go @@ -1865,6 +1865,36 @@ func TestZunion(t *testing.T) { }) } +func TestZinter(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := proto.Dial(s.Addr()) + ok(t, err) + defer c.Close() + + s.ZAdd("h1", 1.0, "field1") + s.ZAdd("h1", 2.0, "field2") + s.ZAdd("h1", 3.0, "field3") + s.ZAdd("h2", 1.0, "field1") + s.ZAdd("h2", 2.0, "field2") + s.ZAdd("h2", 4.0, "field4") + s.SAdd("s2", "field1") + + // Simple case + { + mustDo(t, c, + "ZINTER", "2", "h1", "h2", + proto.Strings("field1", "field2"), + ) + mustDo(t, c, + "ZINTER", "2", "h1", "h2", "WITHSCORES", + proto.Strings("field1", "2", "field2", "4"), + ) + } + // it's the same code as ZINTERSTORE, so see TestZinterstore() +} + func TestZinterstore(t *testing.T) { s, err := Run() ok(t, err) diff --git a/integration/sorted_set_test.go b/integration/sorted_set_test.go index 933317f..a9ce211 100644 --- a/integration/sorted_set_test.go +++ b/integration/sorted_set_test.go @@ -783,8 +783,48 @@ func TestZunionstore(t *testing.T) { }) } -func TestZinterstore(t *testing.T) { +func TestZinter(t *testing.T) { skip(t) + // ZINTER + testRaw(t, func(c *client) { + c.Do("ZADD", "h1", "1.0", "key1") + c.Do("ZADD", "h1", "2.0", "key2") + c.Do("ZADD", "h1", "3.0", "key3") + c.Do("ZADD", "h2", "1.0", "key1") + c.Do("ZADD", "h2", "4.0", "key2") + c.Do("ZADD", "h3", "4.0", "key4") + c.DoSorted("ZINTER", "2", "h1", "h2") + + c.DoSorted("ZINTER", "2", "h1", "h2", "WEIGHTS", "2.0", "12") + c.DoSorted("ZINTER", "2", "h1", "h2", "WEIGHTS", "2", "-12") + + c.DoSorted("ZINTER", "2", "h1", "h2", "AGGREGATE", "min") + c.DoSorted("ZINTER", "2", "h1", "h2", "AGGREGATE", "max") + c.DoSorted("ZINTER", "2", "h1", "h2", "AGGREGATE", "sum") + + // normal set + c.Do("ZADD", "q1", "2", "f1") + c.Do("SADD", "q2", "f1") + c.Do("ZINTER", "2", "q1", "q2") + c.DoSorted("ZINTER", "2", "q1", "q2", "WITHSCORES") + + // Error cases + c.Error("wrong number", "ZINTER") + c.Error("wrong number", "ZINTER", "noint") + c.Error("at least 1", "ZINTER", "0", "f") + c.Error("syntax error", "ZINTER", "2", "f") + c.Error("at least 1", "ZINTER", "-1", "f") + c.Error("syntax error", "ZINTER", "2", "f1", "f2", "f3") + c.Error("syntax error", "ZINTER", "2", "f1", "f2", "WEIGHTS") + c.Error("syntax error", "ZINTER", "2", "f1", "f2", "WEIGHTS", "1") + c.Error("syntax error", "ZINTER", "2", "f1", "f2", "WEIGHTS", "1", "2", "3") + c.Error("not a float", "ZINTER", "2", "f1", "f2", "WEIGHTS", "f", "2") + c.Error("syntax error", "ZINTER", "2", "f1", "f2", "AGGREGATE", "foo") + c.Do("SET", "str", "1") + c.Error("wrong kind", "ZINTER", "1", "str") + }) + + // ZINTERSTORE testRaw(t, func(c *client) { c.Do("ZADD", "h1", "1.0", "key1") c.Do("ZADD", "h1", "2.0", "key2")