Skip to content

Commit

Permalink
use viper for config file / env var support
Browse files Browse the repository at this point in the history
  • Loading branch information
jcodybaker committed Feb 18, 2024
1 parent 5e4580c commit cdccfab
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 81 deletions.
90 changes: 36 additions & 54 deletions cmd/helper_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,90 +9,63 @@ import (
"net"
"os"
"sync"
"time"

"github.com/jcodybaker/shellyctl/pkg/discovery"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
"golang.org/x/crypto/ssh/terminal"
"github.com/spf13/viper"
"golang.org/x/term"
)

var (
auth string
hosts []string
mdnsSearch bool
bleSearch bool
bleDevices []string
mdnsInterface string
mdnsZone string
mdnsService string
discoveryDeviceTTL time.Duration
searchTimeout time.Duration
discoveryConcurrency int
skipFailedHosts bool

preferIPVersion string
)

func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
f.StringVar(
&auth,
f.String(
"auth",
"",
"password to use for authenticating with devices.",
)

f.StringArrayVar(
&hosts,
f.StringArray(
"host",
[]string{},
"host address of a single device. IP, DNS, or mDNS/BonJour addresses are accepted. If a URL scheme is provided, only `http` and `https` are supported. mDNS names must be within the zone specified by the `--mdns-zone` flag (default `local`).")

f.BoolVar(
&mdnsSearch,
f.Bool(
"mdns-search",
false,
"if true, devices will be discovered via mDNS")

f.BoolVar(
&bleSearch,
f.Bool(
"ble-search",
false,
"if true, devices will be discovered via Bluetooth Low-Energy")

f.StringArrayVar(
&bleDevices,
f.StringArray(
"ble-device",
[]string{},
"MAC address of a single bluetooth low-energy device. May be specified multiple times to work with multiple devices.")

f.StringVar(
&mdnsInterface,
f.String(
"mdns-interface",
"",
"if specified, search only the specified network interface for devices.")

f.StringVar(
&mdnsZone,
f.String(
"mdns-zone",
discovery.DefaultMDNSZone,
"mDNS zone to search")

f.StringVar(
&mdnsService,
f.String(
"mdns-service",
discovery.DefaultMDNSService,
"mDNS service to search")

f.DurationVar(
&searchTimeout,
f.Duration(
"search-timeout",
discovery.DefaultMDNSSearchTimeout,
"timeout for devices to respond to the mDNS discovery query.",
)

// search-interactive and interactive cannot use the BoolVar() pattern as the default
// search-interactive and interactive cannot use the Bool() pattern as the default
// varies by command and the global be set to whatever the last value was.
f.Bool(
"search-interactive",
Expand All @@ -106,29 +79,25 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
"if true prompt for confirmation or passwords.",
)

f.IntVar(
&discoveryConcurrency,
f.Int(
"discovery-concurrency",
discovery.DefaultConcurrency,
"number of concurrent ",
)

f.StringVar(
&preferIPVersion,
f.String(
"prefer-ip-version",
"",
"prefer ip version (`4` or `6`)")

f.BoolVar(
&skipFailedHosts,
f.Bool(
"skip-failed-hosts",
false,
"continue with other hosts in the face errors.",
)

if withTTL {
f.DurationVar(
&discoveryDeviceTTL,
f.Duration(
"device-ttl",
discovery.DefaultDeviceTTL,
"time-to-live for discovered devices in long-lived commands like the prometheus server.",
Expand All @@ -137,6 +106,14 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) {
}

func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.DiscovererOption, err error) {
viper.BindPFlags(flags)
hosts := viper.GetStringSlice("host")
bleDevices := viper.GetStringSlice("ble-device")
mdnsSearch := viper.GetBool("mdns-search")
bleSearch := viper.GetBool("ble-search")
mdnsInterface := viper.GetString("mdns-interface")
preferIPVersion := viper.GetString("prefer-ip-version")

if len(hosts) == 0 && len(bleDevices) == 0 && !mdnsSearch && !bleSearch {
return nil, errors.New("no hosts and or discovery (mDNS)")
}
Expand Down Expand Up @@ -192,11 +169,11 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere
opts = append(opts, discovery.WithSearchConfirm(searchConfirm))
}
opts = append(opts,
discovery.WithMDNSZone(mdnsZone),
discovery.WithMDNSService(mdnsService),
discovery.WithSearchTimeout(searchTimeout),
discovery.WithConcurrency(discoveryConcurrency),
discovery.WithDeviceTTL(discoveryDeviceTTL),
discovery.WithMDNSZone(viper.GetString("mdns-zone")),
discovery.WithMDNSService(viper.GetString("mdns-service")),
discovery.WithSearchTimeout(viper.GetDuration("search-timeout")),
discovery.WithConcurrency(viper.GetInt("discovery-concurrency")),
discovery.WithDeviceTTL(viper.GetDuration("device-ttl")),
discovery.WithMDNSSearchEnabled(mdnsSearch),
discovery.WithBLESearchEnabled(bleSearch),
)
Expand All @@ -206,9 +183,12 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere
func discoveryAddDevices(ctx context.Context, d *discovery.Discoverer) error {
l := log.Ctx(ctx)
var wg sync.WaitGroup
concurrencyLimit := make(chan struct{}, discoveryConcurrency)
concurrencyLimit := make(chan struct{}, viper.GetInt("discovery-concurrency"))
defer close(concurrencyLimit)
defer wg.Wait()
hosts := viper.GetStringSlice("host")
bleDevices := viper.GetStringSlice("ble-device")
skipFailedHosts := viper.GetBool("skip-failed-hosts")
if len(bleDevices) > 0 {
select {
case concurrencyLimit <- struct{}{}:
Expand Down Expand Up @@ -251,6 +231,8 @@ func discoveryAddDevices(ctx context.Context, d *discovery.Discoverer) error {

func discoveryAddBLEDevices(ctx context.Context, d *discovery.Discoverer) error {
l := log.Ctx(ctx)
skipFailedHosts := viper.GetBool("skip-failed-hosts")
bleDevices := viper.GetStringSlice("ble-device")
for _, mac := range bleDevices {
if err := ctx.Err(); err != nil {
return err
Expand Down Expand Up @@ -311,12 +293,12 @@ func passwordPrompt(ctx context.Context, desc string) (w string, err error) {
fmt.Printf("\nDevice %s requires authentication. Please enter a password:\n", desc)
log.Ctx(ctx)

oldState, err := terminal.MakeRaw(int(os.Stdin.Fd()))
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("failed to convert terminal to raw mode for password entry")
} else {
defer func() {
if err := terminal.Restore(int(os.Stdin.Fd()), oldState); err != nil {
if err := term.Restore(int(os.Stdin.Fd()), oldState); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("failed to convert terminal to raw mode for password entry")
}
fmt.Println()
Expand Down
34 changes: 13 additions & 21 deletions cmd/prometheus.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,16 @@ import (
"github.com/jcodybaker/shellyctl/pkg/promserver"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
)

var (
bindAddr net.IP
bindPort uint16
promNamespace string
promSubsystem string
promConcurrency int
promDeviceTimeout time.Duration
"github.com/spf13/viper"
)

func init() {
prometheusCmd.Flags().IPVar(&bindAddr, "bind-addr", net.IPv6zero, "local ip address to bind the metrics server to")
prometheusCmd.Flags().Uint16Var(&bindPort, "bind-port", 8080, "port to bind the metrics server")
prometheusCmd.Flags().StringVar(&promNamespace, "prometheus-namespace", promserver.DefaultNamespace, "set the namespace string to use for prometheus metric names.")
prometheusCmd.Flags().StringVar(&promSubsystem, "prometheus-subsystem", promserver.DefaultSubsystem, "set the subsystem section of the prometheus metric names.")
prometheusCmd.Flags().IntVar(&promConcurrency, "probe-concurrency", promserver.DefaultConcurrency, "set the number of concurrent probes which will be made to service a metrics request.")
prometheusCmd.Flags().DurationVar(&promDeviceTimeout, "device-timeout", promserver.DefaultDeviceTimeout, "set the maximum time allowed for a device to respond to it probe.")
prometheusCmd.Flags().IP("bind-addr", net.IPv6zero, "local ip address to bind the metrics server to")
prometheusCmd.Flags().Uint16("bind-port", 8080, "port to bind the metrics server")
prometheusCmd.Flags().String("prometheus-namespace", promserver.DefaultNamespace, "set the namespace string to use for prometheus metric names.")
prometheusCmd.Flags().String("prometheus-subsystem", promserver.DefaultSubsystem, "set the subsystem section of the prometheus metric names.")
prometheusCmd.Flags().Int("probe-concurrency", promserver.DefaultConcurrency, "set the number of concurrent probes which will be made to service a metrics request.")
prometheusCmd.Flags().Duration("device-timeout", promserver.DefaultDeviceTimeout, "set the maximum time allowed for a device to respond to it probe.")
discoveryFlags(prometheusCmd.Flags(), true, false)
rootCmd.AddCommand(prometheusCmd)
rootCmd.AddGroup(&cobra.Group{
Expand Down Expand Up @@ -63,15 +55,15 @@ var prometheusCmd = &cobra.Command{
ps := promserver.NewServer(
ctx,
disc,
promserver.WithPrometheusNamespace(promNamespace),
promserver.WithPrometheusSubsystem(promSubsystem),
promserver.WithConcurrency(promConcurrency),
promserver.WithDeviceTimeout(promDeviceTimeout),
promserver.WithPrometheusNamespace(viper.GetString("prometheus-namespace")),
promserver.WithPrometheusSubsystem(viper.GetString("prometheus-subsystem")),
promserver.WithConcurrency(viper.GetInt("probe-concurrency")),
promserver.WithDeviceTimeout(viper.GetDuration("device-timeout")),
)

hs := http.Server{
Handler: ps,
Addr: net.JoinHostPort(bindAddr.String(), strconv.Itoa(int(bindPort))),
Addr: net.JoinHostPort(viper.GetString("bind-addr"), strconv.Itoa(int(viper.GetUint16("bind-port")))),
}
go func() {
<-ctx.Done()
Expand All @@ -81,7 +73,7 @@ var prometheusCmd = &cobra.Command{
l.Err(err).Msg("shutting down http server")
}
}()
l.Info().Msg("starting metrics server")
l.Info().Str("bind_address", hs.Addr).Msg("starting metrics server")
if err := hs.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.Err(err).Msg("starting http server")
}
Expand Down
21 changes: 15 additions & 6 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (

var (
ctx context.Context
logLevel string
outputFormat string
activeOutputter outputter.Outputter = outputter.JSON
)

Expand All @@ -33,15 +31,26 @@ func init() {
rootCmd.Run = func(cmd *cobra.Command, args []string) {
rootCmd.Help()
}
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "warn", "threshold for outputing logs: trace, debug, info, warn, error, fatal, panic")
rootCmd.PersistentFlags().StringVarP(&outputFormat, "output-format", "o", "text", "desired output format: json, min-json, yaml, text, log")
rootCmd.PersistentFlags().String("log-level", "warn", "threshold for outputing logs: trace, debug, info, warn, error, fatal, panic")
rootCmd.PersistentFlags().StringP("output-format", "o", "text", "desired output format: json, min-json, yaml, text, log")
rootCmd.PersistentFlags().String("config", "", "path to config file. format will be determined by extension (.yaml, .json, .toml, .ini valid)")

rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.SetEnvPrefix("SHELLYCTL")
viper.AutomaticEnv()
viper.BindPFlags(rootCmd.PersistentFlags())
viper.BindPFlags(cmd.Flags())

if configFile := viper.GetString("config"); configFile != "" {
viper.SetConfigFile(configFile)
if err := viper.ReadInConfig(); err != nil {
return err
}
}

log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
switch strings.ToLower(logLevel) {
switch strings.ToLower(viper.GetString("log-level")) {
case "trace":
log.Logger = log.Level(zerolog.TraceLevel)
case "debug":
Expand All @@ -64,7 +73,7 @@ func init() {
ctx = log.Logger.WithContext(ctx)

var err error
activeOutputter, err = outputter.ByName(outputFormat)
activeOutputter, err = outputter.ByName(viper.GetString("output-format"))
if err != nil {
return err
}
Expand Down

0 comments on commit cdccfab

Please sign in to comment.