diff --git a/cmd/client/cli/run.go b/cmd/client/cli/run.go
index a96788b..cf25b38 100644
--- a/cmd/client/cli/run.go
+++ b/cmd/client/cli/run.go
@@ -26,7 +26,7 @@ func parseJobsFlag(cmd *cobra.Command, name string) int {
 		logrus.Fatalf("Could not get jobs number: %v", err)
 	}
 	if jobs < 0 {
-		logrus.Fatal("run: job count should be non-negavtive")
+		logrus.Fatal("run: job count should be non-negative")
 	}
 	return jobs
 }
@@ -42,6 +42,13 @@ func NewRun(cmd *cobra.Command, _ []string, cfg *client.Config) NeoCLI {
 	jobs := parseJobsFlag(cmd, "jobs")
 	endlessJobs := parseJobsFlag(cmd, "endless-jobs")
 
+	timeoutScaleTarget, err := cmd.Flags().GetFloat64("timeout-autoscale-target")
+	if err != nil {
+		logrus.Fatalf("Could not get timeout-autoscale-target flag: %v", err)
+	}
+	if timeoutScaleTarget < 0 {
+		logrus.Fatalf("timeout-autoscale-target should be non-negative")
+	}
 	neocli.Weight = jobs
 
 	cli.sender = joblogger.NewRemoteSender(neocli)
@@ -49,6 +56,7 @@ func NewRun(cmd *cobra.Command, _ []string, cfg *client.Config) NeoCLI {
 		cli.ClientID(),
 		jobs,
 		endlessJobs,
+		timeoutScaleTarget,
 		cfg,
 		neocli,
 		cli.sender,
diff --git a/cmd/client/cmd/run.go b/cmd/client/cmd/run.go
index 5452142..c7980ef 100644
--- a/cmd/client/cmd/run.go
+++ b/cmd/client/cmd/run.go
@@ -29,4 +29,10 @@ func init() {
 	rootCmd.AddCommand(runCmd)
 	runCmd.Flags().IntP("jobs", "j", runtime.NumCPU()*cli.JobsPerCPU, "workers to run")
 	runCmd.Flags().IntP("endless-jobs", "e", 0, "workers to run for endless mode. Default is 0 for no endless mode")
+	runCmd.Flags().Float64(
+		"timeout-autoscale-target",
+		1.5,
+		"target upper bound for recurrent exploit worker utilization by scaling timeouts."+
+			" Setting this to 0 disables scaling",
+	)
 }
diff --git a/internal/exploit/metrics.go b/internal/exploit/metrics.go
index 26cc8de..f5179a0 100644
--- a/internal/exploit/metrics.go
+++ b/internal/exploit/metrics.go
@@ -3,18 +3,16 @@ package exploit
 import (
 	"github.com/prometheus/client_golang/prometheus"
 	"github.com/prometheus/client_golang/prometheus/promauto"
-	"github.com/samber/lo"
 )
 
 type Metrics struct {
 	FlagsSubmitted *prometheus.CounterVec
 	Teams          prometheus.Gauge
+	Queue          *prometheus.GaugeVec
 }
 
 func NewMetrics(namespace string) *Metrics {
 	const subsystem = "exploit_runner"
-	targetLabels := []string{"target_id", "target_ip"}
-	exploitLabels := []string{"exploit_id", "exploit_version", "exploit_type"}
 
 	return &Metrics{
 		FlagsSubmitted: promauto.NewCounterVec(
@@ -24,7 +22,13 @@ func NewMetrics(namespace string) *Metrics {
 				Name:      "flags_submitted_total",
 				Help:      "Number of exploits finished",
 			},
-			lo.Union(targetLabels, exploitLabels),
+			[]string{
+				"target_id",
+				"target_ip",
+				"exploit_id",
+				"exploit_version",
+				"exploit_type",
+			},
 		),
 
 		Teams: promauto.NewGauge(prometheus.GaugeOpts{
@@ -33,5 +37,15 @@ func NewMetrics(namespace string) *Metrics {
 			Name:      "teams",
 			Help:      "Number of teams scheduled for the current runner",
 		}),
+
+		Queue: promauto.NewGaugeVec(
+			prometheus.GaugeOpts{
+				Namespace: namespace,
+				Subsystem: subsystem,
+				Name:      "queue",
+				Help:      "Number of exploits in the queue",
+			},
+			[]string{"type"},
+		),
 	}
 }
diff --git a/internal/exploit/models.go b/internal/exploit/models.go
index 3edf3fd..d74df74 100644
--- a/internal/exploit/models.go
+++ b/internal/exploit/models.go
@@ -42,15 +42,16 @@ func (r *FullResult) MetricLabels() prometheus.Labels {
 }
 
 type State struct {
-	ID       string
-	Version  int64
-	Dir      string
-	Path     string
-	Disabled bool
-	Endless  bool
-	RunEvery time.Duration
-	LastRun  time.Time
-	Timeout  time.Duration
+	ID            string
+	Version       int64
+	Dir           string
+	Path          string
+	Disabled      bool
+	Endless       bool
+	RunEvery      time.Duration
+	LastRun       time.Time
+	ScaledTimeout time.Duration
+	Timeout       time.Duration
 }
 
 func (s *State) ExploitType() models.ExploitType {
diff --git a/internal/exploit/runner.go b/internal/exploit/runner.go
index 7a2cec1..0493931 100644
--- a/internal/exploit/runner.go
+++ b/internal/exploit/runner.go
@@ -31,19 +31,23 @@ var (
 func NewRunner(
 	clientID string,
 	maxJobs, maxEndlessJobs int,
+	timeoutScaleTarget float64,
 	clientConfig *client.Config,
 	c *client.Client,
 	logSender joblogger.Sender,
 ) *Runner {
 	return &Runner{
-		storage:        NewStorage(NewCache(), clientConfig.ExploitDir, c),
-		cfg:            &config.ExploitsConfig{},
-		client:         c,
-		maxJobs:        maxJobs,
-		maxEndlessJobs: maxEndlessJobs,
-		singleRuns:     make(chan *epb.SingleRunSubscribeResponse),
-		restarts:       make(chan struct{}, 1),
-		logSender:      logSender,
+		storage: NewStorage(NewCache(), clientConfig.ExploitDir, c),
+		cfg:     &config.ExploitsConfig{},
+		client:  c,
+
+		maxJobs:            maxJobs,
+		maxEndlessJobs:     maxEndlessJobs,
+		timeoutScaleTarget: timeoutScaleTarget,
+
+		singleRuns: make(chan *epb.SingleRunSubscribeResponse),
+		restarts:   make(chan struct{}, 1),
+		logSender:  logSender,
 		metricsPusher: push.
 			New(clientConfig.MetricsHost, "neo_runner").
 			Grouping("client_id", clientID).
@@ -63,8 +67,9 @@ type Runner struct {
 	metricsPusher *push.Pusher
 	metrics       *Metrics
 
-	maxJobs        int
-	maxEndlessJobs int
+	maxJobs            int
+	maxEndlessJobs     int
+	timeoutScaleTarget float64
 
 	simpleLoop  *submitLoop
 	endlessLoop *submitLoop
@@ -371,6 +376,10 @@ func (r *Runner) onServerStateUpdate(ctx context.Context, state *epb.ServerState
 	}
 
 	if r.storage.UpdateExploits(ctx, state.Exploits) {
+		if r.timeoutScaleTarget > 0 {
+			r.storage.ScaleTimeouts(r.maxJobs, len(r.teams), r.timeoutScaleTarget)
+		}
+
 		r.logger.Info("Exploits changed, scheduling loops restart")
 		r.restartLoops()
 	}
@@ -400,7 +409,7 @@ func CreateExploitJobs(
 			ex.Path,
 			ex.Dir,
 			environ,
-			ex.Timeout,
+			ex.ScaledTimeout,
 			joblogger.New(ex.ID, ex.Version, ip, sender),
 		))
 	}
diff --git a/internal/exploit/storage.go b/internal/exploit/storage.go
index 8206d9e..b2397cd 100644
--- a/internal/exploit/storage.go
+++ b/internal/exploit/storage.go
@@ -61,6 +61,46 @@ func (s *Storage) UpdateExploits(ctx context.Context, exs []*epb.ExploitState) b
 	return true
 }
 
+func (s *Storage) ScaleTimeouts(workers, teams int, target float64) {
+	// Alpha is a worker usage coefficient.
+	// For example, an exploit with a 10s timeout that runs every 20s
+	// uses half of a worker's time for each team, so with teams = 4
+	// the exploit occupies 2 full workers.
+	// Alpha in the case above will be 10/20 = 0.5 after the loop,
+	// and if workers = 2 its final value will be 0.5 * 4 / 2 = 1,
+	// which means full worker utilization.
+	// If alpha is smaller, we can increase the timeouts; if it is larger,
+	// we decrease them proportionally to their original values.
+	// Target allows specifying the desired alpha value,
+	// as in most cases exploits finish before their timeout,
+	// and the "safe" choice of target = 1 leads to
+	// suboptimal worker utilization.
+	// NB 1: endless exploits are not scaled.
+	// NB 2: timeouts are rounded down to the nearest second.
+	alpha := 0.0
+
+	for _, ex := range s.cache.Exploits() {
+		if ex.Endless {
+			continue
+		}
+		alpha += ex.Timeout.Seconds() / ex.RunEvery.Seconds()
+	}
+	alpha = alpha * float64(teams) / float64(workers)
+	logrus.Infof("Scaling timeouts: alpha = %.2f, target = %.2f", alpha, target)
+	for _, ex := range s.cache.Exploits() {
+		if ex.Endless {
+			continue
+		}
+		newTimeout := time.Duration(float64(ex.Timeout) * target / alpha)
+
+		// Round down to the nearest second.
+		newTimeout -= newTimeout % time.Second
+
+		logrus.Infof("Scaling timeout for exploit %s: %s -> %s", ex.ID, ex.ScaledTimeout, newTimeout)
+		ex.ScaledTimeout = newTimeout
+	}
+}
+
 func (s *Storage) updateExploit(ctx context.Context, exploitID string) (*State, error) {
 	// Download the current exploit state.
 	resp, err := s.client.Exploit(ctx, exploitID)
@@ -115,14 +155,15 @@ func (s *Storage) updateExploit(ctx context.Context, exploitID string) (*State,
 	}
 
 	res := &State{
-		ID:       state.ExploitId,
-		Version:  state.Version,
-		Dir:      "",
-		Path:     entryPath,
-		Disabled: state.Config.Disabled,
-		Endless:  state.Config.Endless,
-		RunEvery: state.Config.RunEvery.AsDuration(),
-		Timeout:  state.Config.Timeout.AsDuration(),
+		ID:            state.ExploitId,
+		Version:       state.Version,
+		Dir:           "",
+		Path:          entryPath,
+		Disabled:      state.Config.Disabled,
+		Endless:       state.Config.Endless,
+		RunEvery:      state.Config.RunEvery.AsDuration(),
+		ScaledTimeout: state.Config.Timeout.AsDuration(),
+		Timeout:       state.Config.Timeout.AsDuration(),
 	}
 	if state.Config.IsArchive {
 		res.Dir = oPath
diff --git a/internal/exploit/storage_test.go b/internal/exploit/storage_test.go
index 20aa70c..5fb8c86 100644
--- a/internal/exploit/storage_test.go
+++ b/internal/exploit/storage_test.go
@@ -7,6 +7,7 @@ import (
 	"path"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
@@ -168,3 +169,69 @@ func Test_prepareEntry(t *testing.T) {
 	// Check that file is executable.
 	require.NotZero(t, fi.Mode()&0111)
 }
+
+func TestStorage_Scale(t *testing.T) {
+	st, cleanup := mockStorage()
+	defer func() {
+		require.NoError(t, cleanup())
+	}()
+
+	// This exploit's timeout should be halved, as teams = 2 * workers.
+	st.cache.Update([]*State{
+		{
+			ID:            "1",
+			Version:       1,
+			RunEvery:      time.Minute,
+			ScaledTimeout: time.Minute,
+			Timeout:       time.Minute,
+		},
+	})
+	st.ScaleTimeouts(10, 20, 1)
+
+	res, ok := st.Exploit("1")
+	require.True(t, ok)
+	require.EqualValues(t, 1, res.Version)
+	require.EqualValues(t, time.Minute, res.RunEvery)
+	require.EqualValues(t, 30*time.Second, res.ScaledTimeout)
+
+	// Now it should be doubled, as workers = 2 * teams.
+	st.ScaleTimeouts(20, 10, 1)
+
+	res, ok = st.Exploit("1")
+	require.True(t, ok)
+	require.EqualValues(t, time.Minute, res.RunEvery)
+	require.EqualValues(t, 2*time.Minute, res.ScaledTimeout)
+
+	// Add another exploit, expect scaling to work proportionally to the original timeouts.
+	st.cache.Update([]*State{
+		{
+			ID:            "2",
+			Version:       1,
+			RunEvery:      time.Minute,
+			ScaledTimeout: time.Minute,
+			Timeout:       time.Minute,
+		},
+	})
+	st.ScaleTimeouts(20, 10, 1)
+
+	res, ok = st.Exploit("1")
+	require.True(t, ok)
+	require.EqualValues(t, time.Minute, res.RunEvery)
+	require.EqualValues(t, time.Minute, res.ScaledTimeout)
+
+	res, ok = st.Exploit("2")
+	require.True(t, ok)
+	require.EqualValues(t, time.Minute, res.RunEvery)
+	require.EqualValues(t, time.Minute, res.ScaledTimeout)
+
+	// Scale with target = 2, expect exploit timeouts to scale up.
+	st.ScaleTimeouts(20, 10, 2)
+
+	res, ok = st.Exploit("1")
+	require.True(t, ok)
+	require.EqualValues(t, 2*time.Minute, res.ScaledTimeout)
+
+	res, ok = st.Exploit("2")
+	require.True(t, ok)
+	require.EqualValues(t, 2*time.Minute, res.ScaledTimeout)
+}
diff --git a/internal/exploit/submit_loop.go b/internal/exploit/submit_loop.go
index 6f0d86a..852a0a2 100644
--- a/internal/exploit/submit_loop.go
+++ b/internal/exploit/submit_loop.go
@@ -126,6 +126,7 @@ func (l *submitLoop) Start(ctx context.Context) {
 			}
 		case <-t.C:
 			flush()
+			l.metrics.Queue.WithLabelValues(string(l.q.Type())).Set(float64(l.q.Size()))
 		case <-ctx.Done():
 			return
 		}
diff --git a/internal/exploit/submit_loop_test.go b/internal/exploit/submit_loop_test.go
index 1554654..b45b7b6 100644
--- a/internal/exploit/submit_loop_test.go
+++ b/internal/exploit/submit_loop_test.go
@@ -330,6 +330,10 @@ func (m *mockQueue) Type() queue.Type {
 	return "mock"
 }
 
+func (m *mockQueue) Size() int {
+	return len(m.in)
+}
+
 func (m *mockQueue) Start(ctx context.Context) {
 	<-ctx.Done()
 }
diff --git a/internal/queue/endless.go b/internal/queue/endless.go
index e0bff0e..73aea72 100644
--- a/internal/queue/endless.go
+++ b/internal/queue/endless.go
@@ -50,6 +50,10 @@ func (q *endlessQueue) Type() Type {
 	return TypeEndless
 }
 
+func (q *endlessQueue) Size() int {
+	return len(q.c)
+}
+
 // Start is synchronous.
 // Cancel the start's context to stop the queue.
 func (q *endlessQueue) Start(ctx context.Context) {
diff --git a/internal/queue/queue.go b/internal/queue/queue.go
index 59b84c2..7ff30ba 100644
--- a/internal/queue/queue.go
+++ b/internal/queue/queue.go
@@ -31,6 +31,7 @@ type Queue interface {
 	Add(*Job) error
 	Results() <-chan *Output
 	Type() Type
+	Size() int
 
 	fmt.Stringer
 }
diff --git a/internal/queue/simple.go b/internal/queue/simple.go
index 5c16dc1..28bf2a8 100644
--- a/internal/queue/simple.go
+++ b/internal/queue/simple.go
@@ -49,6 +49,10 @@ func (q *simpleQueue) Type() Type {
 	return TypeSimple
 }
 
+func (q *simpleQueue) Size() int {
+	return len(q.c)
+}
+
 // Start is synchronous.
 // Cancel the start's context to stop the queue.
 func (q *simpleQueue) Start(ctx context.Context) {
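The standalone sketch below (not part of the diff) illustrates the scaling arithmetic that Storage.ScaleTimeouts applies: alpha approximates the share of total worker time consumed by recurrent exploits, and each timeout is rescaled by target/alpha and rounded down to a whole second. The file name, the exploit struct, the returned slice, and the alpha == 0 guard are assumptions made for the sketch; they are not code from the repository.

// timeoutscale_sketch.go: a minimal, self-contained illustration of the
// alpha/target computation performed by Storage.ScaleTimeouts in the diff
// above. The exploit struct and the alpha == 0 guard are assumptions for
// the sketch, not repository code.
package main

import (
	"fmt"
	"time"
)

type exploit struct {
	Timeout  time.Duration
	RunEvery time.Duration
	Endless  bool
}

// scaleTimeouts returns rescaled timeouts: alpha is the fraction of total
// worker time the recurrent (non-endless) exploits would consume, and each
// timeout is multiplied by target/alpha, rounded down to a whole second.
func scaleTimeouts(exs []exploit, workers, teams int, target float64) []time.Duration {
	alpha := 0.0
	for _, ex := range exs {
		if ex.Endless {
			continue
		}
		alpha += ex.Timeout.Seconds() / ex.RunEvery.Seconds()
	}
	alpha = alpha * float64(teams) / float64(workers)

	out := make([]time.Duration, len(exs))
	for i, ex := range exs {
		if ex.Endless || alpha == 0 {
			out[i] = ex.Timeout // endless exploits keep their original timeout
			continue
		}
		scaled := time.Duration(float64(ex.Timeout) * target / alpha)
		out[i] = scaled - scaled%time.Second
	}
	return out
}

func main() {
	// A 10s exploit running every 20s uses half a worker per team; with
	// 4 teams and 2 workers alpha is 1.0 (full utilization), so the default
	// target of 1.5 stretches the timeout from 10s to 15s.
	exs := []exploit{{Timeout: 10 * time.Second, RunEvery: 20 * time.Second}}
	fmt.Println(scaleTimeouts(exs, 2, 4, 1.5)) // prints [15s]
}

The halving and doubling cases exercised by TestStorage_Scale above follow from the same formula.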