diff --git a/internal/util/importutilv2/common/util.go b/internal/util/importutilv2/common/util.go index ba26bd5f91994..bb6e86d730152 100644 --- a/internal/util/importutilv2/common/util.go +++ b/internal/util/importutilv2/common/util.go @@ -84,8 +84,15 @@ func EstimateReadCountPerBatch(bufferSize int, schema *schemapb.CollectionSchema if err != nil { return 0, err } + if sizePerRecord <= 0 || bufferSize <= 0 { + return 0, fmt.Errorf("invalid size, sizePerRecord=%d, bufferSize=%d", sizePerRecord, bufferSize) + } if 1000*sizePerRecord <= bufferSize { return 1000, nil } - return int64(bufferSize) / int64(sizePerRecord), nil + ret := int64(bufferSize) / int64(sizePerRecord) + if ret <= 0 { + return 1, nil + } + return ret, nil } diff --git a/internal/util/importutilv2/common/util_test.go b/internal/util/importutilv2/common/util_test.go index efbb32cbdb201..cc7f14b531070 100644 --- a/internal/util/importutilv2/common/util_test.go +++ b/internal/util/importutilv2/common/util_test.go @@ -66,3 +66,43 @@ func TestUtil_EstimateReadCountPerBatch(t *testing.T) { _, err = EstimateReadCountPerBatch(16*1024*1024, schema) assert.Error(t, err) } + +func TestUtil_EstimateReadCountPerBatch_InvalidBufferSize(t *testing.T) { + schema := &schemapb.CollectionSchema{} + count, err := EstimateReadCountPerBatch(16*1024*1024, schema) + assert.Error(t, err) + assert.Equal(t, int64(0), count) + t.Logf("err=%v", err) + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + DataType: schemapb.DataType_Int64, + }, + }, + } + count, err = EstimateReadCountPerBatch(0, schema) + assert.Error(t, err) + assert.Equal(t, int64(0), count) + t.Logf("err=%v", err) +} + +func TestUtil_EstimateReadCountPerBatch_LargeSchema(t *testing.T) { + schema := &schemapb.CollectionSchema{} + for i := 0; i < 100; i++ { + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: int64(i), + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: "10000000", + }, + }, + }) + } + count, err := EstimateReadCountPerBatch(16*1024*1024, schema) + assert.NoError(t, err) + assert.Equal(t, int64(1), count) +}