diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go index 5fb60ba458286..e8545223056d8 100644 --- a/internal/util/importutilv2/parquet/field_reader.go +++ b/internal/util/importutilv2/parquet/field_reader.go @@ -223,12 +223,18 @@ func ReadNullableBoolData(pcr *FieldReader, count int64) (any, []bool, error) { dataNums := chunk.Data().Len() boolReader, ok := chunk.(*array.Boolean) if !ok { - return nil, nil, WrapTypeErr("bool", chunk.DataType().Name(), pcr.field) - } - validData = append(validData, bytesToBoolArray(dataNums, boolReader.NullBitmapBytes())...) - - for i := 0; i < dataNums; i++ { - data = append(data, boolReader.Value(i)) + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("bool|null", chunk.DataType().Name(), pcr.field) + } + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([]bool, dataNums)...) + } else { + validData = append(validData, bytesToBoolArray(dataNums, boolReader.NullBitmapBytes())...) + for i := 0; i < dataNums; i++ { + data = append(data, boolReader.Value(i)) + } } } if len(data) == 0 { @@ -353,8 +359,12 @@ func ReadNullableIntegerOrFloatData[T constraints.Integer | constraints.Float](p for i := 0; i < dataNums; i++ { data = append(data, T(float64Reader.Value(i))) } + case arrow.NULL: + // the chunk type may be *array.Null if the data in chunk is all null + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([]T, dataNums)...) default: - return nil, nil, WrapTypeErr("integer|float", chunk.DataType().Name(), pcr.field) + return nil, nil, WrapTypeErr("integer|float|null", chunk.DataType().Name(), pcr.field) } } if len(data) == 0 { @@ -402,15 +412,22 @@ func ReadNullableStringData(pcr *FieldReader, count int64) (any, []bool, error) dataNums := chunk.Data().Len() stringReader, ok := chunk.(*array.String) if !ok { - return nil, nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field) - } - validData = append(validData, bytesToBoolArray(dataNums, stringReader.NullBitmapBytes())...) - for i := 0; i < dataNums; i++ { - if stringReader.IsNull(i) { - data = append(data, "") - continue + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("string|null", chunk.DataType().Name(), pcr.field) + } + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([]string, dataNums)...) + } else { + validData = append(validData, bytesToBoolArray(dataNums, stringReader.NullBitmapBytes())...) + for i := 0; i < dataNums; i++ { + if stringReader.IsNull(i) { + data = append(data, "") + continue + } + data = append(data, stringReader.ValueStr(i)) } - data = append(data, stringReader.ValueStr(i)) } } if len(data) == 0 { @@ -469,18 +486,25 @@ func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error) dataNums := chunk.Data().Len() stringReader, ok := chunk.(*array.String) if !ok { - return nil, nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field) - } - validData = append(validData, bytesToBoolArray(dataNums, stringReader.NullBitmapBytes())...) - for i := 0; i < dataNums; i++ { - if stringReader.IsNull(i) { - data = append(data, "") - continue + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("string|null", chunk.DataType().Name(), pcr.field) } - if err = common.CheckVarcharLength(stringReader.Value(i), maxLength); err != nil { - return nil, nil, err + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([]string, dataNums)...) + } else { + validData = append(validData, bytesToBoolArray(dataNums, stringReader.NullBitmapBytes())...) + for i := 0; i < dataNums; i++ { + if stringReader.IsNull(i) { + data = append(data, "") + continue + } + if err = common.CheckVarcharLength(stringReader.Value(i), maxLength); err != nil { + return nil, nil, err + } + data = append(data, stringReader.ValueStr(i)) } - data = append(data, stringReader.ValueStr(i)) } } if len(data) == 0 { @@ -686,25 +710,33 @@ func ReadNullableBoolArrayData(pcr *FieldReader, count int64) (any, []bool, erro for _, chunk := range chunked.Chunks() { listReader, ok := chunk.(*array.List) if !ok { - return nil, nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) - } - boolReader, ok := listReader.ListValues().(*array.Boolean) - if !ok { - return nil, nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field) - } - offsets := listReader.Offsets() - for i := 1; i < len(offsets); i++ { - start, end := offsets[i-1], offsets[i] - elementData := make([]bool, 0, end-start) - for j := start; j < end; j++ { - elementData = append(elementData, boolReader.Value(int(j))) + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field) } - data = append(data, elementData) - elementDataValid := true - if start == end { - elementDataValid = false + dataNums := chunk.Data().Len() + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([][]bool, dataNums)...) + } else { + boolReader, ok := listReader.ListValues().(*array.Boolean) + if !ok { + return nil, nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]bool, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, boolReader.Value(int(j))) + } + data = append(data, elementData) + elementDataValid := true + if start == end { + elementDataValid = false + } + validData = append(validData, elementDataValid) } - validData = append(validData, elementDataValid) } } if len(data) == 0 { @@ -813,49 +845,57 @@ func ReadNullableIntegerOrFloatArrayData[T constraints.Integer | constraints.Flo for _, chunk := range chunked.Chunks() { listReader, ok := chunk.(*array.List) if !ok { - return nil, nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) - } - offsets := listReader.Offsets() - dataType := pcr.field.GetDataType() - if typeutil.IsVectorType(dataType) { - if err = checkVectorAligned(offsets, pcr.dim, dataType); err != nil { - return nil, nil, merr.WrapErrImportFailed(fmt.Sprintf("length of vector is not aligned: %s, data type: %s", err.Error(), dataType.String())) + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field) + } + dataNums := chunk.Data().Len() + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([][]T, dataNums)...) + } else { + offsets := listReader.Offsets() + dataType := pcr.field.GetDataType() + if typeutil.IsVectorType(dataType) { + if err = checkVectorAligned(offsets, pcr.dim, dataType); err != nil { + return nil, nil, merr.WrapErrImportFailed(fmt.Sprintf("length of vector is not aligned: %s, data type: %s", err.Error(), dataType.String())) + } + } + valueReader := listReader.ListValues() + switch valueReader.DataType().ID() { + case arrow.INT8: + int8Reader := valueReader.(*array.Int8) + getDataFunc(offsets, func(i int) T { + return T(int8Reader.Value(i)) + }) + case arrow.INT16: + int16Reader := valueReader.(*array.Int16) + getDataFunc(offsets, func(i int) T { + return T(int16Reader.Value(i)) + }) + case arrow.INT32: + int32Reader := valueReader.(*array.Int32) + getDataFunc(offsets, func(i int) T { + return T(int32Reader.Value(i)) + }) + case arrow.INT64: + int64Reader := valueReader.(*array.Int64) + getDataFunc(offsets, func(i int) T { + return T(int64Reader.Value(i)) + }) + case arrow.FLOAT32: + float32Reader := valueReader.(*array.Float32) + getDataFunc(offsets, func(i int) T { + return T(float32Reader.Value(i)) + }) + case arrow.FLOAT64: + float64Reader := valueReader.(*array.Float64) + getDataFunc(offsets, func(i int) T { + return T(float64Reader.Value(i)) + }) + default: + return nil, nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field) } - } - valueReader := listReader.ListValues() - switch valueReader.DataType().ID() { - case arrow.INT8: - int8Reader := valueReader.(*array.Int8) - getDataFunc(offsets, func(i int) T { - return T(int8Reader.Value(i)) - }) - case arrow.INT16: - int16Reader := valueReader.(*array.Int16) - getDataFunc(offsets, func(i int) T { - return T(int16Reader.Value(i)) - }) - case arrow.INT32: - int32Reader := valueReader.(*array.Int32) - getDataFunc(offsets, func(i int) T { - return T(int32Reader.Value(i)) - }) - case arrow.INT64: - int64Reader := valueReader.(*array.Int64) - getDataFunc(offsets, func(i int) T { - return T(int64Reader.Value(i)) - }) - case arrow.FLOAT32: - float32Reader := valueReader.(*array.Float32) - getDataFunc(offsets, func(i int) T { - return T(float32Reader.Value(i)) - }) - case arrow.FLOAT64: - float64Reader := valueReader.(*array.Float64) - getDataFunc(offsets, func(i int) T { - return T(float64Reader.Value(i)) - }) - default: - return nil, nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field) } } if len(data) == 0 { @@ -908,25 +948,33 @@ func ReadNullableStringArrayData(pcr *FieldReader, count int64) (any, []bool, er for _, chunk := range chunked.Chunks() { listReader, ok := chunk.(*array.List) if !ok { - return nil, nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field) - } - stringReader, ok := listReader.ListValues().(*array.String) - if !ok { - return nil, nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field) - } - offsets := listReader.Offsets() - for i := 1; i < len(offsets); i++ { - start, end := offsets[i-1], offsets[i] - elementData := make([]string, 0, end-start) - for j := start; j < end; j++ { - elementData = append(elementData, stringReader.Value(int(j))) + // the chunk type may be *array.Null if the data in chunk is all null + _, ok := chunk.(*array.Null) + if !ok { + return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field) } - data = append(data, elementData) - elementDataValid := true - if start == end { - elementDataValid = false + dataNums := chunk.Data().Len() + validData = append(validData, make([]bool, dataNums)...) + data = append(data, make([][]string, dataNums)...) + } else { + stringReader, ok := listReader.ListValues().(*array.String) + if !ok { + return nil, nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field) + } + offsets := listReader.Offsets() + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]string, 0, end-start) + for j := start; j < end; j++ { + elementData = append(elementData, stringReader.Value(int(j))) + } + data = append(data, elementData) + elementDataValid := true + if start == end { + elementDataValid = false + } + validData = append(validData, elementDataValid) } - validData = append(validData, elementDataValid) } } if len(data) == 0 { diff --git a/internal/util/importutilv2/parquet/reader_test.go b/internal/util/importutilv2/parquet/reader_test.go index 6db2b2668e4a4..8bb886a80831c 100644 --- a/internal/util/importutilv2/parquet/reader_test.go +++ b/internal/util/importutilv2/parquet/reader_test.go @@ -68,8 +68,12 @@ func randomString(length int) string { return string(b) } -func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int) (*storage.InsertData, error) { - pqSchema, err := ConvertToArrowSchema(schema) +func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int, nullPercent int) (*storage.InsertData, error) { + useNullType := false + if nullPercent == 100 { + useNullType = true + } + pqSchema, err := ConvertToArrowSchema(schema, useNullType) if err != nil { return nil, err } @@ -79,12 +83,11 @@ func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int) ( } defer fw.Close() - insertData, err := testutil.CreateInsertData(schema, numRows) + insertData, err := testutil.CreateInsertData(schema, numRows, nullPercent) if err != nil { return nil, err } - - columns, err := testutil.BuildArrayData(schema, insertData) + columns, err := testutil.BuildArrayData(schema, insertData, useNullType) if err != nil { return nil, err } @@ -98,7 +101,7 @@ func writeParquet(w io.Writer, schema *schemapb.CollectionSchema, numRows int) ( return insertData, nil } -func (s *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType, nullable bool) { +func (s *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType, nullable bool, nullPercent int) { schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { @@ -148,7 +151,7 @@ func (s *ReaderSuite) run(dataType schemapb.DataType, elemType schemapb.DataType defer os.Remove(filePath) wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) assert.NoError(s.T(), err) - insertData, err := writeParquet(wf, schema, s.numRows) + insertData, err := writeParquet(wf, schema, s.numRows, nullPercent) assert.NoError(s.T(), err) ctx := context.Background() @@ -250,7 +253,7 @@ func (s *ReaderSuite) failRun(dt schemapb.DataType, isDynamic bool) { defer os.Remove(filePath) wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) assert.NoError(s.T(), err) - _, err = writeParquet(wf, schema, s.numRows) + _, err = writeParquet(wf, schema, s.numRows, 50) assert.NoError(s.T(), err) ctx := context.Background() @@ -265,66 +268,85 @@ func (s *ReaderSuite) failRun(dt schemapb.DataType, isDynamic bool) { } func (s *ReaderSuite) TestReadScalarFields() { - s.run(schemapb.DataType_Bool, schemapb.DataType_None, false) - s.run(schemapb.DataType_Int8, schemapb.DataType_None, false) - s.run(schemapb.DataType_Int16, schemapb.DataType_None, false) - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) - s.run(schemapb.DataType_Int64, schemapb.DataType_None, false) - s.run(schemapb.DataType_Float, schemapb.DataType_None, false) - s.run(schemapb.DataType_Double, schemapb.DataType_None, false) - s.run(schemapb.DataType_String, schemapb.DataType_None, false) - s.run(schemapb.DataType_VarChar, schemapb.DataType_None, false) - s.run(schemapb.DataType_JSON, schemapb.DataType_None, false) - - s.run(schemapb.DataType_Array, schemapb.DataType_Bool, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Int8, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Int16, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Int32, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Int64, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Float, false) - s.run(schemapb.DataType_Array, schemapb.DataType_Double, false) - s.run(schemapb.DataType_Array, schemapb.DataType_String, false) - - s.run(schemapb.DataType_Bool, schemapb.DataType_None, true) - s.run(schemapb.DataType_Int8, schemapb.DataType_None, true) - s.run(schemapb.DataType_Int16, schemapb.DataType_None, true) - s.run(schemapb.DataType_Int32, schemapb.DataType_None, true) - s.run(schemapb.DataType_Int64, schemapb.DataType_None, true) - s.run(schemapb.DataType_Float, schemapb.DataType_None, true) - s.run(schemapb.DataType_Double, schemapb.DataType_None, true) - s.run(schemapb.DataType_String, schemapb.DataType_None, true) - s.run(schemapb.DataType_VarChar, schemapb.DataType_None, true) - s.run(schemapb.DataType_JSON, schemapb.DataType_None, true) - - s.run(schemapb.DataType_Array, schemapb.DataType_Bool, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Int8, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Int16, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Int32, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Int64, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Float, true) - s.run(schemapb.DataType_Array, schemapb.DataType_Double, true) - s.run(schemapb.DataType_Array, schemapb.DataType_String, true) + s.run(schemapb.DataType_Bool, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Int8, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Int16, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Int64, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Float, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Double, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_String, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_VarChar, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_JSON, schemapb.DataType_None, false, 0) + + s.run(schemapb.DataType_Array, schemapb.DataType_Bool, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Int8, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Int16, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Int32, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Int64, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Float, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_Double, false, 0) + s.run(schemapb.DataType_Array, schemapb.DataType_String, false, 0) + + s.run(schemapb.DataType_Bool, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Int8, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Int16, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Int64, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Float, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_String, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_VarChar, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_JSON, schemapb.DataType_None, true, 50) + + s.run(schemapb.DataType_Array, schemapb.DataType_Bool, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Int8, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Int16, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Int32, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Int64, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Float, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_Double, true, 50) + s.run(schemapb.DataType_Array, schemapb.DataType_String, true, 50) + + s.run(schemapb.DataType_Bool, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_Int8, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_Int16, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_Int64, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_Float, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_String, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_VarChar, schemapb.DataType_None, true, 100) + s.run(schemapb.DataType_JSON, schemapb.DataType_None, true, 100) + + s.run(schemapb.DataType_Array, schemapb.DataType_Bool, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Int8, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Int16, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Int32, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Int64, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Float, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_Double, true, 100) + s.run(schemapb.DataType_Array, schemapb.DataType_String, true, 100) s.failRun(schemapb.DataType_JSON, true) } func (s *ReaderSuite) TestStringPK() { s.pkDataType = schemapb.DataType_VarChar - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) - s.run(schemapb.DataType_Int32, schemapb.DataType_None, true) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, true, 50) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, true, 100) } func (s *ReaderSuite) TestVector() { s.vecDataType = schemapb.DataType_BinaryVector - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) s.vecDataType = schemapb.DataType_FloatVector - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) s.vecDataType = schemapb.DataType_Float16Vector - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) s.vecDataType = schemapb.DataType_BFloat16Vector - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) s.vecDataType = schemapb.DataType_SparseFloatVector - s.run(schemapb.DataType_Int32, schemapb.DataType_None, false) + s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0) } func TestUtil(t *testing.T) { diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go index 451436c34c96e..8b5d1b1987acc 100644 --- a/internal/util/importutilv2/parquet/util.go +++ b/internal/util/importutilv2/parquet/util.go @@ -115,7 +115,7 @@ func isArrowArithmeticType(dataType arrow.Type) bool { return isArrowIntegerType(dataType) || isArrowFloatingType(dataType) } -func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType) bool { +func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType, nullable bool) bool { srcType := src.ID() dstType := dst.ID() switch srcType { @@ -142,7 +142,9 @@ func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType) bool { case arrow.BINARY: return dstType == arrow.LIST && dst.(*arrow.ListType).Elem().ID() == arrow.UINT8 case arrow.LIST: - return dstType == arrow.LIST && isArrowDataTypeConvertible(src.(*arrow.ListType).Elem(), dst.(*arrow.ListType).Elem()) + return dstType == arrow.LIST && isArrowDataTypeConvertible(src.(*arrow.ListType).Elem(), dst.(*arrow.ListType).Elem(), nullable) + case arrow.NULL: + return nullable default: return false } @@ -204,7 +206,7 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da } } -func ConvertToArrowSchema(schema *schemapb.CollectionSchema) (*arrow.Schema, error) { +func ConvertToArrowSchema(schema *schemapb.CollectionSchema, useNullType bool) (*arrow.Schema, error) { arrFields := make([]arrow.Field, 0) for _, field := range schema.GetFields() { if typeutil.IsAutoPKField(field) { @@ -214,10 +216,13 @@ func ConvertToArrowSchema(schema *schemapb.CollectionSchema) (*arrow.Schema, err if err != nil { return nil, err } + if field.GetNullable() && useNullType { + arrDataType = arrow.Null + } arrFields = append(arrFields, arrow.Field{ Name: field.GetName(), Type: arrDataType, - Nullable: true, + Nullable: field.GetNullable(), Metadata: arrow.Metadata{}, }) } @@ -243,7 +248,7 @@ func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) e if err != nil { return err } - if !isArrowDataTypeConvertible(arrField.Type, toArrDataType) { + if !isArrowDataTypeConvertible(arrField.Type, toArrDataType, field.GetNullable()) { return merr.WrapErrImportFailed(fmt.Sprintf("field '%s' type mis-match, milvus data type '%s', arrow data type get '%s'", field.Name, field.DataType.String(), arrField.Type.String())) } diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go index 7b827bd6f5ae9..3e89b33d155de 100644 --- a/internal/util/testutil/test_util.go +++ b/internal/util/testutil/test_util.go @@ -15,6 +15,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/testutils" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -102,7 +103,7 @@ func randomString(length int) string { return string(b) } -func CreateInsertData(schema *schemapb.CollectionSchema, rows int) (*storage.InsertData, error) { +func CreateInsertData(schema *schemapb.CollectionSchema, rows int, nullPercent ...int) (*storage.InsertData, error) { insertData, err := storage.NewInsertData(schema) if err != nil { return nil, err @@ -193,13 +194,22 @@ func CreateInsertData(schema *schemapb.CollectionSchema, rows int) (*storage.Ins panic(fmt.Sprintf("unsupported data type: %s", f.GetDataType().String())) } if f.GetNullable() { - insertData.Data[f.FieldID].AppendValidDataRows(testutils.GenerateBoolArray(rows)) + if len(nullPercent) > 1 { + return nil, merr.WrapErrParameterInvalidMsg("the length of nullPercent is wrong") + } + if len(nullPercent) == 0 || nullPercent[0] == 50 { + insertData.Data[f.FieldID].AppendValidDataRows(testutils.GenerateBoolArray(rows)) + } else if len(nullPercent) == 1 && nullPercent[0] == 100 { + insertData.Data[f.FieldID].AppendValidDataRows(make([]bool, rows)) + } else { + return nil, merr.WrapErrParameterInvalidMsg("not support the number of nullPercent") + } } } return insertData, nil } -func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.InsertData) ([]arrow.Array, error) { +func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.InsertData, useNullType bool) ([]arrow.Array, error) { mem := memory.NewGoAllocator() columns := make([]arrow.Array, 0, len(schema.Fields)) for _, field := range schema.Fields { @@ -209,6 +219,10 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser fieldID := field.GetFieldID() dataType := field.GetDataType() elementType := field.GetElementType() + if field.GetNullable() && useNullType { + columns = append(columns, array.NewNull(insertData.Data[fieldID].RowNum())) + continue + } switch dataType { case schemapb.DataType_Bool: builder := array.NewBooleanBuilder(mem) diff --git a/tests/integration/import/util_test.go b/tests/integration/import/util_test.go index d8add9e57d445..5eb63a847493b 100644 --- a/tests/integration/import/util_test.go +++ b/tests/integration/import/util_test.go @@ -68,7 +68,7 @@ func GenerateParquetFileAndReturnInsertData(filePath string, schema *schemapb.Co return nil, err } - pqSchema, err := pq.ConvertToArrowSchema(schema) + pqSchema, err := pq.ConvertToArrowSchema(schema, false) if err != nil { return nil, err } @@ -83,7 +83,7 @@ func GenerateParquetFileAndReturnInsertData(filePath string, schema *schemapb.Co return nil, err } - columns, err := testutil.BuildArrayData(schema, insertData) + columns, err := testutil.BuildArrayData(schema, insertData, false) if err != nil { return nil, err }