diff --git a/client.go b/client.go index 890134a3..210e79c6 100644 --- a/client.go +++ b/client.go @@ -28,6 +28,8 @@ import ( "time" "github.com/eclipse/paho.mqtt.golang/packets" + + "github.com/desertbit/timer" ) const ( @@ -96,10 +98,9 @@ type Client interface { // client implements the Client interface type client struct { - lastSent atomic.Value - lastReceived atomic.Value - pingOutstanding int32 - status uint32 + keepaliveTimer *timer.Timer + pingTimeoutTimer *timer.Timer + status uint32 sync.RWMutex messageIds conn net.Conn @@ -333,9 +334,6 @@ func (c *client) Connect() Token { c.options.protocolVersionExplicit = true if c.options.KeepAlive != 0 { - atomic.StoreInt32(&c.pingOutstanding, 0) - c.lastReceived.Store(time.Now()) - c.lastSent.Store(time.Now()) c.workers.Add(1) go keepalive(c) } @@ -454,9 +452,6 @@ func (c *client) reconnect() { c.stop = make(chan struct{}) if c.options.KeepAlive != 0 { - atomic.StoreInt32(&c.pingOutstanding, 0) - c.lastReceived.Store(time.Now()) - c.lastSent.Store(time.Now()) c.workers.Add(1) go keepalive(c) } diff --git a/net.go b/net.go index 804cb3fc..bb74f444 100644 --- a/net.go +++ b/net.go @@ -22,7 +22,6 @@ import ( "net/url" "os" "reflect" - "sync/atomic" "time" "github.com/eclipse/paho.mqtt.golang/packets" @@ -118,7 +117,7 @@ func incoming(c *client) { case c.ibound <- cp: // Notify keepalive logic that we recently received a packet if c.options.KeepAlive != 0 { - c.lastReceived.Store(time.Now()) + resetKeepaliveTimer(c) } case <-c.stop: // This avoids a deadlock should a message arrive while shutting down. @@ -196,7 +195,7 @@ func outgoing(c *client) { } // Reset ping timer after sending control packet. if c.options.KeepAlive != 0 { - c.lastSent.Store(time.Now()) + resetKeepaliveTimer(c) } } } @@ -219,7 +218,7 @@ func alllogic(c *client) { switch m := msg.(type) { case *packets.PingrespPacket: DEBUG.Println(NET, "received pingresp") - atomic.StoreInt32(&c.pingOutstanding, 0) + stopPingTimeOutTimer(c) case *packets.SubackPacket: DEBUG.Println(NET, "received suback, id:", m.MessageID) token := c.getToken(m.MessageID) diff --git a/ping.go b/ping.go index dbc1ff45..1f6c9719 100644 --- a/ping.go +++ b/ping.go @@ -16,54 +16,49 @@ package mqtt import ( "errors" - "sync/atomic" "time" "github.com/eclipse/paho.mqtt.golang/packets" + + "github.com/desertbit/timer" ) func keepalive(c *client) { defer c.workers.Done() DEBUG.Println(PNG, "keepalive starting") - var checkInterval int64 - var pingSent time.Time - if c.options.KeepAlive > 10 { - checkInterval = 5 - } else { - checkInterval = c.options.KeepAlive / 2 - } + c.keepaliveTimer = timer.NewTimer(time.Duration(c.options.KeepAlive * int64(time.Second))) + defer c.keepaliveTimer.Stop() - intervalTicker := time.NewTicker(time.Duration(checkInterval * int64(time.Second))) - defer intervalTicker.Stop() + c.pingTimeoutTimer = timer.NewTimer(c.options.PingTimeout) + c.pingTimeoutTimer.Stop() + defer c.pingTimeoutTimer.Stop() for { select { case <-c.stop: DEBUG.Println(PNG, "keepalive stopped") return - case <-intervalTicker.C: - lastSent := c.lastSent.Load().(time.Time) - lastReceived := c.lastReceived.Load().(time.Time) - - DEBUG.Println(PNG, "ping check", time.Since(lastSent).Seconds()) - if time.Since(lastSent) >= time.Duration(c.options.KeepAlive*int64(time.Second)) || time.Since(lastReceived) >= time.Duration(c.options.KeepAlive*int64(time.Second)) { - if atomic.LoadInt32(&c.pingOutstanding) == 0 { - DEBUG.Println(PNG, "keepalive sending ping") - ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket) - //We don't want to wait behind large messages being sent, the Write call - //will block until it it able to send the packet. - atomic.StoreInt32(&c.pingOutstanding, 1) - ping.Write(c.conn) - c.lastSent.Store(time.Now()) - pingSent = time.Now() - } - } - if atomic.LoadInt32(&c.pingOutstanding) > 0 && time.Since(pingSent) >= c.options.PingTimeout { - CRITICAL.Println(PNG, "pingresp not received, disconnecting") - c.errors <- errors.New("pingresp not received, disconnecting") - return - } + case <-c.keepaliveTimer.C: + resetKeepaliveTimer(c) + DEBUG.Println(PNG, "keepalive sending ping") + ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket) + //We don't want to wait behind large messages being sent, the Write call + //will block until it it able to send the packet. + ping.Write(c.conn) + c.pingTimeoutTimer.Reset(c.options.PingTimeout) + case <-c.pingTimeoutTimer.C: + CRITICAL.Println(PNG, "pingresp not received, disconnecting") + c.errors <- errors.New("pingresp not received, disconnecting") + return } } } + +func resetKeepaliveTimer(c *client) { + c.keepaliveTimer.Reset(time.Duration(c.options.KeepAlive * int64(time.Second))) +} + +func stopPingTimeOutTimer(c *client) { + c.pingTimeoutTimer.Stop() +}