diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 87bf91379..bc00ae314 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -160,6 +160,7 @@ type ( updateResult struct { success interface{} err error + mu *sync.Mutex } // testWorkflowEnvironmentShared is the shared data between parent workflow and child workflow test environments @@ -235,7 +236,8 @@ type ( queryHandler func(string, *commonpb.Payloads, *commonpb.Header) (*commonpb.Payloads, error) updateHandler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks) updateMap map[string]updateResult - startedHandler func(r WorkflowExecution, e error) + // updateMapLock map[string]sync.Mutex + startedHandler func(r WorkflowExecution, e error) isWorkflowCompleted bool testResult converter.EncodedValue @@ -2236,7 +2238,6 @@ func (env *testWorkflowEnvironmentImpl) RegisterUpdateHandler( handler func(name string, id string, input *commonpb.Payloads, header *commonpb.Header, resp UpdateCallbacks), ) { env.updateHandler = handler - env.updateMap = make(map[string]updateResult) } func (env *testWorkflowEnvironmentImpl) RegisterQueryHandler( @@ -2925,19 +2926,28 @@ func (env *testWorkflowEnvironmentImpl) updateWorkflow(name string, id string, u panic(err) } + if env.updateMap == nil { + env.updateMap = make(map[string]updateResult) + } + var ucWrapper = updateCallbacksWrapper{uc: uc, env: env, updateID: id} // check for duplicate update ID if result, ok := env.updateMap[id]; ok { + result.mu.Lock() // return cached result env.postCallback(func() { ucWrapper.uc.Accept() ucWrapper.uc.Complete(result.success, result.err) + defer result.mu.Unlock() }, false) } else { + env.updateMap[id] = updateResult{nil, nil, &sync.Mutex{}} + env.updateMap[id].mu.Lock() env.postCallback(func() { // Do not send any headers on test invocations env.updateHandler(name, id, data, nil, ucWrapper) + defer env.updateMap[id].mu.Unlock() }, true) } @@ -3120,7 +3130,13 @@ func (uc updateCallbacksWrapper) Complete(success interface{}, err error) { if uc.env == nil { panic("env is needed in updateCallback to cache update results for deduping purposes") } - uc.env.updateMap[uc.updateID] = updateResult{success, err} + if result, ok := uc.env.updateMap[uc.updateID]; ok { + result.success = success + result.err = err + uc.env.updateMap[uc.updateID] = result + } else { + panic("updateMap[updateID] should already be created from updateWorkflow()") + } uc.uc.Complete(success, err) }