From 69d4322d97ef00e4215e40ea44939d24ebd7687e Mon Sep 17 00:00:00 2001 From: Cody Baker Date: Sun, 18 Feb 2024 23:58:36 +0000 Subject: [PATCH] ensure discovery completes on schedule --- cmd/gen.go | 2 +- cmd/helper_discovery.go | 37 ++++++++++++++++++++--------------- cmd/prometheus.go | 6 +++++- cmd/shelly.go | 2 +- pkg/discovery/discovery.go | 2 +- pkg/discovery/mdns.go | 7 ++++++- pkg/discovery/mdns_test.go | 2 +- pkg/discovery/options.go | 18 +++++++++++++---- pkg/discovery/test_harness.go | 3 ++- pkg/promserver/server.go | 2 ++ 10 files changed, 54 insertions(+), 27 deletions(-) diff --git a/cmd/gen.go b/cmd/gen.go index 1448118..4827f46 100644 --- a/cmd/gen.go +++ b/cmd/gen.go @@ -202,7 +202,7 @@ func init() { } for _, childCmd := range c.Parent.Commands() { childRun := childCmd.RunE - discoveryFlags(childCmd.Flags(), false, true) + discoveryFlags(childCmd.Flags(), discoveryFlagsOptions{interactive: true}) childCmd.RunE = func(cmd *cobra.Command, args []string) error { if err := rootCmd.PersistentPreRunE(cmd, args); err != nil { return err diff --git a/cmd/helper_discovery.go b/cmd/helper_discovery.go index 9354ebe..668a33e 100644 --- a/cmd/helper_discovery.go +++ b/cmd/helper_discovery.go @@ -19,7 +19,13 @@ import ( var addAll bool -func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) { +type discoveryFlagsOptions struct { + withTTL bool + interactive bool + searchStrictTimeoutDefault bool +} + +func discoveryFlags(f *pflag.FlagSet, opts discoveryFlagsOptions) { f.String( "auth", "", @@ -67,17 +73,23 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) { "timeout for devices to respond to the mDNS discovery query.", ) + f.Bool( + "search-strict-timeout", + opts.searchStrictTimeoutDefault, + "ignore devices which have been found but completed their initial query within the search-timeout", + ) + // search-interactive and interactive cannot use the Bool() pattern as the default // varies by command and the global be set to whatever the last value was. f.Bool( "search-interactive", - interactive, + opts.interactive, "if true confirm devices discovered in search before proceeding with commands. Defers to --interactive if not explicitly set.", ) f.Bool( "interactive", - interactive, + opts.interactive, "if true prompt for confirmation or passwords.", ) @@ -98,7 +110,7 @@ func discoveryFlags(f *pflag.FlagSet, withTTL, interactive bool) { "continue with other hosts in the face errors.", ) - if withTTL { + if opts.withTTL { f.Duration( "device-ttl", discovery.DefaultDeviceTTL, @@ -134,22 +146,14 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere default: return nil, errors.New("invalid value for --prefer-ip-version; must be `4` or `6`") } - searchInteractive, err := flags.GetBool("search-interactive") - if err != nil { - return nil, err - } + searchInteractive := viper.GetBool("search-interactive") explictSearchInteractive := flags.Lookup("search-interactive").Changed - interactive, err := flags.GetBool("interactive") - if err != nil { - return nil, err - } + interactive := viper.GetBool("interactive") + if !explictSearchInteractive { searchInteractive = interactive } - auth, err := flags.GetString("auth") - if err != nil { - return nil, err - } + auth := viper.GetString("auth") if auth != "" { opts = append(opts, discovery.WithAuthCallback(func(_ context.Context, _ string) (passwd string, err error) { return auth, nil @@ -174,6 +178,7 @@ func discoveryOptionsFromFlags(flags *pflag.FlagSet) (opts []discovery.Discovere discovery.WithMDNSZone(viper.GetString("mdns-zone")), discovery.WithMDNSService(viper.GetString("mdns-service")), discovery.WithSearchTimeout(viper.GetDuration("search-timeout")), + discovery.WithSearchStrictTimeout(viper.GetBool("search-strict-timeout")), discovery.WithConcurrency(viper.GetInt("discovery-concurrency")), discovery.WithDeviceTTL(viper.GetDuration("device-ttl")), discovery.WithMDNSSearchEnabled(mdnsSearch), diff --git a/cmd/prometheus.go b/cmd/prometheus.go index 8b05e02..6308ba5 100644 --- a/cmd/prometheus.go +++ b/cmd/prometheus.go @@ -25,7 +25,11 @@ func init() { prometheusCmd.Flags().Int("probe-concurrency", promserver.DefaultConcurrency, "set the number of concurrent probes which will be made to service a metrics request.") prometheusCmd.Flags().Duration("device-timeout", promserver.DefaultDeviceTimeout, "set the maximum time allowed for a device to respond to it probe.") prometheusCmd.Flags().Duration("scrape-duration-warning", promserver.DefaultScrapeDurationWarning, "sets the value for scrape duration warning. Scrapes which exceed this duration will log a warning generate. Default value 8s is 80% of the 10s default prometheus scrape_timeout.") - discoveryFlags(prometheusCmd.Flags(), true, false) + discoveryFlags(prometheusCmd.Flags(), discoveryFlagsOptions{ + withTTL: true, + interactive: false, + searchStrictTimeoutDefault: true, + }) rootCmd.AddCommand(prometheusCmd) rootCmd.AddGroup(&cobra.Group{ ID: "servers", diff --git a/cmd/shelly.go b/cmd/shelly.go index 183796b..5c02743 100644 --- a/cmd/shelly.go +++ b/cmd/shelly.go @@ -20,7 +20,7 @@ func init() { "password", "", "password to use for auth. If empty, the password will be cleared.", ) shellyComponent.Parent.AddCommand(shellyAuthCmd) - discoveryFlags(shellyAuthCmd.Flags(), false, true) + discoveryFlags(shellyAuthCmd.Flags(), discoveryFlagsOptions{interactive: true}) shellyAuthCmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() ll := log.Ctx(ctx).With().Str("request", (&shelly.ShellySetAuthRequest{}).Method()).Logger() diff --git a/pkg/discovery/discovery.go b/pkg/discovery/discovery.go index a40c006..a83db3b 100644 --- a/pkg/discovery/discovery.go +++ b/pkg/discovery/discovery.go @@ -37,7 +37,7 @@ func NewDiscoverer(opts ...DiscovererOption) *Discoverer { mdnsService: DefaultMDNSService, searchTimeout: DefaultMDNSSearchTimeout, concurrency: DefaultConcurrency, - mdnsQueryFunc: mdns.Query, + mdnsQueryFunc: mdns.QueryContext, }, } for _, o := range opts { diff --git a/pkg/discovery/mdns.go b/pkg/discovery/mdns.go index 6784277..c261ac1 100644 --- a/pkg/discovery/mdns.go +++ b/pkg/discovery/mdns.go @@ -22,6 +22,11 @@ func (d *Discoverer) searchMDNS(ctx context.Context, stop chan struct{}) ([]*Dev if !d.mdnsSearchEnabled { return nil, nil } + if d.searchStrictTimeout { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, d.searchTimeout) + defer cancel() + } c := make(chan *mdns.ServiceEntry, mdnsSearchBuffer) params := &mdns.QueryParam{ Service: d.mdnsService, @@ -93,7 +98,7 @@ func (d *Discoverer) searchMDNS(ctx context.Context, stop chan struct{}) ([]*Dev } }() - if err := d.mdnsQueryFunc(params); err != nil { + if err := d.mdnsQueryFunc(ctx, params); err != nil { close(c) return nil, fmt.Errorf("querying mdns for devices: %w", err) } diff --git a/pkg/discovery/mdns_test.go b/pkg/discovery/mdns_test.go index 22070fb..49b609c 100644 --- a/pkg/discovery/mdns_test.go +++ b/pkg/discovery/mdns_test.go @@ -60,7 +60,7 @@ func TestDiscovererMDNSSearch(t *testing.T) { serviceEntryTemplate.Port, err = strconv.Atoi(port) require.NoError(t, err) - queryFunc := func(params *mdns.QueryParam) error { + queryFunc := func(ctx context.Context, params *mdns.QueryParam) error { se := serviceEntryTemplate params.Entries <- &se return nil diff --git a/pkg/discovery/options.go b/pkg/discovery/options.go index 03799e5..4c4a313 100644 --- a/pkg/discovery/options.go +++ b/pkg/discovery/options.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "net" "sync" "time" @@ -24,9 +25,10 @@ type options struct { mdnsService string mdnsSearchEnabled bool - searchTimeout time.Duration - searchConfirm SearchConfirm - concurrency int + searchStrictTimeout bool + searchTimeout time.Duration + searchConfirm SearchConfirm + concurrency int // deviceTTL is relevant for long-lived commands (like prometheus metrics server) when // mixed with mDNS or other ephemeral discovery. @@ -34,7 +36,7 @@ type options struct { preferIPVersion string - mdnsQueryFunc func(*mdns.QueryParam) error + mdnsQueryFunc func(context.Context, *mdns.QueryParam) error } // DiscovererOption provides optional parameters for the Discoverer. @@ -126,4 +128,12 @@ func WithAuthCallback(authCallback AuthCallback) DiscovererOption { } } +// WithSearchStrictTimeout will force devices which have been discovered, but not resolved and added +// to finish within the search timeout or be cancelled. +func WithSearchStrictTimeout(strictTimeoutMode bool) DiscovererOption { + return func(d *Discoverer) { + d.searchStrictTimeout = strictTimeoutMode + } +} + type DeviceOption func(*Device) diff --git a/pkg/discovery/test_harness.go b/pkg/discovery/test_harness.go index ab81ceb..d2b0a09 100644 --- a/pkg/discovery/test_harness.go +++ b/pkg/discovery/test_harness.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "encoding/json" "fmt" "io" @@ -38,7 +39,7 @@ func NewTestDiscoverer(t *testing.T, opts ...DiscovererOption) *TestDiscoverer { } // SetMDNSQueryFunc facilitates overriding the mDNS query function for testing. -func (td *TestDiscoverer) SetMDNSQueryFunc(q func(*mdns.QueryParam) error) { +func (td *TestDiscoverer) SetMDNSQueryFunc(q func(context.Context, *mdns.QueryParam) error) { td.mdnsQueryFunc = q } diff --git a/pkg/promserver/server.go b/pkg/promserver/server.go index e3152bc..5409b06 100644 --- a/pkg/promserver/server.go +++ b/pkg/promserver/server.go @@ -266,9 +266,11 @@ func (s *Server) Collect(ch chan<- prometheus.Metric) { } l.Debug().Dur("duration", duration).Msg("finished all collection") }() + l.Debug().Msg("starting discovery") if _, err := s.discoverer.Search(s.ctx); err != nil { l.Err(err).Msg("finding new devices") } + l.Debug().Dur("duration", time.Since(start)).Msg("finished discovery") var wg sync.WaitGroup defer wg.Wait() concurrencyLimit := make(chan struct{}, s.concurrency)