diff --git a/pkg/common/morpc/backend.go b/pkg/common/morpc/backend.go
index ae82cad27d72c..95a2e485f0650 100644
--- a/pkg/common/morpc/backend.go
+++ b/pkg/common/morpc/backend.go
@@ -631,7 +631,9 @@ func (rb *remoteBackend) fetch(messages []*Future, maxFetchCount int) ([]*Future
 			// If the connect needs to be reset, then all futures in the waiting response state will never
 			// get the response and need to be notified of an error immediately.
 			rb.makeAllWaitingFutureFailed()
-			rb.handleResetConn()
+			if err := rb.handleResetConn(); err != nil {
+				return nil, true
+			}
 		case <-rb.stopWriteC:
 			return rb.fetchN(messages, math.MaxInt), true
 		}
@@ -687,11 +689,13 @@ func (rb *remoteBackend) makeAllWaitingFutureFailed() {
 	}
 }
 
-func (rb *remoteBackend) handleResetConn() {
+func (rb *remoteBackend) handleResetConn() error {
 	if err := rb.resetConn(); err != nil {
 		rb.logger.Error("fail to reset backend connection", zap.Error(err))
 		rb.inactive()
+		return err
 	}
+	return nil
 }
 
 func (rb *remoteBackend) doClose() {
diff --git a/pkg/frontend/computation_wrapper.go b/pkg/frontend/computation_wrapper.go
index 48aecb3f49452..080daff8c1e7c 100644
--- a/pkg/frontend/computation_wrapper.go
+++ b/pkg/frontend/computation_wrapper.go
@@ -268,20 +268,7 @@ func (cwft *TxnComputationWrapper) Compile(any any, fill func(*batch.Batch) erro
 	stats := statistic.StatsInfoFromContext(execCtx.reqCtx)
 	stats.CompileStart()
 	defer stats.CompileEnd()
-	cwft.compile = compile.NewCompile(
-		addr,
-		cwft.ses.GetDatabaseName(),
-		cwft.ses.GetSql(),
-		tenant,
-		cwft.ses.GetUserName(),
-		execCtx.reqCtx,
-		cwft.ses.GetTxnHandler().GetStorage(),
-		cwft.proc,
-		cwft.stmt,
-		cwft.ses.GetIsInternal(),
-		deepcopy.Copy(cwft.ses.getCNLabels()).(map[string]string),
-		getStatementStartAt(execCtx.reqCtx),
-	)
+	cwft.compile = compile.NewCompile(addr, cwft.ses.GetDatabaseName(), cwft.ses.GetSql(), tenant, cwft.ses.GetUserName(), cwft.ses.GetTxnHandler().GetStorage(), cwft.proc, cwft.stmt, cwft.ses.GetIsInternal(), deepcopy.Copy(cwft.ses.getCNLabels()).(map[string]string), getStatementStartAt(execCtx.reqCtx))
 	defer func() {
 		if err != nil {
 			cwft.compile.Release()
diff --git a/pkg/sql/colexec/types.go b/pkg/sql/colexec/types.go
index b4c433646fb3a..9683df71ca07c 100644
--- a/pkg/sql/colexec/types.go
+++ b/pkg/sql/colexec/types.go
@@ -42,10 +42,7 @@ type ReceiveInfo struct {
 	Uuid uuid.UUID
 }
 
-// Server used to support cn2s3 directly, for more info, refer to docs about it
 type Server struct {
-	sync.Mutex
-	hakeeper      logservice.CNHAKeeperClient
 	uuidCsChanMap UuidProcMap
 	//txn's local segments.
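The compile.go hunks below are the other half of the computation_wrapper.go change above: NewCompile loses its explicit context.Context parameter, and the Compile's context is instead derived from proc.Ctx when the new compile service hands out the object. A minimal sketch of the new call shape, assuming a caller that already holds an engine.Engine and a *process.Process — buildCompile and its parameter list are hypothetical, only the NewCompile signature comes from this patch:

```go
package frontend

import (
	"time"

	"github.com/matrixorigin/matrixone/pkg/sql/compile"
	"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
	"github.com/matrixorigin/matrixone/pkg/vm/engine"
	"github.com/matrixorigin/matrixone/pkg/vm/process"
)

// buildCompile is a hypothetical helper showing the new NewCompile shape:
// there is no ctx argument anymore, so callers must put everything the
// query needs into proc.Ctx before calling it.
func buildCompile(
	addr, db, sql, tenant, uid string,
	eng engine.Engine,
	proc *process.Process,
	stmt tree.Statement,
	labels map[string]string,
) *compile.Compile {
	// The compile service registers the Compile and wires its context to proc.Ctx.
	c := compile.NewCompile(addr, db, sql, tenant, uid,
		eng, proc, stmt, false /* isInternal */, labels, time.Now())
	// Release now hands the object back to the compile service instead of
	// freeing it directly (see the Release hunk below).
	return c
}
```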
diff --git a/pkg/sql/compile/compile.go b/pkg/sql/compile/compile.go
index c5e2a44cbc946..8404593e8e198 100644
--- a/pkg/sql/compile/compile.go
+++ b/pkg/sql/compile/compile.go
@@ -104,24 +104,13 @@ var (
 )
 
 // NewCompile is used to new an object of compile
-func NewCompile(
-	addr, db, sql, tenant, uid string,
-	ctx context.Context,
-	e engine.Engine,
-	proc *process.Process,
-	stmt tree.Statement,
-	isInternal bool,
-	cnLabel map[string]string,
-	startAt time.Time,
-) *Compile {
-	c := reuse.Alloc[Compile](nil)
+func NewCompile(addr, db, sql, tenant, uid string, e engine.Engine, proc *process.Process, stmt tree.Statement, isInternal bool, cnLabel map[string]string, startAt time.Time) *Compile {
+	c := GetCompileService().getCompile(proc)
 	c.e = e
 	c.db = db
-	c.ctx = ctx
 	c.tenant = tenant
 	c.uid = uid
 	c.sql = sql
-	c.proc = proc
 	c.proc.MessageBoard = c.MessageBoard
 	c.stmt = stmt
 	c.addr = addr
@@ -145,7 +134,7 @@ func (c *Compile) Release() {
 	if c == nil {
 		return
 	}
-	reuse.Free[Compile](c, nil)
+	_, _ = GetCompileService().putCompile(c)
 }
 
 func (c Compile) TypeName() string {
@@ -582,7 +571,7 @@ func (c *Compile) prepareRetry(defChanged bool) (*Compile, error) {
 	// improved to refresh expression in the future.
 
 	var e error
-	runC := NewCompile(c.addr, c.db, c.sql, c.tenant, c.uid, c.proc.Ctx, c.e, c.proc, c.stmt, c.isInternal, c.cnLabel, c.startAt)
+	runC := NewCompile(c.addr, c.db, c.sql, c.tenant, c.uid, c.e, c.proc, c.stmt, c.isInternal, c.cnLabel, c.startAt)
 	defer func() {
 		if e != nil {
 			runC.Release()
diff --git a/pkg/sql/compile/compileService.go b/pkg/sql/compile/compileService.go
new file mode 100644
index 0000000000000..75060f9a09b5a
--- /dev/null
+++ b/pkg/sql/compile/compileService.go
@@ -0,0 +1,159 @@
+// Copyright 2024 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compile
+
+import (
+	"context"
+	"github.com/matrixorigin/matrixone/pkg/common/reuse"
+	txnClient "github.com/matrixorigin/matrixone/pkg/txn/client"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+	"sync"
+)
+
+// todo: Move it to a CN-level structure one day.
+var compileService *ServiceOfCompile
+
+func init() {
+	compileService = InitCompileService()
+	txnClient.SetRunningPipelineManagement(compileService)
+}
+
+func GetCompileService() *ServiceOfCompile {
+	return compileService
+}
+
+// ServiceOfCompile is used to manage the lifecycle of Compile structures,
+// including their creation and deletion.
+//
+// It also tracks the currently active compiles within a single CN.
+type ServiceOfCompile struct {
+	sync.RWMutex
+
+	// ongoing compiles with additional information.
+	aliveCompiles map[*Compile]compileAdditionalInformation
+}
+
+// compileAdditionalInformation holds additional information for one compile,
+// to help control it.
+type compileAdditionalInformation struct {
+	// mustReturnError holds an error that must be returned if set.
+	mustReturnError error
+
+	// queryCancel is a method to cancel an ongoing query.
+	queryCancel context.CancelFunc
+	// queryDone is a waiter that checks if this query has been completed or not.
+	queryDone queryDoneWaiter
+}
+
+// kill one query and block until it has completed.
+func (info *compileAdditionalInformation) kill(errResult error) {
+	info.queryCancel()
+	info.queryDone.checkCompleted()
+	info.mustReturnError = errResult
+}
+
+type queryDoneWaiter chan bool
+
+func newQueryDoneWaiter() queryDoneWaiter {
+	return make(chan bool, 1)
+}
+
+func (waiter queryDoneWaiter) noticeQueryCompleted() {
+	waiter <- true
+}
+
+func (waiter queryDoneWaiter) checkCompleted() {
+	<-waiter
+	waiter <- true
+}
+
+func (waiter queryDoneWaiter) clear() {
+	for len(waiter) > 0 {
+		<-waiter
+	}
+}
+
+func InitCompileService() *ServiceOfCompile {
+	srv := &ServiceOfCompile{
+		aliveCompiles: make(map[*Compile]compileAdditionalInformation, 1024),
+	}
+	return srv
+}
+
+func (srv *ServiceOfCompile) getCompile(
+	proc *process.Process) *Compile {
+	// make sure the process has a cancel function.
+	if proc.Cancel == nil {
+		proc.Ctx, proc.Cancel = context.WithCancel(proc.Ctx)
+	}
+
+	runningCompile := reuse.Alloc[Compile](nil)
+	runningCompile.ctx = proc.Ctx
+	runningCompile.proc = proc
+
+	if runningCompile.queryStatus == nil {
+		runningCompile.queryStatus = newQueryDoneWaiter()
+	} else {
+		runningCompile.queryStatus.clear()
+	}
+
+	srv.Lock()
+	srv.aliveCompiles[runningCompile] = compileAdditionalInformation{
+		mustReturnError: nil,
+		queryCancel:     proc.Cancel,
+		queryDone:       runningCompile.queryStatus,
+	}
+	srv.Unlock()
+
+	return runningCompile
+}
+
+func (srv *ServiceOfCompile) putCompile(c *Compile) (mustReturnError bool, err error) {
+	c.queryStatus.noticeQueryCompleted()
+
+	srv.Lock()
+
+	if item, ok := srv.aliveCompiles[c]; ok {
+		err = item.mustReturnError
+	}
+	delete(srv.aliveCompiles, c)
+	c.queryStatus.clear()
+	srv.Unlock()
+
+	reuse.Free[Compile](c, nil)
+
+	return err != nil, err
+}
+
+func (srv *ServiceOfCompile) aliveCompile() int {
+	srv.Lock()
+	defer srv.Unlock()
+
+	return len(srv.aliveCompiles)
+}
+
+func (srv *ServiceOfCompile) PauseService() {
+	srv.Lock()
+}
+
+func (srv *ServiceOfCompile) ResumeService() {
+	srv.Unlock()
+}
+
+func (srv *ServiceOfCompile) KillAllQueriesWithError(err error) {
+	for _, v := range srv.aliveCompiles {
+		v.kill(err)
+	}
+}
diff --git a/pkg/sql/compile/compileService_test.go b/pkg/sql/compile/compileService_test.go
new file mode 100644
index 0000000000000..06c36aa6fe58e
--- /dev/null
+++ b/pkg/sql/compile/compileService_test.go
@@ -0,0 +1,72 @@
+// Copyright 2024 Matrix Origin
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compile
+
+import (
+	"context"
+	"github.com/matrixorigin/matrixone/pkg/vm/process"
+	"github.com/stretchr/testify/require"
+	"sync"
+	"sync/atomic"
+	"testing"
+)
+
+func generateRunningProc(n int) []*process.Process {
+	rs := make([]*process.Process, n)
+	for i := range rs {
+		ctx, cancel := context.WithCancel(context.TODO())
+
+		rs[i] = &process.Process{
+			Ctx:    ctx,
+			Cancel: cancel,
+		}
+	}
+	return rs
+}
+
+func TestCompileService(t *testing.T) {
+	service := InitCompileService()
+
+	doneRoutine := atomic.Int32{}
+	doneRoutine.Store(0)
+	wg := sync.WaitGroup{}
+
+	// 1. the service should count running Compiles correctly.
+	inputs := generateRunningProc(10)
+	for _, p := range inputs {
+		wg.Add(1)
+
+		c := service.getCompile(p)
+		go func(cc *Compile) {
+			<-cc.ctx.Done()
+
+			doneRoutine.Add(1)
+			_, _ = service.putCompile(cc)
+			wg.Done()
+		}(c)
+	}
+	require.Equal(t, 10, service.aliveCompile())
+
+	// 2. kill all running Compiles.
+	service.PauseService()
+	service.KillAllQueriesWithError(nil)
+	service.ResumeService()
+
+	require.Equal(t, int32(10), doneRoutine.Load())
+
+	// after all, alive compiles should be 0.
+	wg.Wait()
+	require.Equal(t, 0, service.aliveCompile())
+}
diff --git a/pkg/sql/compile/compile_test.go b/pkg/sql/compile/compile_test.go
index 91f6bfa6fcebe..cf27af26549ee 100644
--- a/pkg/sql/compile/compile_test.go
+++ b/pkg/sql/compile/compile_test.go
@@ -139,7 +139,7 @@ func TestCompile(t *testing.T) {
 		tc.proc.TxnClient = txnCli
 		tc.proc.TxnOperator = txnOp
 		tc.proc.Ctx = ctx
-		c := NewCompile("test", "test", tc.sql, "", "", ctx, tc.e, tc.proc, tc.stmt, false, nil, time.Now())
+		c := NewCompile("test", "test", tc.sql, "", "", tc.e, tc.proc, tc.stmt, false, nil, time.Now())
 		err := c.Compile(ctx, tc.pn, testPrint)
 		require.NoError(t, err)
 		c.getAffectedRows()
@@ -168,7 +168,7 @@ func TestCompileWithFaults(t *testing.T) {
 	tc.proc.TxnClient = txnCli
 	tc.proc.TxnOperator = txnOp
 	tc.proc.Ctx = ctx
-	c := NewCompile("test", "test", tc.sql, "", "", ctx, tc.e, tc.proc, nil, false, nil, time.Now())
+	c := NewCompile("test", "test", tc.sql, "", "", tc.e, tc.proc, nil, false, nil, time.Now())
 	err := c.Compile(ctx, tc.pn, testPrint)
 	require.NoError(t, err)
 	c.getAffectedRows()
diff --git a/pkg/sql/compile/remoterunServer.go b/pkg/sql/compile/remoterunServer.go
index 0834c746900a4..b7af5361372c8 100644
--- a/pkg/sql/compile/remoterunServer.go
+++ b/pkg/sql/compile/remoterunServer.go
@@ -403,15 +403,14 @@ func (receiver *messageReceiverOnServer) newCompile() (*Compile, error) {
 		proc.StmtProfile = process.NewStmtProfile(uuid.UUID(txnId), pHelper.StmtId)
 	}
 
-	c := reuse.Alloc[Compile](nil)
-	c.proc = proc
+	c := GetCompileService().getCompile(proc)
 	c.proc.MessageBoard = c.MessageBoard
 	c.e = cnInfo.storeEngine
 	c.anal = newAnaylze()
 	c.anal.analInfos = proc.AnalInfos
 	c.addr = receiver.cnInformation.cnAddr
 	c.proc.Ctx = perfcounter.WithCounterSet(c.proc.Ctx, c.counterSet)
-	c.ctx = defines.AttachAccountId(c.proc.Ctx, pHelper.accountId)
+	c.ctx = defines.AttachAccountId(c.ctx, pHelper.accountId)
 
 	// a method to send back.
 	c.fill = func(b *batch.Batch) error {
diff --git a/pkg/sql/compile/scope_test.go b/pkg/sql/compile/scope_test.go
index 63a123d3466a3..6bab316ba6510 100644
--- a/pkg/sql/compile/scope_test.go
+++ b/pkg/sql/compile/scope_test.go
@@ -86,7 +86,7 @@ func generateScopeCases(t *testing.T, testCases []string) []*Scope {
 		qry, err := opt.Optimize(stmts[0], false)
 		require.NoError(t1, err)
 		proc.Ctx = ctx
-		c := NewCompile("test", "test", sql, "", "", context.Background(), e, proc, nil, false, nil, time.Now())
+		c := NewCompile("test", "test", sql, "", "", e, proc, nil, false, nil, time.Now())
 		err = c.Compile(ctx, &plan.Plan{Plan: &plan.Plan_Query{Query: qry}}, func(batch *batch.Batch) error {
 			return nil
 		})
diff --git a/pkg/sql/compile/sql_executor.go b/pkg/sql/compile/sql_executor.go
index 287f0ef3f3fc8..3142bb34a5416 100644
--- a/pkg/sql/compile/sql_executor.go
+++ b/pkg/sql/compile/sql_executor.go
@@ -310,7 +310,7 @@ func (exec *txnExecutor) Exec(
 		return executor.Result{}, err
 	}
 
-	c := NewCompile(exec.s.addr, exec.getDatabase(), sql, "", "", exec.ctx, exec.s.eng, proc, stmts[0], false, nil, receiveAt)
+	c := NewCompile(exec.s.addr, exec.getDatabase(), sql, "", "", exec.s.eng, proc, stmts[0], false, nil, receiveAt)
 	defer c.Release()
 	c.disableRetry = exec.opts.DisableIncrStatement()
 	c.SetBuildPlanFunc(func() (*plan.Plan, error) {
diff --git a/pkg/sql/compile/types.go b/pkg/sql/compile/types.go
index bb1651270daa2..2e2556e808483 100644
--- a/pkg/sql/compile/types.go
+++ b/pkg/sql/compile/types.go
@@ -247,6 +247,9 @@ type Compile struct {
 	sql       string
 	originSQL string
 
+	// queryStatus records whether the query has finished.
+	queryStatus queryDoneWaiter
+
 	anal *anaylze
 	// e db engine instance.
 	e engine.Engine
diff --git a/pkg/txn/client/client.go b/pkg/txn/client/client.go
index 68106f0e642ba..5b7c39ecd0ebf 100644
--- a/pkg/txn/client/client.go
+++ b/pkg/txn/client/client.go
@@ -571,21 +571,38 @@ func (client *txnClient) Resume() {
 	}
 }
 
+// NodeRunningPipelineManager to avoid package import cycles.
+type NodeRunningPipelineManager interface {
+	PauseService()
+	KillAllQueriesWithError(err error)
+	ResumeService()
+}
+
+var runningPipelines NodeRunningPipelineManager
+
+func SetRunningPipelineManagement(m NodeRunningPipelineManager) {
+	runningPipelines = m
+}
+
 func (client *txnClient) AbortAllRunningTxn() {
 	client.mu.Lock()
+	runningPipelines.PauseService()
+
 	ops := make([]*txnOperator, 0, len(client.mu.activeTxns))
 	for _, op := range client.mu.activeTxns {
 		ops = append(ops, op)
 	}
 	waitOps := append(([]*txnOperator)(nil), client.mu.waitActiveTxns...)
 	client.mu.waitActiveTxns = client.mu.waitActiveTxns[:0]
-	client.mu.Unlock()
 
 	if client.timestampWaiter != nil {
 		// Cancel all waiters, means that all waiters do not need to wait for
 		// the newer timestamp from logtail consumer.
 		client.timestampWaiter.Pause()
 	}
+	runningPipelines.KillAllQueriesWithError(nil)
+
+	client.mu.Unlock()
+	runningPipelines.ResumeService()
 
 	for _, op := range ops {
 		op.cannotCleanWorkspace = true
diff --git a/pkg/txn/client/client_test.go b/pkg/txn/client/client_test.go
index a3c645e0a798d..4ab341cae2cb1 100644
--- a/pkg/txn/client/client_test.go
+++ b/pkg/txn/client/client_test.go
@@ -135,7 +135,15 @@ func TestNewTxnWithSnapshotTS(t *testing.T) {
 	assert.Equal(t, txn.TxnStatus_Active, txnMeta.Status)
 }
 
+type fakeRunningPipelinesManager struct{}
+
+func (m *fakeRunningPipelinesManager) PauseService()                   {}
+func (m *fakeRunningPipelinesManager) KillAllQueriesWithError(_ error) {}
+func (m *fakeRunningPipelinesManager) ResumeService()                  {}
+
 func TestTxnClientAbortAllRunningTxn(t *testing.T) {
+	SetRunningPipelineManagement(&fakeRunningPipelinesManager{})
+
 	rt := runtime.NewRuntime(metadata.ServiceType_CN, "",
 		logutil.GetPanicLogger(),
 		runtime.WithClock(clock.NewHLCClock(func() int64 {
diff --git a/pkg/txn/service/service.go b/pkg/txn/service/service.go
index 27646785ada10..5c22565e31495 100644
--- a/pkg/txn/service/service.go
+++ b/pkg/txn/service/service.go
@@ -104,7 +104,9 @@ func (s *service) Start() error {
 	if err := s.storage.Start(); err != nil {
 		return err
 	}
+	s.logger.Info("start txn recovery")
 	s.startRecovery()
+	s.logger.Info("end txn recovery")
 	return nil
 }
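End-to-end, the new kill path works like this: the txn client pauses the compile service, cancels every registered query via KillAllQueriesWithError, then resumes the service, while each query's goroutine unregisters itself once it observes the cancellation. A minimal in-package sketch of that handshake, modeled on TestCompileService above — killOneQuery is a hypothetical helper; in the real flow the pause/kill/resume calls are driven by txnClient.AbortAllRunningTxn, not by user code:

```go
package compile

import (
	"context"
	"errors"

	"github.com/matrixorigin/matrixone/pkg/vm/process"
)

// killOneQuery walks one Compile through the register -> kill -> unregister
// handshake introduced by this patch.
func killOneQuery() {
	srv := InitCompileService()

	ctx, cancel := context.WithCancel(context.Background())
	proc := &process.Process{Ctx: ctx, Cancel: cancel}
	c := srv.getCompile(proc) // registers c in aliveCompiles

	done := make(chan struct{})
	go func() {
		// A running pipeline would poll this context while executing.
		<-c.ctx.Done()
		// Report completion and unregister; this blocks on the service
		// lock until ResumeService runs.
		_, _ = srv.putCompile(c)
		close(done)
	}()

	srv.PauseService() // take the lock so no new Compile can register
	// Cancel each query's context, then wait (via queryDone) until the
	// query has reported completion.
	srv.KillAllQueriesWithError(errors.New("txn aborted"))
	srv.ResumeService() // release the lock so putCompile can finish
	<-done              // afterwards srv.aliveCompile() is 0 again
}
```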