diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..f5925e3 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,37 @@ +name: Test with nats-server +on: [push, pull_request] + +jobs: + test: + env: + GOPATH: /home/runner/work/mqtt-test/go + GO111MODULE: "on" + runs-on: ubuntu-latest + steps: + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: 1.21 + + - name: Checkout code + uses: actions/checkout@v4 + with: + path: go/src/github.com/ConnectEverything/mqtt-test + + - name: Checkout nats-server + uses: actions/checkout@v4 + with: + repository: nats-io/nats-server + path: go/src/github.com/nats-io/nats-server + + - name: Build and install + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + cd go/src/github.com/ConnectEverything/mqtt-test + go install -v . + + - name: Run 'MQTTEx from nats-server' + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + cd go/src/github.com/nats-io/nats-server + go test -v --run='-' --bench 'MQTTEx' --benchtime=100x ./server | tee /tmp/current-bench-result.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..bb022e3 --- /dev/null +++ b/README.md @@ -0,0 +1,70 @@ +MQTT Test is a CLI command used to test and benchmark the MQTT support in [NATS Server](https://github.com/nats-io/nats-server) + +Outputs JSON results that can be reported in a `go test --bench` wrapper. + +#### Usage + +##### Subcommands and common flags + +```sh +mqtt-test [pub|pubsub|subret] [flags...] +``` + +Available Commands: +- [pub](#pub) - Publish N messages +- [pubsub](#pubsub) - Subscribe and receive N published messages +- [subret](#subret) - Subscribe N times, and receive NTopics retained messages +Common Flags: +``` +-h, --help help for mqtt-test + --id string MQTT client ID +-n, --n int Number of transactions to run, see the specific command (default 1) +-p, --password string MQTT client password (empty if auth disabled) +-q, --quiet Quiet mode, only print results +-s, --servers stringArray MQTT servers endpoint as host:port (default [tcp://localhost:1883]) +-u, --username string MQTT client username (empty if auth disabled) + --version version for mqtt-test +-v, --very-verbose Very verbose, print everything we can +``` + +##### pub + +Publishes N messages using the flags and reports the results. Used with `--num-publishers` can run several concurrent publish connections. + +Flags: + +``` +--mps int Publish mps messages per second; 0 means no delay (default 1000) +--num-publishers int Number of publishers to run concurrently, at --mps each (default 1) +--num-topics int Cycle through NTopics appending "-{n}" where n starts with --num-topics-start; 0 means use --topic +--num-topics-start int Start topic suffixes with this number (default 0) +--qos int MQTT QOS +--retain Mark each published message as retained +--size int Approximate size of each message (pub adds a timestamp) +--topic string MQTT topic +``` + +##### pubsub + +Publishes N messages, and waits for all of them to be received by subscribers. Measures end-end delivery time on the messages. Used with `--num-subscribers` can run several concurrent subscriber connections. + +``` +--mps int Publish mps messages per second; 0 means all ASAP (default 1000) +--num-subscribers int Number of subscribers to run concurrently (default 1) +--qos int MQTT QOS +--size int Approximate size of each message (pub adds a timestamp) +--topic string MQTT topic +``` + +##### subret + +Publishes retained messages into NTopics, then subscribes to a wildcard with all +topics N times. Measures time to SUBACK and to all retained messages received. +Used with `--num-subscribers` can run several concurrent subscriber connections. + +``` +--num-subscribers int Number of subscribers to run concurrently (default 1) +--num-topics int Use this many topics with retained messages +--qos int MQTT QOS +--size int Approximate size of each message (pub adds a timestamp) +``` diff --git a/common.go b/common.go new file mode 100644 index 0000000..1a969c5 --- /dev/null +++ b/common.go @@ -0,0 +1,92 @@ +package main + +import ( + "fmt" + "log" + "math/rand" + "time" + + "github.com/spf13/cobra" +) + +const ( + Name = "mqtt-test" + Version = "v0.0.1" + + DefaultServer = "tcp://localhost:1883" + DefaultQOS = 0 + + Timeout = 10 * time.Second + DisconnectCleanupTimeout = 500 // milliseconds + +) + +var ( + ClientID string + MPS int + N int + NPublishers int + NSubscribers int + NTopics int + NTopicsStart int + Password string + QOS int + Retain bool + Servers []string + Size int + Username string + MatchTopicPrefix string + Quiet bool + Verbose bool +) + +func initPubSub(cmd *cobra.Command) *cobra.Command { + cmd.Flags().IntVar(&QOS, "qos", DefaultQOS, "MQTT QOS") + cmd.Flags().IntVar(&Size, "size", 0, "Approximate size of each message (pub adds a timestamp)") + return cmd +} + +type PubValue struct { + Seq int `json:"seq"` + Timestamp int64 `json:"timestamp"` +} + +type MQTTBenchmarkResult struct { + Ops int `json:"ops"` + NS map[string]time.Duration `json:"ns"` + Bytes int64 `json:"bytes"` +} + +func randomPayload(sz int) []byte { + const ch = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()" + b := make([]byte, sz) + for i := range b { + b[i] = ch[rand.Intn(len(ch))] + } + return b +} + +func mqttVarIntLen(value int) int { + c := 0 + for ; value > 0; value >>= 7 { + c++ + } + return c +} + +func mqttPublishLen(topic string, qos byte, retained bool, msg []byte) int { + // Compute len (will have to add packet id if message is sent as QoS>=1) + pkLen := 2 + len(topic) + len(msg) + if qos > 0 { + pkLen += 2 + } + return 1 + mqttVarIntLen(pkLen) + pkLen +} + +func logOp(clientID, op string, dur time.Duration, f string, args ...interface{}) { + log.Printf("%8s %-6s %30s\t"+f, append([]any{ + fmt.Sprintf("%.3fms", float64(dur)/float64(time.Millisecond)), + op, + clientID + ":"}, + args...)...) +} diff --git a/connect.go b/connect.go new file mode 100644 index 0000000..983f96b --- /dev/null +++ b/connect.go @@ -0,0 +1,59 @@ +package main + +import ( + "log" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/nats-io/nuid" +) + +const ( + CleanSession = true + PersistentSession = false +) + +func connect(clientID string, cleanSession bool, setoptsF func(*paho.ClientOptions)) (paho.Client, func(), error) { + if clientID == "" { + clientID = ClientID + } + if clientID == "" { + clientID = Name + "-" + nuid.Next() + } + + clientOpts := paho.NewClientOptions(). + SetClientID(clientID). + SetCleanSession(cleanSession). + SetProtocolVersion(4). + SetUsername(Username). + SetPassword(Password). + SetStore(paho.NewMemoryStore()). + SetAutoReconnect(false). + SetDefaultPublishHandler(func(client paho.Client, msg paho.Message) { + log.Fatalf("received an unexpected message on %q (default handler)", msg.Topic()) + }) + + for _, s := range Servers { + clientOpts.AddBroker(s) + } + if setoptsF != nil { + setoptsF(clientOpts) + } + + cl := paho.NewClient(clientOpts) + + disconnectedWG.Add(1) + start := time.Now() + if t := cl.Connect(); t.Wait() && t.Error() != nil { + disconnectedWG.Done() + return nil, func() {}, t.Error() + } + + logOp(clientOpts.ClientID, "CONN", time.Since(start), "Connected to %q\n", Servers) + return cl, + func() { + cl.Disconnect(DisconnectCleanupTimeout) + disconnectedWG.Done() + }, + nil +} diff --git a/dependencies.md b/dependencies.md new file mode 100644 index 0000000..e3b1bd9 --- /dev/null +++ b/dependencies.md @@ -0,0 +1,10 @@ +# External Dependencies + +This file lists the dependencies used in this repository. + +| Dependency | License | +|--------------------------------------|------------------------------------------| +| Go | BSD 3-Clause | +| github.com/eclipse/paho.mqtt.golang | Eclipse Public License - v 2.0 (EPL-2.0) | +| github.com/spf13/cobra v1.8.0 | Apache-2.0 | +| github.com/nats-io/nuid | Apache-2.0 | diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..714aef3 --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/ConnectEverything/mqtt-test + +go 1.21 + +require ( + github.com/eclipse/paho.mqtt.golang v1.4.3 + github.com/nats-io/nuid v1.0.1 + github.com/spf13/cobra v1.8.0 +) + +require ( + github.com/gorilla/websocket v1.5.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sync v0.1.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f62af80 --- /dev/null +++ b/go.sum @@ -0,0 +1,20 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/eclipse/paho.mqtt.golang v1.4.3 h1:2kwcUGn8seMUfWndX0hGbvH8r7crgcJguQNCyp70xik= +github.com/eclipse/paho.mqtt.golang v1.4.3/go.mod h1:CSYvoAlsMkhYOXh/oKyxa8EcBci6dVkLCbo5tTC1RIE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go new file mode 100644 index 0000000..0b02bcb --- /dev/null +++ b/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "io" + "log" + "os" + "sync" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/spf13/cobra" +) + +var mainCmd = &cobra.Command{ + Use: Name + " [conn|pub|sub|...|] [--flags...]", + Short: "MQTT Test/Benchmark Utility", + Version: Version, +} + +func init() { + mainCmd.PersistentFlags().StringVar(&ClientID, "id", "", "MQTT client ID") + mainCmd.PersistentFlags().StringArrayVarP(&Servers, "servers", "s", []string{DefaultServer}, "MQTT servers endpoint as host:port") + mainCmd.PersistentFlags().StringVarP(&Username, "username", "u", "", "MQTT client username (empty if auth disabled)") + mainCmd.PersistentFlags().StringVarP(&Password, "password", "p", "", "MQTT client password (empty if auth disabled)") + mainCmd.PersistentFlags().IntVarP(&N, "n", "n", 1, "Number of transactions to run, see the specific command") + mainCmd.PersistentFlags().BoolVarP(&Quiet, "quiet", "q", false, "Quiet mode, only print results") + mainCmd.PersistentFlags().BoolVarP(&Verbose, "very-verbose", "v", false, "Very verbose, print everything we can") + + mainCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + paho.CRITICAL = log.New(os.Stderr, "[MQTT CRIT] ", 0) + if Quiet { + Verbose = false + log.SetOutput(io.Discard) + } + if !Quiet { + paho.ERROR = log.New(os.Stderr, "[MQTT ERROR] ", 0) + } + if Verbose { + paho.WARN = log.New(os.Stderr, "[MQTT WARN] ", 0) + paho.DEBUG = log.New(os.Stderr, "[MQTT DEBUG] ", 0) + } + } +} + +var disconnectedWG = sync.WaitGroup{} + +func main() { + _ = mainCmd.Execute() + disconnectedWG.Wait() +} diff --git a/pub.go b/pub.go new file mode 100644 index 0000000..aadc1d0 --- /dev/null +++ b/pub.go @@ -0,0 +1,144 @@ +package main + +import ( + "encoding/json" + "log" + "os" + "strconv" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/nats-io/nuid" + "github.com/spf13/cobra" +) + +func init() { + cmd := initPubSub(&cobra.Command{ + Use: "pub [--flags...]", + Short: "Publish N messages", + Run: runPub, + Args: cobra.NoArgs, + }) + + cmd.Flags().IntVar(&NPublishers, "num-publishers", 1, `Number of publishers to run concurrently, at --mps each`) + cmd.Flags().IntVar(&NTopics, "num-topics", 0, `Cycle through NTopics appending "-{n}" where n starts with --num-topics-start; 0 means use --topic`) + cmd.Flags().IntVar(&NTopicsStart, "num-topics-start", 0, `Start topic suffixes with this number (default 0)`) + cmd.Flags().IntVar(&MPS, "mps", 1000, `Publish mps messages per second; 0 means no delay`) + cmd.Flags().BoolVar(&Retain, "retain", false, "Mark each published message as retained") + + mainCmd.AddCommand(cmd) +} + +func runPub(_ *cobra.Command, _ []string) { + clientID := ClientID + if clientID == "" { + clientID = Name + "-pub-" + nuid.Next() + } + topic := "/" + Name + "/" + nuid.Next() + msgChan := make(chan *MQTTBenchmarkResult) + errChan := make(chan error) + + for i := 0; i < NPublishers; i++ { + id := clientID + "-" + strconv.Itoa(i) + go func() { + cl, cleanup, err := connect(id, CleanSession, nil) + if err != nil { + log.Fatal(err) + } + defer cleanup() + + r, err := publish(cl, topic) + if err == nil { + msgChan <- r + } else { + errChan <- err + } + }() + } + + // wait for messages to arrive + pubOps := 0 + pubNS := time.Duration(0) + pubBytes := int64(0) + timeout := time.NewTimer(Timeout) + defer timeout.Stop() + + // get back 1 report per publisher + for n := 0; n < NPublishers; { + select { + case r := <-msgChan: + pubOps += r.Ops + pubNS += r.NS["pub"] + pubBytes += r.Bytes + n++ + + case err := <-errChan: + log.Fatalf("Error: %v", err) + + case <-timeout.C: + log.Fatalf("Error: timeout waiting for publishers") + } + } + + bb, _ := json.Marshal(MQTTBenchmarkResult{ + Ops: pubOps, + NS: map[string]time.Duration{"pub": pubNS}, + Bytes: pubBytes, + }) + os.Stdout.Write(bb) +} + +func publish(cl paho.Client, topic string) (*MQTTBenchmarkResult, error) { + opts := cl.OptionsReader() + start := time.Now() + var elapsed time.Duration + bc := 0 + iTopic := 0 + + for n := 0; n < N; n++ { + now := time.Now() + if n > 0 && MPS > 0 { + next := start.Add(time.Duration(n) * time.Second / time.Duration(MPS)) + time.Sleep(next.Sub(now)) + } + + // payload always starts with JSON containing timestamp, etc. The JSON + // is always terminated with a '-', which can not be part of the random + // fill. payload is then filled to the requested size with random data. + payload := randomPayload(Size) + structuredPayload, _ := json.Marshal(PubValue{ + Seq: n, + Timestamp: time.Now().UnixNano(), + }) + structuredPayload = append(structuredPayload, '\n') + if len(structuredPayload) > len(payload) { + payload = structuredPayload + } else { + copy(payload, structuredPayload) + } + + currTopic := topic + if NTopics > 0 { + currTopic = topic + "-" + strconv.Itoa(iTopic+NTopicsStart) + iTopic++ + if iTopic >= NTopics { + iTopic = 0 + } + } + + startPublish := time.Now() + if token := cl.Publish(currTopic, byte(QOS), Retain, payload); token.Wait() && token.Error() != nil { + return nil, token.Error() + } + elapsedPublish := time.Since(startPublish) + elapsed += elapsedPublish + logOp(opts.ClientID(), "PUB <-", elapsedPublish, "Published: %d bytes to %q, qos:%v, retain:%v", len(payload), currTopic, QOS, Retain) + bc += mqttPublishLen(currTopic, byte(QOS), Retain, payload) + } + + return &MQTTBenchmarkResult{ + Ops: N, + NS: map[string]time.Duration{"pub": elapsed}, + Bytes: int64(bc), + }, nil +} diff --git a/pubsub.go b/pubsub.go new file mode 100644 index 0000000..91422cc --- /dev/null +++ b/pubsub.go @@ -0,0 +1,169 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "os" + "strconv" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/nats-io/nuid" + "github.com/spf13/cobra" +) + +func init() { + cmd := initPubSub(&cobra.Command{ + Use: "pubsub [--flags...]", + Short: "Subscribe and receive N published messages", + Run: runPubSub, + Args: cobra.NoArgs, + }) + + cmd.Flags().IntVar(&MPS, "mps", 1000, `Publish mps messages per second; 0 means all ASAP`) + cmd.Flags().IntVar(&NSubscribers, "num-subscribers", 1, `Number of subscribers to run concurrently`) + + mainCmd.AddCommand(cmd) +} + +func pubsubMsgHandler(topic string, errChan chan error, msgChan chan MQTTBenchmarkResult, +) func(paho.Client, paho.Message) { + return func(client paho.Client, msg paho.Message) { + opts := client.OptionsReader() + clientID := opts.ClientID() + switch { + case msg.Topic() != topic: + log.Printf("Received a QOS %d message on unexpected topic: %s\n", msg.Qos(), msg.Topic()) + // ignore + + case msg.Duplicate(): + errChan <- fmt.Errorf("received unexpected duplicate message") + return + + case msg.Retained(): + errChan <- fmt.Errorf("received unexpected retained message") + return + } + + v := PubValue{} + body := msg.Payload() + if i := bytes.IndexByte(body, '\n'); i != -1 { + body = body[:i] + } + if err := json.Unmarshal(body, &v); err != nil { + log.Fatalf("Error parsing message JSON: %v", err) + } + elapsed := time.Since(time.Unix(0, v.Timestamp)) + msgChan <- MQTTBenchmarkResult{ + Ops: 1, + Bytes: int64(len(msg.Payload())), + NS: map[string]time.Duration{"receive": elapsed}, + } + logOp(clientID, "REC ->", elapsed, "Received %d bytes on %q, qos:%v", len(msg.Payload()), msg.Topic(), QOS) + } +} + +func runPubSub(_ *cobra.Command, _ []string) { + clientID := ClientID + if clientID == "" { + clientID = Name + "-sub-" + nuid.Next() + } + topic := "/" + Name + "/" + nuid.Next() + + clientCleanupChan := make(chan func()) + subsChan := make(chan struct{}) + msgChan := make(chan MQTTBenchmarkResult) + errChan := make(chan error) + + // Connect all subscribers (and subscribe) + for i := 0; i < NSubscribers; i++ { + id := clientID + "-" + strconv.Itoa(i) + go func() { + _, cleanup, err := connect(id, CleanSession, func(opts *paho.ClientOptions) { + opts. + SetOnConnectHandler(func(cl paho.Client) { + start := time.Now() + token := cl.Subscribe(topic, byte(QOS), pubsubMsgHandler(topic, errChan, msgChan)) + if token.Wait() && token.Error() != nil { + errChan <- token.Error() + return + } + logOp(id, "SUB", time.Since(start), "Subscribed to %q", topic) + subsChan <- struct{}{} + }). + SetDefaultPublishHandler(func(client paho.Client, msg paho.Message) { + errChan <- fmt.Errorf("received an unexpected message on %q", msg.Topic()) + }) + }) + if err != nil { + errChan <- err + } else { + clientCleanupChan <- cleanup + } + }() + } + + cConn, cSub := 0, 0 + timeout := time.NewTimer(Timeout) + defer timeout.Stop() + for (cConn < NSubscribers) || (cSub < NSubscribers) { + select { + case cleanup := <-clientCleanupChan: + defer cleanup() + cConn++ + case <-subsChan: + cSub++ + case err := <-errChan: + log.Fatal(err) + case <-timeout.C: + log.Fatalf("timeout waiting for connections") + } + } + + // ready to receive, start publishing. The publisher will exit when done, no need to wait for it. + clientID = ClientID + if clientID == "" { + clientID = Name + "-pub-" + nuid.Next() + } else { + clientID = clientID + "-pub" + } + go func() { + cl, cleanup, err := connect(clientID, CleanSession, nil) + defer cleanup() + if err == nil { + _, err = publish(cl, topic) + } + if err != nil { + errChan <- err + } + }() + + // wait for messages to arrive + elapsed := time.Duration(0) + bc := int64(0) + timeout = time.NewTimer(Timeout) + defer timeout.Stop() + for n := 0; n < N*NSubscribers; { + select { + case r := <-msgChan: + elapsed += r.NS["receive"] + bc += r.Bytes + n++ + + case err := <-errChan: + log.Fatalf("Error: %v", err) + + case <-timeout.C: + log.Fatalf("Error: timeout waiting for messages") + } + } + + bb, _ := json.Marshal(MQTTBenchmarkResult{ + Ops: N * NSubscribers, + NS: map[string]time.Duration{"receive": elapsed}, + Bytes: bc, + }) + os.Stdout.Write(bb) +} diff --git a/subret.go b/subret.go new file mode 100644 index 0000000..d72e8ce --- /dev/null +++ b/subret.go @@ -0,0 +1,188 @@ +package main + +import ( + "encoding/json" + "log" + "os" + "strconv" + "strings" + "sync/atomic" + "time" + + paho "github.com/eclipse/paho.mqtt.golang" + "github.com/nats-io/nuid" + "github.com/spf13/cobra" +) + +func init() { + cmd := initPubSub(&cobra.Command{ + Use: "subret [--flags...]", + Short: "Subscribe N times, and receive NTopics retained messages", + Run: runSubRet, + Args: cobra.NoArgs, + }) + + cmd.Flags().IntVar(&NSubscribers, "num-subscribers", 1, `Number of subscribers to run concurrently`) + cmd.Flags().IntVar(&NTopics, "num-topics", 0, `Use this many topics with retained messages`) + + mainCmd.AddCommand(cmd) +} + +func subretMsgHandler( + topicPrefix string, + expectNRetained int, + start time.Time, + doneChan chan MQTTBenchmarkResult, +) func(paho.Client, paho.Message) { + var cRetained atomic.Int32 + var bc atomic.Int64 + return func(client paho.Client, msg paho.Message) { + opts := client.OptionsReader() + clientID := opts.ClientID() + switch { + case !strings.HasPrefix(msg.Topic(), topicPrefix): + log.Printf("Received a QOS %d message on unexpected topic: %s\n", msg.Qos(), msg.Topic()) + // ignore + + case msg.Duplicate(): + log.Fatal("received unexpected duplicate message") + return + + case msg.Retained(): + newC := cRetained.Add(1) + bc.Add(int64(len(msg.Payload()))) + switch { + case newC < int32(expectNRetained): + logOp(clientID, "REC ->", time.Since(start), "Received %d bytes on %q, qos:%v", len(msg.Payload()), msg.Topic(), QOS) + // skip it + return + case newC > int32(expectNRetained): + log.Fatal("received unexpected retained message") + default: // last expected retained message + elapsed := time.Since(start) + r := MQTTBenchmarkResult{ + Ops: 1, + NS: map[string]time.Duration{"receive": elapsed}, + Bytes: bc.Load(), + } + doneChan <- r + } + } + } +} + +func runSubRet(_ *cobra.Command, _ []string) { + topic := "/" + Name + "/" + nuid.Next() + + clientID := ClientID + if clientID == "" { + clientID = Name + "-pub-" + nuid.Next() + } else { + clientID = clientID + "-pub" + } + nTopics := NTopics + if nTopics == 0 { + nTopics = 1 + } + + // Publish NTopics retained messages, 1 per topic; Use at least QoS1 to + // ensure the retained messages are fully processed by the time the + // publisher exits. + cl, disconnect, err := connect(clientID, CleanSession, nil) + if err != nil { + log.Fatal("Error connecting: ", err) + } + for i := 0; i < nTopics; i++ { + t := topic + "/" + strconv.Itoa(i) + payload := randomPayload(Size) + start := time.Now() + publishQOS := 1 + if publishQOS < QOS { + publishQOS = QOS + } + if token := cl.Publish(t, byte(publishQOS), true, payload); token.Wait() && token.Error() != nil { + log.Fatal("Error publishing: ", token.Error()) + } + logOp(clientID, "PUB <-", time.Since(start), "Published: %d bytes to %q, qos:%v, retain:%v", len(payload), t, QOS, true) + } + disconnect() + + // Now subscribe and verify that all subs receive all messages + clientID = ClientID + if clientID == "" { + clientID = Name + "-sub-" + nuid.Next() + } + + // Connect all subscribers (and subscribe to a wildcard topic that includes + // all published retained messages). + doneChan := make(chan MQTTBenchmarkResult) + for i := 0; i < NSubscribers; i++ { + id := clientID + "-" + strconv.Itoa(i) + prefix := topic + t := topic + "/+" + go subscribeAndReceiveRetained(id, t, prefix, N, nTopics, doneChan) + } + + var cDone int + total := MQTTBenchmarkResult{ + NS: map[string]time.Duration{}, + } + timeout := time.NewTimer(Timeout) + defer timeout.Stop() + + for cDone < NSubscribers { + select { + case r := <-doneChan: + total.Ops += r.Ops + total.NS["sub"] += r.NS["sub"] + total.NS["receive"] += r.NS["receive"] + total.Bytes += r.Bytes + cDone++ + case <-timeout.C: + log.Fatalf("timeout waiting for connections") + } + } + + bb, _ := json.Marshal(total) + os.Stdout.Write(bb) +} + +func subscribeAndReceiveRetained(id string, subTopic string, pubTopicPrefix string, n, expected int, doneChan chan MQTTBenchmarkResult) { + subNS := time.Duration(0) + receiveNS := time.Duration(0) + receiveBytes := int64(0) + cl, cleanup, err := connect(id, CleanSession, nil) + if err != nil { + log.Fatal(err) + } + defer cleanup() + + for i := 0; i < n; i++ { + start := time.Now() + doneChan := make(chan MQTTBenchmarkResult) + token := cl.Subscribe(subTopic, byte(QOS), + subretMsgHandler(pubTopicPrefix, expected, start, doneChan)) + if token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + subElapsed := time.Since(start) + subNS += subElapsed + logOp(id, "SUB", subElapsed, "Subscribed to %q", subTopic) + + r := <-doneChan + receiveNS += r.NS["receive"] + receiveBytes += r.Bytes + logOp(id, "SUBRET", r.NS["receive"], "Received %d messages (%d bytes) on %q", expected, r.Bytes, subTopic) + + // Unsubscribe + if token = cl.Unsubscribe(subTopic); token.Wait() && token.Error() != nil { + log.Fatal(token.Error()) + } + } + + doneChan <- MQTTBenchmarkResult{ + Ops: N, + NS: map[string]time.Duration{"sub": subNS, "receive": receiveNS}, + Bytes: receiveBytes, + } +}