Skip to content

Commit

Permalink
fix(database/gdb): issue where the Count/Value/Array query logic wa…
Browse files Browse the repository at this point in the history
…s incompatible with the old version when users extended the returned result fields using the `Select` Hook (#3995)
  • Loading branch information
gqcn authored Dec 1, 2024
1 parent 42eae41 commit 2c916f8
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 55 deletions.
34 changes: 33 additions & 1 deletion contrib/drivers/mysql/mysql_z_unit_issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
Expand Down Expand Up @@ -1283,12 +1284,12 @@ func Test_Issue3754(t *testing.T) {
func Test_Issue3626(t *testing.T) {
table := "issue3626"
array := gstr.SplitAndTrim(gtest.DataContent(`issue3626.sql`), ";")
defer dropTable(table)
for _, v := range array {
if _, err := db.Exec(ctx, v); err != nil {
gtest.Error(err)
}
}
defer dropTable(table)

// Insert.
gtest.C(t, func(t *gtest.T) {
Expand Down Expand Up @@ -1377,3 +1378,34 @@ func Test_Issue3932(t *testing.T) {
t.Assert(one["id"], 10)
})
}

// https://github.com/gogf/gf/issues/3968
func Test_Issue3968(t *testing.T) {
table := createInitTable()
defer dropTable(table)

gtest.C(t, func(t *gtest.T) {
var hook = gdb.HookHandler{
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
result, err = in.Next(ctx)
if err != nil {
return nil, err
}
if result != nil {
for i, _ := range result {
result[i]["location"] = gvar.New("ny")
}
}
return
},
}
var (
count int
result gdb.Result
)
err := db.Model(table).Hook(hook).ScanAndCount(&result, &count, false)
t.AssertNil(err)
t.Assert(count, 10)
t.Assert(len(result), 10)
})
}
21 changes: 11 additions & 10 deletions database/gdb/gdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,13 @@ const (
linkPattern = `(\w+):([\w\-\$]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)`
)

type queryType int
type SelectType int

const (
queryTypeNormal queryType = iota
queryTypeCount
queryTypeValue
SelectTypeDefault SelectType = iota
SelectTypeCount
SelectTypeValue
SelectTypeArray
)

type joinOperator string
Expand Down Expand Up @@ -700,21 +701,21 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode {
}
// Exclude the right border value.
var (
min = 0
max = 0
random = grand.N(0, total-1)
minWeight = 0
maxWeight = 0
random = grand.N(0, total-1)
)
for i := 0; i < len(cg); i++ {
max = min + cg[i].Weight*100
if random >= min && random < max {
maxWeight = minWeight + cg[i].Weight*100
if random >= minWeight && random < maxWeight {
// ====================================================
// Return a COPY of the ConfigNode.
// ====================================================
node := ConfigNode{}
node = cg[i]
return &node
}
min = max
minWeight = maxWeight
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion database/gdb/gdb_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Mo
unionTypeStr = "UNION"
}
for _, v := range unions {
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false)
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, SelectTypeDefault, false)
if composedSqlStr == "" {
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
} else {
Expand Down
2 changes: 0 additions & 2 deletions database/gdb/gdb_core_ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ type internalCtxData struct {
}

// column stores column data in ctx for internal usage purpose.
// Deprecated.
// TODO remove this usage in future.
type internalColumnData struct {
// The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose,
Expand Down
9 changes: 5 additions & 4 deletions database/gdb/gdb_model_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args .
}

func (m *Model) saveSelectResultToCache(
ctx context.Context, queryType queryType, result Result, sql string, args ...interface{},
ctx context.Context, selectType SelectType, result Result, sql string, args ...interface{},
) (err error) {
if !m.cacheEnabled || m.tx != nil {
return
Expand All @@ -108,18 +108,19 @@ func (m *Model) saveSelectResultToCache(
// Special handler for Value/Count operations result.
if len(result) > 0 {
var core = m.db.GetCore()
switch queryType {
case queryTypeValue, queryTypeCount:
switch selectType {
case SelectTypeValue, SelectTypeArray, SelectTypeCount:
if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil {
if result[0][internalData.FirstResultColumn].IsEmpty() {
result = nil
}
}
default:
}
}

// In case of Cache Penetration.
if result.IsEmpty() {
if result != nil && result.IsEmpty() {
if m.cacheOption.Force {
result = Result{}
} else {
Expand Down
11 changes: 6 additions & 5 deletions database/gdb/gdb_model_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ type internalParamHookDelete struct {
// which is usually not be interesting for upper business hook handler.
type HookSelectInput struct {
internalParamHookSelect
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed.
Args []interface{} // The arguments of sql.
Model *Model // Current operation Model.
Table string // The table name that to be used. Update this attribute to change target table name.
Schema string // The schema name that to be used. Update this attribute to change target schema name.
Sql string // The sql string that to be committed.
Args []interface{} // The arguments of sql.
SelectType SelectType // The type of this SELECT operation.
}

// HookInsertInput holds the parameters for insert hook operation.
Expand Down
108 changes: 76 additions & 32 deletions database/gdb/gdb_model_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// see Model.Where.
func (m *Model) All(where ...interface{}) (Result, error) {
var ctx = m.GetCtx()
return m.doGetAll(ctx, false, where...)
return m.doGetAll(ctx, SelectTypeDefault, false, where...)
}

// AllAndCount retrieves all records and the total count of records from the model.
Expand Down Expand Up @@ -69,7 +69,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in
}

// Retrieve all records
result, err = m.doGetAll(m.GetCtx(), false)
result, err = m.doGetAll(m.GetCtx(), SelectTypeDefault, false)
return
}

Expand Down Expand Up @@ -110,7 +110,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).One()
}
all, err := m.doGetAll(ctx, true)
all, err := m.doGetAll(ctx, SelectTypeDefault, true)
if err != nil {
return nil, err
}
Expand All @@ -136,24 +136,41 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
return m.Fields(gconv.String(fieldsAndWhere[0])).Array()
}
}
all, err := m.All()

var (
field string
core = m.db.GetCore()
ctx = core.injectInternalColumn(m.GetCtx())
)
all, err := m.doGetAll(ctx, SelectTypeArray, false)
if err != nil {
return nil, err
}
var field string
if len(all) > 0 {
var recordFields = m.getRecordFields(all[0])
if len(recordFields) > 1 {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
if len(recordFields) == 1 {
field = recordFields[0]
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
field = internalData.FirstResultColumn
if field == "" {
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
field = recordFields[0]
} else {
// it returns error if there are multiple fields in the result record.
return nil, gerror.NewCodef(
gcode.CodeInvalidParameter,
`invalid fields for "Array" operation, result fields number "%d"%s, but expect one`,
len(recordFields),
gjson.MustEncodeString(recordFields),
)
}
}
}
return all.Array(field), nil
Expand Down Expand Up @@ -398,13 +415,26 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
}
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeValue, true)
all, err = m.doGetAllBySql(ctx, queryTypeValue, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeValue, true)
all, err = m.doGetAllBySql(ctx, SelectTypeValue, sqlWithHolder, holderArgs...)
)
if err != nil {
return nil, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return nil, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v, nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
Expand Down Expand Up @@ -445,13 +475,26 @@ func (m *Model) Count(where ...interface{}) (int, error) {
return m.Where(where[0], where[1:]...).Count()
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, SelectTypeCount, false)
all, err = m.doGetAllBySql(ctx, SelectTypeCount, sqlWithHolder, holderArgs...)
)
if err != nil {
return 0, err
}
if len(all) > 0 {
internalData := core.getInternalColumnFromCtx(ctx)
if internalData == nil {
return 0, gerror.NewCode(
gcode.CodeInternalError,
`query count error: the internal context data is missing. there's internal issue should be fixed`,
)
}
// If FirstResultColumn present, it returns the value of the first record of the first field.
// It means it use no cache mechanism, while cache mechanism makes `internalData` missing.
if v, ok := all[0][internalData.FirstResultColumn]; ok {
return v.Int(), nil
}
// Fields number check.
var recordFields = m.getRecordFields(all[0])
if len(recordFields) == 1 {
for _, v := range all[0] {
Expand Down Expand Up @@ -616,17 +659,17 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) {
func (m *Model) doGetAll(ctx context.Context, selectType SelectType, limit1 bool, where ...interface{}) (Result, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).All()
}
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1)
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, selectType, limit1)
return m.doGetAllBySql(ctx, selectType, sqlWithHolder, holderArgs...)
}

// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{},
ctx context.Context, selectType SelectType, sql string, args ...interface{},
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
Expand All @@ -639,24 +682,25 @@ func (m *Model) doGetAllBySql(
},
handler: m.hookHandler.Select,
},
Model: m,
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
Model: m,
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
SelectType: selectType,
}
if result, err = in.Next(ctx); err != nil {
return
}

err = m.saveSelectResultToCache(ctx, queryType, result, sql, args...)
err = m.saveSelectResultToCache(ctx, selectType, result, sql, args...)
return
}

func (m *Model) getFormattedSqlAndArgs(
ctx context.Context, queryType queryType, limit1 bool,
ctx context.Context, selectType SelectType, limit1 bool,
) (sqlWithHolder string, holderArgs []interface{}) {
switch queryType {
case queryTypeCount:
switch selectType {
case SelectTypeCount:
queryFields := "COUNT(1)"
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
Expand Down Expand Up @@ -698,7 +742,7 @@ func (m *Model) getFormattedSqlAndArgs(

func (m *Model) getHolderAndArgsAsSubModel(ctx context.Context) (holder string, args []interface{}) {
holder, args = m.getFormattedSqlAndArgs(
ctx, queryTypeNormal, false,
ctx, SelectTypeDefault, false,
)
args = m.mergeArguments(args)
return
Expand Down

0 comments on commit 2c916f8

Please sign in to comment.