diff --git a/contrib/drivers/mysql/mysql_z_unit_issue_test.go b/contrib/drivers/mysql/mysql_z_unit_issue_test.go index b3329a8fd91..58ceeb0a5f3 100644 --- a/contrib/drivers/mysql/mysql_z_unit_issue_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_issue_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" @@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) { func Test_Issue3626(t *testing.T) { table := "issue3626" array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";") + defer dropTable(table) for _, v := range array { if _, err := db.Exec(ctx, v); err != nil { gtest.Error(err) } } - defer dropTable(table) // Insert. gtest.C(t, func(t *gtest.T) { @@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) { t.Assert(one["id"], 10) }) } + +// https://github.com/gogf/gf/issues/3968 +func Test_Issue3968(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + var hook = gdb.HookHandler{ + Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + result, err = in.Next(ctx) + if err != nil { + return nil, err + } + if result != nil { + for i, _ := range result { + result[i]["location"] = gvar.New("ny") + } + } + return + }, + } + var ( + count int + result gdb.Result + ) + err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false) + t.AssertNil(err) + t.Assert(count, 10) + t.Assert(len(result), 10) + }) +} diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 80fc265648e..74d76e5c8a1 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -396,12 +396,13 @@ const ( linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)` ) -type queryType int +type SelectType int const ( - queryTypeNormal queryType = iota - queryTypeCount - queryTypeValue + SelectTypeDefault SelectType = iota + SelectTypeCount + SelectTypeValue + SelectTypeArray ) type joinOperator string @@ -700,13 +701,13 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode { } // Exclude the right border value. var ( - min = 0 - max = 0 - random = grand.N(0, total-1) + minWeight = 0 + maxWeight = 0 + random = grand.N(0, total-1) ) for i := 0; i < len(cg); i++ { - max = min + cg[i].Weight*100 - if random >= min && random < max { + maxWeight = minWeight + cg[i].Weight*100 + if random >= minWeight && random < maxWeight { // ==================================================== // Return a COPY of the ConfigNode. // ==================================================== @@ -714,7 +715,7 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode { node = cg[i] return &node } - min = max + minWeight = maxWeight } return nil } diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 3bce0512f1b..ec14ffd1be2 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo unionTypeStr = "UNION" } for _, v := range unions { - sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false) + sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false) if composedSqlStr == "" { composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder) } else { diff --git a/database/gdb/gdb_core_ctx.go b/database/gdb/gdb_core_ctx.go index d77e33a02b1..ea8ca4b665a 100644 --- a/database/gdb/gdb_core_ctx.go +++ b/database/gdb/gdb_core_ctx.go @@ -23,8 +23,6 @@ type internalCtxData struct { } // column stores column data in ctx for internal usage purpose. -// Deprecated. -// TODO remove this usage in future. type internalColumnData struct { // The first column in result response from database server. // This attribute is used for Value/Count selection statement purpose, diff --git a/database/gdb/gdb_model_cache.go b/database/gdb/gdb_model_cache.go index c1b6ca82041..76592cdec29 100644 --- a/database/gdb/gdb_model_cache.go +++ b/database/gdb/gdb_model_cache.go @@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args . } func (m *Model) saveSelectResultToCache( - ctx context.Context, queryType queryType, result Result, sql string, args ...interface{}, + ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{}, ) (err error) { if !m.cacheEnabled || m.tx != nil { return @@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache( // Special handler for Value/Count operations result. if len(result) > 0 { var core = m.db.GetCore() - switch queryType { - case queryTypeValue, queryTypeCount: + switch selectType { + case SelectTypeValue, SelectTypeArray, SelectTypeCount: if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if result[0][internalData.FirstResultColumn].IsEmpty() { result = nil } } + default: } } // In case of Cache Penetration. - if result.IsEmpty() { + if result != nil && result.IsEmpty() { if m.cacheOption.Force { result = Result{} } else { diff --git a/database/gdb/gdb_model_hook.go b/database/gdb/gdb_model_hook.go index 7dd00740578..3428990a73c 100644 --- a/database/gdb/gdb_model_hook.go +++ b/database/gdb/gdb_model_hook.go @@ -66,11 +66,12 @@ type internalParamHookDelete struct { // which is usually not be interesting for upper business hook handler. type HookSelectInput struct { internalParamHookSelect - Model *Model // Current operation Model. - Table string // The table name that to be used. Update this attribute to change target table name. - Schema string // The schema name that to be used. Update this attribute to change target schema name. - Sql string // The sql string that to be committed. - Args []interface{} // The arguments of sql. + Model *Model // Current operation Model. + Table string // The table name that to be used. Update this attribute to change target table name. + Schema string // The schema name that to be used. Update this attribute to change target schema name. + Sql string // The sql string that to be committed. + Args []interface{} // The arguments of sql. + SelectType SelectType // The type of this SELECT operation. } // HookInsertInput holds the parameters for insert hook operation. diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index a369f71ade2..8c6433abf02 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -28,7 +28,7 @@ import ( // see Model.Where. func (m *Model) All(where ...interface{}) (Result, error) { var ctx = m.GetCtx() - return m.doGetAll(ctx, false, where...) + return m.doGetAll(ctx, SelectTypeDefault, false, where...) } // AllAndCount retrieves all records and the total count of records from the model. @@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in } // Retrieve all records - result, err = m.doGetAll(m.GetCtx(), false) + result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false) return } @@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).One() } - all, err := m.doGetAll(ctx, true) + all, err := m.doGetAll(ctx, SelectTypeDefault, true) if err != nil { return nil, err } @@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) { return m.Fields(gconv.String(fieldsAndWhere[0])).Array() } } - all, err := m.All() + + var ( + field string + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) + all, err := m.doGetAll(ctx, SelectTypeArray, false) if err != nil { return nil, err } - var field string if len(all) > 0 { - var recordFields = m.getRecordFields(all[0]) - if len(recordFields) > 1 { - // it returns error if there are multiple fields in the result record. - return nil, gerror.NewCodef( - gcode.CodeInvalidParameter, - `invalid fields for "Array" operation, result fields number "%d"%s, but expect one`, - len(recordFields), - gjson.MustEncodeString(recordFields), + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return nil, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, ) } - if len(recordFields) == 1 { - field = recordFields[0] + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + field = internalData.FirstResultColumn + if field == "" { + // Fields number check. + var recordFields = m.getRecordFields(all[0]) + if len(recordFields) == 1 { + field = recordFields[0] + } else { + // it returns error if there are multiple fields in the result record. + return nil, gerror.NewCodef( + gcode.CodeInvalidParameter, + `invalid fields for "Array" operation, result fields number "%d"%s, but expect one`, + len(recordFields), + gjson.MustEncodeString(recordFields), + ) + } } } return all.Array(field), nil @@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { } } var ( - sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true) - all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...) + sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true) + all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...) ) if err != nil { return nil, err } if len(all) > 0 { + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return nil, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, + ) + } + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + if v, ok := all[0][internalData.FirstResultColumn]; ok { + return v, nil + } + // Fields number check. var recordFields = m.getRecordFields(all[0]) if len(recordFields) == 1 { for _, v := range all[0] { @@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) { return m.Where(where[0], where[1:]...).Count() } var ( - sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false) - all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...) + sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false) + all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...) ) if err != nil { return 0, err } if len(all) > 0 { + internalData := core.getInternalColumnFromCtx(ctx) + if internalData == nil { + return 0, gerror.NewCode( + gcode.CodeInternalError, + `query count error: the internal context data is missing. there's internal issue should be fixed`, + ) + } + // If FirstResultColumn present, it returns the value of the first record of the first field. + // It means it use no cache mechanism, while cache mechanism makes `internalData` missing. + if v, ok := all[0][internalData.FirstResultColumn]; ok { + return v.Int(), nil + } + // Fields number check. var recordFields = m.getRecordFields(all[0]) if len(recordFields) == 1 { for _, v := range all[0] { @@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model { // The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. -func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) { +func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).All() } - sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1) - return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...) + sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1) + return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...) } // doGetAllBySql does the select statement on the database. func (m *Model) doGetAllBySql( - ctx context.Context, queryType queryType, sql string, args ...interface{}, + ctx context.Context, selectType SelectType, sql string, args ...interface{}, ) (result Result, err error) { if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil { return @@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql( }, handler: m.hookHandler.Select, }, - Model: m, - Table: m.tables, - Sql: sql, - Args: m.mergeArguments(args), + Model: m, + Table: m.tables, + Sql: sql, + Args: m.mergeArguments(args), + SelectType: selectType, } if result, err = in.Next(ctx); err != nil { return } - err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...) + err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...) return } func (m *Model) getFormattedSqlAndArgs( - ctx context.Context, queryType queryType, limit1 bool, + ctx context.Context, selectType SelectType, limit1 bool, ) (sqlWithHolder string, holderArgs []interface{}) { - switch queryType { - case queryTypeCount: + switch selectType { + case SelectTypeCount: queryFields := "COUNT(1)" if len(m.fields) > 0 { // DO NOT quote the m.fields here, in case of fields like: @@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs( func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) { holder, args = m.getFormattedSqlAndArgs( - ctx, queryTypeNormal, false, + ctx, SelectTypeDefault, false, ) args = m.mergeArguments(args) return