diff --git a/pkg/rsqueue/queue/queue.go b/pkg/rsqueue/queue/queue.go index bc76dc9..5b4d8a7 100644 --- a/pkg/rsqueue/queue/queue.go +++ b/pkg/rsqueue/queue/queue.go @@ -81,11 +81,17 @@ type Queue interface { type QueueSupportedTypes interface { Enabled() []uint64 SetEnabled(typeId uint64, enabled bool) + SetEnabledConditional(typeId uint64, enabled func() bool) DisableAll() } +type Enabled struct { + Always bool + Conditional func() bool +} + type DefaultQueueSupportedTypes struct { - types map[uint64]bool + types map[uint64]Enabled mutex sync.RWMutex } @@ -94,7 +100,7 @@ func (d *DefaultQueueSupportedTypes) Enabled() []uint64 { defer d.mutex.RUnlock() results := make([]uint64, 0) for i, enabled := range d.types { - if enabled { + if enabled.Always || (enabled.Conditional != nil && enabled.Conditional()) { results = append(results, i) } } @@ -103,18 +109,27 @@ func (d *DefaultQueueSupportedTypes) Enabled() []uint64 { func (d *DefaultQueueSupportedTypes) SetEnabled(typeId uint64, enabled bool) { if d.types == nil { - d.types = make(map[uint64]bool) + d.types = make(map[uint64]Enabled) + } + d.mutex.Lock() + defer d.mutex.Unlock() + d.types[typeId] = Enabled{Always: enabled} +} + +func (d *DefaultQueueSupportedTypes) SetEnabledConditional(typeId uint64, enabled func() bool) { + if d.types == nil { + d.types = make(map[uint64]Enabled) } d.mutex.Lock() defer d.mutex.Unlock() - d.types[typeId] = enabled + d.types[typeId] = Enabled{Conditional: enabled} } func (d *DefaultQueueSupportedTypes) DisableAll() { d.mutex.Lock() defer d.mutex.Unlock() for i := range d.types { - d.types[i] = false + d.types[i] = Enabled{Always: false} } } diff --git a/pkg/rsqueue/runnerfactory/runnerfactory.go b/pkg/rsqueue/runnerfactory/runnerfactory.go index e8133fb..a343d34 100644 --- a/pkg/rsqueue/runnerfactory/runnerfactory.go +++ b/pkg/rsqueue/runnerfactory/runnerfactory.go @@ -32,17 +32,24 @@ func (r *RunnerFactory) Add(workType uint64, runner queue.WorkRunner) { r.types.SetEnabled(workType, true) } +func (r *RunnerFactory) AddConditional(workType uint64, enabled func() bool, runner queue.WorkRunner) { + r.runners[workType] = runner + r.types.SetEnabledConditional(workType, enabled) +} + +// Run runs work if the work type is configured. Note that this doesn't check to +// see if the work type is enabled (in r.types). func (r *RunnerFactory) Run(work queue.RecursableWork) error { runner, ok := r.runners[work.WorkType] if !ok { - return fmt.Errorf("Invalid work type %d", work.WorkType) + return fmt.Errorf("invalid work type %d", work.WorkType) } return runner.Run(work) } -// Stops all the runners in the factory. After each runner is stopped, +// Stop all the runners in the factory. After each runner is stopped, // it is marked as disabled so that we won't attempt to grab future // work for that runner from the queue. func (r *RunnerFactory) Stop(timeout time.Duration) error { diff --git a/pkg/rsqueue/runnerfactory/runnerfactory_test.go b/pkg/rsqueue/runnerfactory/runnerfactory_test.go index 10284be..063cb50 100644 --- a/pkg/rsqueue/runnerfactory/runnerfactory_test.go +++ b/pkg/rsqueue/runnerfactory/runnerfactory_test.go @@ -93,7 +93,31 @@ func (s *RunnerFactorySuite) TestNewRunner(c *check.C) { Work: []byte{}, WorkType: 2, }) - c.Assert(err, check.ErrorMatches, "Invalid work type 2") + c.Assert(err, check.ErrorMatches, "invalid work type 2") +} + +func (s *RunnerFactorySuite) TestRunnerConditional(c *check.C) { + types := &queue.DefaultQueueSupportedTypes{} + r := NewRunnerFactory(RunnerFactoryConfig{SupportedTypes: types}) + c.Check(r, check.DeepEquals, &RunnerFactory{ + runners: make(map[uint64]queue.WorkRunner), + types: types, + }) + + yes := func() bool { + return true + } + no := func() bool { + return false + } + + // Add a runner + fr1 := &FakeRunnerOne{} + fr2 := &FakeRunnerTwo{} + r.AddConditional(0, yes, fr1) + r.AddConditional(1, no, fr2) + c.Check(r.runners, check.HasLen, 2) + c.Check(r.types.Enabled(), check.DeepEquals, []uint64{0}) } func (s *RunnerFactorySuite) TestStop(c *check.C) {