Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add upsert test cases #614

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ func GenLongString(n int) string {
return builder.String()
}

// ColumnIndexFunc generate column index
func ColumnIndexFunc(data []entity.Column, fieldName string) int {
for index, column := range data {
if column.Name() == fieldName {
return index
}
}
return -1
}

// --- common utils ---

// --- gen fields ---
Expand Down
3 changes: 1 addition & 2 deletions test/testcases/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ func TestInsertAutoIdPkData(t *testing.T) {
// insert
pkColumn, floatColumn, vecColumn := common.GenDefaultColumnData(0, common.DefaultNb, common.DefaultDim)
_, errInsert := mc.Insert(ctx, collName, "", pkColumn, floatColumn, vecColumn)
//TODO change to check error code
common.CheckErr(t, errInsert, false, "invalid parameter") //, "can not assign primary field data when auto id enabled")
common.CheckErr(t, errInsert, false, "the length of passed fields is equal to needed: expected=2, actual=3: invalid parameter")

// flush and check row count
errFlush := mc.Flush(ctx, collName, false)
Expand Down
114 changes: 74 additions & 40 deletions test/testcases/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,27 +242,14 @@ func createVarcharCollectionWithDataIndex(ctx context.Context, t *testing.T, mc
return collName, ids
}

type CollectionFieldsType string

const (
Int64FloatVec CollectionFieldsType = "PkInt64FloatVec" // int64 + float + floatVec
Int64BinaryVec CollectionFieldsType = "Int64BinaryVec" // int64 + float + binaryVec
VarcharBinaryVec CollectionFieldsType = "PkVarcharBinaryVec" // varchar + binaryVec
Int64FloatVecJSON CollectionFieldsType = "PkInt64FloatVecJson" // int64 + float + floatVec + json
AllFields CollectionFieldsType = "AllFields" // all scalar fields + floatVec
CustomerFields CollectionFieldsType = "CustomerFields" // customer fields
)

type CollectionParams struct {
CollectionFieldsType CollectionFieldsType // collection fields type
AutoID bool // autoId
EnableDynamicField bool // enable dynamic field
ShardsNum int32
Fields []*entity.Field
Dim int64
MaxLength int64
}

func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient, cp CollectionParams, opts ...client.CreateCollectionOption) string {
collName := common.GenRandomString(4)
var fields []*entity.Field
Expand All @@ -282,8 +269,6 @@ func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient,
fields = append(fields, jsonField)
case AllFields:
fields = common.GenAllFields()
case CustomerFields:
fields = cp.Fields
}

// schema
Expand All @@ -300,19 +285,7 @@ func createCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient,
return collName
}

type DataParams struct {
CollectionName string // insert data into which collection
PartitionName string
CollectionFieldsType CollectionFieldsType // collection fields type
start int // start
nb int // insert how many data
dim int64
EnableDynamicField bool // whether insert dynamic field data
WithRows bool
Data []entity.Column
Rows []interface{}
}

// insert nb data
func insertData(ctx context.Context, t *testing.T, mc *base.MilvusClient, dp DataParams) (entity.Column, error) {
// todo autoid
// prepare data
Expand Down Expand Up @@ -360,12 +333,6 @@ func insertData(ctx context.Context, t *testing.T, mc *base.MilvusClient, dp Dat
rows = common.GenAllFieldsRows(dp.start, dp.nb, dp.dim, dp.EnableDynamicField)
}
data = common.GenAllFieldsData(dp.start, dp.nb, dp.dim)
case CustomerFields:
if dp.WithRows {
rows = dp.Rows
} else {
data = dp.Data
}
}

if dp.EnableDynamicField && !dp.WithRows {
Expand Down Expand Up @@ -446,12 +413,6 @@ func createCollectionAllFields(ctx context.Context, t *testing.T, mc *base.Milvu
return collName, ids
}

type HelpPartitionColumns struct {
PartitionName string
IdsColumn entity.Column
VectorColumn entity.Column
}

func createInsertTwoPartitions(ctx context.Context, t *testing.T, mc *base.MilvusClient, collName string, nb int) (partitionName string, defaultPartition HelpPartitionColumns, newPartition HelpPartitionColumns) {
// create new partition
partitionName = "new"
Expand Down Expand Up @@ -486,6 +447,79 @@ func createInsertTwoPartitions(ctx context.Context, t *testing.T, mc *base.Milvu
return partitionName, defaultPartition, newPartition
}

// prepare collection, maybe data index and load
func prepareCollection(ctx context.Context, t *testing.T, mc *base.MilvusClient, collParam CollectionParams, opts ...PrepareCollectionOption) string {
// default insert nb entities with 0 start
defaultDp := DataParams{DoInsert: true, CollectionName: "", PartitionName: "", CollectionFieldsType: collParam.CollectionFieldsType,
start: 0, nb: common.DefaultNb, dim: collParam.Dim, EnableDynamicField: collParam.EnableDynamicField, WithRows: false}

// default do flush
defaultFp := FlushParams{DoFlush: true, PartitionNames: []string{}, async: false}

// default build index
idx, err := entity.NewIndexHNSW(entity.L2, 8, 96)
common.CheckErr(t, err, true)
defaultIndexParams := IndexParams{BuildIndex: true, Index: idx, FieldName: common.DefaultFloatVecFieldName, async: false}

// default load collection
defaultLp := LoadParams{DoLoad: true, async: false}
opt := &ClientParamsOption{
DataParams: defaultDp,
FlushParams: defaultFp,
IndexParams: defaultIndexParams,
LoadParams: defaultLp,
}
for _, o := range opts {
o(opt)
}
// create collection
collName := createCollection(ctx, t, mc, collParam, opt.CreateOpts)

// insert
if opt.DataParams.DoInsert {
if opt.DataParams.EnableDynamicField != collParam.EnableDynamicField {
t.Fatalf("The EnableDynamicField of CollectionParams and DataParams should be equal.")
}
opt.DataParams.CollectionName = collName
opt.DataParams.CollectionFieldsType = collParam.CollectionFieldsType
insertData(ctx, t, mc, opt.DataParams)
}

// flush
if opt.FlushParams.DoFlush {
err := mc.Flush(ctx, collName, opt.FlushParams.async)
common.CheckErr(t, err, true)
}

// index
if opt.IndexParams.BuildIndex {
var err error
if opt.IndexOpts == nil {
err = mc.CreateIndex(ctx, collName, opt.IndexParams.FieldName, opt.IndexParams.Index, opt.IndexParams.async)
} else {
err = mc.CreateIndex(ctx, collName, opt.IndexParams.FieldName, opt.IndexParams.Index, opt.IndexParams.async, opt.IndexOpts)
}
common.CheckErr(t, err, true)
}

// load
if opt.LoadParams.DoLoad {
var err error
if len(opt.LoadParams.PartitionNames) > 0 {
err = mc.LoadPartitions(ctx, collName, opt.LoadParams.PartitionNames, opt.LoadParams.async)
common.CheckErr(t, err, true)
} else {
if opt.LoadOpts != nil {
err = mc.LoadCollection(ctx, collName, opt.LoadParams.async, opt.LoadOpts)
} else {
err = mc.LoadCollection(ctx, collName, opt.LoadParams.async)
}
common.CheckErr(t, err, true)
}
}
return collName
}

func TestMain(m *testing.M) {
flag.Parse()
log.Printf("parse addr=%s", *addr)
Expand Down
116 changes: 116 additions & 0 deletions test/testcases/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package testcases

import (
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

type HelpPartitionColumns struct {
PartitionName string
IdsColumn entity.Column
VectorColumn entity.Column
}

type CollectionFieldsType string

type CollectionParams struct {
CollectionFieldsType CollectionFieldsType // collection fields type
AutoID bool // autoId
EnableDynamicField bool // enable dynamic field
ShardsNum int32
Dim int64
MaxLength int64
}

type DataParams struct {
CollectionName string // insert data into which collection
PartitionName string
CollectionFieldsType CollectionFieldsType // collection fields type
start int // start
nb int // insert how many data
dim int64
EnableDynamicField bool // whether insert dynamic field data
WithRows bool
DoInsert bool
}

func (d DataParams) IsEmpty() bool {
return d.CollectionName == "" || d.nb == 0
}

type FlushParams struct {
DoFlush bool
PartitionNames []string
async bool
}

type IndexParams struct {
BuildIndex bool
Index entity.Index
FieldName string
async bool
}

func (i IndexParams) IsEmpty() bool {
return i.Index == nil || i.FieldName == ""
}

type LoadParams struct {
DoLoad bool
PartitionNames []string
async bool
}

type ClientParamsOption struct {
DataParams DataParams
FlushParams FlushParams
IndexParams IndexParams
LoadParams LoadParams
CreateOpts client.CreateCollectionOption
IndexOpts client.IndexOption
LoadOpts client.LoadCollectionOption
}

type PrepareCollectionOption func(opt *ClientParamsOption)

func WithDataParams(dp DataParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.DataParams = dp
}
}

func WithFlushParams(fp FlushParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.FlushParams = fp
}
}

func WithIndexParams(ip IndexParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.IndexParams = ip
}
}

func WithLoadParams(lp LoadParams) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.LoadParams = lp
}
}

func WithCreateOption(createOpts client.CreateCollectionOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.CreateOpts = createOpts
}
}

func WithIndexOption(indexOpts client.IndexOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.IndexOpts = indexOpts
}
}

func WithLoadOption(loadOpts client.LoadCollectionOption) PrepareCollectionOption {
return func(opt *ClientParamsOption) {
opt.LoadOpts = loadOpts
}
}
10 changes: 5 additions & 5 deletions test/testcases/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ func TestQueryOutputInvalidOutputFieldCount(t *testing.T) {
}

// test query count* after insert -> delete -> upsert -> compact
func TestQueryCountAfterDml(t *testing.T) {
func TestQueryCountAfterDml(t *testing.T) {
ctx := createContext(t, time.Second*common.DefaultTimeout)
// connect
mc := createMilvusClient(ctx, t)
Expand Down Expand Up @@ -872,7 +872,7 @@ func TestQueryCountAfterDml(t *testing.T) {
start: common.DefaultNb, nb: insertNb, dim: common.DefaultDim, EnableDynamicField: true}
insertData(ctx, t, mc, dpInsert)
countAfterInsert, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + insertNb), countAfterInsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+insertNb), countAfterInsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// delete 1000 entities -> count*
mc.Delete(ctx, collName, common.DefaultPartition, fmt.Sprintf("%s < 1000 ", common.DefaultIntFieldName))
Expand All @@ -885,20 +885,20 @@ func TestQueryCountAfterDml(t *testing.T) {
jsonColumn := common.GenDefaultJSONData(common.DefaultJSONFieldName, 0, upsertNb)
mc.Upsert(ctx, collName, "", intColumn, floatColumn, vecColumn, jsonColumn)
countAfterUpsert, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterUpsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterUpsert.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// upsert existed 100 entities -> count*
intColumn, floatColumn, vecColumn = common.GenDefaultColumnData(common.DefaultNb, upsertNb, common.DefaultDim)
jsonColumn = common.GenDefaultJSONData(common.DefaultJSONFieldName, common.DefaultNb, upsertNb)
mc.Upsert(ctx, collName, "", intColumn, floatColumn, vecColumn, jsonColumn)
countAfterUpsert2, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterUpsert2.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterUpsert2.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])

// compact -> count(*)
_, err := mc.Compact(ctx, collName, time.Second*60)
common.CheckErr(t, err, true)
countAfterCompact, _ := mc.Query(ctx, collName, []string{common.DefaultPartition}, "", []string{common.QueryCountFieldName})
require.Equal(t, int64(common.DefaultNb + upsertNb), countAfterCompact.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
require.Equal(t, int64(common.DefaultNb+upsertNb), countAfterCompact.GetColumn(common.QueryCountFieldName).(*entity.ColumnInt64).Data()[0])
}

// TODO offset and limit
Expand Down
Loading