From 28626dc6523ab80e07b6c6f31790f053f24532c8 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 14:43:07 +0000 Subject: [PATCH 01/13] initial --- Makefile | 2 +- common/errorinjector/noop_impl.go | 22 +++++++++++++ common/errorinjector/test_impl.go | 51 +++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 common/errorinjector/noop_impl.go create mode 100644 common/errorinjector/test_impl.go diff --git a/Makefile b/Makefile index b11a5cf07ba..3a0a4db71ba 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility # Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages # during proto library transition. ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG) -ALL_TEST_TAGS := $(ALL_BUILD_TAGS),$(TEST_TAG) +ALL_TEST_TAGS := $(ALL_BUILD_TAGS),errorinjector,$(TEST_TAG) BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS) TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS) diff --git a/common/errorinjector/noop_impl.go b/common/errorinjector/noop_impl.go new file mode 100644 index 00000000000..f9a2de4fb30 --- /dev/null +++ b/common/errorinjector/noop_impl.go @@ -0,0 +1,22 @@ +//go:build !errorinjector + +package errorinjector + +import "go.uber.org/fx" + +var Module = fx.Options( + fx.Provide(func() ErrorInjector { return nil }), +) + +type ( + ErrorInjector interface { + } +) + +func Get[T any](ei ErrorInjector, key string) (T, bool) { + var zero T + return zero, false +} + +func Set[T any](ei ErrorInjector, key string, val T) { +} diff --git a/common/errorinjector/test_impl.go b/common/errorinjector/test_impl.go new file mode 100644 index 00000000000..3ec3f1460da --- /dev/null +++ b/common/errorinjector/test_impl.go @@ -0,0 +1,51 @@ +//go:build errorinjector + +package errorinjector + +import ( + "sync" + + "go.uber.org/fx" +) + +var Module = fx.Options( + fx.Provide(func() ErrorInjector { return newTestErrorInjector() }), +) + +type ( + ErrorInjector interface { + // private accessors; access must go through package-level Get/Set + get(string) (any, bool) + set(string, any) + } + + errorInjectorImpl struct { + m sync.Map + } +) + +func Get[T any](ei ErrorInjector, key string) (T, bool) { + if val, ok := ei.get(key); ok { + // this is only used in test so we want to panic on type mismatch: + return val.(T), ok + } + var zero T + return zero, false +} + +func Set[T any](ei ErrorInjector, key string, val T) { + ei.set(key, val) +} + +func newTestErrorInjector() *errorInjectorImpl { + return &errorInjectorImpl{} +} + +func (ei *errorInjectorImpl) get(key string) (any, bool) { + val, ok := ei.m.Load(key) + return val, ok +} + +func (ei *errorInjectorImpl) set(key string, val any) { + ei.m.Store(key, val) +} From 45d620d1bd2b1f432637f6c2c1a078c1992986d8 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 15:07:31 +0000 Subject: [PATCH 02/13] use in matching --- client/clientfactory.go | 7 +++- client/matching/loadbalancer.go | 41 +++++++++---------- common/dynamicconfig/constants.go | 17 -------- common/errorinjector/constants.go | 7 ++++ common/errorinjector/test_impl.go | 12 ++++-- common/resource/fx.go | 4 ++ service/matching/config.go | 4 -- service/matching/handler.go | 3 ++ service/matching/matching_engine.go | 4 ++ .../matching/physical_task_queue_manager.go | 3 +- tests/testcore/functional_test_base.go | 13 ++++-- tests/testcore/onebox.go | 22 ++++++++-- 12 files changed, 83 insertions(+), 54 deletions(-) create mode 100644 common/errorinjector/constants.go diff --git a/client/clientfactory.go b/client/clientfactory.go index e3c7838378e..7c3eb974f81 100644 --- a/client/clientfactory.go +++ b/client/clientfactory.go @@ -39,6 +39,7 @@ import ( "go.temporal.io/server/client/matching" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/metrics" @@ -65,6 +66,7 @@ type ( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, + errorInjector errorinjector.ErrorInjector, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -79,6 +81,7 @@ type ( monitor membership.Monitor metricsHandler metrics.Handler dynConfig *dynamicconfig.Collection + errorInjector errorinjector.ErrorInjector numberOfHistoryShards int32 logger log.Logger throttledLogger log.Logger @@ -103,6 +106,7 @@ func (p *factoryProviderImpl) NewFactory( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, + errorInjector errorinjector.ErrorInjector, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -112,6 +116,7 @@ func (p *factoryProviderImpl) NewFactory( monitor: monitor, metricsHandler: metricsHandler, dynConfig: dc, + errorInjector: errorInjector, numberOfHistoryShards: numberOfHistoryShards, logger: logger, throttledLogger: throttledLogger, @@ -159,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout( common.NewClientCache(keyResolver, clientProvider), cf.metricsHandler, cf.logger, - matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig), + matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.errorInjector), ) if cf.metricsHandler != nil { diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go index ddc7095316e..32de7362284 100644 --- a/client/matching/loadbalancer.go +++ b/client/matching/loadbalancer.go @@ -29,6 +29,7 @@ import ( "sync" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/tqid" ) @@ -57,11 +58,10 @@ type ( } defaultLoadBalancer struct { - namespaceIDToName func(id namespace.ID) (namespace.Name, error) - nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter - nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter - forceReadPartition dynamicconfig.IntPropertyFn - forceWritePartition dynamicconfig.IntPropertyFn + namespaceIDToName func(id namespace.ID) (namespace.Name, error) + nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter + nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter + errorInjector errorinjector.ErrorInjector lock sync.RWMutex taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer @@ -85,15 +85,14 @@ type ( func NewLoadBalancer( namespaceIDToName func(id namespace.ID) (namespace.Name, error), dc *dynamicconfig.Collection, + errorInjector errorinjector.ErrorInjector, ) LoadBalancer { lb := &defaultLoadBalancer{ - namespaceIDToName: namespaceIDToName, - nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc), - nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc), - forceReadPartition: dynamicconfig.TestMatchingLBForceReadPartition.Get(dc), - forceWritePartition: dynamicconfig.TestMatchingLBForceWritePartition.Get(dc), - lock: sync.RWMutex{}, - taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer), + namespaceIDToName: namespaceIDToName, + nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc), + nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc), + errorInjector: errorInjector, + taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer), } return lb } @@ -101,7 +100,7 @@ func NewLoadBalancer( func (lb *defaultLoadBalancer) PickWritePartition( taskQueue *tqid.TaskQueue, ) *tqid.NormalPartition { - if n := lb.forceWritePartition(); n >= 0 { + if n, ok := errorinjector.Get[int](lb.errorInjector, errorinjector.MatchingLBForceWritePartition); ok { return taskQueue.NormalPartition(n) } @@ -130,7 +129,7 @@ func (lb *defaultLoadBalancer) PickReadPartition( partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType()) } - return tqlb.pickReadPartition(partitionCount, lb.forceReadPartition()) + return tqlb.pickReadPartition(partitionCount, lb.errorInjector) } func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer { @@ -157,16 +156,16 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer { } } -func (b *tqLoadBalancer) pickReadPartition(partitionCount int, forcedPartition int) *pollToken { +func (b *tqLoadBalancer) pickReadPartition(partitionCount int, ei errorinjector.ErrorInjector) *pollToken { b.lock.Lock() defer b.lock.Unlock() - // ensure we reflect dynamic config change if it ever happens - b.ensurePartitionCountLocked(max(partitionCount, forcedPartition+1)) - - partitionID := forcedPartition - - if partitionID < 0 { + var partitionID int + if n, ok := errorinjector.Get[int](ei, errorinjector.MatchingLBForceWritePartition); ok { + b.ensurePartitionCountLocked(max(partitionCount, n+1)) // allow n to be >= partitionCount + partitionID = n + } else { + b.ensurePartitionCountLocked(partitionCount) partitionID = b.pickReadPartitionWithFewestPolls(partitionCount) } diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index 780c014f02b..83b9907ce7e 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -1234,23 +1234,6 @@ these log lines can be noisy, we want to be able to turn on and sample selective 1000, `MatchingMaxTaskQueuesInDeployment represents the maximum number of task-queues that can be registed in a single deployment`, ) - // for matching testing only: - - TestMatchingDisableSyncMatch = NewGlobalBoolSetting( - "test.matching.disableSyncMatch", - false, - `TestMatchingDisableSyncMatch forces tasks to go through the db once`, - ) - TestMatchingLBForceReadPartition = NewGlobalIntSetting( - "test.matching.lbForceReadPartition", - -1, - `TestMatchingLBForceReadPartition forces polls to go to a specific partition`, - ) - TestMatchingLBForceWritePartition = NewGlobalIntSetting( - "test.matching.lbForceWritePartition", - -1, - `TestMatchingLBForceWritePartition forces adds to go to a specific partition`, - ) // keys for history diff --git a/common/errorinjector/constants.go b/common/errorinjector/constants.go new file mode 100644 index 00000000000..0c26955f375 --- /dev/null +++ b/common/errorinjector/constants.go @@ -0,0 +1,7 @@ +package errorinjector + +const ( + MatchingDisableSyncMatch = "matching.disableSyncMatch" + MatchingLBForceReadPartition = "matching.lbForceReadPartition" + MatchingLBForceWritePartition = "matching.lbForceWritePartition" +) diff --git a/common/errorinjector/test_impl.go b/common/errorinjector/test_impl.go index 3ec3f1460da..8289ff7a77d 100644 --- a/common/errorinjector/test_impl.go +++ b/common/errorinjector/test_impl.go @@ -9,7 +9,7 @@ import ( ) var Module = fx.Options( - fx.Provide(func() ErrorInjector { return newTestErrorInjector() }), + fx.Provide(NewTestErrorInjector), ) type ( @@ -17,6 +17,7 @@ type ( // private accessors; access must go through package-level Get/Set get(string) (any, bool) set(string, any) + del(string) } errorInjectorImpl struct { @@ -33,11 +34,12 @@ func Get[T any](ei ErrorInjector, key string) (T, bool) { return zero, false } -func Set[T any](ei ErrorInjector, key string, val T) { +func Set[T any](ei ErrorInjector, key string, val T) func() { ei.set(key, val) + return func() { ei.del(key) } } -func newTestErrorInjector() *errorInjectorImpl { +func NewTestErrorInjector() ErrorInjector { return &errorInjectorImpl{} } @@ -49,3 +51,7 @@ func (ei *errorInjectorImpl) get(key string) (any, bool) { func (ei *errorInjectorImpl) set(key string, val any) { ei.m.Store(key, val) } + +func (ei *errorInjectorImpl) del(key string) { + ei.m.Delete(key) +} diff --git a/common/resource/fx.go b/common/resource/fx.go index 117d661a469..662b3e2b16c 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -46,6 +46,7 @@ import ( "go.temporal.io/server/common/config" "go.temporal.io/server/common/deadlock" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -129,6 +130,7 @@ var Module = fx.Options( deadlock.Module, config.Module, utf8validator.Module, + errorinjector.Module, fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere ) @@ -227,6 +229,7 @@ func ClientFactoryProvider( membershipMonitor membership.Monitor, metricsHandler metrics.Handler, dynamicCollection *dynamicconfig.Collection, + errorInjector errorinjector.ErrorInjector, persistenceConfig *config.Persistence, logger log.SnTaggedLogger, throttledLogger log.ThrottledLogger, @@ -236,6 +239,7 @@ func ClientFactoryProvider( membershipMonitor, metricsHandler, dynamicCollection, + errorInjector, persistenceConfig.NumHistoryShards, logger, throttledLogger, diff --git a/service/matching/config.go b/service/matching/config.go index b155bb03475..476004890f7 100644 --- a/service/matching/config.go +++ b/service/matching/config.go @@ -47,7 +47,6 @@ type ( PersistenceDynamicRateLimitingParams dynamicconfig.TypedPropertyFn[dynamicconfig.DynamicRateLimitingParams] PersistenceQPSBurstRatio dynamicconfig.FloatPropertyFn SyncMatchWaitDuration dynamicconfig.DurationPropertyFnWithTaskQueueFilter - TestDisableSyncMatch dynamicconfig.BoolPropertyFn RPS dynamicconfig.IntPropertyFn OperatorRPSRatio dynamicconfig.FloatPropertyFn AlignMembershipChange dynamicconfig.DurationPropertyFn @@ -132,7 +131,6 @@ type ( BacklogNegligibleAge func() time.Duration MaxWaitForPollerBeforeFwd func() time.Duration QueryPollerUnavailableWindow func() time.Duration - TestDisableSyncMatch func() bool // Time to hold a poll request before returning an empty response if there are no tasks LongPollExpirationInterval func() time.Duration RangeSize int64 @@ -211,7 +209,6 @@ func NewConfig( PersistenceDynamicRateLimitingParams: dynamicconfig.MatchingPersistenceDynamicRateLimitingParams.Get(dc), PersistenceQPSBurstRatio: dynamicconfig.PersistenceQPSBurstRatio.Get(dc), SyncMatchWaitDuration: dynamicconfig.MatchingSyncMatchWaitDuration.Get(dc), - TestDisableSyncMatch: dynamicconfig.TestMatchingDisableSyncMatch.Get(dc), LoadUserData: dynamicconfig.MatchingLoadUserData.Get(dc), HistoryMaxPageSize: dynamicconfig.MatchingHistoryMaxPageSize.Get(dc), EnableDeployments: dynamicconfig.EnableDeployments.Get(dc), @@ -303,7 +300,6 @@ func newTaskQueueConfig(tq *tqid.TaskQueue, config *Config, ns namespace.Name) * return config.MaxWaitForPollerBeforeFwd(ns.String(), taskQueueName, taskType) }, QueryPollerUnavailableWindow: config.QueryPollerUnavailableWindow, - TestDisableSyncMatch: config.TestDisableSyncMatch, LongPollExpirationInterval: func() time.Duration { return config.LongPollExpirationInterval(ns.String(), taskQueueName, taskType) }, diff --git a/service/matching/handler.go b/service/matching/handler.go index 0b9bc7f641d..6bb2a404e80 100644 --- a/service/matching/handler.go +++ b/service/matching/handler.go @@ -35,6 +35,7 @@ import ( "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/metrics" @@ -88,6 +89,7 @@ func NewHandler( namespaceReplicationQueue persistence.NamespaceReplicationQueue, visibilityManager manager.VisibilityManager, nexusEndpointManager persistence.NexusEndpointManager, + errorInjector errorinjector.ErrorInjector, ) *Handler { handler := &Handler{ config: config, @@ -110,6 +112,7 @@ func NewHandler( namespaceReplicationQueue, visibilityManager, nexusEndpointManager, + errorInjector, ), namespaceRegistry: namespaceRegistry, } diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index affa38585e8..dc4ff71e6cc 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -56,6 +56,7 @@ import ( hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -143,6 +144,7 @@ type ( partitions map[tqid.PartitionKey]taskQueuePartitionManager gaugeMetrics gaugeMetrics // per-namespace task queue counters config *Config + errorInjector errorinjector.ErrorInjector // queryResults maps query TaskID (which is a UUID generated in QueryWorkflow() call) to a channel // that QueryWorkflow() will block on. The channel is unblocked either by worker sending response through // RespondQueryTaskCompleted() or through an internal service error causing temporal to be unable to dispatch @@ -203,6 +205,7 @@ func NewEngine( namespaceReplicationQueue persistence.NamespaceReplicationQueue, visibilityManager manager.VisibilityManager, nexusEndpointManager persistence.NexusEndpointManager, + errorInjector errorinjector.ErrorInjector, ) Engine { scopedMetricsHandler := metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingEngineScope)) e := &matchingEngineImpl{ @@ -233,6 +236,7 @@ func NewEngine( loadedPhysicalTaskQueueCount: make(map[taskQueueCounterKey]int), }, config: config, + errorInjector: errorInjector, queryResults: collection.NewSyncMap[string, chan *queryResult](), nexusResults: collection.NewSyncMap[string, chan *nexusResult](), outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](), diff --git a/service/matching/physical_task_queue_manager.go b/service/matching/physical_task_queue_manager.go index 9e9c31991c2..91df31a5d5e 100644 --- a/service/matching/physical_task_queue_manager.go +++ b/service/matching/physical_task_queue_manager.go @@ -46,6 +46,7 @@ import ( "go.temporal.io/server/common/clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/debug" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" @@ -529,7 +530,7 @@ func (c *physicalTaskQueueManagerImpl) TrySyncMatch(ctx context.Context, task *i // request sent by history service c.liveness.markAlive() c.tasksAddedInIntervals.incrementTaskCount() - if c.config.TestDisableSyncMatch() { + if disable, _ := errorinjector.Get[bool](c.partitionMgr.engine.errorInjector, errorinjector.MatchingDisableSyncMatch); disable { return false, nil } } diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index 95d640cfaca..87f37d3eaf0 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -48,6 +48,7 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/payloads" @@ -478,6 +479,10 @@ func (s *FunctionalTestBase) OverrideDynamicConfig(setting dynamicconfig.Generic return s.testCluster.host.overrideDynamicConfig(s.T(), setting.Key(), value) } +func (s *FunctionalTestBase) InjectError(key string, value any) (cleanup func()) { + return s.testCluster.host.injectError(s.T(), key, value) +} + func (s *FunctionalTestBase) GetNamespaceID(namespace string) string { namespaceResp, err := s.FrontendClient().DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{ Namespace: namespace, @@ -510,20 +515,20 @@ func (s *FunctionalTestBase) RunTestWithMatchingBehavior(subtest func()) { name, func() { if forceTaskForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 13) - s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceWritePartition, 11) + s.InjectError(errorinjector.MatchingLBForceWritePartition, 11) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 1) } if forcePollForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 13) - s.OverrideDynamicConfig(dynamicconfig.TestMatchingLBForceReadPartition, 5) + s.InjectError(errorinjector.MatchingLBForceReadPartition, 5) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 1) } if forceAsync { - s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, true) + s.InjectError(errorinjector.MatchingDisableSyncMatch, true) } else { - s.OverrideDynamicConfig(dynamicconfig.TestMatchingDisableSyncMatch, false) + s.InjectError(errorinjector.MatchingDisableSyncMatch, false) } subtest() diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 95a5f1f9be3..ffc29bad059 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -52,6 +52,7 @@ import ( "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -99,6 +100,7 @@ type ( matchingClient matchingservice.MatchingServiceClient dcClient *dynamicconfig.MemoryClient + errorInjector errorinjector.ErrorInjector logger log.Logger clusterMetadataConfig *cluster.Config persistenceConfig config.Persistence @@ -258,9 +260,11 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { tlsConfigProvider: params.TLSConfigProvider, captureMetricsHandler: params.CaptureMetricsHandler, dcClient: dynamicconfig.NewMemoryClient(), - serviceFxOptions: params.ServiceFxOptions, - taskCategoryRegistry: params.TaskCategoryRegistry, - hostsByProtocolByService: params.HostsByProtocolByService, + // If this doesn't build, make sure you're building with tags 'errorinjector': + errorInjector: errorinjector.NewTestErrorInjector(), + serviceFxOptions: params.ServiceFxOptions, + taskCategoryRegistry: params.TaskCategoryRegistry, + hostsByProtocolByService: params.HostsByProtocolByService, } for k, v := range staticOverrides { @@ -410,6 +414,7 @@ func (c *TemporalImpl) startFrontend() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), + fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), @@ -481,6 +486,7 @@ func (c *TemporalImpl) startHistory() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), + fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), @@ -534,6 +540,7 @@ func (c *TemporalImpl) startMatching() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), + fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), fx.Provide(c.GetTLSConfigProvider), @@ -597,6 +604,7 @@ func (c *TemporalImpl) startWorker() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), + fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() esclient.Client { return c.esClient }), fx.Provide(func() *esclient.Config { return c.esConfig }), @@ -772,6 +780,7 @@ func (p *clientFactoryProvider) NewFactory( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, + errorInjector errorinjector.ErrorInjector, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -781,6 +790,7 @@ func (p *clientFactoryProvider) NewFactory( monitor, metricsHandler, dc, + errorInjector, numberOfHistoryShards, logger, throttledLogger, @@ -894,6 +904,12 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Ke return cleanup } +func (c *TemporalImpl) injectError(t *testing.T, key string, value any) func() { + cleanup := errorinjector.Set(c.errorInjector, key, value) + t.Cleanup(cleanup) + return cleanup +} + func mustPortFromAddress(addr string) httpPort { _, port, err := net.SplitHostPort(addr) if err != nil { From 18a174d15ba568b5d387cd7d6ba6dfcf5980a04d Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 15:11:05 +0000 Subject: [PATCH 03/13] simplify noop --- common/errorinjector/noop_impl.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/common/errorinjector/noop_impl.go b/common/errorinjector/noop_impl.go index f9a2de4fb30..7c2b1e29aa4 100644 --- a/common/errorinjector/noop_impl.go +++ b/common/errorinjector/noop_impl.go @@ -5,18 +5,14 @@ package errorinjector import "go.uber.org/fx" var Module = fx.Options( - fx.Provide(func() ErrorInjector { return nil }), + fx.Provide(func() (ei ErrorInjector) { return }), ) type ( - ErrorInjector interface { - } + ErrorInjector struct{} ) func Get[T any](ei ErrorInjector, key string) (T, bool) { var zero T return zero, false } - -func Set[T any](ei ErrorInjector, key string, val T) { -} From c4efeacc2c1b0dc9db161109018c611c73c041a5 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 19:42:26 +0000 Subject: [PATCH 04/13] mock --- client/client_factory_mock.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/client_factory_mock.go b/client/client_factory_mock.go index 00788b6554b..31910b11f71 100644 --- a/client/client_factory_mock.go +++ b/client/client_factory_mock.go @@ -43,6 +43,7 @@ import ( matchingservice "go.temporal.io/server/api/matchingservice/v1" common "go.temporal.io/server/common" dynamicconfig "go.temporal.io/server/common/dynamicconfig" + errorinjector "go.temporal.io/server/common/errorinjector" log "go.temporal.io/server/common/log" membership "go.temporal.io/server/common/membership" metrics "go.temporal.io/server/common/metrics" @@ -187,15 +188,15 @@ func (m *MockFactoryProvider) EXPECT() *MockFactoryProviderMockRecorder { } // NewFactory mocks base method. -func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory { +func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, errorInjector errorinjector.ErrorInjector, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger) + ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger) ret0, _ := ret[0].(Factory) return ret0 } // NewFactory indicates an expected call of NewFactory. -func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call { +func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, numberOfHistoryShards, logger, throttledLogger) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger) } From 1a9615a78c81a95cabd642d285a3fde41a65d476 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 19:42:37 +0000 Subject: [PATCH 05/13] fix matching test --- client/matching/loadbalancer.go | 30 +++++++++++----- client/matching/loadbalancer_test.go | 54 ++++++++++++++-------------- 2 files changed, 48 insertions(+), 36 deletions(-) diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go index 32de7362284..90474e5271c 100644 --- a/client/matching/loadbalancer.go +++ b/client/matching/loadbalancer.go @@ -129,7 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition( partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType()) } - return tqlb.pickReadPartition(partitionCount, lb.errorInjector) + if n, ok := errorinjector.Get[int](lb.errorInjector, errorinjector.MatchingLBForceWritePartition); ok { + return tqlb.forceReadPartition(partitionCount, n) + } else { + return tqlb.pickReadPartition(partitionCount) + } } func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer { @@ -156,18 +160,26 @@ func newTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer { } } -func (b *tqLoadBalancer) pickReadPartition(partitionCount int, ei errorinjector.ErrorInjector) *pollToken { +func (b *tqLoadBalancer) pickReadPartition(partitionCount int) *pollToken { b.lock.Lock() defer b.lock.Unlock() - var partitionID int - if n, ok := errorinjector.Get[int](ei, errorinjector.MatchingLBForceWritePartition); ok { - b.ensurePartitionCountLocked(max(partitionCount, n+1)) // allow n to be >= partitionCount - partitionID = n - } else { - b.ensurePartitionCountLocked(partitionCount) - partitionID = b.pickReadPartitionWithFewestPolls(partitionCount) + b.ensurePartitionCountLocked(partitionCount) + partitionID := b.pickReadPartitionWithFewestPolls(partitionCount) + + b.pollerCounts[partitionID]++ + + return &pollToken{ + TQPartition: b.taskQueue.NormalPartition(partitionID), + balancer: b, } +} + +func (b *tqLoadBalancer) forceReadPartition(partitionCount, partitionID int) *pollToken { + b.lock.Lock() + defer b.lock.Unlock() + + b.ensurePartitionCountLocked(max(partitionCount, partitionID+1)) b.pollerCounts[partitionID]++ diff --git a/client/matching/loadbalancer_test.go b/client/matching/loadbalancer_test.go index 65a2ae6421f..76a84253a16 100644 --- a/client/matching/loadbalancer_test.go +++ b/client/matching/loadbalancer_test.go @@ -63,22 +63,22 @@ func TestTQLoadBalancer(t *testing.T) { tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY)) // pick 4 times, each partition picked would have one poller - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) - p3 := tqlb.pickReadPartition(partitionCount, -1) + p3 := tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) // release one, and pick one, the newly picked one should have one poller p3.Release() - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) // pick one again, this time it should have 2 pollers - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 2, maxPollerCount(tqlb)) } @@ -89,27 +89,27 @@ func TestTQLoadBalancerForce(t *testing.T) { tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY)) // pick 4 times, each partition picked would have one poller - p1 := tqlb.pickReadPartition(partitionCount, 1) + p1 := tqlb.forceReadPartition(partitionCount, 1) assert.Equal(t, 1, p1.TQPartition.PartitionId()) assert.Equal(t, 1, maxPollerCount(tqlb)) - tqlb.pickReadPartition(partitionCount, 1) + tqlb.forceReadPartition(partitionCount, 1) assert.Equal(t, 2, maxPollerCount(tqlb)) // when we don't force it should balance out - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 2, maxPollerCount(tqlb)) // releasing the forced one and adding another should still be balanced p1.Release() - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 2, maxPollerCount(tqlb)) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 3, maxPollerCount(tqlb)) } @@ -125,7 +125,7 @@ func TestLoadBalancerConcurrent(t *testing.T) { for i := 0; i < concurrentCount; i++ { go func() { defer wg.Done() - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) }() } wg.Wait() @@ -142,23 +142,23 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) { f, err := tqid.NewTaskQueueFamily("fake-namespace-id", "fake-taskqueue") assert.NoError(t, err) tqlb := newTaskQueueLoadBalancer(f.TaskQueue(enumspb.TASK_QUEUE_TYPE_ACTIVITY)) - p1 := tqlb.pickReadPartition(partitionCount, -1) - p2 := tqlb.pickReadPartition(partitionCount, -1) + p1 := tqlb.pickReadPartition(partitionCount) + p2 := tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) assert.Equal(t, 1, maxPollerCount(tqlb)) partitionCount += 2 // increase partition count - p3 := tqlb.pickReadPartition(partitionCount, -1) - p4 := tqlb.pickReadPartition(partitionCount, -1) + p3 := tqlb.pickReadPartition(partitionCount) + p4 := tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) assert.Equal(t, 1, maxPollerCount(tqlb)) partitionCount -= 2 // reduce partition count - p5 := tqlb.pickReadPartition(partitionCount, -1) - p6 := tqlb.pickReadPartition(partitionCount, -1) + p5 := tqlb.pickReadPartition(partitionCount) + p6 := tqlb.pickReadPartition(partitionCount) assert.Equal(t, 2, maxPollerCount(tqlb)) assert.Equal(t, 2, maxPollerCount(tqlb)) - p7 := tqlb.pickReadPartition(partitionCount, -1) + p7 := tqlb.pickReadPartition(partitionCount) assert.Equal(t, 3, maxPollerCount(tqlb)) // release all of them and it should be ok. @@ -170,11 +170,11 @@ func TestLoadBalancer_ReducedPartitionCount(t *testing.T) { p6.Release() p7.Release() - tqlb.pickReadPartition(partitionCount, -1) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 1, maxPollerCount(tqlb)) assert.Equal(t, 1, maxPollerCount(tqlb)) - tqlb.pickReadPartition(partitionCount, -1) + tqlb.pickReadPartition(partitionCount) assert.Equal(t, 2, maxPollerCount(tqlb)) } From 525924c7c7f7e0f757a0d50dfcbb361e68ee7aff Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 19:59:27 +0000 Subject: [PATCH 06/13] use test tags for lint --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 3a0a4db71ba..dd2a804adeb 100644 --- a/Makefile +++ b/Makefile @@ -331,7 +331,7 @@ lint-actions: $(ACTIONLINT) lint-code: $(GOLANGCI_LINT) @printf $(COLOR) "Linting code..." - @$(GOLANGCI_LINT) run --verbose --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml + @$(GOLANGCI_LINT) run --verbose --build-tags $(ALL_TEST_TAGS) --timeout 10m --fix=$(GOLANGCI_LINT_FIX) --new-from-rev=$(GOLANGCI_LINT_BASE_REV) --config=.golangci.yml fmt-imports: $(GCI) # Don't get confused, there is a single linter called gci, which is a part of the mega linter we use is called golangci-lint. @printf $(COLOR) "Formatting imports..." From ffeb3e5d2835cf86a834d542b55a8513d6d66780 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 20:32:54 +0000 Subject: [PATCH 07/13] rename to testhooks --- Makefile | 2 +- client/client_factory_mock.go | 10 ++-- client/clientfactory.go | 12 ++-- client/matching/loadbalancer.go | 16 +++--- common/errorinjector/noop_impl.go | 18 ------ common/errorinjector/test_impl.go | 57 ------------------- common/resource/fx.go | 8 +-- .../testhooks}/constants.go | 2 +- common/testing/testhooks/noop_impl.go | 18 ++++++ common/testing/testhooks/test_impl.go | 57 +++++++++++++++++++ service/matching/handler.go | 6 +- service/matching/matching_engine.go | 8 +-- .../matching/physical_task_queue_manager.go | 4 +- tests/testcore/functional_test_base.go | 10 ++-- tests/testcore/onebox.go | 22 +++---- 15 files changed, 125 insertions(+), 125 deletions(-) delete mode 100644 common/errorinjector/noop_impl.go delete mode 100644 common/errorinjector/test_impl.go rename common/{errorinjector => testing/testhooks}/constants.go (90%) create mode 100644 common/testing/testhooks/noop_impl.go create mode 100644 common/testing/testhooks/test_impl.go diff --git a/Makefile b/Makefile index dd2a804adeb..cd3c927d959 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility # Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages # during proto library transition. ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG) -ALL_TEST_TAGS := $(ALL_BUILD_TAGS),errorinjector,$(TEST_TAG) +ALL_TEST_TAGS := $(ALL_BUILD_TAGS),testhooks,$(TEST_TAG) BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS) TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS) diff --git a/client/client_factory_mock.go b/client/client_factory_mock.go index 31910b11f71..93246d97d40 100644 --- a/client/client_factory_mock.go +++ b/client/client_factory_mock.go @@ -43,10 +43,10 @@ import ( matchingservice "go.temporal.io/server/api/matchingservice/v1" common "go.temporal.io/server/common" dynamicconfig "go.temporal.io/server/common/dynamicconfig" - errorinjector "go.temporal.io/server/common/errorinjector" log "go.temporal.io/server/common/log" membership "go.temporal.io/server/common/membership" metrics "go.temporal.io/server/common/metrics" + testhooks "go.temporal.io/server/common/testing/testhooks" gomock "go.uber.org/mock/gomock" grpc "google.golang.org/grpc" ) @@ -188,15 +188,15 @@ func (m *MockFactoryProvider) EXPECT() *MockFactoryProviderMockRecorder { } // NewFactory mocks base method. -func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, errorInjector errorinjector.ErrorInjector, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory { +func (m *MockFactoryProvider) NewFactory(rpcFactory common.RPCFactory, monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger, throttledLogger log.Logger) Factory { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger) + ret := m.ctrl.Call(m, "NewFactory", rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger) ret0, _ := ret[0].(Factory) return ret0 } // NewFactory indicates an expected call of NewFactory. -func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call { +func (mr *MockFactoryProviderMockRecorder) NewFactory(rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, errorInjector, numberOfHistoryShards, logger, throttledLogger) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewFactory", reflect.TypeOf((*MockFactoryProvider)(nil).NewFactory), rpcFactory, monitor, metricsHandler, dc, testHooks, numberOfHistoryShards, logger, throttledLogger) } diff --git a/client/clientfactory.go b/client/clientfactory.go index 7c3eb974f81..85366422b91 100644 --- a/client/clientfactory.go +++ b/client/clientfactory.go @@ -39,12 +39,12 @@ import ( "go.temporal.io/server/client/matching" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/primitives" + "go.temporal.io/server/common/testing/testhooks" "google.golang.org/grpc" ) @@ -66,7 +66,7 @@ type ( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -81,7 +81,7 @@ type ( monitor membership.Monitor metricsHandler metrics.Handler dynConfig *dynamicconfig.Collection - errorInjector errorinjector.ErrorInjector + testHooks testhooks.TestHooks numberOfHistoryShards int32 logger log.Logger throttledLogger log.Logger @@ -106,7 +106,7 @@ func (p *factoryProviderImpl) NewFactory( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -116,7 +116,7 @@ func (p *factoryProviderImpl) NewFactory( monitor: monitor, metricsHandler: metricsHandler, dynConfig: dc, - errorInjector: errorInjector, + testHooks: testHooks, numberOfHistoryShards: numberOfHistoryShards, logger: logger, throttledLogger: throttledLogger, @@ -164,7 +164,7 @@ func (cf *rpcClientFactory) NewMatchingClientWithTimeout( common.NewClientCache(keyResolver, clientProvider), cf.metricsHandler, cf.logger, - matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.errorInjector), + matching.NewLoadBalancer(namespaceIDToName, cf.dynConfig, cf.testHooks), ) if cf.metricsHandler != nil { diff --git a/client/matching/loadbalancer.go b/client/matching/loadbalancer.go index 90474e5271c..ee08ee6fd31 100644 --- a/client/matching/loadbalancer.go +++ b/client/matching/loadbalancer.go @@ -29,8 +29,8 @@ import ( "sync" "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/tqid" ) @@ -61,7 +61,7 @@ type ( namespaceIDToName func(id namespace.ID) (namespace.Name, error) nReadPartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter nWritePartitions dynamicconfig.IntPropertyFnWithTaskQueueFilter - errorInjector errorinjector.ErrorInjector + testHooks testhooks.TestHooks lock sync.RWMutex taskQueueLBs map[tqid.TaskQueue]*tqLoadBalancer @@ -85,13 +85,13 @@ type ( func NewLoadBalancer( namespaceIDToName func(id namespace.ID) (namespace.Name, error), dc *dynamicconfig.Collection, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, ) LoadBalancer { lb := &defaultLoadBalancer{ namespaceIDToName: namespaceIDToName, nReadPartitions: dynamicconfig.MatchingNumTaskqueueReadPartitions.Get(dc), nWritePartitions: dynamicconfig.MatchingNumTaskqueueWritePartitions.Get(dc), - errorInjector: errorInjector, + testHooks: testHooks, taskQueueLBs: make(map[tqid.TaskQueue]*tqLoadBalancer), } return lb @@ -100,7 +100,7 @@ func NewLoadBalancer( func (lb *defaultLoadBalancer) PickWritePartition( taskQueue *tqid.TaskQueue, ) *tqid.NormalPartition { - if n, ok := errorinjector.Get[int](lb.errorInjector, errorinjector.MatchingLBForceWritePartition); ok { + if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok { return taskQueue.NormalPartition(n) } @@ -129,11 +129,11 @@ func (lb *defaultLoadBalancer) PickReadPartition( partitionCount = lb.nReadPartitions(string(namespaceName), taskQueue.Name(), taskQueue.TaskType()) } - if n, ok := errorinjector.Get[int](lb.errorInjector, errorinjector.MatchingLBForceWritePartition); ok { + if n, ok := testhooks.Get[int](lb.testHooks, testhooks.MatchingLBForceWritePartition); ok { return tqlb.forceReadPartition(partitionCount, n) - } else { - return tqlb.pickReadPartition(partitionCount) } + + return tqlb.pickReadPartition(partitionCount) } func (lb *defaultLoadBalancer) getTaskQueueLoadBalancer(tq *tqid.TaskQueue) *tqLoadBalancer { diff --git a/common/errorinjector/noop_impl.go b/common/errorinjector/noop_impl.go deleted file mode 100644 index 7c2b1e29aa4..00000000000 --- a/common/errorinjector/noop_impl.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build !errorinjector - -package errorinjector - -import "go.uber.org/fx" - -var Module = fx.Options( - fx.Provide(func() (ei ErrorInjector) { return }), -) - -type ( - ErrorInjector struct{} -) - -func Get[T any](ei ErrorInjector, key string) (T, bool) { - var zero T - return zero, false -} diff --git a/common/errorinjector/test_impl.go b/common/errorinjector/test_impl.go deleted file mode 100644 index 8289ff7a77d..00000000000 --- a/common/errorinjector/test_impl.go +++ /dev/null @@ -1,57 +0,0 @@ -//go:build errorinjector - -package errorinjector - -import ( - "sync" - - "go.uber.org/fx" -) - -var Module = fx.Options( - fx.Provide(NewTestErrorInjector), -) - -type ( - ErrorInjector interface { - // private accessors; access must go through package-level Get/Set - get(string) (any, bool) - set(string, any) - del(string) - } - - errorInjectorImpl struct { - m sync.Map - } -) - -func Get[T any](ei ErrorInjector, key string) (T, bool) { - if val, ok := ei.get(key); ok { - // this is only used in test so we want to panic on type mismatch: - return val.(T), ok - } - var zero T - return zero, false -} - -func Set[T any](ei ErrorInjector, key string, val T) func() { - ei.set(key, val) - return func() { ei.del(key) } -} - -func NewTestErrorInjector() ErrorInjector { - return &errorInjectorImpl{} -} - -func (ei *errorInjectorImpl) get(key string) (any, bool) { - val, ok := ei.m.Load(key) - return val, ok -} - -func (ei *errorInjectorImpl) set(key string, val any) { - ei.m.Store(key, val) -} - -func (ei *errorInjectorImpl) del(key string) { - ei.m.Delete(key) -} diff --git a/common/resource/fx.go b/common/resource/fx.go index 662b3e2b16c..eb33b1883f4 100644 --- a/common/resource/fx.go +++ b/common/resource/fx.go @@ -46,7 +46,6 @@ import ( "go.temporal.io/server/common/config" "go.temporal.io/server/common/deadlock" "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -64,6 +63,7 @@ import ( "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/telemetry" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/utf8validator" "go.uber.org/fx" "google.golang.org/grpc" @@ -130,7 +130,7 @@ var Module = fx.Options( deadlock.Module, config.Module, utf8validator.Module, - errorinjector.Module, + testhooks.Module, fx.Invoke(func(*utf8validator.Validator) {}), // force this to be constructed even if not referenced elsewhere ) @@ -229,7 +229,7 @@ func ClientFactoryProvider( membershipMonitor membership.Monitor, metricsHandler metrics.Handler, dynamicCollection *dynamicconfig.Collection, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, persistenceConfig *config.Persistence, logger log.SnTaggedLogger, throttledLogger log.ThrottledLogger, @@ -239,7 +239,7 @@ func ClientFactoryProvider( membershipMonitor, metricsHandler, dynamicCollection, - errorInjector, + testHooks, persistenceConfig.NumHistoryShards, logger, throttledLogger, diff --git a/common/errorinjector/constants.go b/common/testing/testhooks/constants.go similarity index 90% rename from common/errorinjector/constants.go rename to common/testing/testhooks/constants.go index 0c26955f375..338ddfaa187 100644 --- a/common/errorinjector/constants.go +++ b/common/testing/testhooks/constants.go @@ -1,4 +1,4 @@ -package errorinjector +package testhooks const ( MatchingDisableSyncMatch = "matching.disableSyncMatch" diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go new file mode 100644 index 00000000000..4cb3901bcca --- /dev/null +++ b/common/testing/testhooks/noop_impl.go @@ -0,0 +1,18 @@ +//go:build !testhooks + +package testhooks + +import "go.uber.org/fx" + +var Module = fx.Options( + fx.Provide(func() (_ TestHooks) { return }), +) + +type ( + TestHooks struct{} +) + +func Get[T any](_ TestHooks, key string) (T, bool) { + var zero T + return zero, false +} diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go new file mode 100644 index 00000000000..52d9be6f673 --- /dev/null +++ b/common/testing/testhooks/test_impl.go @@ -0,0 +1,57 @@ +//go:build testhooks + +package testhooks + +import ( + "sync" + + "go.uber.org/fx" +) + +var Module = fx.Options( + fx.Provide(NewTestHooksImpl), +) + +type ( + TestHooks interface { + // private accessors; access must go through package-level Get/Set + get(string) (any, bool) + set(string, any) + del(string) + } + + testHooksImpl struct { + m sync.Map + } +) + +func Get[T any](th TestHooks, key string) (T, bool) { + if val, ok := th.get(key); ok { + // this is only used in test so we want to panic on type mismatch: + return val.(T), ok // nolint:revive + } + var zero T + return zero, false +} + +func Set[T any](th TestHooks, key string, val T) func() { + th.set(key, val) + return func() { th.del(key) } +} + +func NewTestHooksImpl() TestHooks { + return &testHooksImpl{} +} + +func (th *testHooksImpl) get(key string) (any, bool) { + val, ok := th.m.Load(key) + return val, ok +} + +func (th *testHooksImpl) set(key string, val any) { + th.m.Store(key, val) +} + +func (th *testHooksImpl) del(key string) { + th.m.Delete(key) +} diff --git a/service/matching/handler.go b/service/matching/handler.go index 6bb2a404e80..5e9f13be783 100644 --- a/service/matching/handler.go +++ b/service/matching/handler.go @@ -35,7 +35,6 @@ import ( "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/metrics" @@ -43,6 +42,7 @@ import ( "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/resource" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/tqid" "go.temporal.io/server/service/worker/deployment" "google.golang.org/protobuf/proto" @@ -89,7 +89,7 @@ func NewHandler( namespaceReplicationQueue persistence.NamespaceReplicationQueue, visibilityManager manager.VisibilityManager, nexusEndpointManager persistence.NexusEndpointManager, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, ) *Handler { handler := &Handler{ config: config, @@ -112,7 +112,7 @@ func NewHandler( namespaceReplicationQueue, visibilityManager, nexusEndpointManager, - errorInjector, + testHooks, ), namespaceRegistry: namespaceRegistry, } diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index dc4ff71e6cc..364474db443 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -56,7 +56,6 @@ import ( hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/collection" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -70,6 +69,7 @@ import ( "go.temporal.io/server/common/resource" serviceerrors "go.temporal.io/server/common/serviceerror" "go.temporal.io/server/common/tasktoken" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/tqid" "go.temporal.io/server/common/util" "go.temporal.io/server/common/worker_versioning" @@ -144,7 +144,7 @@ type ( partitions map[tqid.PartitionKey]taskQueuePartitionManager gaugeMetrics gaugeMetrics // per-namespace task queue counters config *Config - errorInjector errorinjector.ErrorInjector + testHooks testhooks.TestHooks // queryResults maps query TaskID (which is a UUID generated in QueryWorkflow() call) to a channel // that QueryWorkflow() will block on. The channel is unblocked either by worker sending response through // RespondQueryTaskCompleted() or through an internal service error causing temporal to be unable to dispatch @@ -205,7 +205,7 @@ func NewEngine( namespaceReplicationQueue persistence.NamespaceReplicationQueue, visibilityManager manager.VisibilityManager, nexusEndpointManager persistence.NexusEndpointManager, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, ) Engine { scopedMetricsHandler := metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingEngineScope)) e := &matchingEngineImpl{ @@ -236,7 +236,7 @@ func NewEngine( loadedPhysicalTaskQueueCount: make(map[taskQueueCounterKey]int), }, config: config, - errorInjector: errorInjector, + testHooks: testHooks, queryResults: collection.NewSyncMap[string, chan *queryResult](), nexusResults: collection.NewSyncMap[string, chan *nexusResult](), outstandingPollers: collection.NewSyncMap[string, context.CancelFunc](), diff --git a/service/matching/physical_task_queue_manager.go b/service/matching/physical_task_queue_manager.go index 91df31a5d5e..25ae86ee5d9 100644 --- a/service/matching/physical_task_queue_manager.go +++ b/service/matching/physical_task_queue_manager.go @@ -46,11 +46,11 @@ import ( "go.temporal.io/server/common/clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/debug" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/worker_versioning" "go.temporal.io/server/service/worker/deployment" "google.golang.org/protobuf/types/known/durationpb" @@ -530,7 +530,7 @@ func (c *physicalTaskQueueManagerImpl) TrySyncMatch(ctx context.Context, task *i // request sent by history service c.liveness.markAlive() c.tasksAddedInIntervals.incrementTaskCount() - if disable, _ := errorinjector.Get[bool](c.partitionMgr.engine.errorInjector, errorinjector.MatchingDisableSyncMatch); disable { + if disable, _ := testhooks.Get[bool](c.partitionMgr.engine.testHooks, testhooks.MatchingDisableSyncMatch); disable { return false, nil } } diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index 87f37d3eaf0..197b83159cf 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -48,7 +48,6 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/payloads" @@ -56,6 +55,7 @@ import ( "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/rpc" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/environment" "go.uber.org/fx" "google.golang.org/protobuf/types/known/durationpb" @@ -515,20 +515,20 @@ func (s *FunctionalTestBase) RunTestWithMatchingBehavior(subtest func()) { name, func() { if forceTaskForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 13) - s.InjectError(errorinjector.MatchingLBForceWritePartition, 11) + s.InjectError(testhooks.MatchingLBForceWritePartition, 11) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 1) } if forcePollForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 13) - s.InjectError(errorinjector.MatchingLBForceReadPartition, 5) + s.InjectError(testhooks.MatchingLBForceReadPartition, 5) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 1) } if forceAsync { - s.InjectError(errorinjector.MatchingDisableSyncMatch, true) + s.InjectError(testhooks.MatchingDisableSyncMatch, true) } else { - s.InjectError(errorinjector.MatchingDisableSyncMatch, false) + s.InjectError(testhooks.MatchingDisableSyncMatch, false) } subtest() diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index ffc29bad059..90ea411ebf8 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -52,7 +52,6 @@ import ( "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" - "go.temporal.io/server/common/errorinjector" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/membership" @@ -71,6 +70,7 @@ import ( "go.temporal.io/server/common/rpc/encryption" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/frontend" "go.temporal.io/server/service/history" "go.temporal.io/server/service/history/replication" @@ -100,7 +100,7 @@ type ( matchingClient matchingservice.MatchingServiceClient dcClient *dynamicconfig.MemoryClient - errorInjector errorinjector.ErrorInjector + testHooks testhooks.TestHooks logger log.Logger clusterMetadataConfig *cluster.Config persistenceConfig config.Persistence @@ -260,8 +260,8 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { tlsConfigProvider: params.TLSConfigProvider, captureMetricsHandler: params.CaptureMetricsHandler, dcClient: dynamicconfig.NewMemoryClient(), - // If this doesn't build, make sure you're building with tags 'errorinjector': - errorInjector: errorinjector.NewTestErrorInjector(), + // If this doesn't build, make sure you're building with tags 'testhooks': + testHooks: testhooks.NewTestHooksImpl(), serviceFxOptions: params.ServiceFxOptions, taskCategoryRegistry: params.TaskCategoryRegistry, hostsByProtocolByService: params.HostsByProtocolByService, @@ -414,7 +414,7 @@ func (c *TemporalImpl) startFrontend() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), + fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), @@ -486,7 +486,7 @@ func (c *TemporalImpl) startHistory() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), + fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), @@ -540,7 +540,7 @@ func (c *TemporalImpl) startMatching() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), + fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), fx.Provide(func() *esclient.Config { return c.esConfig }), fx.Provide(func() esclient.Client { return c.esClient }), fx.Provide(c.GetTLSConfigProvider), @@ -604,7 +604,7 @@ func (c *TemporalImpl) startWorker() { fx.Provide(func() persistenceClient.AbstractDataStoreFactory { return c.abstractDataStoreFactory }), fx.Provide(func() visibility.VisibilityStoreFactory { return c.visibilityStoreFactory }), fx.Provide(func() dynamicconfig.Client { return c.dcClient }), - fx.Decorate(func() errorinjector.ErrorInjector { return c.errorInjector }), + fx.Decorate(func() testhooks.TestHooks { return c.testHooks }), fx.Provide(resource.DefaultSnTaggedLoggerProvider), fx.Provide(func() esclient.Client { return c.esClient }), fx.Provide(func() *esclient.Config { return c.esConfig }), @@ -780,7 +780,7 @@ func (p *clientFactoryProvider) NewFactory( monitor membership.Monitor, metricsHandler metrics.Handler, dc *dynamicconfig.Collection, - errorInjector errorinjector.ErrorInjector, + testHooks testhooks.TestHooks, numberOfHistoryShards int32, logger log.Logger, throttledLogger log.Logger, @@ -790,7 +790,7 @@ func (p *clientFactoryProvider) NewFactory( monitor, metricsHandler, dc, - errorInjector, + testHooks, numberOfHistoryShards, logger, throttledLogger, @@ -905,7 +905,7 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Ke } func (c *TemporalImpl) injectError(t *testing.T, key string, value any) func() { - cleanup := errorinjector.Set(c.errorInjector, key, value) + cleanup := testhooks.Set(c.testHooks, key, value) t.Cleanup(cleanup) return cleanup } From 64e85cb8db039b8f1e4c6d1122e738898f28be19 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 22:53:20 +0000 Subject: [PATCH 08/13] copyright --- common/testing/testhooks/constants.go | 22 ++++++++++++++++++++++ common/testing/testhooks/noop_impl.go | 22 ++++++++++++++++++++++ common/testing/testhooks/test_impl.go | 22 ++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/common/testing/testhooks/constants.go b/common/testing/testhooks/constants.go index 338ddfaa187..3e65be1cb41 100644 --- a/common/testing/testhooks/constants.go +++ b/common/testing/testhooks/constants.go @@ -1,3 +1,25 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + package testhooks const ( diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go index 4cb3901bcca..6a6d292341d 100644 --- a/common/testing/testhooks/noop_impl.go +++ b/common/testing/testhooks/noop_impl.go @@ -1,3 +1,25 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + //go:build !testhooks package testhooks diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go index 52d9be6f673..8dbc990c83c 100644 --- a/common/testing/testhooks/test_impl.go +++ b/common/testing/testhooks/test_impl.go @@ -1,3 +1,25 @@ +// The MIT License +// +// Copyright (c) 2024 Temporal Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + //go:build testhooks package testhooks From 2a5d4c51136a06cb099c6183b20baf2752219522 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 23:00:41 +0000 Subject: [PATCH 09/13] change flag, comments --- Makefile | 2 +- common/testing/testhooks/noop_impl.go | 6 +++++- common/testing/testhooks/test_impl.go | 14 +++++++++++--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index cd3c927d959..758083dd44b 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ VISIBILITY_DB ?= temporal_visibility # Always use "protolegacy" tag to allow disabling utf-8 validation on proto messages # during proto library transition. ALL_BUILD_TAGS := protolegacy,$(BUILD_TAG) -ALL_TEST_TAGS := $(ALL_BUILD_TAGS),testhooks,$(TEST_TAG) +ALL_TEST_TAGS := $(ALL_BUILD_TAGS),test_dep,$(TEST_TAG) BUILD_TAG_FLAG := -tags $(ALL_BUILD_TAGS) TEST_TAG_FLAG := -tags $(ALL_TEST_TAGS) diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go index 6a6d292341d..d68134bf172 100644 --- a/common/testing/testhooks/noop_impl.go +++ b/common/testing/testhooks/noop_impl.go @@ -20,7 +20,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -//go:build !testhooks +//go:build !test_dep package testhooks @@ -31,9 +31,13 @@ var Module = fx.Options( ) type ( + // TestHooks (in production mode) is an empty struct just so the build works. + // See TestHooks in test_impl.go. TestHooks struct{} ) +// Get gets the value of a test hook. In production mode it always returns the zero value and +// false, which hopefully the compiler will inline and remove the hook as dead code. func Get[T any](_ TestHooks, key string) (T, bool) { var zero T return zero, false diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go index 8dbc990c83c..5d86d99673a 100644 --- a/common/testing/testhooks/test_impl.go +++ b/common/testing/testhooks/test_impl.go @@ -20,7 +20,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -//go:build testhooks +//go:build test_dep package testhooks @@ -35,6 +35,8 @@ var Module = fx.Options( ) type ( + // TestHooks holds a registry of active test hooks. It should be obtained through fx and + // used with Get and Set. TestHooks interface { // private accessors; access must go through package-level Get/Set get(string) (any, bool) @@ -42,11 +44,13 @@ type ( del(string) } + // testHooksImpl is an implementation of TestHooks. testHooksImpl struct { m sync.Map } ) +// Get gets the value of a test hook from the registry. func Get[T any](th TestHooks, key string) (T, bool) { if val, ok := th.get(key); ok { // this is only used in test so we want to panic on type mismatch: @@ -56,18 +60,22 @@ func Get[T any](th TestHooks, key string) (T, bool) { return zero, false } +// Set sets a test hook to a value and returns a cleanup function to unset it. +// Calls to Set and the cleanup functions should form a stack. func Set[T any](th TestHooks, key string, val T) func() { th.set(key, val) return func() { th.del(key) } } +// NewTestHooksImpl returns a new instance of a test hook registry. This is provided and used +// in the main "resource" module as a default, but in functional tests, it's overridden by an +// explicitly constructed instance. func NewTestHooksImpl() TestHooks { return &testHooksImpl{} } func (th *testHooksImpl) get(key string) (any, bool) { - val, ok := th.m.Load(key) - return val, ok + return th.m.Load(key) } func (th *testHooksImpl) set(key string, val any) { From c2be9e0be5641e2a45c1afe18ddda99a12169e2c Mon Sep 17 00:00:00 2001 From: Stephan Behnke Date: Thu, 5 Dec 2024 07:58:46 -0800 Subject: [PATCH 10/13] Add UpdateWithStart test using hook --- common/testing/testhooks/constants.go | 2 ++ service/history/api/multioperation/api.go | 10 ++++-- service/history/history_engine.go | 5 +++ service/history/history_engine_factory.go | 3 ++ tests/update_workflow_test.go | 37 +++++++++++++++++++++++ 5 files changed, 54 insertions(+), 3 deletions(-) diff --git a/common/testing/testhooks/constants.go b/common/testing/testhooks/constants.go index 3e65be1cb41..3aaf375c24b 100644 --- a/common/testing/testhooks/constants.go +++ b/common/testing/testhooks/constants.go @@ -26,4 +26,6 @@ const ( MatchingDisableSyncMatch = "matching.disableSyncMatch" MatchingLBForceReadPartition = "matching.lbForceReadPartition" MatchingLBForceWritePartition = "matching.lbForceWritePartition" + + UpdateWithStartInBetweenLockAndStart = "history.updateWithStartInBetweenLockAndStart" ) diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go index 2a3a7b3be28..4712d9ad661 100644 --- a/service/history/api/multioperation/api.go +++ b/service/history/api/multioperation/api.go @@ -38,6 +38,7 @@ import ( "go.temporal.io/server/common/definition" "go.temporal.io/server/common/locks" "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/api" "go.temporal.io/server/service/history/api/startworkflow" "go.temporal.io/server/service/history/api/updateworkflow" @@ -45,9 +46,7 @@ import ( "go.temporal.io/server/service/history/workflow" ) -var ( - multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.") -) +var multiOpAbortedErr = serviceerror.NewMultiOperationAborted("Operation was aborted.") type ( // updateError is a wrapper to distinguish an update error from a start error. @@ -64,6 +63,7 @@ func Invoke( tokenSerializer common.TaskTokenSerializer, visibilityManager manager.VisibilityManager, matchingClient matchingservice.MatchingServiceClient, + testHooks testhooks.TestHooks, ) (*historyservice.ExecuteMultiOperationResponse, error) { if len(req.Operations) != 2 { return nil, serviceerror.NewInvalidArgument("expected exactly 2 operations") @@ -205,6 +205,10 @@ func Invoke( } } + if hook, ok := testhooks.Get[func()](testHooks, testhooks.UpdateWithStartInBetweenLockAndStart); ok { + hook() + } + // workflow hasn't been started yet: start and then apply update resp, err := startAndUpdateWorkflow(ctx, starter, updater) var noStartErr *noStartError diff --git a/service/history/history_engine.go b/service/history/history_engine.go index 24fd4c240c0..774c132d0c2 100644 --- a/service/history/history_engine.go +++ b/service/history/history_engine.go @@ -55,6 +55,7 @@ import ( "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/sdk" "go.temporal.io/server/common/searchattribute" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/api" "go.temporal.io/server/service/history/api/addtasks" "go.temporal.io/server/service/history/api/deleteworkflow" @@ -157,6 +158,7 @@ type ( replicationProgressCache replication.ProgressCache syncStateRetriever replication.SyncStateRetriever outboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool + testHooks testhooks.TestHooks } ) @@ -183,6 +185,7 @@ func NewEngineWithShardContext( dlqWriter replication.DLQWriter, commandHandlerRegistry *workflow.CommandHandlerRegistry, outboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool, + testHooks testhooks.TestHooks, ) shard.Engine { currentClusterName := shard.GetClusterMetadata().GetCurrentClusterName() @@ -232,6 +235,7 @@ func NewEngineWithShardContext( replicationProgressCache: replicationProgressCache, syncStateRetriever: syncStateRetriever, outboundQueueCBPool: outboundQueueCBPool, + testHooks: testHooks, } historyEngImpl.queueProcessors = make(map[tasks.Category]queues.Queue) @@ -429,6 +433,7 @@ func (e *historyEngineImpl) ExecuteMultiOperation( e.tokenSerializer, e.persistenceVisibilityMgr, e.matchingClient, + e.testHooks, ) } diff --git a/service/history/history_engine_factory.go b/service/history/history_engine_factory.go index de90aa4fb3c..a005be63929 100644 --- a/service/history/history_engine_factory.go +++ b/service/history/history_engine_factory.go @@ -32,6 +32,7 @@ import ( "go.temporal.io/server/common/persistence/visibility/manager" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/sdk" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/service/history/api" "go.temporal.io/server/service/history/circuitbreakerpool" "go.temporal.io/server/service/history/configs" @@ -68,6 +69,7 @@ type ( ReplicationDLQWriter replication.DLQWriter CommandHandlerRegistry *workflow.CommandHandlerRegistry OutboundQueueCBPool *circuitbreakerpool.OutboundQueueCircuitBreakerPool + TestHooks testhooks.TestHooks } historyEngineFactory struct { @@ -108,5 +110,6 @@ func (f *historyEngineFactory) CreateEngine( f.ReplicationDLQWriter, f.CommandHandlerRegistry, f.OutboundQueueCBPool, + f.TestHooks, ) } diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index fd4c4e5af46..c6f6e9538f1 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -51,6 +51,7 @@ import ( "go.temporal.io/server/common/metrics/metricstest" "go.temporal.io/server/common/testing/protoutils" "go.temporal.io/server/common/testing/taskpoller" + "go.temporal.io/server/common/testing/testhooks" "go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/tests/testcore" "google.golang.org/protobuf/types/known/durationpb" @@ -5319,6 +5320,42 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() { }) }) + s.Run("workflow start conflict", func() { + + s.Run("workflow id conflict policy fail: use-existing", func() { + tv := testvars.New(s.T()) + + startReq := startWorkflowReq(tv) + startReq.WorkflowIdConflictPolicy = enumspb.WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING + updateReq := s.updateWorkflowRequest(tv, + &updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1") + + // simulate a race condition + s.InjectError(testhooks.UpdateWithStartInBetweenLockAndStart, func() { + _, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq) + s.NoError(err) + }) + + uwsCh := sendUpdateWithStart(testcore.NewContext(), startReq, updateReq) + + _, err := s.TaskPoller.PollAndHandleWorkflowTask(tv, + func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) { + return &workflowservice.RespondWorkflowTaskCompletedRequest{}, nil + }) + s.NoError(err) + + _, err = s.TaskPoller.PollAndHandleWorkflowTask(tv, + func(task *workflowservice.PollWorkflowTaskQueueResponse) (*workflowservice.RespondWorkflowTaskCompletedRequest, error) { + return &workflowservice.RespondWorkflowTaskCompletedRequest{ + Messages: s.UpdateAcceptCompleteMessages(tv, task.Messages[0], "1"), + }, nil + }) + s.NoError(err) + + <-uwsCh + }) + }) + s.Run("return update rate limit error", func() { // lower maximum total number of updates for testing purposes maxTotalUpdates := 0 From a933ed6bcb65e0bb9a0ba5cb144342117b49804f Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 23:13:26 +0000 Subject: [PATCH 11/13] rename --- tests/testcore/functional_test_base.go | 12 ++++++------ tests/testcore/onebox.go | 4 ++-- tests/update_workflow_test.go | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index 197b83159cf..a45b3fb3846 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -479,8 +479,8 @@ func (s *FunctionalTestBase) OverrideDynamicConfig(setting dynamicconfig.Generic return s.testCluster.host.overrideDynamicConfig(s.T(), setting.Key(), value) } -func (s *FunctionalTestBase) InjectError(key string, value any) (cleanup func()) { - return s.testCluster.host.injectError(s.T(), key, value) +func (s *FunctionalTestBase) InjectHook(key string, value any) (cleanup func()) { + return s.testCluster.host.injectHook(s.T(), key, value) } func (s *FunctionalTestBase) GetNamespaceID(namespace string) string { @@ -515,20 +515,20 @@ func (s *FunctionalTestBase) RunTestWithMatchingBehavior(subtest func()) { name, func() { if forceTaskForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 13) - s.InjectError(testhooks.MatchingLBForceWritePartition, 11) + s.InjectHook(testhooks.MatchingLBForceWritePartition, 11) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueWritePartitions, 1) } if forcePollForward { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 13) - s.InjectError(testhooks.MatchingLBForceReadPartition, 5) + s.InjectHook(testhooks.MatchingLBForceReadPartition, 5) } else { s.OverrideDynamicConfig(dynamicconfig.MatchingNumTaskqueueReadPartitions, 1) } if forceAsync { - s.InjectError(testhooks.MatchingDisableSyncMatch, true) + s.InjectHook(testhooks.MatchingDisableSyncMatch, true) } else { - s.InjectError(testhooks.MatchingDisableSyncMatch, false) + s.InjectHook(testhooks.MatchingDisableSyncMatch, false) } subtest() diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 90ea411ebf8..855e8f75083 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -260,7 +260,7 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { tlsConfigProvider: params.TLSConfigProvider, captureMetricsHandler: params.CaptureMetricsHandler, dcClient: dynamicconfig.NewMemoryClient(), - // If this doesn't build, make sure you're building with tags 'testhooks': + // If this doesn't build, make sure you're building with tags 'test_dep': testHooks: testhooks.NewTestHooksImpl(), serviceFxOptions: params.ServiceFxOptions, taskCategoryRegistry: params.TaskCategoryRegistry, @@ -904,7 +904,7 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Ke return cleanup } -func (c *TemporalImpl) injectError(t *testing.T, key string, value any) func() { +func (c *TemporalImpl) injectHook(t *testing.T, key string, value any) func() { cleanup := testhooks.Set(c.testHooks, key, value) t.Cleanup(cleanup) return cleanup diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index c6f6e9538f1..ee4d41753c1 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -5331,7 +5331,7 @@ func (s *UpdateWorkflowSuite) TestUpdateWithStart() { &updatepb.WaitPolicy{LifecycleStage: enumspb.UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_COMPLETED}, "1") // simulate a race condition - s.InjectError(testhooks.UpdateWithStartInBetweenLockAndStart, func() { + s.InjectHook(testhooks.UpdateWithStartInBetweenLockAndStart, func() { _, err := s.FrontendClient().StartWorkflowExecution(testcore.NewContext(), startReq) s.NoError(err) }) From 6194d2c86268609ae3913d0832b3e6c5476c029e Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 5 Dec 2024 23:17:16 +0000 Subject: [PATCH 12/13] simplify common case --- common/testing/testhooks/noop_impl.go | 4 ++++ common/testing/testhooks/test_impl.go | 7 +++++++ service/history/api/multioperation/api.go | 4 +--- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go index d68134bf172..675e4a29fcf 100644 --- a/common/testing/testhooks/noop_impl.go +++ b/common/testing/testhooks/noop_impl.go @@ -42,3 +42,7 @@ func Get[T any](_ TestHooks, key string) (T, bool) { var zero T return zero, false } + +// Call calls a func() hook if present. +func Call(_ TestHooks, key string) { +} diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go index 5d86d99673a..fc7651c321f 100644 --- a/common/testing/testhooks/test_impl.go +++ b/common/testing/testhooks/test_impl.go @@ -60,6 +60,13 @@ func Get[T any](th TestHooks, key string) (T, bool) { return zero, false } +// Call calls a func() hook if present. +func Call(th TestHooks, key string) { + if hook, ok := Get[func()](th, key); ok { + hook() + } +} + // Set sets a test hook to a value and returns a cleanup function to unset it. // Calls to Set and the cleanup functions should form a stack. func Set[T any](th TestHooks, key string, val T) func() { diff --git a/service/history/api/multioperation/api.go b/service/history/api/multioperation/api.go index 4712d9ad661..29bdf7db214 100644 --- a/service/history/api/multioperation/api.go +++ b/service/history/api/multioperation/api.go @@ -205,9 +205,7 @@ func Invoke( } } - if hook, ok := testhooks.Get[func()](testHooks, testhooks.UpdateWithStartInBetweenLockAndStart); ok { - hook() - } + testhooks.Call(testHooks, testhooks.UpdateWithStartInBetweenLockAndStart) // workflow hasn't been started yet: start and then apply update resp, err := startAndUpdateWorkflow(ctx, starter, updater) From 6025df71d6220680b815480af421dbc0873b0c54 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Fri, 6 Dec 2024 01:27:05 +0000 Subject: [PATCH 13/13] use int keys --- common/testing/testhooks/constants.go | 11 ++++++----- common/testing/testhooks/noop_impl.go | 4 ++-- common/testing/testhooks/test_impl.go | 18 +++++++++--------- tests/testcore/functional_test_base.go | 2 +- tests/testcore/onebox.go | 2 +- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/common/testing/testhooks/constants.go b/common/testing/testhooks/constants.go index 3aaf375c24b..eb8f32a10a5 100644 --- a/common/testing/testhooks/constants.go +++ b/common/testing/testhooks/constants.go @@ -22,10 +22,11 @@ package testhooks -const ( - MatchingDisableSyncMatch = "matching.disableSyncMatch" - MatchingLBForceReadPartition = "matching.lbForceReadPartition" - MatchingLBForceWritePartition = "matching.lbForceWritePartition" +type Key int - UpdateWithStartInBetweenLockAndStart = "history.updateWithStartInBetweenLockAndStart" +const ( + MatchingDisableSyncMatch Key = iota + MatchingLBForceReadPartition + MatchingLBForceWritePartition + UpdateWithStartInBetweenLockAndStart ) diff --git a/common/testing/testhooks/noop_impl.go b/common/testing/testhooks/noop_impl.go index 675e4a29fcf..cdeecd20209 100644 --- a/common/testing/testhooks/noop_impl.go +++ b/common/testing/testhooks/noop_impl.go @@ -38,11 +38,11 @@ type ( // Get gets the value of a test hook. In production mode it always returns the zero value and // false, which hopefully the compiler will inline and remove the hook as dead code. -func Get[T any](_ TestHooks, key string) (T, bool) { +func Get[T any](_ TestHooks, key Key) (T, bool) { var zero T return zero, false } // Call calls a func() hook if present. -func Call(_ TestHooks, key string) { +func Call(_ TestHooks, key Key) { } diff --git a/common/testing/testhooks/test_impl.go b/common/testing/testhooks/test_impl.go index fc7651c321f..f08197e5828 100644 --- a/common/testing/testhooks/test_impl.go +++ b/common/testing/testhooks/test_impl.go @@ -39,9 +39,9 @@ type ( // used with Get and Set. TestHooks interface { // private accessors; access must go through package-level Get/Set - get(string) (any, bool) - set(string, any) - del(string) + get(Key) (any, bool) + set(Key, any) + del(Key) } // testHooksImpl is an implementation of TestHooks. @@ -51,7 +51,7 @@ type ( ) // Get gets the value of a test hook from the registry. -func Get[T any](th TestHooks, key string) (T, bool) { +func Get[T any](th TestHooks, key Key) (T, bool) { if val, ok := th.get(key); ok { // this is only used in test so we want to panic on type mismatch: return val.(T), ok // nolint:revive @@ -61,7 +61,7 @@ func Get[T any](th TestHooks, key string) (T, bool) { } // Call calls a func() hook if present. -func Call(th TestHooks, key string) { +func Call(th TestHooks, key Key) { if hook, ok := Get[func()](th, key); ok { hook() } @@ -69,7 +69,7 @@ func Call(th TestHooks, key string) { // Set sets a test hook to a value and returns a cleanup function to unset it. // Calls to Set and the cleanup functions should form a stack. -func Set[T any](th TestHooks, key string, val T) func() { +func Set[T any](th TestHooks, key Key, val T) func() { th.set(key, val) return func() { th.del(key) } } @@ -81,14 +81,14 @@ func NewTestHooksImpl() TestHooks { return &testHooksImpl{} } -func (th *testHooksImpl) get(key string) (any, bool) { +func (th *testHooksImpl) get(key Key) (any, bool) { return th.m.Load(key) } -func (th *testHooksImpl) set(key string, val any) { +func (th *testHooksImpl) set(key Key, val any) { th.m.Store(key, val) } -func (th *testHooksImpl) del(key string) { +func (th *testHooksImpl) del(key Key) { th.m.Delete(key) } diff --git a/tests/testcore/functional_test_base.go b/tests/testcore/functional_test_base.go index a45b3fb3846..3f71d8421d8 100644 --- a/tests/testcore/functional_test_base.go +++ b/tests/testcore/functional_test_base.go @@ -479,7 +479,7 @@ func (s *FunctionalTestBase) OverrideDynamicConfig(setting dynamicconfig.Generic return s.testCluster.host.overrideDynamicConfig(s.T(), setting.Key(), value) } -func (s *FunctionalTestBase) InjectHook(key string, value any) (cleanup func()) { +func (s *FunctionalTestBase) InjectHook(key testhooks.Key, value any) (cleanup func()) { return s.testCluster.host.injectHook(s.T(), key, value) } diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 855e8f75083..13629ee3735 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -904,7 +904,7 @@ func (c *TemporalImpl) overrideDynamicConfig(t *testing.T, name dynamicconfig.Ke return cleanup } -func (c *TemporalImpl) injectHook(t *testing.T, key string, value any) func() { +func (c *TemporalImpl) injectHook(t *testing.T, key testhooks.Key, value any) func() { cleanup := testhooks.Set(c.testHooks, key, value) t.Cleanup(cleanup) return cleanup