Generic hooks for testing #6938

Open · wants to merge 13 commits into base: main
4 changes: 2 additions & 2 deletions Makefile
@@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility
# Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages
# during proto library transition.
ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG)
ALL_TEST_TAGS := $(ALL_BUILD_TAGS),$(TEST_TAG)
ALL_TEST_TAGS := $(ALL_BUILD_TAGS),test_dep,$(TEST_TAG)
BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS)
TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS)

@@ -331,7 +331,7 @@ lint-actions: $(ACTIONLINT)

lint-code: $(GOLANGCI_LINT)
@printf $(COLOR) "Linting code..."
@$(GOLANGCI_LINT) run --verbose --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml
@$(GOLANGCI_LINT) run --verbose --build-tags $(ALL_TEST_TAGS) --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml

fmt-imports: $(GCI) # Don't get confused: there is a single linter called gci, which is part of the mega-linter we use, golangci-lint.
@printf $(COLOR) "Formatting imports..."
9 changes: 5 additions & 4 deletions client/client_factory_mock.go

(Generated file; diff not rendered by default.)

7 changes: 6 additions & 1 deletion client/clientfactory.go
@@ -44,6 +44,7 @@ import (
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/primitives"
"go.temporal.io/server/common/testing/testhooks"
"google.golang.org/grpc"
)

@@ -65,6 +66,7 @@ type (
monitor membership.Monitor,
metricsHandler metrics.Handler,
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
numberOfHistoryShards int32,
logger log.Logger,
throttledLogger log.Logger,
@@ -79,6 +81,7 @@ type (
monitor membership.Monitor
metricsHandler metrics.Handler
dynConfig *dynamicconfig.Collection
testHooks testhooks.TestHooks
numberOfHistoryShards int32
logger log.Logger
throttledLogger log.Logger
@@ -103,6 +106,7 @@ func (p *factoryProviderImpl) NewFactory(
monitor membership.Monitor,
metricsHandler metrics.Handler,
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
numberOfHistoryShards int32,
logger log.Logger,
throttledLogger log.Logger,
@@ -112,6 +116,7 @@
monitor: monitor,
metricsHandler: metricsHandler,
dynConfig: dc,
testHooks: testHooks,
numberOfHistoryShards: numberOfHistoryShards,
logger: logger,
throttledLogger: throttledLogger,
@@ -159,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout(
common.NewClientCache(keyResolver, clientProvider),
cf.metricsHandler,
cf.logger,
matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig),
matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.testHooks),
)

if cf.metricsHandler != nil {
51 changes: 31 additions & 20 deletions client/matching/loadbalancer.go
@@ -30,6 +30,7 @@ import (

"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/tqid"
)

@@ -57,11 +58,10 @@ type (
}

defaultLoadBalancer struct {
namespaceIDToName func(id namespace.ID) (namespace.Name, error)
nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
forceReadPartition dynamicconfig.IntPropertyFn
forceWritePartition dynamicconfig.IntPropertyFn
namespaceIDToName func(id namespace.ID) (namespace.Name, error)
nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter
testHooks testhooks.TestHooks

lock sync.RWMutex
taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer
@@ -85,23 +85,22 @@ type (
func NewLoadBalancer(
namespaceIDToName func(id namespace.ID) (namespace.Name, error),
dc *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
) LoadBalancer {
lb := &defaultLoadBalancer{
namespaceIDToName: namespaceIDToName,
nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
forceReadPartition: dynamicconfig.TestMatchingLBForceReadPartition.Get(dc),
forceWritePartition: dynamicconfig.TestMatchingLBForceWritePartition.Get(dc),
lock: sync.RWMutex{},
taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer),
namespaceIDToName: namespaceIDToName,
nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc),
nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc),
testHooks: testHooks,
taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer),
}
return lb
}

func (lb *defaultLoadBalancer) PickWritePartition(
taskQueue *tqid.TaskQueue,
) *tqid.NormalPartition {
if n := lb.forceWritePartition(); n >= 0 {
if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok {
Contributor:
Did we discuss this? I really dislike the idea of having this kind of dependency, and having this kind of code in the main code path. I would suggest extracting this functionality into something like a "PartitionPicker" and providing different implementations in functional tests.
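For illustration only, a minimal sketch of the kind of seam being suggested here; the interface and type names are made up, not taken from this PR or the codebase:

```go
// Illustrative sketch of the suggested "PartitionPicker" seam; not code from this PR.
package matching

import "go.temporal.io/server/common/tqid"

// PartitionPicker would be the injected dependency: production wiring supplies
// the real load-balancing implementation, and functional tests swap in a fixed one.
type PartitionPicker interface {
	PickReadPartition(tq *tqid.TaskQueue, partitionCount int) *tqid.NormalPartition
}

// fixedPartitionPicker always returns the same partition, for use in tests.
type fixedPartitionPicker struct{ partition int }

func (p fixedPartitionPicker) PickReadPartition(tq *tqid.TaskQueue, _ int) *tqid.NormalPartition {
	return tq.NormalPartition(p.partition)
}
```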

Member Author:

We're discussing it now...

  1. If you look at what it was doing before, it was abusing dynamic config to hook in here, so this is already a strict improvement (it reduces runtime overhead and makes it clearer that this is a hook for testing).

  2. We only want this hook in some tests, and only some of the time. So even in tests, most of the time we want the standard behavior. That means we'd need a test LoadBalancer that can be set/unset to a mode with fixed behavior and otherwise falls back to the default. I think that's worse:

    1. First, it's just a lot more code.
    2. Second, it means the mechanism to poke the test implementation is specific to each object, and tests have to do their own cleanup. This generic mechanism is simpler for test writers: you just call s.InjectHook and it's automatically cleaned up (see the sketch below).
  3. How do we do that for the other two examples here, forcing async match and injecting a racing call in the middle of an update-with-start sequence? The alternative implementation method doesn't work there, as far as I can see.
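For concreteness, a minimal sketch of the test-writer side of this mechanism, assuming a suite helper named s.InjectHook(key, value) that registers the hook and cleans it up when the test ends; the suite, helper, and poll-call names are assumptions, not copied from this PR:

```go
// Hypothetical functional test using the generic test hooks (names assumed).
func (s *MatchingFunctionalSuite) TestPollsGoToForcedPartition() {
	// Register the hook for this test only; cleanup happens automatically.
	s.InjectHook(testhooks.MatchingLBForceReadPartition, 2)

	// Inside the server, PickReadPartition reads the hook via
	// testhooks.Get[int](...) and routes every poll issued below to partition 2.
	s.pollActivityTaskQueue("my-task-queue")
}
```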

Member:

Another good use will be when a test needs to continue only after a specific line of code is reached somewhere deep inside the server (as in almost all Update tests, where I need to make sure the Update actually reached the server and was added to the registry before moving forward). In that case, the hook can unblock a channel that the test code is waiting on.
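A sketch of that pattern, assuming the hook value can be a func() that server code invokes at the point of interest; the key name, helper, and server-side call site are illustrative, not from this PR:

```go
// Hypothetical synchronization hook (key and helper names are assumptions).
updateRegistered := make(chan struct{})
s.InjectHook(testhooks.UpdateRegistered, func() { close(updateRegistered) })

go s.sendUpdateWorkflowExecution() // fire the Update from the test

// Deep inside the server, the update registry would call the injected hook, e.g.:
//   if f, ok := testhooks.Get[func()](r.testHooks, testhooks.UpdateRegistered); ok { f() }

<-updateRegistered // the test proceeds only once the Update is in the registry
```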

return taskQueue.NormalPartition(n)
}

@@ -130,7 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition(
partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType())
}

return tqlb.pickReadPartition(partitionCount, lb.forceReadPartition())
if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceReadPartition); ok {
return tqlb.forceReadPartition(partitionCount, n)
}

return tqlb.pickReadPartition(partitionCount)
}

func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
@@ -157,18 +160,26 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer {
}
}

func (b *tqLoadBalancer) pickReadPartition(partitionCount int, forcedPartition int) *pollToken {
func (b *tqLoadBalancer) pickReadPartition(partitionCount int) *pollToken {
b.lock.Lock()
defer b.lock.Unlock()

// ensure we reflect dynamic config change if it ever happens
b.ensurePartitionCountLocked(max(partitionCount, forcedPartition+1))
b.ensurePartitionCountLocked(partitionCount)
partitionID := b.pickReadPartitionWithFewestPolls(partitionCount)

partitionID := forcedPartition
b.pollerCounts[partitionID]++

if partitionID < 0 {
partitionID = b.pickReadPartitionWithFewestPolls(partitionCount)
return &pollToken{
TQPartition: b.taskQueue.NormalPartition(partitionID),
balancer: b,
}
}

func (b *tqLoadBalancer) forceReadPartition(partitionCount, partitionID int) *pollToken {
b.lock.Lock()
defer b.lock.Unlock()

b.ensurePartitionCountLocked(max(partitionCount, partitionID+1))

b.pollerCounts[partitionID]++

54 changes: 27 additions & 27 deletions client/matching/loadbalancer_test.go
@@ -63,22 +63,22 @@ func TestTQLoadBalancer(t *testing.T) {
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))

// pick 4 times, each partition picked would have one poller
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
p3 := tqlb.pickReadPartition(partitionCount, -1)
p3 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))

// release one, and pick one, the newly picked one should have one poller
p3.Release()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))

// pick one again, this time it should have 2 pollers
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
}

@@ -89,27 +89,27 @@ func TestTQLoadBalancerForce(t *testing.T) {
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))

// pick 4 times, each partition picked would have one poller
p1 := tqlb.pickReadPartition(partitionCount, 1)
p1 := tqlb.forceReadPartition(partitionCount, 1)
assert.Equal(t, 1, p1.TQPartition.PartitionId())
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, 1)
tqlb.forceReadPartition(partitionCount, 1)
assert.Equal(t, 2, maxPollerCount(tqlb))

// when we don't force it should balance out
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))

// releasing the forced one and adding another should still be balanced
p1.Release()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))

tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 3, maxPollerCount(tqlb))
}

@@ -125,7 +125,7 @@ func TestLoadBalancerConcurrent(t *testing.T) {
for i := 0; i < concurrentCount; i++ {
go func() {
defer wg.Done()
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
}()
}
wg.Wait()
@@ -142,23 +142,23 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
f, err := tqid.NewTaskQueueFamily("fake-namespace-id", "fake-taskqueue")
assert.NoError(t, err)
tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY))
p1 := tqlb.pickReadPartition(partitionCount, -1)
p2 := tqlb.pickReadPartition(partitionCount, -1)
p1 := tqlb.pickReadPartition(partitionCount)
p2 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))

partitionCount += 2 // increase partition count
p3 := tqlb.pickReadPartition(partitionCount, -1)
p4 := tqlb.pickReadPartition(partitionCount, -1)
p3 := tqlb.pickReadPartition(partitionCount)
p4 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))

partitionCount -= 2 // reduce partition count
p5 := tqlb.pickReadPartition(partitionCount, -1)
p6 := tqlb.pickReadPartition(partitionCount, -1)
p5 := tqlb.pickReadPartition(partitionCount)
p6 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
assert.Equal(t, 2, maxPollerCount(tqlb))
p7 := tqlb.pickReadPartition(partitionCount, -1)
p7 := tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 3, maxPollerCount(tqlb))

// release all of them and it should be ok.
@@ -170,11 +170,11 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) {
p6.Release()
p7.Release()

tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 1, maxPollerCount(tqlb))
assert.Equal(t, 1, maxPollerCount(tqlb))
tqlb.pickReadPartition(partitionCount, -1)
tqlb.pickReadPartition(partitionCount)
assert.Equal(t, 2, maxPollerCount(tqlb))
}

17 changes: 0 additions & 17 deletions common/dynamicconfig/constants.go
@@ -1234,23 +1234,6 @@ these log lines can be noisy, we want to be able to turn on and sample selective
1000,
`MatchingMaxTaskQueuesInDeployment represents the maximum number of task-queues that can be registered in a single deployment`,
)
// for matching testing only:

TestMatchingDisableSyncMatch = NewGlobalBoolSetting(
"test.matching.disableSyncMatch",
false,
`TestMatchingDisableSyncMatch forces tasks to go through the db once`,
)
TestMatchingLBForceReadPartition = NewGlobalIntSetting(
"test.matching.lbForceReadPartition",
-1,
`TestMatchingLBForceReadPartition forces polls to go to a specific partition`,
)
TestMatchingLBForceWritePartition = NewGlobalIntSetting(
"test.matching.lbForceWritePartition",
-1,
`TestMatchingLBForceWritePartition forces adds to go to a specific partition`,
)

// keys for history

4 changes: 4 additions & 0 deletions common/resource/fx.go
@@ -63,6 +63,7 @@ import (
"go.temporal.io/server/common/sdk"
"go.temporal.io/server/common/searchattribute"
"go.temporal.io/server/common/telemetry"
"go.temporal.io/server/common/testing/testhooks"
"go.temporal.io/server/common/utf8validator"
"go.uber.org/fx"
"google.golang.org/grpc"
@@ -129,6 +130,7 @@ var Module = fx.Options(
deadlock.Module,
config.Module,
utf8validator.Module,
testhooks.Module,
fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere
)

@@ -227,6 +229,7 @@ func ClientFactoryProvider(
membershipMonitor membership.Monitor,
metricsHandler metrics.Handler,
dynamicCollection *dynamicconfig.Collection,
testHooks testhooks.TestHooks,
persistenceConfig *config.Persistence,
logger log.SnTaggedLogger,
throttledLogger log.ThrottledLogger,
Expand All @@ -236,6 +239,7 @@ func ClientFactoryProvider(
membershipMonitor,
metricsHandler,
dynamicCollection,
testHooks,
persistenceConfig.NumHistoryShards,
logger,
throttledLogger,