diff --git a/README.md b/README.md index 0ed50c3..1c104ab 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Common flags: --id string MQTT client ID (default "mqtt-test-bssJjZUs1vhTvf6KpTpTLw") -q, --quiet Quiet mode, only print results -s, --server stringArray MQTT endpoint as username:password@host:port (default [tcp://localhost:1883]) + --timeout duration Timeout for the test (default 10s) --version version for mqtt-test -v, --very-verbose Very verbose, print everything we can ``` @@ -40,7 +41,7 @@ Flags: --retain Mark each published message as retained --size int Approximate size of each message (pub adds a timestamp) --timestamp Prepend a timestamp to each message ---topic string Base topic (prefix) to publish into (/{n} will be added if --topics > 0) (default "mqtt-test/fIqfOq5Lg5wk636V4sLXoc") +--topic string Base topic (prefix) to publish into (/{n} will be added if --topics > 0) --topics int Cycle through NTopics appending "/{n}" ``` @@ -56,7 +57,8 @@ Flags: --repeat int Subscribe, receive retained messages, and unsubscribe N times (default 1) --retained int Expect to receive this many retained messages --subscribers int Number of subscribers to run concurrently (default 1) ---topic string Base topic for the test, will subscribe to {topic}/+ +--timestamp Expect a timestamp in the payload and use it to calculate receive time +--topic string Topic to subscribe to ``` ##### pubsub @@ -64,11 +66,14 @@ Flags: Publishes N messages, and waits for all of them to be received by subscribers. Measures end-end delivery time on the messages. Used with `--num-subscribers` can run several concurrent subscriber connections. ``` ---messages int Number of messages to publish and receive (default 1) ---qos int MQTT QOS ---size int Approximate size of each message (pub adds a timestamp) ---subscribers int Number of subscribers to run concurrently (default 1) ---topic string Topic to publish and subscribe to (default "mqtt-test/JPrbNU6U3IbVQLIyazkP4y") +--messages int Number of messages to publish and receive (default 1) +--mps int Publish mps messages per second; 0 means no delay (default 1000) +--pub-server string Server to publish to. Defaults to the first server in --servers +--qos int MQTT QOS +--size int Message extra payload size (in addition to the JSON timestamp) +--subscribers int Number of subscribers to run concurrently (default 1) +--topic string Topic (or base topic if --topics > 1) +--topics int Number of topics to use, If more than one will add /1, /2, ... to --topic when publishing, and subscribe to topic/+ (default 1) ``` ##### subret @@ -78,10 +83,12 @@ topics N times. Measures time to SUBACK and to all retained messages received. Used with `--subscribers` can run several concurrent subscriber connections. ``` ---qos int MQTT QOS ---repeat int Subscribe, receive retained messages, and unsubscribe N times (default 1) ---size int Approximate size of each message (pub adds a timestamp) ---subscribers int Number of subscribers to run concurrently (default 1) ---topic string Base topic (prefix) for the test (default "mqtt-test/yNkmAFnFHETSGnQJNjwGdN") ---topics int Number of sub-topics to publish retained messages to (default 1) +--mps int Publish mps messages per second; 0 means no delay (default 1000) +--pub-server stringArray Server(s) to publish to. Defaults to --servers +--qos int MQTT QOS for subscriptions. Messages are published as QOS1. +--repeat int Subscribe, receive retained messages, and unsubscribe N times (default 1) +--retained int Number of retained messages to publish and receive (default 1) +--size int Message payload size +--subscribers int Number of subscribers to run concurrently (default 1) +--topic string base topic (if --retaned > 1 will be published to topic/1, topic/2, ...) ``` diff --git a/command-pub.go b/command-pub.go index c42bb80..9437c71 100644 --- a/command-pub.go +++ b/command-pub.go @@ -14,19 +14,14 @@ package main import ( - "encoding/json" - "log" - "os" "strconv" - "time" "github.com/spf13/cobra" ) type pubCommand struct { - publisher + opts publisher publishers int - timestamp bool } func newPubCommand() *cobra.Command { @@ -39,59 +34,30 @@ func newPubCommand() *cobra.Command { Args: cobra.NoArgs, } - // Message options - cmd.Flags().StringVar(&c.topic, "topic", defaultTopic(), "Base topic (prefix) to publish into (/{n} will be added if --topics > 0)") - cmd.Flags().IntVar(&c.qos, "qos", DefaultQOS, "MQTT QOS") - cmd.Flags().IntVar(&c.size, "size", 0, "Approximate size of each message (pub adds a timestamp)") - cmd.Flags().BoolVar(&c.retain, "retain", false, "Mark each published message as retained") - cmd.Flags().BoolVar(&c.timestamp, "timestamp", false, "Prepend a timestamp to each message") - - // Test options - cmd.Flags().IntVar(&c.mps, "mps", 1000, `Publish mps messages per second; 0 means no delay`) - cmd.Flags().IntVar(&c.messages, "messages", 1, "Number of transactions to run, see the specific command") + cmd.Flags().IntVar(&c.opts.messages, "messages", 1, "Number of transactions to run, see the specific command") + cmd.Flags().IntVar(&c.opts.mps, "mps", 1000, `Publish mps messages per second; 0 means no delay`) + cmd.Flags().IntVar(&c.opts.qos, "qos", DefaultQOS, "MQTT QOS") + cmd.Flags().BoolVar(&c.opts.retain, "retain", false, "Mark each published message as retained") + cmd.Flags().IntVar(&c.opts.size, "size", 0, "Approximate size of each message (pub adds a timestamp)") + cmd.Flags().BoolVar(&c.opts.timestamp, "timestamp", false, "Prepend a timestamp to each message") + cmd.Flags().StringVar(&c.opts.topic, "topic", defaultTopic(), "Base topic (prefix) to publish into (/{n} will be added if --topics > 0)") + cmd.Flags().IntVar(&c.opts.topics, "topics", 0, `Cycle through NTopics appending "/{n}"`) cmd.Flags().IntVar(&c.publishers, "publishers", 1, `Number of publishers to run concurrently, at --mps each`) - cmd.Flags().IntVar(&c.topics, "topics", 0, `Cycle through NTopics appending "/{n}"`) return cmd } func (c *pubCommand) run(_ *cobra.Command, _ []string) { - msgChan := make(chan *Stat) - errChan := make(chan error) - + doneCh := make(chan struct{}) for i := 0; i < c.publishers; i++ { - p := c.publisher // copy - p.clientID = ClientID + "-" + strconv.Itoa(i) - go p.publish(msgChan, errChan, c.timestamp) - } - - pubOps := 0 - pubNS := time.Duration(0) - pubBytes := int64(0) - timeout := time.NewTimer(Timeout) - defer timeout.Stop() - - // get back 1 report per publisher - for n := 0; n < c.publishers; { - select { - case stat := <-msgChan: - pubOps += stat.Ops - pubNS += stat.NS["pub"] - pubBytes += stat.Bytes - n++ - - case err := <-errChan: - log.Fatalf("Error: %v", err) - - case <-timeout.C: - log.Fatalf("Error: timeout waiting for publishers") + p := c.opts // copy + p.dials = dials(Servers) + p.clientID = ClientID + if c.publishers > 1 { + p.clientID = p.clientID + "-" + strconv.Itoa(i) } + go p.publish(doneCh) } - bb, _ := json.Marshal(Stat{ - Ops: pubOps, - NS: map[string]time.Duration{"pub": pubNS}, - Bytes: pubBytes, - }) - os.Stdout.Write(bb) + waitN(doneCh, c.publishers, "publisher to finish") } diff --git a/command-pubsub.go b/command-pubsub.go index 27bbaeb..f2f8a7c 100644 --- a/command-pubsub.go +++ b/command-pubsub.go @@ -14,25 +14,20 @@ package main import ( - "encoding/json" - "log" - "os" "strconv" - "time" "github.com/spf13/cobra" ) type pubsubCommand struct { - messageOpts - - messages int + pubOpts publisher + subOpts receiver subscribers int + pubServer string } func newPubSubCommand() *cobra.Command { c := &pubsubCommand{} - cmd := &cobra.Command{ Use: "pubsub [--flags...]", Short: "Subscribe and receive N published messages", @@ -40,81 +35,70 @@ func newPubSubCommand() *cobra.Command { Args: cobra.NoArgs, } - // Message options - cmd.Flags().IntVar(&c.messages, "messages", 1, "Number of messages to publish and receive") - cmd.Flags().StringVar(&c.topic, "topic", defaultTopic(), "Topic to publish and subscribe to") - cmd.Flags().IntVar(&c.qos, "qos", DefaultQOS, "MQTT QOS") - cmd.Flags().IntVar(&c.size, "size", 0, "Approximate size of each message (pub adds a timestamp)") - - // Test options + cmd.Flags().IntVar(&c.pubOpts.messages, "messages", 1, "Number of messages to publish and receive") + cmd.Flags().IntVar(&c.pubOpts.mps, "mps", 1000, `Publish mps messages per second; 0 means no delay`) + cmd.Flags().IntVar(&c.pubOpts.qos, "qos", DefaultQOS, "MQTT QOS") + cmd.Flags().IntVar(&c.pubOpts.size, "size", 0, "Message extra payload size (in addition to the JSON timestamp)") + cmd.Flags().StringVar(&c.pubOpts.topic, "topic", defaultTopic(), "Topic (or base topic if --topics > 1)") + cmd.Flags().IntVar(&c.pubOpts.topics, "topics", 1, "Number of topics to use, If more than one will add /1, /2, ... to --topic when publishing, and subscribe to topic/+") + cmd.Flags().StringVar(&c.pubServer, "pub-server", "", "Server to publish to. Defaults to the first server in --servers") cmd.Flags().IntVar(&c.subscribers, "subscribers", 1, `Number of subscribers to run concurrently`) + cmd.PreRun = func(_ *cobra.Command, _ []string) { + c.pubOpts.clientID = ClientID + "-pub" + c.pubOpts.timestamp = true + s := c.pubServer + if s == "" { + s = Servers[0] + } + c.pubOpts.dials = []dial{dial(s)} + + c.subOpts.clientID = ClientID + "-sub" + c.subOpts.expectPublished = c.pubOpts.messages + c.subOpts.expectTimestamp = true + c.subOpts.filterPrefix = c.pubOpts.topic + c.subOpts.qos = c.pubOpts.qos + c.subOpts.repeat = 1 + c.subOpts.topic = c.pubOpts.topic + if c.pubOpts.topics > 1 { + c.subOpts.topic = c.pubOpts.topic + "/+" + } + } + return cmd } func (c *pubsubCommand) run(_ *cobra.Command, _ []string) { - clientID := ClientID + "-sub" readyCh := make(chan struct{}) - errCh := make(chan error) - statsCh := make(chan *Stat) + doneCh := make(chan struct{}) - // Connect all subscribers (and subscribe) - for i := 0; i < c.subscribers; i++ { - r := &receiver{ - clientID: clientID + "-" + strconv.Itoa(i), - topic: c.topic, - qos: c.qos, - expectPublished: c.messages, - repeat: 1, - } - go r.receive(readyCh, statsCh, errCh) + counter := 0 + if len(Servers) > 1 || c.subscribers > 1 { + counter = 1 } - - // Wait for all subscriptions to signal ready - cSub := 0 - timeout := time.NewTimer(Timeout) - defer timeout.Stop() - for cSub < c.subscribers { - select { - case <-readyCh: - cSub++ - case err := <-errCh: - log.Fatal(err) - case <-timeout.C: - log.Fatalf("timeout waiting for subscribers to be ready") - } - } - - // ready to receive, start publishing. The publisher will exit when done, no need to wait for it. - p := &publisher{ - clientID: ClientID + "-pub", - messageOpts: c.messageOpts, - messages: c.messages, - mps: 1000, - } - go p.publish(nil, errCh, true) - - // wait for the stats - total := Stat{ - NS: make(map[string]time.Duration), - } - timeout = time.NewTimer(Timeout) - defer timeout.Stop() - for i := 0; i < c.subscribers; i++ { - select { - case stat := <-statsCh: - total.Ops += stat.Ops - total.Bytes += stat.Bytes - for k, v := range stat.NS { - total.NS[k] += v + N := c.subscribers * len(Servers) + + // Connect all subscribers and subscribe. Wait for all subscriptions to + // signal ready before publishing. + for _, d := range dials(Servers) { + for i := 0; i < c.subscribers; i++ { + r := c.subOpts // copy + if r.clientID == "" { + r.clientID = ClientID } - case err := <-errCh: - log.Fatalf("Error: %v", err) - case <-timeout.C: - log.Fatalf("Error: timeout waiting for messages") + if counter != 0 { + r.clientID = r.clientID + "-" + strconv.Itoa(counter) + counter++ + } + r.dial = d + go r.receive(readyCh, doneCh) } } + waitN(readyCh, N, "subscribers to be ready") + + // ready to receive, start publishing. Give the publisher the same done + // channel, will wait for one more. + go c.pubOpts.publish(doneCh) - bb, _ := json.Marshal(total) - os.Stdout.Write(bb) + waitN(doneCh, N+1, "publisher and all subscribers to finish") } diff --git a/command-sub.go b/command-sub.go index 73bf462..e21bcf9 100644 --- a/command-sub.go +++ b/command-sub.go @@ -14,142 +14,66 @@ package main import ( - "encoding/json" - "log" - "os" "strconv" - "time" "github.com/spf13/cobra" ) type subCommand struct { - // message options - messageOpts - - // test options - repeat int - subscribers int - expectRetained int - expectPublished int + opts receiver + subscribers int } func newSubCommand() *cobra.Command { c := &subCommand{} - cmd := &cobra.Command{ Use: "sub [--flags...]", Short: "Subscribe, receive all messages, unsubscribe, {repeat} times.", Run: c.run, Args: cobra.NoArgs, } - - cmd.Flags().StringVar(&c.topic, "topic", "", "Base topic for the test, will subscribe to {topic}/+") - cmd.Flags().IntVar(&c.qos, "qos", DefaultQOS, "MQTT QOS") - cmd.Flags().IntVar(&c.repeat, "repeat", 1, "Subscribe, receive retained messages, and unsubscribe N times") + cmd.Flags().IntVar(&c.opts.expectPublished, "messages", 0, `Expect to receive this many published messages`) + cmd.Flags().IntVar(&c.opts.expectRetained, "retained", 0, `Expect to receive this many retained messages`) + cmd.Flags().BoolVar(&c.opts.expectTimestamp, "timestamp", false, "Expect a timestamp in the payload and use it to calculate receive time") + cmd.Flags().IntVar(&c.opts.qos, "qos", DefaultQOS, "MQTT QOS") + cmd.Flags().IntVar(&c.opts.repeat, "repeat", 1, "Subscribe, receive (retained) messages, and unsubscribe this many times") + cmd.Flags().StringVar(&c.opts.topic, "topic", defaultTopic(), "Topic to subscribe to") cmd.Flags().IntVar(&c.subscribers, "subscribers", 1, `Number of subscribers to run concurrently`) - cmd.Flags().IntVar(&c.expectRetained, "retained", 0, `Expect to receive this many retained messages`) - cmd.Flags().IntVar(&c.expectPublished, "messages", 0, `Expect to receive this many published messages`) + + cmd.PreRun = func(_ *cobra.Command, _ []string) { + prefix := c.opts.topic + i := len(prefix) - 1 + for ; i >= 0; i-- { + if prefix[i] != '/' && prefix[i] != '#' && prefix[i] != '+' { + break + } + } + c.opts.filterPrefix = prefix[:i+1] + } return cmd } func (c *subCommand) run(_ *cobra.Command, _ []string) { - total := runSubPrepublishRetained(c.subscribers, c.repeat, c.expectRetained, c.expectPublished, c.messageOpts, false) - bb, _ := json.Marshal(total) - os.Stdout.Write(bb) -} - -func runSubPrepublishRetained( - nSubscribers int, - repeat int, - expectRetained, - expectPublished int, - messageOpts messageOpts, - prepublishRetained bool, -) *Stat { - errCh := make(chan error) - receiverReadyCh := make(chan struct{}) - statsCh := make(chan *Stat) - - if prepublishRetained { - if expectPublished != 0 { - log.Fatalf("Error: --messages is not supported with --retained") - } - - // We need to wait for all prepublished retained messages to be processed. - // To ensure, subscribe once before we pre-publish and receive all published - // messages. - r := &receiver{ - clientID: ClientID + "-sub-init", - filterPrefix: messageOpts.topic, - topic: messageOpts.topic + "/+", - qos: messageOpts.qos, - expectRetained: 0, - expectPublished: expectRetained, - repeat: 1, - } - go r.receive(receiverReadyCh, statsCh, errCh) - <-receiverReadyCh - - // Pre-publish retained messages. - p := &publisher{ - clientID: ClientID + "-pub", - messages: expectRetained, - topics: expectRetained, - messageOpts: messageOpts, - } - p.messageOpts.retain = true - go p.publish(nil, errCh, true) - - // wait for the initial subscription to have received all messages - timeout := time.NewTimer(Timeout) - defer timeout.Stop() - select { - case err := <-errCh: - log.Fatalf("Error: %v", err) - case <-timeout.C: - log.Fatalf("Error: timeout waiting for messages in initial subscription") - case <-statsCh: - // all received - } + doneCh := make(chan struct{}) + counter := 0 + if len(Servers) > 1 || c.subscribers > 1 { + counter = 1 } - // Connect all subscribers (and subscribe to a wildcard topic that includes - // all published retained messages). - for i := 0; i < nSubscribers; i++ { - r := &receiver{ - clientID: ClientID + "-sub-" + strconv.Itoa(i), - filterPrefix: messageOpts.topic, - topic: messageOpts.topic + "/+", - qos: messageOpts.qos, - expectRetained: expectRetained, - expectPublished: expectPublished, - repeat: repeat, - } - go r.receive(nil, statsCh, errCh) - } - - // wait for the stats - total := &Stat{ - NS: make(map[string]time.Duration), - } - timeout := time.NewTimer(Timeout) - defer timeout.Stop() - for i := 0; i < nSubscribers*repeat; i++ { - select { - case stat := <-statsCh: - total.Ops += stat.Ops - total.Bytes += stat.Bytes - for k, v := range stat.NS { - total.NS[k] += v + for _, d := range dials(Servers) { + for i := 0; i < c.subscribers; i++ { + r := c.opts // copy + r.clientID = ClientID + if counter != 0 { + r.clientID = r.clientID + "-" + strconv.Itoa(counter) + counter++ } - case err := <-errCh: - log.Fatalf("Error: %v", err) - case <-timeout.C: - log.Fatalf("Error: timeout waiting for messages") + r.dial = d + go r.receive(nil, doneCh) } } - return total + + waitN(doneCh, c.subscribers*len(Servers), "all subscribers to finish") } diff --git a/command-subret.go b/command-subret.go index 9370bf5..3d33795 100644 --- a/command-subret.go +++ b/command-subret.go @@ -14,25 +14,20 @@ package main import ( - "encoding/json" - "os" + "strconv" "github.com/spf13/cobra" ) type subretCommand struct { - // message options - messageOpts - - // test options - repeat int + pubOpts publisher + subOpts receiver + pubServers []string subscribers int - messages int } func newSubRetCommand() *cobra.Command { c := &subretCommand{} - cmd := &cobra.Command{ Use: "subret [--flags...]", Short: "Publish {topics} retained messages, subscribe {repeat} times, and receive all retained messages.", @@ -40,18 +35,76 @@ func newSubRetCommand() *cobra.Command { Args: cobra.NoArgs, } - cmd.Flags().StringVar(&c.topic, "topic", defaultTopic(), "Base topic (prefix) for the test") - cmd.Flags().IntVar(&c.qos, "qos", DefaultQOS, "MQTT QOS") - cmd.Flags().IntVar(&c.size, "size", 0, "Approximate size of each message (pub adds a timestamp)") - cmd.Flags().IntVar(&c.repeat, "repeat", 1, "Subscribe, receive retained messages, and unsubscribe N times") + cmd.Flags().IntVar(&c.pubOpts.messages, "retained", 1, "Number of retained messages to publish and receive") + cmd.Flags().IntVar(&c.pubOpts.mps, "mps", 1000, `Publish mps messages per second; 0 means no delay`) + cmd.Flags().IntVar(&c.pubOpts.size, "size", 0, "Message payload size") + cmd.Flags().StringVar(&c.pubOpts.topic, "topic", defaultTopic(), "base topic (if --retained > 1 will be published to topic/1, topic/2, ...)") + + cmd.Flags().IntVar(&c.subOpts.qos, "qos", DefaultQOS, "MQTT QOS for subscriptions. Messages are published as QOS1.") + cmd.Flags().IntVar(&c.subOpts.repeat, "repeat", 1, "Subscribe, receive retained messages, and unsubscribe N times") + + cmd.Flags().StringArrayVar(&c.pubServers, "pub-server", nil, "Server(s) to publish to. Defaults to --servers") cmd.Flags().IntVar(&c.subscribers, "subscribers", 1, `Number of subscribers to run concurrently`) - cmd.Flags().IntVar(&c.messages, "topics", 1, `Number of sub-topics to publish retained messages to`) + + cmd.PreRun = func(_ *cobra.Command, _ []string) { + c.pubOpts.clientID = ClientID + "-pub" + if len(c.pubServers) > 0 { + c.pubOpts.dials = dials(c.pubServers) + } else { + c.pubOpts.dials = dials(Servers) + } + c.pubOpts.retain = true + c.pubOpts.timestamp = false + c.pubOpts.topics = c.pubOpts.messages + c.pubOpts.qos = c.subOpts.qos + // Always use at least QoS1 to ensure retained is processed. + if c.pubOpts.qos < 1 { + c.pubOpts.qos = 1 + } + // Always send at least 1 byte so the messages are not treated as "retained delete". + if c.pubOpts.size < 1 { + c.pubOpts.size = 1 + } + + c.subOpts.clientID = ClientID + "-sub" + c.subOpts.expectRetained = c.pubOpts.messages + c.subOpts.expectTimestamp = false + c.subOpts.filterPrefix = c.pubOpts.topic + c.subOpts.topic = c.pubOpts.topic + if c.pubOpts.topics > 1 { + c.subOpts.topic = c.pubOpts.topic + "/+" + } + } return cmd } func (c *subretCommand) run(_ *cobra.Command, _ []string) { - total := runSubPrepublishRetained(c.subscribers, c.repeat, c.messages, 0, c.messageOpts, true) - bb, _ := json.Marshal(total) - os.Stdout.Write(bb) + // Pre-publish retained messages, and wait for all to be received. + c.pubOpts.publish(nil) + + counter := 0 + if len(Servers) > 1 || c.subscribers > 1 { + counter = 1 + } + N := c.subscribers * len(Servers) + doneCh := make(chan struct{}) + + // Connect all subscribers and subscribe. Wait for all subscriptions to + // signal ready before publishing. + for _, d := range dials(Servers) { + for i := 0; i < c.subscribers; i++ { + r := c.subOpts // copy + if r.clientID == "" { + r.clientID = ClientID + } + if counter != 0 { + r.clientID = r.clientID + "-" + strconv.Itoa(counter) + counter++ + } + r.dial = d + go r.receive(nil, doneCh) + } + } + waitN(doneCh, N, "subscribers to finish") } diff --git a/connect.go b/connect.go index 0eb0173..9a826b2 100644 --- a/connect.go +++ b/connect.go @@ -2,8 +2,6 @@ package main import ( "log" - "strings" - "sync/atomic" "time" paho "github.com/eclipse/paho.mqtt.golang" @@ -28,44 +26,14 @@ const ( PersistentSession = false ) -var nextConnectServerIndex = atomic.Uint64{} - -func connect(clientID string, cleanSession bool) (paho.Client, *Stat, func(), error) { +func connect(d dial, clientID string, cleanSession bool) (paho.Client, func(), error) { if clientID == "" { clientID = ClientID } if clientID == "" { clientID = Name + "-" + nuid.Next() } - - parseDial := func(in string) (u, p, s, c string) { - if in == "" { - return "", "", DefaultServer, "" - } - - if i := strings.LastIndex(in, "#"); i != -1 { - c = in[i+1:] - in = in[:i] - } - - if i := strings.LastIndex(in, "@"); i != -1 { - up := in[:i] - in = in[i+1:] - u = up - if i := strings.Index(up, ":"); i != -1 { - u = up[:i] - p = up[i+1:] - } - } - - s = in - return u, p, s, c - } - - // round-robin the servers. since we start at 0 and add first, subtract 1 to - // compensate and start at 0! - next := int((nextConnectServerIndex.Add(1) - 1) % uint64(len(Servers))) - u, p, s, c := parseDial(Servers[next]) + u, p, s, _ := d.parse() cl := paho.NewClient(paho.NewClientOptions(). SetClientID(clientID). @@ -84,22 +52,16 @@ func connect(clientID string, cleanSession bool) (paho.Client, *Stat, func(), er start := time.Now() if t := cl.Connect(); t.Wait() && t.Error() != nil { disconnectedWG.Done() - return nil, nil, nil, t.Error() + return nil, nil, t.Error() } - if c != "" { - logOp(clientID, "CONN", time.Since(start), "Connected to %q (%s)\n", s, c) - } else { - logOp(clientID, "CONN", time.Since(start), "Connected to %q\n", s) + recordOp(clientID, d, "conn", 1, time.Since(start), 0, "Connected to %s\n", d.String()) + + cleanup := func() { + start := time.Now() + cl.Disconnect(DisconnectCleanupTimeout) + recordOp(clientID, d, "disc", 1, time.Since(start), 0, "Disconnected from %s\n", d.String()) + disconnectedWG.Done() } - return cl, - &Stat{ - Ops: 1, - NS: map[string]time.Duration{"conn": time.Since(start)}, - }, - func() { - cl.Disconnect(DisconnectCleanupTimeout) - disconnectedWG.Done() - }, - nil + return cl, cleanup, nil } diff --git a/main.go b/main.go index 5c39c37..25a8a7d 100644 --- a/main.go +++ b/main.go @@ -14,12 +14,9 @@ package main import ( - "fmt" "io" "log" - "math/rand" "os" - "sync" "time" paho "github.com/eclipse/paho.mqtt.golang" @@ -29,11 +26,10 @@ import ( const ( Name = "mqtt-test" - Version = "v0.1.0" + Version = "v0.2.0" DefaultServer = "tcp://localhost:1883" DefaultQOS = 0 - Timeout = 10 * time.Second - DisconnectCleanupTimeout = 500 // milliseconds + DisconnectCleanupTimeout = 2000 // milliseconds ) var ( @@ -43,28 +39,36 @@ var ( Servers []string Username string Verbose bool + Timeout time.Duration ) -var disconnectedWG = sync.WaitGroup{} +type Stat struct { + Ops int `json:"ops"` + NS time.Duration `json:"ns"` + Bytes int64 `json:"bytes"` +} func main() { _ = mainCmd.Execute() disconnectedWG.Wait() + + printTotals() } var mainCmd = &cobra.Command{ - Use: Name + " [pub|sub|subret|...] [--flags...]", + Use: Name + " [pub|sub|test...] [--flags...]", Short: "MQTT Test and Benchmark Utility", Version: Version, } func init() { mainCmd.PersistentFlags().StringVar(&ClientID, "id", Name+"-"+nuid.Next(), "MQTT client ID") + mainCmd.PersistentFlags().DurationVar(&Timeout, "timeout", 10*time.Second, "Timeout for the test") mainCmd.PersistentFlags().StringArrayVarP(&Servers, "server", "s", []string{DefaultServer}, "MQTT endpoint as username:password@host:port") mainCmd.PersistentFlags().BoolVarP(&Quiet, "quiet", "q", false, "Quiet mode, only print results") mainCmd.PersistentFlags().BoolVarP(&Verbose, "very-verbose", "v", false, "Very verbose, print everything we can") - mainCmd.PersistentFlags().StringArrayVar(&Servers, "servers", []string{DefaultServer}, "MQTT endpoint as username:password@host:port") + oldServers := mainCmd.PersistentFlags().StringArray("servers", nil, "MQTT endpoint as username:password@host:port") mainCmd.PersistentFlags().MarkDeprecated("servers", "please use server instead.") mainCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { @@ -72,14 +76,17 @@ func init() { if Quiet { Verbose = false log.SetOutput(io.Discard) - } - if !Quiet { + } else { paho.ERROR = log.New(os.Stderr, "[MQTT ERROR] ", 0) } if Verbose { paho.WARN = log.New(os.Stderr, "[MQTT WARN] ", 0) paho.DEBUG = log.New(os.Stderr, "[MQTT DEBUG] ", 0) } + + if len(*oldServers) > 0 { + Servers = *oldServers + } } mainCmd.AddCommand(newPubCommand()) @@ -88,49 +95,4 @@ func init() { mainCmd.AddCommand(newSubRetCommand()) } -type PubValue struct { - Seq int `json:"seq"` - Timestamp int64 `json:"timestamp"` -} - -type Stat struct { - Ops int `json:"ops"` - NS map[string]time.Duration `json:"ns"` - Bytes int64 `json:"bytes"` -} - -func randomPayload(sz int) []byte { - const ch = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()" - b := make([]byte, sz) - for i := range b { - b[i] = ch[rand.Intn(len(ch))] - } - return b -} - -func mqttVarIntLen(value int) int { - c := 0 - for ; value > 0; value >>= 7 { - c++ - } - return c -} - -func mqttPublishLen(topic string, qos byte, retained bool, msg []byte) int { - // Compute len (will have to add packet id if message is sent as QoS>=1) - pkLen := 2 + len(topic) + len(msg) - if qos > 0 { - pkLen += 2 - } - return 1 + mqttVarIntLen(pkLen) + pkLen -} - func defaultTopic() string { return Name + "/" + nuid.Next() } - -func logOp(clientID, op string, dur time.Duration, f string, args ...interface{}) { - log.Printf("%8s %-6s %30s\t"+f, append([]any{ - fmt.Sprintf("%.3fms", float64(dur)/float64(time.Millisecond)), - op, - clientID + ":"}, - args...)...) -} diff --git a/publish.go b/publish.go index 4528b75..8dc43f1 100644 --- a/publish.go +++ b/publish.go @@ -18,39 +18,55 @@ import ( "log" "strconv" "time" + + paho "github.com/eclipse/paho.mqtt.golang" ) -// Message options -type messageOpts struct { - qos int - retain bool - size int - topic string +type publisher struct { + clientID string + dials []dial + messages int // send this many messages + mps int // send at this rate (messages per second) + qos int // MQTT QOS + retain bool // Mark each published message as retained + size int // Approximate size of each message (may add a timestamp) + timestamp bool // Prepend a timestamp to each message + topic string // Base topic (prefix) to publish into (/{n} will be added if --topics > 0) + topics int // Cycle through this many topics appending "/{n}" } -type publisher struct { - messageOpts +func (p *publisher) publish(doneCh chan struct{}) { + defer func() { + if doneCh != nil { + doneCh <- struct{}{} + } + }() - mps int - messages int - topics int - clientID string -} + // Connect to all servers. + clients := make([]paho.Client, len(p.dials)) + for i, d := range p.dials { + id := p.clientID + if len(p.dials) > 1 { + id = p.clientID + "-" + strconv.Itoa(i) + } -func (p *publisher) publish(msgCh chan *Stat, errorCh chan error, timestamp bool) { - cl, _, cleanup, err := connect(p.clientID, CleanSession) - if err != nil { - log.Fatal(err) + cl, cleanup, err := connect(dial(d), id, CleanSession) + if err != nil { + log.Fatal(err) + } + defer cleanup() + clients[i] = cl } - defer cleanup() - opts := cl.OptionsReader() start := time.Now() - var elapsed time.Duration - bc := 0 iTopic := 0 - + iServer := 0 for n := 0; n < p.messages; n++ { + // Get a round-robin client. + i := iServer % len(p.dials) + iServer++ + d := dial(p.dials[i]) + cl := clients[i] now := time.Now() if n > 0 && p.mps > 0 { next := start.Add(time.Duration(n) * time.Second / time.Duration(p.mps)) @@ -61,7 +77,7 @@ func (p *publisher) publish(msgCh chan *Stat, errorCh chan error, timestamp bool // is always terminated with a '-', which can not be part of the random // fill. payload is then filled to the requested size with random data. payload := randomPayload(p.size) - if timestamp { + if p.timestamp { structuredPayload, _ := json.Marshal(PubValue{ Seq: n, Timestamp: time.Now().UnixNano(), @@ -75,27 +91,22 @@ func (p *publisher) publish(msgCh chan *Stat, errorCh chan error, timestamp bool } currTopic := p.topic - if p.topics > 0 { + if p.topics > 1 { currTopic = p.topic + "/" + strconv.Itoa(iTopic) iTopic = (iTopic + 1) % p.topics } startPublish := time.Now() if token := cl.Publish(currTopic, byte(p.qos), p.retain, payload); token.Wait() && token.Error() != nil { - errorCh <- token.Error() + log.Fatal(token.Error()) return } + elapsedPublish := time.Since(startPublish) - elapsed += elapsedPublish - logOp(opts.ClientID(), "PUB <-", elapsedPublish, "Published: %d bytes to %q, qos:%v, retain:%v", len(payload), currTopic, p.qos, p.retain) - bc += mqttPublishLen(currTopic, byte(p.qos), p.retain, payload) - } + pubBytes := mqttPublishLen(currTopic, byte(p.qos), p.retain, payload) + opts := cl.OptionsReader() - if msgCh != nil { - msgCh <- &Stat{ - Ops: p.messages, - NS: map[string]time.Duration{"pub": elapsed}, - Bytes: int64(bc), - } + recordOp(opts.ClientID(), d, "pub", 1, elapsedPublish, pubBytes, "<- Published: %d bytes to %q, qos:%v, retain:%v", pubBytes, currTopic, p.qos, p.retain) } + } diff --git a/receive.go b/receive.go index ae7ecb1..e9298e0 100644 --- a/receive.go +++ b/receive.go @@ -16,7 +16,6 @@ package main import ( "bytes" "encoding/json" - "fmt" "log" "strings" "sync/atomic" @@ -26,66 +25,74 @@ import ( ) type receiver struct { + dial dial // MQTT server to connect to. clientID string // MQTT client ID. - topic string // Subscription topic. + expectPublished int // expect to receive this many published messages. + expectRetained int // expect to receive this many retained messages. + expectTimestamp bool // Expect a timestamp in the payload. filterPrefix string // Only count messages if their topic starts with the prefix. qos int // MQTT QOS for the subscription. - expectRetained int // expect to receive this many retained messages. - expectPublished int // expect to receive this many published messages. repeat int // Number of times to repeat subscribe/receive/unsubscribe. + topic string // Subscription topic. - cRetained atomic.Int32 // Count of retained messages received. - cPublished atomic.Int32 // Count of published messages received. - durPublished atomic.Int64 // Total duration of published messages received (measured from the sent timestamp in the message). - bc atomic.Int64 // Byte count of all messages received. - - start time.Time - errCh chan error - statCh chan *Stat + // state + cRetained *atomic.Int32 // Count of retained messages received. + cPublished *atomic.Int32 // Count of published messages received. + durPublished *atomic.Int64 // Total duration of published messages received (measured from the sent timestamp in the message). + bc *atomic.Int64 // Byte count of all messages received. + start time.Time + allReceivedCh chan struct{} // Signal that all expected messages have been received. } -func (r *receiver) receive(readyCh chan struct{}, statCh chan *Stat, errCh chan error) { - r.errCh = errCh - r.statCh = make(chan *Stat) +func (r *receiver) receive(readyCh chan struct{}, doneCh chan struct{}) { + defer func() { + if doneCh != nil { + doneCh <- struct{}{} + } + }() + if r.filterPrefix == "" { r.filterPrefix = r.topic } - cl, _, cleanup, err := connect(r.clientID, CleanSession) + cl, cleanup, err := connect(r.dial, r.clientID, CleanSession) if err != nil { - errCh <- err - return + log.Fatal(err) } + defer cleanup() for i := 0; i < r.repeat; i++ { - // Reset the stats for each iteration. + // Reset the state for each iteration. + r.cRetained = new(atomic.Int32) + r.cPublished = new(atomic.Int32) + r.durPublished = new(atomic.Int64) + r.bc = new(atomic.Int64) r.start = time.Now() - r.cRetained.Store(0) - r.cPublished.Store(0) - r.durPublished.Store(0) - r.bc.Store(0) + r.allReceivedCh = make(chan struct{}) token := cl.Subscribe(r.topic, byte(r.qos), r.msgHandler) if token.Wait() && token.Error() != nil { - errCh <- token.Error() - return + log.Fatal(token.Error()) } - logOp(r.clientID, "SUB", time.Since(r.start), "Subscribed to %q", r.topic) + elapsed := time.Since(r.start) + r.start = time.Now() + recordOp(r.clientID, r.dial, "sub", 1, elapsed, 0, "Subscribed to %q", r.topic) + + // signal that the sub is ready to receive (pulished) messages. if readyCh != nil { readyCh <- struct{}{} } - // wait for the stat value, then clean up and forward it to the caller. Errors are handled by the caller. - stat := <-r.statCh - statCh <- stat + // wait for all messages to be received, then clean up and signal to the caller. + <-r.allReceivedCh + start := time.Now() token = cl.Unsubscribe(r.topic) if token.Wait() && token.Error() != nil { - errCh <- token.Error() - return + log.Fatal(token.Error()) } + recordOp(r.clientID, r.dial, "unsub", 1, time.Since(start), 0, "Unsubscribed from %q", r.topic) } - cleanup() } func (r *receiver) msgHandler(client paho.Client, msg paho.Message) { @@ -97,57 +104,49 @@ func (r *receiver) msgHandler(client paho.Client, msg paho.Message) { return case msg.Duplicate(): - r.errCh <- fmt.Errorf("received unexpected duplicate message") + log.Fatalf("received unexpected duplicate message") return case msg.Retained(): newC := r.cRetained.Add(1) if newC > int32(r.expectRetained) { - r.errCh <- fmt.Errorf("received unexpected retained message") + log.Fatalf("received unexpected retained message") return } - logOp(clientID, "RRET ->", time.Since(r.start), "Received %d bytes on %q, qos:%v", len(msg.Payload()), msg.Topic(), msg.Qos()) - r.bc.Add(int64(len(msg.Payload()))) - + bc := r.bc.Add(int64(len(msg.Payload()))) if newC < int32(r.expectRetained) { return } elapsed := time.Since(r.start) - r.statCh <- &Stat{ - Ops: 1, - NS: map[string]time.Duration{fmt.Sprintf("rec%vret", r.expectRetained): elapsed}, - Bytes: r.bc.Load(), - } + recordOp(r.clientID, r.dial, "rec-ret", r.expectRetained, elapsed, bc, "Received %d retained messages", r.expectRetained) + close(r.allReceivedCh) return default: newC := r.cPublished.Add(1) if newC > int32(r.expectPublished) { - r.errCh <- fmt.Errorf("received unexpected published message: dup:%v, topic: %s, qos:%v, retained:%v, payload: %q", + log.Fatalf("received unexpected published message: dup:%v, topic: %s, qos:%v, retained:%v, payload: %q", msg.Duplicate(), msg.Topic(), msg.Qos(), msg.Retained(), msg.Payload()) return } - v := PubValue{} - body := msg.Payload() - if i := bytes.IndexByte(body, '\n'); i != -1 { - body = body[:i] - } - if err := json.Unmarshal(body, &v); err != nil { - log.Fatalf("Error parsing message JSON: %v", err) + elapsed := time.Since(r.start) + if r.expectTimestamp { + v := PubValue{} + body := msg.Payload() + if i := bytes.IndexByte(body, '\n'); i != -1 { + body = body[:i] + } + if err := json.Unmarshal(body, &v); err != nil { + log.Fatalf("Error parsing message JSON: %v", err) + } + elapsed = time.Since(time.Unix(0, v.Timestamp)) } - elapsed := time.Since(time.Unix(0, v.Timestamp)) - logOp(clientID, "RPUB ->", elapsed, "Received %d bytes on %q, qos:%v", len(msg.Payload()), msg.Topic(), msg.Qos()) + recordOp(clientID, r.dial, "rec", 1, elapsed, int64(len(msg.Payload())), "Received published message on %q", msg.Topic()) - dur := r.durPublished.Add(int64(elapsed)) - bb := r.bc.Add(int64(len(msg.Payload()))) if newC < int32(r.expectPublished) { return } - r.statCh <- &Stat{ - Ops: r.expectPublished, - Bytes: bb, - NS: map[string]time.Duration{"receive": time.Duration(dur)}, - } + close(r.allReceivedCh) } } diff --git a/util.go b/util.go new file mode 100644 index 0000000..71f0a08 --- /dev/null +++ b/util.go @@ -0,0 +1,152 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/json" + "fmt" + "log" + "math/rand" + "os" + "strings" + "sync" + "time" +) + +type PubValue struct { + Seq int `json:"seq"` + Timestamp int64 `json:"timestamp"` +} + +func randomPayload(sz int) []byte { + const ch = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()" + b := make([]byte, sz) + for i := range b { + b[i] = ch[rand.Intn(len(ch))] + } + return b +} + +func mqttVarIntLen(value int) int { + c := 0 + for ; value > 0; value >>= 7 { + c++ + } + return c +} + +func mqttPublishLen(topic string, qos byte, retained bool, msg []byte) int64 { + // Compute len (will have to add packet id if message is sent as QoS>=1) + pkLen := 2 + len(topic) + len(msg) + if qos > 0 { + pkLen += 2 + } + return int64(1 + mqttVarIntLen(pkLen) + pkLen) +} + +type dial string + +func (d dial) parse() (u, p, s, c string) { + in := string(d) + if in == "" { + return "", "", DefaultServer, "" + } + + if i := strings.LastIndex(in, "#"); i != -1 { + c = in[i+1:] + in = in[:i] + } + + if i := strings.LastIndex(in, "@"); i != -1 { + up := in[:i] + in = in[i+1:] + u = up + if i := strings.Index(up, ":"); i != -1 { + u = up[:i] + p = up[i+1:] + } + } + + s = in + return u, p, s, c +} + +func (d dial) String() string { + u, _, s, c := d.parse() + if c != "" { + return c + } + if u == "" { + return s + } else { + return u + ":****@" + s + } +} + +var disconnectedWG = sync.WaitGroup{} + +var Stats = make(map[string]*Stat) +var statsMu = new(sync.Mutex) + +func recordOp(clientID string, d dial, name string, n int, dur time.Duration, bytes int64, f string, args ...any) { + statsMu.Lock() + stat, ok := Stats[name] + if !ok { + stat = new(Stat) + Stats[name] = stat + } + stat.NS += dur + stat.Ops += n + stat.Bytes += bytes + statsMu.Unlock() + + log.Printf("%8s %-6s %30s\t"+f, append([]any{ + fmt.Sprintf("%.3fms", float64(dur)/float64(time.Millisecond)), + strings.ToUpper(name), + clientID + "/" + d.String() + ":"}, + args...)...) +} + +func printTotals() { + statsMu.Lock() + defer statsMu.Unlock() + + if len(Stats) == 0 { + return + } + b, _ := json.MarshalIndent(Stats, "", " ") + os.Stdout.Write(b) +} + +func dials(ss []string) []dial { + d := make([]dial, 0, len(ss)) + for _, s := range ss { + d = append(d, dial(s)) + } + return d +} + +func waitN(doneCh chan struct{}, N int, comment string) { + if N == 0 { + N = 1 + } + for n := 0; n < N; n++ { + select { + case <-time.After(Timeout): + log.Fatal("Error: timeout waiting for ", comment) + case <-doneCh: + // one is done + } + } +}