Skip to content

Commit

Permalink
support for websocket client + notifications
Browse files Browse the repository at this point in the history
  • Loading branch information
jcodybaker committed Dec 22, 2024
1 parent d74fc2c commit 47f186e
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 41 deletions.
9 changes: 8 additions & 1 deletion cmd/helper_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import (
"strings"
"sync"

"math/rand"

mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/jcodybaker/shellyctl/pkg/discovery"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"golang.org/x/exp/rand"
"golang.org/x/term"
)

Expand Down Expand Up @@ -47,6 +48,12 @@ func discoveryFlags(f *pflag.FlagSet, opts discoveryFlagsOptions) {
"mDNS names must be within the zone specified by the `--mdns-zone` flag (default `local`).\n"+
"URL formatted auth is supported (ex. `http://admin:[email protected]/`)")

f.String(
"local-id",
"shellyctl-${PID}",
"local src id to use. ${PID} will be replaced with the current process id, ${RANDOM} with a random int",
)

f.Bool(
"mdns-search",
false,
Expand Down
41 changes: 41 additions & 0 deletions cmd/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"os"
"os/signal"

"github.com/jcodybaker/go-shelly"
"github.com/jcodybaker/shellyctl/pkg/discovery"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

func init() {
Expand Down Expand Up @@ -49,6 +51,45 @@ var notificationsCmd = &cobra.Command{
if err := discoveryAddDevices(ctx, disc); err != nil {
l.Fatal().Err(err).Msg("adding devices")
}

for _, d := range disc.AllDevices() {
d := d
m, err := d.Open(ctx)
if err != nil {
if viper.GetBool("skip-failed-hosts") {
l.Warn().Err(err).
Str("instance", d.Instance()).
Str("name", d.BestName()).
Msg("failed to open device; skipping")
continue
} else {
l.Fatal().Err(err).
Str("instance", d.Instance()).
Str("name", d.BestName()).
Msg("failed to open device")
}
}
// We need to make an initial request for websockets
r := &shelly.ShellyGetConfigRequest{}
_, respF, err := r.Do(ctx, m, d.AuthCallback(ctx))
if err != nil {
if viper.GetBool("skip-failed-hosts") {
l.Warn().Err(err).
Str("instance", d.Instance()).
Str("name", d.BestName()).
Msg("failed to open device; skipping")
continue
} else {
l.Fatal().Err(err).
Str("instance", d.Instance()).
Str("name", d.BestName()).
Msg("failed to open device")
}
}
log.Info().Int64("id", respF.ID).Msg("connection established")
defer m.Disconnect(ctx)
}
log.Info().Msg("beginning notification announcements")
for {
select {
case <-ctx.Done():
Expand Down
14 changes: 8 additions & 6 deletions pkg/discovery/ble.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,11 @@ func (d *Discoverer) searchBLE(ctx context.Context, stop chan struct{}) ([]*Devi
}
macStr := strings.ToUpper(scanResult.Address.String())
dev := &Device{
Name: scanResult.LocalName(),
MACAddr: macStr,
uri: (&url.URL{Scheme: "ble", Host: macStr}).String(),
authCallback: d.authCallback,
Name: scanResult.LocalName(),
MACAddr: macStr,
uri: (&url.URL{Scheme: "ble", Host: macStr}).String(),
authCallback: d.authCallback,
notifications: &d.notifications,
}
ll := dev.LogCtx(ctx)

Expand Down Expand Up @@ -213,7 +214,8 @@ func (d *Discoverer) AddBLE(ctx context.Context, mac string) (*Device, error) {
ble: &BLEDevice{
options: d.options,
},
authCallback: d.authCallback,
authCallback: d.authCallback,
notifications: &d.notifications,
})
return dev, nil
}
Expand Down Expand Up @@ -351,7 +353,7 @@ func (b *BLEDevice) Call(
Str("subcomponent", "ble").
Str("method", cmd.Cmd).Logger()
cmd.ID = atomic.AddInt64(&bleMGRPCID, 1)
reqFrame := frame.NewRequestFrame("shellyctl", "", "", cmd, false)
reqFrame := frame.NewRequestFrame(localID(), "", "", cmd, false)
reqFrameBytes, err := json.Marshal(reqFrame)
if err != nil {
return nil, fmt.Errorf("encoding command: %w", err)
Expand Down
32 changes: 26 additions & 6 deletions pkg/discovery/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package discovery
import (
"context"
"fmt"
"math/rand"
"os"
"strconv"
"strings"
"time"

Expand All @@ -12,6 +15,7 @@ import (
"github.com/mongoose-os/mos/common/mgrpc/codec"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
)

type AuthCallback func(ctx context.Context, desc string) (pw string, err error)
Expand All @@ -29,6 +33,8 @@ type Device struct {

mqttPrefix string
mqttClient mqtt.Client

notifications *notifications
}

// Open creates an mongoose rpc channel to the device.
Expand All @@ -39,27 +45,33 @@ func (d *Device) Open(ctx context.Context) (mgrpc.MgRPC, error) {
if err := d.ble.open(ctx, d.MACAddr); err != nil {
return nil, err
}
d.notifications.register(d.ble)
return d.ble, nil
}
if d.mqttClient != nil && d.mqttPrefix != "" {
c, err := newMQTTCodec(ctx, d.mqttPrefix, d.mqttClient)
if err != nil {
return nil, fmt.Errorf("establishing mqtt rpc channel: %w", err)
}
return mgrpc.Serve(ctx, c), nil
m := mgrpc.Serve(ctx, c)
d.notifications.register(m)
return m, nil
}
if strings.HasPrefix(d.uri, "ws://") || strings.HasPrefix(d.uri, "wss://") {
c, err := mgrpc.New(ctx, d.uri,
m, err := mgrpc.New(ctx, d.uri,
mgrpc.UseWebSocket(),
mgrpc.LocalID(localID()),
)
if err != nil {
return nil, fmt.Errorf("establishing rpc channel: %w", err)
}
ll.Info().Str("channel_protocol", "http").Msg("connected to device")
return c, nil
ll.Info().Str("channel_protocol", "ws").Msg("connected to device")
d.notifications.register(m)
return m, nil
}
c, err := mgrpc.New(ctx, d.uri,
m, err := mgrpc.New(ctx, d.uri,
mgrpc.UseHTTPPost(),
mgrpc.LocalID(localID()),
mgrpc.CodecOptions(
codec.Options{
HTTPOut: codec.OutboundHTTPCodecOptions{
Expand All @@ -70,8 +82,9 @@ func (d *Device) Open(ctx context.Context) (mgrpc.MgRPC, error) {
if err != nil {
return nil, fmt.Errorf("establishing rpc channel: %w", err)
}
d.notifications.register(m)
ll.Info().Str("channel_protocol", "http").Msg("connected to device")
return c, nil
return m, nil
}

func (d *Device) resolveSpecs(ctx context.Context) error {
Expand Down Expand Up @@ -136,3 +149,10 @@ func WithDeviceName(name string) DeviceOption {
d.Name = name
}
}

func localID() string {
l := viper.GetString("local-id")
l = strings.Replace(l, "${PID}", strconv.Itoa(os.Getpid()), -1)
l = strings.Replace(l, "${RANDOM}", strconv.Itoa(rand.Int()), -1)
return l
}
20 changes: 10 additions & 10 deletions pkg/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ type Discoverer struct {
// so mDNS/BLE can opperate simultaneously.
ioLock sync.Mutex

statusChan chan StatusNotification
fullStatusChan chan StatusNotification
eventChan chan EventNotification
notifications
}

// AddDeviceByAddress attempts to parse a user-provided URI and add the device.
Expand Down Expand Up @@ -124,9 +122,10 @@ func (d *Discoverer) AddDeviceByAddress(ctx context.Context, addr string, opts .
}

dev := &Device{
uri: u.String(),
source: sourceManual,
authCallback: authCallback,
uri: u.String(),
source: sourceManual,
authCallback: authCallback,
notifications: &d.notifications,
}

if err = dev.resolveSpecs(ctx); err != nil {
Expand All @@ -139,10 +138,11 @@ func (d *Discoverer) AddDeviceByAddress(ctx context.Context, addr string, opts .

func (d *Discoverer) AddMQTTDevice(ctx context.Context, topicPrefix string, opts ...DeviceOption) (*Device, error) {
dev := &Device{
source: sourceManual,
authCallback: d.authCallback,
mqttPrefix: topicPrefix,
mqttClient: d.mqttClient,
source: sourceManual,
authCallback: d.authCallback,
mqttPrefix: topicPrefix,
mqttClient: d.mqttClient,
notifications: &d.notifications,
}

if err := dev.resolveSpecs(ctx); err != nil {
Expand Down
4 changes: 1 addition & 3 deletions pkg/discovery/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ func (d *Discoverer) MQTTConnect(ctx context.Context) error {
return fmt.Errorf("subscribing to MQTT topic %q: %w", t, err)
}
s := mgrpc.Serve(ctx, c)
s.AddHandler("NotifyStatus", d.statusNotificationHandler)
s.AddHandler("NotifyFullStatus", d.fullStatusNotificationHandler)
s.AddHandler("NotifyEvent", d.eventNotificationHandler)
d.notifications.register(s)
}
return nil
}
Expand Down
46 changes: 31 additions & 15 deletions pkg/discovery/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@ package discovery

import (
"encoding/json"
"sync"

"github.com/jcodybaker/go-shelly"
"github.com/mongoose-os/mos/common/mgrpc"
"github.com/mongoose-os/mos/common/mgrpc/frame"
"github.com/rs/zerolog/log"
)

type notifications struct {
statusChan chan StatusNotification
fullStatusChan chan StatusNotification
eventChan chan EventNotification

lock sync.Mutex
}

func (n *notifications) register(s mgrpc.MgRPC) {
log.Debug().Msg("registering notification handlers")
s.AddHandler("NotifyStatus", n.statusNotificationHandler)
s.AddHandler("NotifyFullStatus", n.fullStatusNotificationHandler)
s.AddHandler("NotifyEvent", n.eventNotificationHandler)
}

// StatusNotification carries a status notification and metadata.
type StatusNotification struct {
Status *shelly.NotifyStatus
Expand Down Expand Up @@ -56,10 +72,10 @@ func (d *Discoverer) GetEventNotifications(buffer int) <-chan EventNotification
return d.eventChan
}

func (d *Discoverer) statusNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
d.lock.Lock()
defer d.lock.Unlock()
if d.statusChan == nil {
func (n *notifications) statusNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
n.lock.Lock()
defer n.lock.Unlock()
if n.statusChan == nil {
return nil
}
s := &shelly.NotifyStatus{}
Expand All @@ -72,17 +88,17 @@ func (d *Discoverer) statusNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *
Str("payload", string(f.Params)).
Msg("unmarshalling NotifyStatus frame")
}
d.statusChan <- StatusNotification{
n.statusChan <- StatusNotification{
Status: s,
Frame: f,
}
return nil
}

func (d *Discoverer) fullStatusNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
d.lock.Lock()
defer d.lock.Unlock()
if d.fullStatusChan == nil {
func (n *notifications) fullStatusNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
n.lock.Lock()
defer n.lock.Unlock()
if n.fullStatusChan == nil {
return nil
}
s := &shelly.NotifyStatus{}
Expand All @@ -95,17 +111,17 @@ func (d *Discoverer) fullStatusNotificationHandler(mr mgrpc.MgRPC, f *frame.Fram
Str("payload", string(f.Params)).
Msg("unmarshalling NotifyFullStatus frame")
}
d.fullStatusChan <- StatusNotification{
n.fullStatusChan <- StatusNotification{
Status: s,
Frame: f,
}
return nil
}

func (d *Discoverer) eventNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
d.lock.Lock()
defer d.lock.Unlock()
if d.eventChan == nil {
func (n *notifications) eventNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *frame.Frame {
n.lock.Lock()
defer n.lock.Unlock()
if n.eventChan == nil {
return nil
}
e := &shelly.NotifyEvent{}
Expand All @@ -118,7 +134,7 @@ func (d *Discoverer) eventNotificationHandler(mr mgrpc.MgRPC, f *frame.Frame) *f
Str("payload", string(f.Params)).
Msg("unmarshalling NotifyFullStatus frame")
}
d.eventChan <- EventNotification{
n.eventChan <- EventNotification{
Event: e,
Frame: f,
}
Expand Down

0 comments on commit 47f186e

Please sign in to comment.