diff --git a/Makefile b/Makefile index 877498cafa6ac..8ad4e63e52e98 100644 --- a/Makefile +++ b/Makefile @@ -457,6 +457,9 @@ generate-mockery-kv: getdeps $(INSTALL_PATH)/mockery --name=SnapShotKV --dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=snapshot_kv.go --with-expecter $(INSTALL_PATH)/mockery --name=Predicate --dir=$(PWD)/internal/kv/predicates --output=$(PWD)/internal/kv/predicates --filename=mock_predicate.go --with-expecter --inpackage +generate-mockery-chunk-manager: getdeps + $(INSTALL_PATH)/mockery --name=ChunkManager --dir=$(PWD)/internal/storage --output=$(PWD)/internal/mocks --filename=mock_chunk_manager.go --with-expecter + generate-mockery-pkg: $(MAKE) -C pkg generate-mockery diff --git a/internal/mocks/mock_chunk_manager.go b/internal/mocks/mock_chunk_manager.go index 7e8cc7a5d82f6..ef20e5547baff 100644 --- a/internal/mocks/mock_chunk_manager.go +++ b/internal/mocks/mock_chunk_manager.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.14.0. DO NOT EDIT. +// Code generated by mockery v2.32.4. DO NOT EDIT. package mocks @@ -32,13 +32,16 @@ func (_m *ChunkManager) Exist(ctx context.Context, filePath string) (bool, error ret := _m.Called(ctx, filePath) var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { r0 = rf(ctx, filePath) } else { r0 = ret.Get(0).(bool) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -72,11 +75,21 @@ func (_c *ChunkManager_Exist_Call) Return(_a0 bool, _a1 error) *ChunkManager_Exi return _c } +func (_c *ChunkManager_Exist_Call) RunAndReturn(run func(context.Context, string) (bool, error)) *ChunkManager_Exist_Call { + _c.Call.Return(run) + return _c +} + // ListWithPrefix provides a mock function with given fields: ctx, prefix, recursive func (_m *ChunkManager) ListWithPrefix(ctx context.Context, prefix string, recursive bool) ([]string, []time.Time, error) { ret := _m.Called(ctx, prefix, recursive) var r0 []string + var r1 []time.Time + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) ([]string, []time.Time, error)); ok { + return rf(ctx, prefix, recursive) + } if rf, ok := ret.Get(0).(func(context.Context, string, bool) []string); ok { r0 = rf(ctx, prefix, recursive) } else { @@ -85,7 +98,6 @@ func (_m *ChunkManager) ListWithPrefix(ctx context.Context, prefix string, recur } } - var r1 []time.Time if rf, ok := ret.Get(1).(func(context.Context, string, bool) []time.Time); ok { r1 = rf(ctx, prefix, recursive) } else { @@ -94,7 +106,6 @@ func (_m *ChunkManager) ListWithPrefix(ctx context.Context, prefix string, recur } } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string, bool) error); ok { r2 = rf(ctx, prefix, recursive) } else { @@ -129,11 +140,20 @@ func (_c *ChunkManager_ListWithPrefix_Call) Return(_a0 []string, _a1 []time.Time return _c } +func (_c *ChunkManager_ListWithPrefix_Call) RunAndReturn(run func(context.Context, string, bool) ([]string, []time.Time, error)) *ChunkManager_ListWithPrefix_Call { + _c.Call.Return(run) + return _c +} + // Mmap provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Mmap(ctx context.Context, filePath string) (*mmap.ReaderAt, error) { ret := _m.Called(ctx, filePath) var r0 *mmap.ReaderAt + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*mmap.ReaderAt, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) *mmap.ReaderAt); ok { r0 = rf(ctx, filePath) } else { @@ -142,7 +162,6 @@ func (_m *ChunkManager) Mmap(ctx context.Context, filePath string) (*mmap.Reader } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -176,11 +195,20 @@ func (_c *ChunkManager_Mmap_Call) Return(_a0 *mmap.ReaderAt, _a1 error) *ChunkMa return _c } +func (_c *ChunkManager_Mmap_Call) RunAndReturn(run func(context.Context, string) (*mmap.ReaderAt, error)) *ChunkManager_Mmap_Call { + _c.Call.Return(run) + return _c +} + // MultiRead provides a mock function with given fields: ctx, filePaths func (_m *ChunkManager) MultiRead(ctx context.Context, filePaths []string) ([][]byte, error) { ret := _m.Called(ctx, filePaths) var r0 [][]byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []string) ([][]byte, error)); ok { + return rf(ctx, filePaths) + } if rf, ok := ret.Get(0).(func(context.Context, []string) [][]byte); ok { r0 = rf(ctx, filePaths) } else { @@ -189,7 +217,6 @@ func (_m *ChunkManager) MultiRead(ctx context.Context, filePaths []string) ([][] } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok { r1 = rf(ctx, filePaths) } else { @@ -223,6 +250,11 @@ func (_c *ChunkManager_MultiRead_Call) Return(_a0 [][]byte, _a1 error) *ChunkMan return _c } +func (_c *ChunkManager_MultiRead_Call) RunAndReturn(run func(context.Context, []string) ([][]byte, error)) *ChunkManager_MultiRead_Call { + _c.Call.Return(run) + return _c +} + // MultiRemove provides a mock function with given fields: ctx, filePaths func (_m *ChunkManager) MultiRemove(ctx context.Context, filePaths []string) error { ret := _m.Called(ctx, filePaths) @@ -261,6 +293,11 @@ func (_c *ChunkManager_MultiRemove_Call) Return(_a0 error) *ChunkManager_MultiRe return _c } +func (_c *ChunkManager_MultiRemove_Call) RunAndReturn(run func(context.Context, []string) error) *ChunkManager_MultiRemove_Call { + _c.Call.Return(run) + return _c +} + // MultiWrite provides a mock function with given fields: ctx, contents func (_m *ChunkManager) MultiWrite(ctx context.Context, contents map[string][]byte) error { ret := _m.Called(ctx, contents) @@ -299,18 +336,26 @@ func (_c *ChunkManager_MultiWrite_Call) Return(_a0 error) *ChunkManager_MultiWri return _c } +func (_c *ChunkManager_MultiWrite_Call) RunAndReturn(run func(context.Context, map[string][]byte) error) *ChunkManager_MultiWrite_Call { + _c.Call.Return(run) + return _c +} + // Path provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Path(ctx context.Context, filePath string) (string, error) { ret := _m.Called(ctx, filePath) var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { r0 = rf(ctx, filePath) } else { r0 = ret.Get(0).(string) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -344,11 +389,20 @@ func (_c *ChunkManager_Path_Call) Return(_a0 string, _a1 error) *ChunkManager_Pa return _c } +func (_c *ChunkManager_Path_Call) RunAndReturn(run func(context.Context, string) (string, error)) *ChunkManager_Path_Call { + _c.Call.Return(run) + return _c +} + // Read provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Read(ctx context.Context, filePath string) ([]byte, error) { ret := _m.Called(ctx, filePath) var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]byte, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) []byte); ok { r0 = rf(ctx, filePath) } else { @@ -357,7 +411,6 @@ func (_m *ChunkManager) Read(ctx context.Context, filePath string) ([]byte, erro } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -391,11 +444,20 @@ func (_c *ChunkManager_Read_Call) Return(_a0 []byte, _a1 error) *ChunkManager_Re return _c } +func (_c *ChunkManager_Read_Call) RunAndReturn(run func(context.Context, string) ([]byte, error)) *ChunkManager_Read_Call { + _c.Call.Return(run) + return _c +} + // ReadAt provides a mock function with given fields: ctx, filePath, off, length func (_m *ChunkManager) ReadAt(ctx context.Context, filePath string, off int64, length int64) ([]byte, error) { ret := _m.Called(ctx, filePath, off, length) var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64) ([]byte, error)); ok { + return rf(ctx, filePath, off, length) + } if rf, ok := ret.Get(0).(func(context.Context, string, int64, int64) []byte); ok { r0 = rf(ctx, filePath, off, length) } else { @@ -404,7 +466,6 @@ func (_m *ChunkManager) ReadAt(ctx context.Context, filePath string, off int64, } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, int64, int64) error); ok { r1 = rf(ctx, filePath, off, length) } else { @@ -440,11 +501,21 @@ func (_c *ChunkManager_ReadAt_Call) Return(p []byte, err error) *ChunkManager_Re return _c } +func (_c *ChunkManager_ReadAt_Call) RunAndReturn(run func(context.Context, string, int64, int64) ([]byte, error)) *ChunkManager_ReadAt_Call { + _c.Call.Return(run) + return _c +} + // ReadWithPrefix provides a mock function with given fields: ctx, prefix func (_m *ChunkManager) ReadWithPrefix(ctx context.Context, prefix string) ([]string, [][]byte, error) { ret := _m.Called(ctx, prefix) var r0 []string + var r1 [][]byte + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, [][]byte, error)); ok { + return rf(ctx, prefix) + } if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { r0 = rf(ctx, prefix) } else { @@ -453,7 +524,6 @@ func (_m *ChunkManager) ReadWithPrefix(ctx context.Context, prefix string) ([]st } } - var r1 [][]byte if rf, ok := ret.Get(1).(func(context.Context, string) [][]byte); ok { r1 = rf(ctx, prefix) } else { @@ -462,7 +532,6 @@ func (_m *ChunkManager) ReadWithPrefix(ctx context.Context, prefix string) ([]st } } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { r2 = rf(ctx, prefix) } else { @@ -496,11 +565,20 @@ func (_c *ChunkManager_ReadWithPrefix_Call) Return(_a0 []string, _a1 [][]byte, _ return _c } +func (_c *ChunkManager_ReadWithPrefix_Call) RunAndReturn(run func(context.Context, string) ([]string, [][]byte, error)) *ChunkManager_ReadWithPrefix_Call { + _c.Call.Return(run) + return _c +} + // Reader provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Reader(ctx context.Context, filePath string) (storage.FileReader, error) { ret := _m.Called(ctx, filePath) var r0 storage.FileReader + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (storage.FileReader, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) storage.FileReader); ok { r0 = rf(ctx, filePath) } else { @@ -509,7 +587,6 @@ func (_m *ChunkManager) Reader(ctx context.Context, filePath string) (storage.Fi } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -543,6 +620,11 @@ func (_c *ChunkManager_Reader_Call) Return(_a0 storage.FileReader, _a1 error) *C return _c } +func (_c *ChunkManager_Reader_Call) RunAndReturn(run func(context.Context, string) (storage.FileReader, error)) *ChunkManager_Reader_Call { + _c.Call.Return(run) + return _c +} + // Remove provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Remove(ctx context.Context, filePath string) error { ret := _m.Called(ctx, filePath) @@ -581,6 +663,11 @@ func (_c *ChunkManager_Remove_Call) Return(_a0 error) *ChunkManager_Remove_Call return _c } +func (_c *ChunkManager_Remove_Call) RunAndReturn(run func(context.Context, string) error) *ChunkManager_Remove_Call { + _c.Call.Return(run) + return _c +} + // RemoveWithPrefix provides a mock function with given fields: ctx, prefix func (_m *ChunkManager) RemoveWithPrefix(ctx context.Context, prefix string) error { ret := _m.Called(ctx, prefix) @@ -619,6 +706,11 @@ func (_c *ChunkManager_RemoveWithPrefix_Call) Return(_a0 error) *ChunkManager_Re return _c } +func (_c *ChunkManager_RemoveWithPrefix_Call) RunAndReturn(run func(context.Context, string) error) *ChunkManager_RemoveWithPrefix_Call { + _c.Call.Return(run) + return _c +} + // RootPath provides a mock function with given fields: func (_m *ChunkManager) RootPath() string { ret := _m.Called() @@ -655,18 +747,26 @@ func (_c *ChunkManager_RootPath_Call) Return(_a0 string) *ChunkManager_RootPath_ return _c } +func (_c *ChunkManager_RootPath_Call) RunAndReturn(run func() string) *ChunkManager_RootPath_Call { + _c.Call.Return(run) + return _c +} + // Size provides a mock function with given fields: ctx, filePath func (_m *ChunkManager) Size(ctx context.Context, filePath string) (int64, error) { ret := _m.Called(ctx, filePath) var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, filePath) + } if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { r0 = rf(ctx, filePath) } else { r0 = ret.Get(0).(int64) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, filePath) } else { @@ -700,6 +800,11 @@ func (_c *ChunkManager_Size_Call) Return(_a0 int64, _a1 error) *ChunkManager_Siz return _c } +func (_c *ChunkManager_Size_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *ChunkManager_Size_Call { + _c.Call.Return(run) + return _c +} + // Write provides a mock function with given fields: ctx, filePath, content func (_m *ChunkManager) Write(ctx context.Context, filePath string, content []byte) error { ret := _m.Called(ctx, filePath, content) @@ -739,13 +844,17 @@ func (_c *ChunkManager_Write_Call) Return(_a0 error) *ChunkManager_Write_Call { return _c } -type mockConstructorTestingTNewChunkManager interface { - mock.TestingT - Cleanup(func()) +func (_c *ChunkManager_Write_Call) RunAndReturn(run func(context.Context, string, []byte) error) *ChunkManager_Write_Call { + _c.Call.Return(run) + return _c } // NewChunkManager creates a new instance of ChunkManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewChunkManager(t mockConstructorTestingTNewChunkManager) *ChunkManager { +// The first argument is typically a *testing.T value. +func NewChunkManager(t interface { + mock.TestingT + Cleanup(func()) +}) *ChunkManager { mock := &ChunkManager{} mock.Mock.Test(t) diff --git a/internal/storage/azure_object_storage.go b/internal/storage/azure_object_storage.go index 286733229c8db..a773601295336 100644 --- a/internal/storage/azure_object_storage.go +++ b/internal/storage/azure_object_storage.go @@ -101,7 +101,7 @@ func (AzureObjectStorage *AzureObjectStorage) GetObject(ctx context.Context, buc if err != nil { return nil, checkObjectStorageError(objectName, err) } - return object.Body, nil + return NewAzureFile(object.Body), nil } func (AzureObjectStorage *AzureObjectStorage) PutObject(ctx context.Context, bucketName, objectName string, reader io.Reader, objectSize int64) error { diff --git a/internal/storage/file.go b/internal/storage/file.go new file mode 100644 index 0000000000000..d27c1fcd85598 --- /dev/null +++ b/internal/storage/file.go @@ -0,0 +1,117 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "io" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" +) + +var errInvalid = errors.New("invalid argument") + +// MemoryFile implements the FileReader interface +type MemoryFile struct { + data []byte + position int +} + +// NewMemoryFile creates a new instance of MemoryFile +func NewMemoryFile(data []byte) *MemoryFile { + return &MemoryFile{data: data} +} + +// ReadAt implements the ReadAt method of the io.ReaderAt interface +func (mf *MemoryFile) ReadAt(p []byte, off int64) (n int, err error) { + if off < 0 || int64(int(off)) < off { + return 0, errInvalid + } + if off > int64(len(mf.data)) { + return 0, io.EOF + } + n = copy(p, mf.data[off:]) + mf.position += n + if n < len(p) { + return n, io.EOF + } + return n, nil +} + +// Seek implements the Seek method of the io.Seeker interface +func (mf *MemoryFile) Seek(offset int64, whence int) (int64, error) { + var newOffset int64 + switch whence { + case io.SeekStart: + newOffset = offset + case io.SeekCurrent: + newOffset = int64(mf.position) + offset + case io.SeekEnd: + newOffset = int64(len(mf.data)) + offset + default: + return 0, errInvalid + } + if newOffset < 0 { + return 0, errInvalid + } + mf.position = int(newOffset) + return newOffset, nil +} + +// Read implements the Read method of the io.Reader interface +func (mf *MemoryFile) Read(p []byte) (n int, err error) { + if mf.position >= len(mf.data) { + return 0, io.EOF + } + n = copy(p, mf.data[mf.position:]) + mf.position += n + return n, nil +} + +// Write implements the Write method of the io.Writer interface +func (mf *MemoryFile) Write(p []byte) (n int, err error) { + // Write data to memory + mf.data = append(mf.data, p...) + return len(p), nil +} + +// Close implements the Close method of the io.Closer interface +func (mf *MemoryFile) Close() error { + // Memory file does not need a close operation + return nil +} + +type AzureFile struct { + *MemoryFile +} + +func NewAzureFile(body io.ReadCloser) *AzureFile { + data, err := io.ReadAll(body) + defer body.Close() + if err != nil && err != io.EOF { + log.Warn("create azure file failed, read data failed", zap.Error(err)) + return &AzureFile{ + NewMemoryFile(nil), + } + } + + return &AzureFile{ + NewMemoryFile(data), + } +} diff --git a/internal/storage/file_test.go b/internal/storage/file_test.go new file mode 100644 index 0000000000000..64a8b45095259 --- /dev/null +++ b/internal/storage/file_test.go @@ -0,0 +1,88 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAzureFile(t *testing.T) { + t.Run("Read", func(t *testing.T) { + data := []byte("Test data for Read.") + azureFile := NewAzureFile(io.NopCloser(bytes.NewReader(data))) + buffer := make([]byte, 4) + n, err := azureFile.Read(buffer) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "Test", string(buffer)) + + buffer = make([]byte, 6) + n, err = azureFile.Read(buffer) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.Equal(t, " data ", string(buffer)) + }) + + t.Run("ReadAt", func(t *testing.T) { + data := []byte("Test data for ReadAt.") + azureFile := NewAzureFile(io.NopCloser(bytes.NewReader(data))) + buffer := make([]byte, 4) + n, err := azureFile.ReadAt(buffer, 5) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "data", string(buffer)) + }) + + t.Run("Seek start", func(t *testing.T) { + data := []byte("Test data for Seek.") + azureFile := NewAzureFile(io.NopCloser(bytes.NewReader(data))) + offset, err := azureFile.Seek(10, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(10), offset) + buffer := make([]byte, 4) + + n, err := azureFile.Read(buffer) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "for ", string(buffer)) + }) + + t.Run("Seek current", func(t *testing.T) { + data := []byte("Test data for Seek.") + azureFile := NewAzureFile(io.NopCloser(bytes.NewReader(data))) + + buffer := make([]byte, 4) + n, err := azureFile.Read(buffer) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "Test", string(buffer)) + + offset, err := azureFile.Seek(10, io.SeekCurrent) + assert.NoError(t, err) + assert.Equal(t, int64(14), offset) + + buffer = make([]byte, 4) + n, err = azureFile.Read(buffer) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "Seek", string(buffer)) + }) +} diff --git a/internal/storage/types.go b/internal/storage/types.go index b12ac8ebc0752..aa9ec9ff81da4 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -37,6 +37,8 @@ func (s StatsLogType) LogIdx() string { type FileReader interface { io.Reader io.Closer + io.ReaderAt + io.Seeker } // ChunkManager is to manager chunks. diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index e543830b9bd25..56b629c2661b3 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -815,7 +816,7 @@ func pkToShard(pk interface{}, shardNum uint32) (uint32, error) { } else { intPK, ok := pk.(int64) if !ok { - log.Warn("Numpy parser: primary key field must be int64 or varchar") + log.Warn("parser: primary key field must be int64 or varchar") return 0, merr.WrapErrImportFailed("primary key field must be int64 or varchar") } hash, _ := typeutil.Hash32Int64(intPK) @@ -843,3 +844,270 @@ func UpdateKVInfo(infos *[]*commonpb.KeyValuePair, k string, v string) error { return nil } + +// appendFunc defines the methods to append data to storage.FieldData +func appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { + switch schema.DataType { + case schemapb.DataType_Bool: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BoolFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(bool)) + return nil + } + case schemapb.DataType_Float: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float32)) + return nil + } + case schemapb.DataType_Double: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.DoubleFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(float64)) + return nil + } + case schemapb.DataType_Int8: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int8FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int8)) + return nil + } + case schemapb.DataType_Int16: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int16FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int16)) + return nil + } + case schemapb.DataType_Int32: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int32FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int32)) + return nil + } + case schemapb.DataType_Int64: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.Int64FieldData) + arr.Data = append(arr.Data, src.GetRow(n).(int64)) + return nil + } + case schemapb.DataType_BinaryVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.BinaryVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]byte)...) + return nil + } + case schemapb.DataType_FloatVector: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.FloatVectorFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]float32)...) + return nil + } + case schemapb.DataType_String, schemapb.DataType_VarChar: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.StringFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(string)) + return nil + } + case schemapb.DataType_JSON: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.JSONFieldData) + arr.Data = append(arr.Data, src.GetRow(n).([]byte)) + return nil + } + case schemapb.DataType_Array: + return func(src storage.FieldData, n int, target storage.FieldData) error { + arr := target.(*storage.ArrayFieldData) + arr.Data = append(arr.Data, src.GetRow(n).(*schemapb.ScalarField)) + return nil + } + + default: + return nil + } +} + +func prepareAppendFunctions(collectionInfo *CollectionInfo) (map[string]func(src storage.FieldData, n int, target storage.FieldData) error, error) { + appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) + for i := 0; i < len(collectionInfo.Schema.Fields); i++ { + schema := collectionInfo.Schema.Fields[i] + appendFuncErr := appendFunc(schema) + if appendFuncErr == nil { + log.Warn("parser: unsupported field data type") + return nil, fmt.Errorf("unsupported field data type: %d", schema.GetDataType()) + } + appendFunctions[schema.GetName()] = appendFuncErr + } + return appendFunctions, nil +} + +// checkRowCount check row count of each field, all fields row count must be equal +func checkRowCount(collectionInfo *CollectionInfo, fieldsData BlockData) (int, error) { + rowCount := 0 + rowCounter := make(map[string]int) + for i := 0; i < len(collectionInfo.Schema.Fields); i++ { + schema := collectionInfo.Schema.Fields[i] + if !schema.GetAutoID() { + v, ok := fieldsData[schema.GetFieldID()] + if !ok { + if schema.GetIsDynamic() { + // user might not provide numpy file for dynamic field, skip it, will auto-generate later + continue + } + log.Warn("field not provided", zap.String("fieldName", schema.GetName())) + return 0, fmt.Errorf("field '%s' not provided", schema.GetName()) + } + rowCounter[schema.GetName()] = v.RowNum() + if v.RowNum() > rowCount { + rowCount = v.RowNum() + } + } + } + + for name, count := range rowCounter { + if count != rowCount { + log.Warn("field row count is not equal to other fields row count", zap.String("fieldName", name), + zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount)) + return 0, fmt.Errorf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount) + } + } + + return rowCount, nil +} + +// hashToPartition hash partition key to get an partition ID, return the first partition ID if no partition key exist +// CollectionInfo ensures only one partition ID in the PartitionIDs if no partition key exist +func hashToPartition(collectionInfo *CollectionInfo, fieldsData BlockData, rowNumber int) (int64, error) { + if collectionInfo.PartitionKey == nil { + // no partition key, directly return the target partition id + if len(collectionInfo.PartitionIDs) != 1 { + return 0, fmt.Errorf("collection '%s' partition list is empty", collectionInfo.Schema.Name) + } + return collectionInfo.PartitionIDs[0], nil + } + + partitionKeyID := collectionInfo.PartitionKey.GetFieldID() + fieldData := fieldsData[partitionKeyID] + value := fieldData.GetRow(rowNumber) + index, err := pkToShard(value, uint32(len(collectionInfo.PartitionIDs))) + if err != nil { + return 0, err + } + + return collectionInfo.PartitionIDs[index], nil +} + +// splitFieldsData is to split the in-memory data(parsed from column-based files) into shards +func splitFieldsData(collectionInfo *CollectionInfo, fieldsData BlockData, shards []ShardData, rowIDAllocator *allocator.IDAllocator) ([]int64, error) { + if len(fieldsData) == 0 { + log.Warn("fields data to split is empty") + return nil, fmt.Errorf("fields data to split is empty") + } + + if len(shards) != int(collectionInfo.ShardNum) { + log.Warn("block count is not equal to collection shard number", zap.Int("shardsLen", len(shards)), + zap.Int32("shardNum", collectionInfo.ShardNum)) + return nil, fmt.Errorf("block count %d is not equal to collection shard number %d", len(shards), collectionInfo.ShardNum) + } + + rowCount, err := checkRowCount(collectionInfo, fieldsData) + if err != nil { + return nil, err + } + + // generate auto id for primary key and rowid field + rowIDBegin, rowIDEnd, err := rowIDAllocator.Alloc(uint32(rowCount)) + if err != nil { + log.Warn("failed to alloc row ID", zap.Int("rowCount", rowCount), zap.Error(err)) + return nil, fmt.Errorf("failed to alloc %d rows ID, error: %w", rowCount, err) + } + + rowIDField, ok := fieldsData[common.RowIDField] + if !ok { + rowIDField = &storage.Int64FieldData{ + Data: make([]int64, 0), + } + fieldsData[common.RowIDField] = rowIDField + } + rowIDFieldArr := rowIDField.(*storage.Int64FieldData) + for i := rowIDBegin; i < rowIDEnd; i++ { + rowIDFieldArr.Data = append(rowIDFieldArr.Data, i) + } + + // reset the primary keys, as we know, only int64 pk can be auto-generated + primaryKey := collectionInfo.PrimaryKey + autoIDRange := make([]int64, 0) + if primaryKey.GetAutoID() { + log.Info("generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin)) + if primaryKey.GetDataType() != schemapb.DataType_Int64 { + log.Warn("primary key field is auto-generated but the field type is not int64") + return nil, fmt.Errorf("primary key field is auto-generated but the field type is not int64") + } + + primaryDataArr := &storage.Int64FieldData{ + Data: make([]int64, 0, rowCount), + } + for i := rowIDBegin; i < rowIDEnd; i++ { + primaryDataArr.Data = append(primaryDataArr.Data, i) + } + + fieldsData[primaryKey.GetFieldID()] = primaryDataArr + autoIDRange = append(autoIDRange, rowIDBegin, rowIDEnd) + } + + // if the primary key is not auto-gernerate and user doesn't provide, return error + primaryData, ok := fieldsData[primaryKey.GetFieldID()] + if !ok || primaryData.RowNum() <= 0 { + log.Warn("primary key field is not provided", zap.String("keyName", primaryKey.GetName())) + return nil, fmt.Errorf("primary key '%s' field data is not provided", primaryKey.GetName()) + } + + // prepare append functions + appendFunctions, err := prepareAppendFunctions(collectionInfo) + if err != nil { + return nil, err + } + + // split data into shards + for i := 0; i < rowCount; i++ { + // hash to a shard number and partition + pk := primaryData.GetRow(i) + shard, err := pkToShard(pk, uint32(collectionInfo.ShardNum)) + if err != nil { + return nil, err + } + + partitionID, err := hashToPartition(collectionInfo, fieldsData, i) + if err != nil { + return nil, err + } + + // set rowID field + rowIDField := shards[shard][partitionID][common.RowIDField].(*storage.Int64FieldData) + rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64)) + + // append row to shard + for k := 0; k < len(collectionInfo.Schema.Fields); k++ { + schema := collectionInfo.Schema.Fields[k] + srcData := fieldsData[schema.GetFieldID()] + targetData := shards[shard][partitionID][schema.GetFieldID()] + if srcData == nil && schema.GetIsDynamic() { + // user might not provide numpy file for dynamic field, skip it, will auto-generate later + continue + } + if srcData == nil || targetData == nil { + log.Warn("cannot append data since source or target field data is nil", + zap.String("FieldName", schema.GetName()), + zap.Bool("sourceNil", srcData == nil), zap.Bool("targetNil", targetData == nil)) + return nil, fmt.Errorf("cannot append data for field '%s', possibly no any fields corresponding to this numpy file, or a required numpy file is not provided", + schema.GetName()) + } + appendFunc := appendFunctions[schema.GetName()] + err := appendFunc(srcData, i, targetData) + if err != nil { + return nil, err + } + } + } + + return autoIDRange, nil +} diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index 2c5f410169b65..838bc821e58d6 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -37,8 +37,9 @@ import ( ) const ( - JSONFileExt = ".json" - NumpyFileExt = ".npy" + JSONFileExt = ".json" + NumpyFileExt = ".npy" + ParquetFileExt = ".parquet" // parsers read JSON/Numpy/CSV files buffer by buffer, this limitation is to define the buffer size. ReadBufferSize = 16 * 1024 * 1024 // 16MB @@ -188,7 +189,7 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { name, fileType := GetFileNameAndExt(filePath) // only allow json file, numpy file and csv file - if fileType != JSONFileExt && fileType != NumpyFileExt { + if fileType != JSONFileExt && fileType != NumpyFileExt && fileType != ParquetFileExt { log.Warn("import wrapper: unsupported file type", zap.String("filePath", filePath)) return false, merr.WrapErrImportFailed(fmt.Sprintf("unsupported file type: '%s'", filePath)) } @@ -206,7 +207,7 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { return rowBased, merr.WrapErrImportFailed(fmt.Sprintf("unsupported file type for row-based mode: '%s'", filePath)) } } else { - if fileType != NumpyFileExt { + if fileType != NumpyFileExt && fileType != ParquetFileExt { log.Warn("import wrapper: unsupported file type for column-based mode", zap.String("filePath", filePath)) return rowBased, merr.WrapErrImportFailed(fmt.Sprintf("unsupported file type for column-based mode: '%s'", filePath)) } @@ -292,18 +293,34 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error printFieldsDataInfo(fields, "import wrapper: prepare to flush binlog data", filePaths) return p.flushFunc(fields, shardID, partitionID) } - parser, err := NewNumpyParser(p.ctx, p.collectionInfo, p.rowIDAllocator, p.binlogSize, - p.chunkManager, flushFunc, p.updateProgressPercent) - if err != nil { - return err - } + _, fileType := GetFileNameAndExt(filePaths[0]) + if fileType == NumpyFileExt { + parser, err := NewNumpyParser(p.ctx, p.collectionInfo, p.rowIDAllocator, p.binlogSize, + p.chunkManager, flushFunc, p.updateProgressPercent) + if err != nil { + return err + } - err = parser.Parse(filePaths) - if err != nil { - return err - } + err = parser.Parse(filePaths) + if err != nil { + return err + } - p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...) + p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...) + } else if fileType == ParquetFileExt { + parser, err := NewParquetParser(p.ctx, p.collectionInfo, p.rowIDAllocator, p.binlogSize, + p.chunkManager, filePaths[0], flushFunc, p.updateProgressPercent) + if err != nil { + return err + } + + err = parser.Parse() + if err != nil { + return err + } + + p.importResult.AutoIds = append(p.importResult.AutoIds, parser.IDRange()...) + } // trigger after parse finished triggerGC() diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index ec0cc15fef065..017ea2479e704 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/apache/arrow/go/v12/parquet" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "golang.org/x/exp/mmap" @@ -144,6 +145,10 @@ func (mc *MockChunkManager) RemoveWithPrefix(ctx context.Context, prefix string) return nil } +func (mc *MockChunkManager) NewParquetReaderAtSeeker(fileName string) (parquet.ReaderAtSeeker, error) { + panic("implement me") +} + type rowCounterTest struct { rowCount int callTime int diff --git a/internal/util/importutil/numpy_parser.go b/internal/util/importutil/numpy_parser.go index 9eb7510f21445..0d282e3644ccc 100644 --- a/internal/util/importutil/numpy_parser.go +++ b/internal/util/importutil/numpy_parser.go @@ -26,7 +26,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/timerecord" @@ -434,7 +433,7 @@ func (p *NumpyParser) consume(columnReaders []*NumpyColumnReader) error { updateProgress(totalRead) tr.Record("readData") // split data to shards - err = p.splitFieldsData(segmentData, shards) + p.autoIDRange, err = splitFieldsData(p.collectionInfo, segmentData, shards, p.rowIDAllocator) if err != nil { return err } @@ -631,262 +630,3 @@ func (p *NumpyParser) readData(columnReader *NumpyColumnReader, rowCount int) (s columnReader.fieldName)) } } - -// appendFunc defines the methods to append data to storage.FieldData -func (p *NumpyParser) appendFunc(schema *schemapb.FieldSchema) func(src storage.FieldData, n int, target storage.FieldData) error { - switch schema.DataType { - case schemapb.DataType_Bool: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.BoolFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(bool)) - return nil - } - case schemapb.DataType_Float: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.FloatFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(float32)) - return nil - } - case schemapb.DataType_Double: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.DoubleFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(float64)) - return nil - } - case schemapb.DataType_Int8: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int8FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int8)) - return nil - } - case schemapb.DataType_Int16: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int16FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int16)) - return nil - } - case schemapb.DataType_Int32: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int32FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int32)) - return nil - } - case schemapb.DataType_Int64: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.Int64FieldData) - arr.Data = append(arr.Data, src.GetRow(n).(int64)) - return nil - } - case schemapb.DataType_BinaryVector: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.BinaryVectorFieldData) - arr.Data = append(arr.Data, src.GetRow(n).([]byte)...) - return nil - } - case schemapb.DataType_FloatVector: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.FloatVectorFieldData) - arr.Data = append(arr.Data, src.GetRow(n).([]float32)...) - return nil - } - case schemapb.DataType_String, schemapb.DataType_VarChar: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.StringFieldData) - arr.Data = append(arr.Data, src.GetRow(n).(string)) - return nil - } - case schemapb.DataType_JSON: - return func(src storage.FieldData, n int, target storage.FieldData) error { - arr := target.(*storage.JSONFieldData) - arr.Data = append(arr.Data, src.GetRow(n).([]byte)) - return nil - } - default: - return nil - } -} - -func (p *NumpyParser) prepareAppendFunctions() (map[string]func(src storage.FieldData, n int, target storage.FieldData) error, error) { - appendFunctions := make(map[string]func(src storage.FieldData, n int, target storage.FieldData) error) - for i := 0; i < len(p.collectionInfo.Schema.Fields); i++ { - schema := p.collectionInfo.Schema.Fields[i] - appendFuncErr := p.appendFunc(schema) - if appendFuncErr == nil { - log.Warn("Numpy parser: unsupported field data type") - return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported field data type: %d", schema.GetDataType())) - } - appendFunctions[schema.GetName()] = appendFuncErr - } - return appendFunctions, nil -} - -// checkRowCount check row count of each field, all fields row count must be equal -func (p *NumpyParser) checkRowCount(fieldsData BlockData) (int, error) { - rowCount := 0 - rowCounter := make(map[string]int) - for i := 0; i < len(p.collectionInfo.Schema.Fields); i++ { - schema := p.collectionInfo.Schema.Fields[i] - if !schema.GetAutoID() { - v, ok := fieldsData[schema.GetFieldID()] - if !ok { - if schema.GetIsDynamic() { - // user might not provide numpy file for dynamic field, skip it, will auto-generate later - continue - } - log.Warn("Numpy parser: field not provided", zap.String("fieldName", schema.GetName())) - return 0, merr.WrapErrImportFailed(fmt.Sprintf("field '%s' not provided", schema.GetName())) - } - rowCounter[schema.GetName()] = v.RowNum() - if v.RowNum() > rowCount { - rowCount = v.RowNum() - } - } - } - - for name, count := range rowCounter { - if count != rowCount { - log.Warn("Numpy parser: field row count is not equal to other fields row count", zap.String("fieldName", name), - zap.Int("rowCount", count), zap.Int("otherRowCount", rowCount)) - return 0, merr.WrapErrImportFailed(fmt.Sprintf("field '%s' row count %d is not equal to other fields row count: %d", name, count, rowCount)) - } - } - - return rowCount, nil -} - -// splitFieldsData is to split the in-memory data(parsed from column-based files) into shards -func (p *NumpyParser) splitFieldsData(fieldsData BlockData, shards []ShardData) error { - if len(fieldsData) == 0 { - log.Warn("Numpy parser: fields data to split is empty") - return merr.WrapErrImportFailed("fields data to split is empty") - } - - if len(shards) != int(p.collectionInfo.ShardNum) { - log.Warn("Numpy parser: block count is not equal to collection shard number", zap.Int("shardsLen", len(shards)), - zap.Int32("shardNum", p.collectionInfo.ShardNum)) - return merr.WrapErrImportFailed(fmt.Sprintf("block count %d is not equal to collection shard number %d", len(shards), p.collectionInfo.ShardNum)) - } - - rowCount, err := p.checkRowCount(fieldsData) - if err != nil { - return err - } - - // generate auto id for primary key and rowid field - rowIDBegin, rowIDEnd, err := p.rowIDAllocator.Alloc(uint32(rowCount)) - if err != nil { - log.Warn("Numpy parser: failed to alloc row ID", zap.Int("rowCount", rowCount), zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to alloc %d rows ID, error: %v", rowCount, err)) - } - - rowIDField, ok := fieldsData[common.RowIDField] - if !ok { - rowIDField = &storage.Int64FieldData{ - Data: make([]int64, 0), - } - fieldsData[common.RowIDField] = rowIDField - } - rowIDFieldArr := rowIDField.(*storage.Int64FieldData) - for i := rowIDBegin; i < rowIDEnd; i++ { - rowIDFieldArr.Data = append(rowIDFieldArr.Data, i) - } - - // reset the primary keys, as we know, only int64 pk can be auto-generated - primaryKey := p.collectionInfo.PrimaryKey - if primaryKey.GetAutoID() { - log.Info("Numpy parser: generating auto-id", zap.Int("rowCount", rowCount), zap.Int64("rowIDBegin", rowIDBegin)) - if primaryKey.GetDataType() != schemapb.DataType_Int64 { - log.Warn("Numpy parser: primary key field is auto-generated but the field type is not int64") - return merr.WrapErrImportFailed("primary key field is auto-generated but the field type is not int64") - } - - primaryDataArr := &storage.Int64FieldData{ - Data: make([]int64, 0, rowCount), - } - for i := rowIDBegin; i < rowIDEnd; i++ { - primaryDataArr.Data = append(primaryDataArr.Data, i) - } - - fieldsData[primaryKey.GetFieldID()] = primaryDataArr - p.autoIDRange = append(p.autoIDRange, rowIDBegin, rowIDEnd) - } - - // if the primary key is not auto-gernerate and user doesn't provide, return error - primaryData, ok := fieldsData[primaryKey.GetFieldID()] - if !ok || primaryData.RowNum() <= 0 { - log.Warn("Numpy parser: primary key field is not provided", zap.String("keyName", primaryKey.GetName())) - return merr.WrapErrImportFailed(fmt.Sprintf("primary key '%s' field data is not provided", primaryKey.GetName())) - } - - // prepare append functions - appendFunctions, err := p.prepareAppendFunctions() - if err != nil { - return err - } - - // split data into shards - for i := 0; i < rowCount; i++ { - // hash to a shard number and partition - pk := primaryData.GetRow(i) - shard, err := pkToShard(pk, uint32(p.collectionInfo.ShardNum)) - if err != nil { - return err - } - - partitionID, err := p.hashToPartition(fieldsData, i) - if err != nil { - return err - } - - // set rowID field - rowIDField := shards[shard][partitionID][common.RowIDField].(*storage.Int64FieldData) - rowIDField.Data = append(rowIDField.Data, rowIDFieldArr.GetRow(i).(int64)) - - // append row to shard - for k := 0; k < len(p.collectionInfo.Schema.Fields); k++ { - schema := p.collectionInfo.Schema.Fields[k] - srcData := fieldsData[schema.GetFieldID()] - targetData := shards[shard][partitionID][schema.GetFieldID()] - if srcData == nil && schema.GetIsDynamic() { - // user might not provide numpy file for dynamic field, skip it, will auto-generate later - continue - } - if srcData == nil || targetData == nil { - log.Warn("Numpy parser: cannot append data since source or target field data is nil", - zap.String("FieldName", schema.GetName()), - zap.Bool("sourceNil", srcData == nil), zap.Bool("targetNil", targetData == nil)) - return merr.WrapErrImportFailed(fmt.Sprintf("cannot append data for field '%s', possibly no any fields corresponding to this numpy file, or a required numpy file is not provided", - schema.GetName())) - } - appendFunc := appendFunctions[schema.GetName()] - err := appendFunc(srcData, i, targetData) - if err != nil { - return err - } - } - } - - return nil -} - -// hashToPartition hash partition key to get a partition ID, return the first partition ID if no partition key exist -// CollectionInfo ensures only one partition ID in the PartitionIDs if no partition key exist -func (p *NumpyParser) hashToPartition(fieldsData BlockData, rowNumber int) (int64, error) { - if p.collectionInfo.PartitionKey == nil { - // no partition key, directly return the target partition id - if len(p.collectionInfo.PartitionIDs) != 1 { - return 0, merr.WrapErrImportFailed(fmt.Sprintf("collection '%s' partition list is empty", p.collectionInfo.Schema.Name)) - } - return p.collectionInfo.PartitionIDs[0], nil - } - - partitionKeyID := p.collectionInfo.PartitionKey.GetFieldID() - fieldData := fieldsData[partitionKeyID] - value := fieldData.GetRow(rowNumber) - index, err := pkToShard(value, uint32(len(p.collectionInfo.PartitionIDs))) - if err != nil { - return 0, err - } - - return p.collectionInfo.PartitionIDs[index], nil -} diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go index 4e20130274d9d..62b89fa39b5df 100644 --- a/internal/util/importutil/numpy_parser_test.go +++ b/internal/util/importutil/numpy_parser_test.go @@ -670,7 +670,7 @@ func Test_NumpyParserPrepareAppendFunctions(t *testing.T) { parser := createNumpyParser(t) // succeed - appendFuncs, err := parser.prepareAppendFunctions() + appendFuncs, err := prepareAppendFunctions(parser.collectionInfo) assert.NoError(t, err) assert.Equal(t, len(createNumpySchema().Fields), len(appendFuncs)) @@ -694,7 +694,7 @@ func Test_NumpyParserPrepareAppendFunctions(t *testing.T) { }, } parser.collectionInfo.resetSchema(schema) - appendFuncs, err = parser.prepareAppendFunctions() + appendFuncs, err = prepareAppendFunctions(parser.collectionInfo) assert.Error(t, err) assert.Nil(t, appendFuncs) } @@ -720,13 +720,13 @@ func Test_NumpyParserCheckRowCount(t *testing.T) { segmentData[reader.fieldID] = fieldData } - rowCount, err := parser.checkRowCount(segmentData) + rowCount, err := checkRowCount(parser.collectionInfo, segmentData) assert.NoError(t, err) assert.Equal(t, 5, rowCount) // field data missed delete(segmentData, 102) - rowCount, err = parser.checkRowCount(segmentData) + rowCount, err = checkRowCount(parser.collectionInfo, segmentData) assert.Error(t, err) assert.Zero(t, rowCount) @@ -759,7 +759,7 @@ func Test_NumpyParserCheckRowCount(t *testing.T) { } parser.collectionInfo.resetSchema(schema) - rowCount, err = parser.checkRowCount(segmentData) + rowCount, err = checkRowCount(parser.collectionInfo, segmentData) assert.Error(t, err) assert.Zero(t, rowCount) @@ -790,7 +790,7 @@ func Test_NumpyParserCheckRowCount(t *testing.T) { } parser.collectionInfo.resetSchema(schema) - rowCount, err = parser.checkRowCount(segmentData) + rowCount, err = checkRowCount(parser.collectionInfo, segmentData) assert.NoError(t, err) assert.Equal(t, 3, rowCount) } @@ -804,7 +804,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { parser := createNumpyParser(t) t.Run("segemnt data is empty", func(t *testing.T) { - err = parser.splitFieldsData(make(BlockData), nil) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, make(BlockData), nil, parser.rowIDAllocator) assert.Error(t, err) }) @@ -827,7 +827,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { fieldsData := createFieldsData(sampleSchema(), 0) shards := createShardsData(sampleSchema(), fieldsData, 1, []int64{1}) segmentData := genFieldsDataFunc() - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.Error(t, err) }) @@ -863,7 +863,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { parser.collectionInfo.ShardNum = 2 fieldsData := createFieldsData(schema, 0) shards := createShardsData(schema, fieldsData, 2, []int64{1}) - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.Error(t, err) }) @@ -874,7 +874,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { fieldsData := createFieldsData(sampleSchema(), 0) shards := createShardsData(sampleSchema(), fieldsData, 2, []int64{1}) segmentData := genFieldsDataFunc() - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.Error(t, err) parser.rowIDAllocator = newIDAllocator(ctx, t, nil) }) @@ -888,7 +888,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { fieldsData := createFieldsData(sampleSchema(), 0) shards := createShardsData(sampleSchema(), fieldsData, 2, []int64{partitionID}) segmentData := genFieldsDataFunc() - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.NoError(t, err) assert.NotEmpty(t, parser.autoIDRange) @@ -900,7 +900,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { // target field data is nil shards[0][partitionID][105] = nil - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.Error(t, err) schema.AutoID = false @@ -935,7 +935,7 @@ func Test_NumpyParserSplitFieldsData(t *testing.T) { segmentData[101] = &storage.Int64FieldData{ Data: []int64{1, 2, 4}, } - err = parser.splitFieldsData(segmentData, shards) + parser.autoIDRange, err = splitFieldsData(parser.collectionInfo, segmentData, shards, parser.rowIDAllocator) assert.NoError(t, err) }) } @@ -1203,14 +1203,14 @@ func Test_NumpyParserHashToPartition(t *testing.T) { // no partition key, partition ID list greater than 1, return error parser.collectionInfo.PartitionIDs = []int64{1, 2} - partID, err := parser.hashToPartition(blockData, 1) + partID, err := hashToPartition(parser.collectionInfo, blockData, 1) assert.Error(t, err) assert.Zero(t, partID) // no partition key, return the only one partition ID partitionID := int64(5) parser.collectionInfo.PartitionIDs = []int64{partitionID} - partID, err = parser.hashToPartition(blockData, 1) + partID, err = hashToPartition(parser.collectionInfo, blockData, 1) assert.NoError(t, err) assert.Equal(t, partitionID, partID) @@ -1219,7 +1219,7 @@ func Test_NumpyParserHashToPartition(t *testing.T) { err = parser.collectionInfo.resetSchema(schema) assert.NoError(t, err) partitionIDs := []int64{3, 4, 5, 6} - partID, err = parser.hashToPartition(blockData, 1) + partID, err = hashToPartition(parser.collectionInfo, blockData, 1) assert.NoError(t, err) assert.Contains(t, partitionIDs, partID) @@ -1227,7 +1227,7 @@ func Test_NumpyParserHashToPartition(t *testing.T) { blockData[102] = &storage.FloatFieldData{ Data: []float32{1, 2, 3, 4, 5}, } - partID, err = parser.hashToPartition(blockData, 1) + partID, err = hashToPartition(parser.collectionInfo, blockData, 1) assert.Error(t, err) assert.Zero(t, partID) } diff --git a/internal/util/importutil/parquet_column_reader.go b/internal/util/importutil/parquet_column_reader.go new file mode 100644 index 0000000000000..70e816ca18324 --- /dev/null +++ b/internal/util/importutil/parquet_column_reader.go @@ -0,0 +1,79 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "fmt" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type ParquetColumnReader struct { + fieldName string + fieldID int64 + columnIndex int + // columnSchema *parquet.SchemaElement + dataType schemapb.DataType + elementType schemapb.DataType + columnReader *pqarrow.ColumnReader + dimension int +} + +func ReadData[T any](pcr *ParquetColumnReader, count int64, getDataFunc func(chunk arrow.Array) ([]T, error)) ([]T, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + data := make([]T, 0, count) + for _, chunk := range chunked.Chunks() { + chunkData, err := getDataFunc(chunk) + if err != nil { + return nil, err + } + data = append(data, chunkData...) + } + return data, nil +} + +func ReadArrayData[T any](pcr *ParquetColumnReader, count int64, getArrayData func(offsets []int32, array arrow.Array) ([][]T, error)) ([][]T, error) { + chunked, err := pcr.columnReader.NextBatch(count) + if err != nil { + return nil, err + } + arrayData := make([][]T, 0, count) + for _, chunk := range chunked.Chunks() { + listReader, ok := chunk.(*array.List) + if !ok { + log.Warn("the column data in parquet is not array", zap.String("fieldName", pcr.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not array of field: %s", pcr.fieldName)) + } + offsets := listReader.Offsets() + chunkData, err := getArrayData(offsets, listReader.ListValues()) + if err != nil { + return nil, err + } + arrayData = append(arrayData, chunkData...) + } + return arrayData, nil +} diff --git a/internal/util/importutil/parquet_parser.go b/internal/util/importutil/parquet_parser.go new file mode 100644 index 0000000000000..cf0951660e5b1 --- /dev/null +++ b/internal/util/importutil/parquet_parser.go @@ -0,0 +1,932 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet/file" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/timerecord" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// ParquetParser is analogous to the ParquetColumnReader, but for Parquet files +type ParquetParser struct { + ctx context.Context // for canceling parse process + collectionInfo *CollectionInfo // collection details including schema + rowIDAllocator *allocator.IDAllocator // autoid allocator + blockSize int64 // maximum size of a read block(unit:byte) + chunkManager storage.ChunkManager // storage interfaces to browse/read the files + autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25 + callFlushFunc ImportFlushFunc // call back function to flush segment + updateProgressFunc func(percent int64) // update working progress percent value + columnMap map[string]*ParquetColumnReader + reader *file.Reader + fileReader *pqarrow.FileReader +} + +// NewParquetParser is helper function to create a ParquetParser +func NewParquetParser(ctx context.Context, + collectionInfo *CollectionInfo, + idAlloc *allocator.IDAllocator, + blockSize int64, + chunkManager storage.ChunkManager, + filePath string, + flushFunc ImportFlushFunc, + updateProgressFunc func(percent int64), +) (*ParquetParser, error) { + if collectionInfo == nil { + log.Warn("Parquet parser: collection schema is nil") + return nil, merr.WrapErrImportFailed("collection schema is nil") + } + + if idAlloc == nil { + log.Warn("Parquet parser: id allocator is nil") + return nil, merr.WrapErrImportFailed("id allocator is nil") + } + + if chunkManager == nil { + log.Warn("Parquet parser: chunk manager pointer is nil") + return nil, merr.WrapErrImportFailed("chunk manager pointer is nil") + } + + if flushFunc == nil { + log.Warn("Parquet parser: flush function is nil") + return nil, merr.WrapErrImportFailed("flush function is nil") + } + + cmReader, err := chunkManager.Reader(ctx, filePath) + if err != nil { + log.Warn("create chunk manager reader failed") + return nil, err + } + + reader, err := file.NewParquetReader(cmReader) + if err != nil { + log.Warn("create parquet reader failed", zap.Error(err)) + return nil, err + } + + fileReader, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{BatchSize: 1}, memory.DefaultAllocator) + if err != nil { + log.Warn("create arrow parquet file reader failed", zap.Error(err)) + return nil, err + } + + parser := &ParquetParser{ + ctx: ctx, + collectionInfo: collectionInfo, + rowIDAllocator: idAlloc, + blockSize: blockSize, + chunkManager: chunkManager, + autoIDRange: make([]int64, 0), + callFlushFunc: flushFunc, + updateProgressFunc: updateProgressFunc, + columnMap: make(map[string]*ParquetColumnReader), + fileReader: fileReader, + reader: reader, + } + + return parser, nil +} + +func (p *ParquetParser) IDRange() []int64 { + return p.autoIDRange +} + +// Parse is the function entry +func (p *ParquetParser) Parse() error { + err := p.createReaders() + defer p.Close() + if err != nil { + return err + } + + // read all data from the Parquet files + err = p.consume() + if err != nil { + return err + } + + return nil +} + +func (p *ParquetParser) createReaders() error { + schema, err := p.fileReader.Schema() + if err != nil { + log.Warn("can't schema from file", zap.Error(err)) + return err + } + for _, field := range p.collectionInfo.Schema.GetFields() { + dim, _ := getFieldDimension(field) + parquetColumnReader := &ParquetColumnReader{ + fieldName: field.GetName(), + fieldID: field.GetFieldID(), + dataType: field.GetDataType(), + elementType: field.GetElementType(), + dimension: dim, + } + fields, exist := schema.FieldsByName(field.GetName()) + if !exist { + if !(field.GetIsPrimaryKey() && field.GetAutoID()) && !field.GetIsDynamic() { + log.Warn("there is no field in parquet file", zap.String("fieldName", field.GetName())) + return merr.WrapErrImportFailed(fmt.Sprintf("there is no field: %s in parquet file", field.GetName())) + } + } else { + if len(fields) != 1 { + log.Warn("there is multi field of fieldName", zap.String("fieldName", field.GetName()), zap.Any("file fields", fields)) + return merr.WrapErrImportFailed(fmt.Sprintf("there is multi field of fieldName: %s", field.GetName())) + } + if !verifyFieldSchema(field.GetDataType(), field.GetElementType(), fields[0]) { + log.Warn("field schema is not match", + zap.String("collection schema", field.GetDataType().String()), + zap.String("file schema", fields[0].Type.Name())) + return merr.WrapErrImportFailed(fmt.Sprintf("field schema is not match, collection field dataType: %s, file field dataType:%s", field.GetDataType().String(), fields[0].Type.Name())) + } + indices := schema.FieldIndices(field.GetName()) + if len(indices) != 1 { + log.Warn("field is not match", zap.String("fieldName", field.GetName()), zap.Ints("indices", indices)) + return merr.WrapErrImportFailed(fmt.Sprintf("there is %d indices of fieldName: %s", len(indices), field.GetName())) + } + parquetColumnReader.columnIndex = indices[0] + columnReader, err := p.fileReader.GetColumn(p.ctx, parquetColumnReader.columnIndex) + if err != nil { + log.Warn("get column reader failed", zap.String("fieldName", field.GetName()), zap.Error(err)) + return err + } + parquetColumnReader.columnReader = columnReader + p.columnMap[field.GetName()] = parquetColumnReader + } + } + return nil +} + +func verifyFieldSchema(dataType, elementType schemapb.DataType, fileField arrow.Field) bool { + switch fileField.Type.ID() { + case arrow.BOOL: + return dataType == schemapb.DataType_Bool + case arrow.INT8: + return dataType == schemapb.DataType_Int8 + case arrow.INT16: + return dataType == schemapb.DataType_Int16 + case arrow.INT32: + return dataType == schemapb.DataType_Int32 + case arrow.INT64: + return dataType == schemapb.DataType_Int64 + case arrow.FLOAT32: + return dataType == schemapb.DataType_Float + case arrow.FLOAT64: + return dataType == schemapb.DataType_Double + case arrow.STRING: + return dataType == schemapb.DataType_VarChar || dataType == schemapb.DataType_String || dataType == schemapb.DataType_JSON + case arrow.LIST: + if dataType != schemapb.DataType_Array && dataType != schemapb.DataType_FloatVector && + dataType != schemapb.DataType_Float16Vector && dataType != schemapb.DataType_BinaryVector { + return false + } + if dataType == schemapb.DataType_Array { + return verifyFieldSchema(elementType, schemapb.DataType_None, fileField.Type.(*arrow.ListType).ElemField()) + } + return true + } + return false +} + +// Close closes the parquet file reader +func (p *ParquetParser) Close() { + p.reader.Close() +} + +// calcRowCountPerBlock calculates a proper value for a batch row count to read file +func (p *ParquetParser) calcRowCountPerBlock() (int64, error) { + sizePerRecord, err := typeutil.EstimateSizePerRecord(p.collectionInfo.Schema) + if err != nil { + log.Warn("Parquet parser: failed to estimate size of each row", zap.Error(err)) + return 0, merr.WrapErrImportFailed(fmt.Sprintf("failed to estimate size of each row: %s", err.Error())) + } + + if sizePerRecord <= 0 { + log.Warn("Parquet parser: failed to estimate size of each row, the collection schema might be empty") + return 0, merr.WrapErrImportFailed("failed to estimate size of each row: the collection schema might be empty") + } + + // the sizePerRecord is estimate value, if the schema contains varchar field, the value is not accurate + // we will read data block by block, by default, each block size is 16MB + // rowCountPerBlock is the estimated row count for a block + rowCountPerBlock := p.blockSize / int64(sizePerRecord) + if rowCountPerBlock <= 0 { + rowCountPerBlock = 1 // make sure the value is positive + } + + log.Info("Parquet parser: calculate row count per block to read file", zap.Int64("rowCountPerBlock", rowCountPerBlock), + zap.Int64("blockSize", p.blockSize), zap.Int("sizePerRecord", sizePerRecord)) + return rowCountPerBlock, nil +} + +// consume method reads Parquet data section into a storage.FieldData +// please note it will require a large memory block(the memory size is almost equal to Parquet file size) +func (p *ParquetParser) consume() error { + rowCountPerBlock, err := p.calcRowCountPerBlock() + if err != nil { + return err + } + + updateProgress := func(readRowCount int64) { + if p.updateProgressFunc != nil && p.reader != nil && p.reader.NumRows() > 0 { + percent := (readRowCount * ProgressValueForPersist) / p.reader.NumRows() + log.Info("Parquet parser: working progress", zap.Int64("readRowCount", readRowCount), + zap.Int64("totalRowCount", p.reader.NumRows()), zap.Int64("percent", percent)) + p.updateProgressFunc(percent) + } + } + + // prepare shards + shards := make([]ShardData, 0, p.collectionInfo.ShardNum) + for i := 0; i < int(p.collectionInfo.ShardNum); i++ { + shardData := initShardData(p.collectionInfo.Schema, p.collectionInfo.PartitionIDs) + if shardData == nil { + log.Warn("Parquet parser: failed to initialize FieldData list") + return merr.WrapErrImportFailed("failed to initialize FieldData list") + } + shards = append(shards, shardData) + } + tr := timerecord.NewTimeRecorder("consume performance") + defer tr.Elapse("end") + // read data from files, batch by batch + totalRead := 0 + for { + readRowCount := 0 + segmentData := make(BlockData) + for _, reader := range p.columnMap { + fieldData, err := p.readData(reader, rowCountPerBlock) + if err != nil { + return err + } + if readRowCount == 0 { + readRowCount = fieldData.RowNum() + } else if readRowCount != fieldData.RowNum() { + log.Warn("Parquet parser: data block's row count mismatch", zap.Int("firstBlockRowCount", readRowCount), + zap.Int("thisBlockRowCount", fieldData.RowNum()), zap.Int64("rowCountPerBlock", rowCountPerBlock), + zap.String("current field", reader.fieldName)) + return merr.WrapErrImportFailed(fmt.Sprintf("data block's row count mismatch: %d vs %d", readRowCount, fieldData.RowNum())) + } + + segmentData[reader.fieldID] = fieldData + } + + // nothing to read + if readRowCount == 0 { + break + } + totalRead += readRowCount + updateProgress(int64(totalRead)) + tr.Record("readData") + // split data to shards + p.autoIDRange, err = splitFieldsData(p.collectionInfo, segmentData, shards, p.rowIDAllocator) + if err != nil { + return err + } + tr.Record("splitFieldsData") + // when the estimated size is close to blockSize, save to binlog + err = tryFlushBlocks(p.ctx, shards, p.collectionInfo.Schema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, false) + if err != nil { + return err + } + tr.Record("tryFlushBlocks") + } + + // force flush at the end + return tryFlushBlocks(p.ctx, shards, p.collectionInfo.Schema, p.callFlushFunc, p.blockSize, MaxTotalSizeInMemory, true) +} + +// readData method reads Parquet data section into a storage.FieldData +func (p *ParquetParser) readData(columnReader *ParquetColumnReader, rowCount int64) (storage.FieldData, error) { + switch columnReader.dataType { + case schemapb.DataType_Bool: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]bool, error) { + boolReader, ok := chunk.(*array.Boolean) + boolData := make([]bool, 0) + if !ok { + log.Warn("the column data in parquet is not bool", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not bool of field: %s", columnReader.fieldName)) + } + for i := 0; i < boolReader.Data().Len(); i++ { + boolData = append(boolData, boolReader.Value(i)) + } + return boolData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read bool array", zap.Error(err)) + return nil, err + } + + return &storage.BoolFieldData{ + Data: data, + }, nil + case schemapb.DataType_Int8: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]int8, error) { + int8Reader, ok := chunk.(*array.Int8) + int8Data := make([]int8, 0) + if !ok { + log.Warn("the column data in parquet is not int8", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not int8 of field: %s", columnReader.fieldName)) + } + for i := 0; i < int8Reader.Data().Len(); i++ { + int8Data = append(int8Data, int8Reader.Value(i)) + } + return int8Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read int8 array", zap.Error(err)) + return nil, err + } + + return &storage.Int8FieldData{ + Data: data, + }, nil + case schemapb.DataType_Int16: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]int16, error) { + int16Reader, ok := chunk.(*array.Int16) + int16Data := make([]int16, 0) + if !ok { + log.Warn("the column data in parquet is not int16", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not int16 of field: %s", columnReader.fieldName)) + } + for i := 0; i < int16Reader.Data().Len(); i++ { + int16Data = append(int16Data, int16Reader.Value(i)) + } + return int16Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to int16 array", zap.Error(err)) + return nil, err + } + + return &storage.Int16FieldData{ + Data: data, + }, nil + case schemapb.DataType_Int32: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]int32, error) { + int32Reader, ok := chunk.(*array.Int32) + int32Data := make([]int32, 0) + if !ok { + log.Warn("the column data in parquet is not int32", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not int32 of field: %s", columnReader.fieldName)) + } + for i := 0; i < int32Reader.Data().Len(); i++ { + int32Data = append(int32Data, int32Reader.Value(i)) + } + return int32Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read int32 array", zap.Error(err)) + return nil, err + } + + return &storage.Int32FieldData{ + Data: data, + }, nil + case schemapb.DataType_Int64: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]int64, error) { + int64Reader, ok := chunk.(*array.Int64) + int64Data := make([]int64, 0) + if !ok { + log.Warn("the column data in parquet is not int64", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not int64 of field: %s", columnReader.fieldName)) + } + for i := 0; i < int64Reader.Data().Len(); i++ { + int64Data = append(int64Data, int64Reader.Value(i)) + } + return int64Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read int64 array", zap.Error(err)) + return nil, err + } + + return &storage.Int64FieldData{ + Data: data, + }, nil + case schemapb.DataType_Float: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]float32, error) { + float32Reader, ok := chunk.(*array.Float32) + float32Data := make([]float32, 0) + if !ok { + log.Warn("the column data in parquet is not float", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not float of field: %s", columnReader.fieldName)) + } + for i := 0; i < float32Reader.Data().Len(); i++ { + float32Data = append(float32Data, float32Reader.Value(i)) + } + return float32Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read float array", zap.Error(err)) + return nil, err + } + + err = typeutil.VerifyFloats32(data) + if err != nil { + log.Warn("Parquet parser: illegal value in float array", zap.Error(err)) + return nil, err + } + + return &storage.FloatFieldData{ + Data: data, + }, nil + case schemapb.DataType_Double: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]float64, error) { + float64Reader, ok := chunk.(*array.Float64) + float64Data := make([]float64, 0) + if !ok { + log.Warn("the column data in parquet is not double", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not double of field: %s", columnReader.fieldName)) + } + for i := 0; i < float64Reader.Data().Len(); i++ { + float64Data = append(float64Data, float64Reader.Value(i)) + } + return float64Data, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read double array", zap.Error(err)) + return nil, err + } + + err = typeutil.VerifyFloats64(data) + if err != nil { + log.Warn("Parquet parser: illegal value in double array", zap.Error(err)) + return nil, err + } + + return &storage.DoubleFieldData{ + Data: data, + }, nil + case schemapb.DataType_VarChar, schemapb.DataType_String: + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]string, error) { + stringReader, ok := chunk.(*array.String) + stringData := make([]string, 0) + if !ok { + log.Warn("the column data in parquet is not string", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not string of field: %s", columnReader.fieldName)) + } + for i := 0; i < stringReader.Data().Len(); i++ { + stringData = append(stringData, stringReader.Value(i)) + } + return stringData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read varchar array", zap.Error(err)) + return nil, err + } + + return &storage.StringFieldData{ + Data: data, + }, nil + case schemapb.DataType_JSON: + // JSON field read data from string array Parquet + data, err := ReadData(columnReader, rowCount, func(chunk arrow.Array) ([]string, error) { + stringReader, ok := chunk.(*array.String) + stringData := make([]string, 0) + if !ok { + log.Warn("the column data in parquet is not json string", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column data in parquet is not json string of field: %s", columnReader.fieldName)) + } + for i := 0; i < stringReader.Data().Len(); i++ { + stringData = append(stringData, stringReader.Value(i)) + } + return stringData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read json string array", zap.Error(err)) + return nil, err + } + + byteArr := make([][]byte, 0) + for _, str := range data { + var dummy interface{} + err := json.Unmarshal([]byte(str), &dummy) + if err != nil { + log.Warn("Parquet parser: illegal string value for JSON field", + zap.String("value", str), zap.String("fieldName", columnReader.fieldName), zap.Error(err)) + return nil, err + } + byteArr = append(byteArr, []byte(str)) + } + + return &storage.JSONFieldData{ + Data: byteArr, + }, nil + case schemapb.DataType_BinaryVector: + data, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]uint8, error) { + arrayData := make([][]uint8, 0) + uint8Reader, ok := reader.(*array.Uint8) + if !ok { + log.Warn("the column element data of array in parquet is not binary", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not binary: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]uint8, 0) + for j := start; j < end; j++ { + elementData = append(elementData, uint8Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read binary vector array", zap.Error(err)) + return nil, err + } + binaryData := make([]byte, 0) + for _, arr := range data { + binaryData = append(binaryData, arr...) + } + + if len(binaryData) != len(data)*columnReader.dimension/8 { + log.Warn("Parquet parser: binary vector is irregular", zap.Int("actual num", len(binaryData)), + zap.Int("expect num", len(data)*columnReader.dimension/8)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("binary vector is irregular, expect num = %d,"+ + " actual num = %d", len(data)*columnReader.dimension/8, len(binaryData))) + } + + return &storage.BinaryVectorFieldData{ + Data: binaryData, + Dim: columnReader.dimension, + }, nil + case schemapb.DataType_FloatVector: + data := make([]float32, 0) + rowNum := 0 + if columnReader.columnReader.Field().Type.(*arrow.ListType).Elem().ID() == arrow.FLOAT32 { + arrayData, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]float32, error) { + arrayData := make([][]float32, 0) + float32Reader, ok := reader.(*array.Float32) + if !ok { + log.Warn("the column element data of array in parquet is not float", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not float: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]float32, 0) + for j := start; j < end; j++ { + elementData = append(elementData, float32Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read float vector array", zap.Error(err)) + return nil, err + } + for _, arr := range arrayData { + data = append(data, arr...) + } + err = typeutil.VerifyFloats32(data) + if err != nil { + log.Warn("Parquet parser: illegal value in float vector array", zap.Error(err)) + return nil, err + } + rowNum = len(arrayData) + } else if columnReader.columnReader.Field().Type.(*arrow.ListType).Elem().ID() == arrow.FLOAT64 { + arrayData, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]float64, error) { + arrayData := make([][]float64, 0) + float64Reader, ok := reader.(*array.Float64) + if !ok { + log.Warn("the column element data of array in parquet is not double", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not double: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]float64, 0) + for j := start; j < end; j++ { + elementData = append(elementData, float64Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + log.Warn("Parquet parser: failed to read float vector array", zap.Error(err)) + return nil, err + } + for _, arr := range arrayData { + for _, f64 := range arr { + err = typeutil.VerifyFloat(f64) + if err != nil { + log.Warn("Parquet parser: illegal value in float vector array", zap.Error(err)) + return nil, err + } + data = append(data, float32(f64)) + } + } + rowNum = len(arrayData) + } else { + log.Warn("Parquet parser: FloatVector type is not float", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("FloatVector type is not float, is: %s", + columnReader.columnReader.Field().Type.(*arrow.ListType).Elem().ID().String())) + } + + if len(data) != rowNum*columnReader.dimension { + log.Warn("Parquet parser: float vector is irregular", zap.Int("actual num", len(data)), + zap.Int("expect num", rowNum*columnReader.dimension)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("float vector is irregular, expect num = %d,"+ + " actual num = %d", rowNum*columnReader.dimension, len(data))) + } + + return &storage.FloatVectorFieldData{ + Data: data, + Dim: columnReader.dimension, + }, nil + + case schemapb.DataType_Array: + data := make([]*schemapb.ScalarField, 0) + switch columnReader.elementType { + case schemapb.DataType_Bool: + boolArray, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]bool, error) { + arrayData := make([][]bool, 0) + boolReader, ok := reader.(*array.Boolean) + if !ok { + log.Warn("the column element data of array in parquet is not bool", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not bool: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]bool, 0) + for j := start; j < end; j++ { + elementData = append(elementData, boolReader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range boolArray { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int8: + int8Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]int32, error) { + arrayData := make([][]int32, 0) + int8Reader, ok := reader.(*array.Int8) + if !ok { + log.Warn("the column element data of array in parquet is not int8", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not int8: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]int32, 0) + for j := start; j < end; j++ { + elementData = append(elementData, int32(int8Reader.Value(int(j)))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range int8Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + case schemapb.DataType_Int16: + int16Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]int32, error) { + arrayData := make([][]int32, 0) + int16Reader, ok := reader.(*array.Int16) + if !ok { + log.Warn("the column element data of array in parquet is not int16", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not int16: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]int32, 0) + for j := start; j < end; j++ { + elementData = append(elementData, int32(int16Reader.Value(int(j)))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range int16Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int32: + int32Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]int32, error) { + arrayData := make([][]int32, 0) + int32Reader, ok := reader.(*array.Int32) + if !ok { + log.Warn("the column element data of array in parquet is not int32", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not int32: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]int32, 0) + for j := start; j < end; j++ { + elementData = append(elementData, int32Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range int32Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Int64: + int64Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]int64, error) { + arrayData := make([][]int64, 0) + int64Reader, ok := reader.(*array.Int64) + if !ok { + log.Warn("the column element data of array in parquet is not int64", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not int64: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]int64, 0) + for j := start; j < end; j++ { + elementData = append(elementData, int64Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range int64Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Float: + float32Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]float32, error) { + arrayData := make([][]float32, 0) + float32Reader, ok := reader.(*array.Float32) + if !ok { + log.Warn("the column element data of array in parquet is not float", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not float: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]float32, 0) + for j := start; j < end; j++ { + elementData = append(elementData, float32Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range float32Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_Double: + float64Array, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]float64, error) { + arrayData := make([][]float64, 0) + float64Reader, ok := reader.(*array.Float64) + if !ok { + log.Warn("the column element data of array in parquet is not double", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not double: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]float64, 0) + for j := start; j < end; j++ { + elementData = append(elementData, float64Reader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range float64Array { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: elementArray, + }, + }, + }) + } + + case schemapb.DataType_VarChar, schemapb.DataType_String: + stringArray, err := ReadArrayData(columnReader, rowCount, func(offsets []int32, reader arrow.Array) ([][]string, error) { + arrayData := make([][]string, 0) + stringReader, ok := reader.(*array.String) + if !ok { + log.Warn("the column element data of array in parquet is not string", zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("the column element data of array in parquet is not string: %s", columnReader.fieldName)) + } + for i := 1; i < len(offsets); i++ { + start, end := offsets[i-1], offsets[i] + elementData := make([]string, 0) + for j := start; j < end; j++ { + elementData = append(elementData, stringReader.Value(int(j))) + } + arrayData = append(arrayData, elementData) + } + return arrayData, nil + }) + if err != nil { + return nil, err + } + for _, elementArray := range stringArray { + data = append(data, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: elementArray, + }, + }, + }) + } + default: + log.Warn("unsupported element type", zap.String("element type", columnReader.elementType.String()), + zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type: %s of array", columnReader.elementType.String())) + } + return &storage.ArrayFieldData{ + ElementType: columnReader.elementType, + Data: data, + }, nil + default: + log.Warn("Parquet parser: unsupported data type of field", + zap.String("dataType", columnReader.dataType.String()), + zap.String("fieldName", columnReader.fieldName)) + return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type: %s", columnReader.elementType.String())) + } +} diff --git a/internal/util/importutil/parquet_parser_test.go b/internal/util/importutil/parquet_parser_test.go new file mode 100644 index 0000000000000..1036d532bc24f --- /dev/null +++ b/internal/util/importutil/parquet_parser_test.go @@ -0,0 +1,1022 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "fmt" + "io" + "math/rand" + "os" + "testing" + + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/pqarrow" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" +) + +// parquetSampleSchema() return a schema contains all supported data types with an int64 primary key +func parquetSampleSchema() *schemapb.CollectionSchema { + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 102, + Name: "FieldBool", + IsPrimaryKey: false, + Description: "bool", + DataType: schemapb.DataType_Bool, + }, + { + FieldID: 103, + Name: "FieldInt8", + IsPrimaryKey: false, + Description: "int8", + DataType: schemapb.DataType_Int8, + }, + { + FieldID: 104, + Name: "FieldInt16", + IsPrimaryKey: false, + Description: "int16", + DataType: schemapb.DataType_Int16, + }, + { + FieldID: 105, + Name: "FieldInt32", + IsPrimaryKey: false, + Description: "int32", + DataType: schemapb.DataType_Int32, + }, + { + FieldID: 106, + Name: "FieldInt64", + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 107, + Name: "FieldFloat", + IsPrimaryKey: false, + Description: "float", + DataType: schemapb.DataType_Float, + }, + { + FieldID: 108, + Name: "FieldDouble", + IsPrimaryKey: false, + Description: "double", + DataType: schemapb.DataType_Double, + }, + { + FieldID: 109, + Name: "FieldString", + IsPrimaryKey: false, + Description: "string", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "128"}, + }, + }, + { + FieldID: 110, + Name: "FieldBinaryVector", + IsPrimaryKey: false, + Description: "binary_vector", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "32"}, + }, + }, + { + FieldID: 111, + Name: "FieldFloatVector", + IsPrimaryKey: false, + Description: "float_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "4"}, + }, + }, + { + FieldID: 112, + Name: "FieldJSON", + IsPrimaryKey: false, + Description: "json", + DataType: schemapb.DataType_JSON, + }, + { + FieldID: 113, + Name: "FieldArrayBool", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Bool, + }, + { + FieldID: 114, + Name: "FieldArrayInt8", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int8, + }, + { + FieldID: 115, + Name: "FieldArrayInt16", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, + }, + { + FieldID: 116, + Name: "FieldArrayInt32", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + { + FieldID: 117, + Name: "FieldArrayInt64", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + }, + { + FieldID: 118, + Name: "FieldArrayFloat", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Float, + }, + { + FieldID: 118, + Name: "FieldArrayDouble", + IsPrimaryKey: false, + Description: "int16 array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Double, + }, + { + FieldID: 120, + Name: "FieldArrayString", + IsPrimaryKey: false, + Description: "string array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + }, + { + FieldID: 121, + Name: "$meta", + IsPrimaryKey: false, + Description: "dynamic field", + DataType: schemapb.DataType_JSON, + IsDynamic: true, + }, + }, + } + return schema +} + +func milvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataType { + switch dataType { + case schemapb.DataType_Bool: + return &arrow.BooleanType{} + case schemapb.DataType_Int8: + return &arrow.Int8Type{} + case schemapb.DataType_Int16: + return &arrow.Int16Type{} + case schemapb.DataType_Int32: + return &arrow.Int32Type{} + case schemapb.DataType_Int64: + return &arrow.Int64Type{} + case schemapb.DataType_Float: + return &arrow.Float32Type{} + case schemapb.DataType_Double: + return &arrow.Float64Type{} + case schemapb.DataType_VarChar, schemapb.DataType_String: + return &arrow.StringType{} + case schemapb.DataType_Array: + return &arrow.ListType{} + case schemapb.DataType_JSON: + return &arrow.StringType{} + case schemapb.DataType_FloatVector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float32Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_BinaryVector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Uint8Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + case schemapb.DataType_Float16Vector: + return arrow.ListOfField(arrow.Field{ + Name: "item", + Type: &arrow.Float16Type{}, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + default: + panic("unsupported data type") + } +} + +func convertMilvusSchemaToArrowSchema(schema *schemapb.CollectionSchema) *arrow.Schema { + fields := make([]arrow.Field, 0) + for _, field := range schema.GetFields() { + dim, _ := getFieldDimension(field) + if field.GetDataType() == schemapb.DataType_Array { + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: arrow.ListOfField(arrow.Field{ + Name: "item", + Type: milvusDataTypeToArrowType(field.GetElementType(), dim), + Nullable: true, + Metadata: arrow.Metadata{}, + }), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + continue + } + fields = append(fields, arrow.Field{ + Name: field.GetName(), + Type: milvusDataTypeToArrowType(field.GetDataType(), dim), + Nullable: true, + Metadata: arrow.Metadata{}, + }) + } + return arrow.NewSchema(fields, nil) +} + +func buildArrayData(dataType, elementType schemapb.DataType, dim, rows int) arrow.Array { + mem := memory.NewGoAllocator() + switch dataType { + case schemapb.DataType_Bool: + builder := array.NewBooleanBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(i%2 == 0) + } + return builder.NewBooleanArray() + case schemapb.DataType_Int8: + builder := array.NewInt8Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int8(i)) + } + return builder.NewInt8Array() + case schemapb.DataType_Int16: + builder := array.NewInt16Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int16(i)) + } + return builder.NewInt16Array() + case schemapb.DataType_Int32: + builder := array.NewInt32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int32(i)) + } + return builder.NewInt32Array() + case schemapb.DataType_Int64: + builder := array.NewInt64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(int64(i)) + } + return builder.NewInt64Array() + case schemapb.DataType_Float: + builder := array.NewFloat32Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float32(i) * 0.1) + } + return builder.NewFloat32Array() + case schemapb.DataType_Double: + builder := array.NewFloat64Builder(mem) + for i := 0; i < rows; i++ { + builder.Append(float64(i) * 0.02) + } + return builder.NewFloat64Array() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(randomString(10)) + } + return builder.NewStringArray() + case schemapb.DataType_FloatVector: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + for i := 0; i < dim*rows; i++ { + builder.ValueBuilder().(*array.Float32Builder).Append(float32(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(i*dim)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_BinaryVector: + builder := array.NewListBuilder(mem, &arrow.Uint8Type{}) + offsets := make([]int32, 0, rows) + valid := make([]bool, 0) + for i := 0; i < dim*rows/8; i++ { + builder.ValueBuilder().(*array.Uint8Builder).Append(uint8(i)) + } + for i := 0; i < rows; i++ { + offsets = append(offsets, int32(dim*i/8)) + valid = append(valid, true) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_JSON: + builder := array.NewStringBuilder(mem) + for i := 0; i < rows; i++ { + builder.Append(fmt.Sprintf("{\"a\": \"%s\", \"b\": %d}", randomString(3), i)) + } + return builder.NewStringArray() + case schemapb.DataType_Array: + offsets := make([]int32, 0, rows) + valid := make([]bool, 0, rows) + index := 0 + for i := 0; i < rows; i++ { + index += i + offsets = append(offsets, int32(index)) + valid = append(valid, true) + } + index += rows + switch elementType { + case schemapb.DataType_Bool: + builder := array.NewListBuilder(mem, &arrow.BooleanType{}) + valueBuilder := builder.ValueBuilder().(*array.BooleanBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(i%2 == 0) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int8: + builder := array.NewListBuilder(mem, &arrow.Int8Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int8Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int8(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int16: + builder := array.NewListBuilder(mem, &arrow.Int16Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int16Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int16(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int32: + builder := array.NewListBuilder(mem, &arrow.Int32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int32(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Int64: + builder := array.NewListBuilder(mem, &arrow.Int64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Int64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(int64(i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Float: + builder := array.NewListBuilder(mem, &arrow.Float32Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float32Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float32(i) * 0.1) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_Double: + builder := array.NewListBuilder(mem, &arrow.Float64Type{}) + valueBuilder := builder.ValueBuilder().(*array.Float64Builder) + for i := 0; i < index; i++ { + valueBuilder.Append(float64(i) * 0.02) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + case schemapb.DataType_VarChar, schemapb.DataType_String: + builder := array.NewListBuilder(mem, &arrow.StringType{}) + valueBuilder := builder.ValueBuilder().(*array.StringBuilder) + for i := 0; i < index; i++ { + valueBuilder.Append(randomString(5) + "-" + fmt.Sprintf("%d", i)) + } + builder.AppendValues(offsets, valid) + return builder.NewListArray() + } + } + return nil +} + +func writeParquet(w io.Writer, milvusSchema *schemapb.CollectionSchema, numRows int) error { + schema := convertMilvusSchemaToArrowSchema(milvusSchema) + columns := make([]arrow.Array, 0, len(milvusSchema.Fields)) + for _, field := range milvusSchema.Fields { + dim, _ := getFieldDimension(field) + columnData := buildArrayData(field.DataType, field.ElementType, dim, numRows) + columns = append(columns, columnData) + } + recordBatch := array.NewRecord(schema, columns, int64(numRows)) + fw, err := pqarrow.NewFileWriter(schema, w, parquet.NewWriterProperties(), pqarrow.DefaultWriterProps()) + if err != nil { + return err + } + defer fw.Close() + + err = fw.Write(recordBatch) + if err != nil { + return err + } + return nil +} + +func randomString(length int) string { + letterRunes := []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, length) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +func TestParquetReader(t *testing.T) { + filePath := "/tmp/wp.parquet" + ctx := context.Background() + schema := parquetSampleSchema() + idAllocator := newIDAllocator(ctx, t, nil) + defer os.Remove(filePath) + + writeFile := func() { + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(t, err) + err = writeParquet(wf, schema, 100) + assert.NoError(t, err) + } + writeFile() + + t.Run("read file", func(t *testing.T) { + cm := createLocalChunkManager(t) + flushFunc := func(fields BlockData, shardID int, partID int64) error { + return nil + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + updateProgress := func(percent int64) { + assert.Greater(t, percent, int64(0)) + } + + // parquet schema sizePreRecord = 5296 + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 102400, cm, filePath, flushFunc, updateProgress) + assert.NoError(t, err) + defer parquetParser.Close() + err = parquetParser.Parse() + assert.NoError(t, err) + }) + + t.Run("field not exist", func(t *testing.T) { + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 200, + Name: "invalid", + Description: "invalid field", + DataType: schemapb.DataType_JSON, + }) + + cm := createLocalChunkManager(t) + flushFunc := func(fields BlockData, shardID int, partID int64) error { + return nil + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, cm, filePath, flushFunc, nil) + assert.NoError(t, err) + defer parquetParser.Close() + err = parquetParser.Parse() + assert.Error(t, err) + + // reset schema + schema = parquetSampleSchema() + }) + + t.Run("schema mismatch", func(t *testing.T) { + schema.Fields[0].DataType = schemapb.DataType_JSON + cm := createLocalChunkManager(t) + flushFunc := func(fields BlockData, shardID int, partID int64) error { + return nil + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, cm, filePath, flushFunc, nil) + assert.NoError(t, err) + defer parquetParser.Close() + err = parquetParser.Parse() + assert.Error(t, err) + + // reset schema + schema = parquetSampleSchema() + }) + + t.Run("data not match", func(t *testing.T) { + cm := createLocalChunkManager(t) + flushFunc := func(fields BlockData, shardID int, partID int64) error { + return nil + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, cm, filePath, flushFunc, nil) + assert.NoError(t, err) + defer parquetParser.Close() + + err = parquetParser.createReaders() + assert.NoError(t, err) + t.Run("read not bool field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldInt8"] + columnReader.dataType = schemapb.DataType_Bool + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int8 field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldInt16"] + columnReader.dataType = schemapb.DataType_Int8 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int16 field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldInt32"] + columnReader.dataType = schemapb.DataType_Int16 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int32 field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldInt64"] + columnReader.dataType = schemapb.DataType_Int32 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int64 field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldFloat"] + columnReader.dataType = schemapb.DataType_Int64 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not float field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldDouble"] + columnReader.dataType = schemapb.DataType_Float + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not double field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldBool"] + columnReader.dataType = schemapb.DataType_Double + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not string field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldBool"] + columnReader.dataType = schemapb.DataType_VarChar + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldBool"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Bool + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not bool array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Bool + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int8 array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Int8 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int16 array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Int16 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int32 array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Int32 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not int64 array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Int64 + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not float array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Float + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not double array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_Double + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not string array field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayBool"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_VarChar + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not float vector field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayBool"] + columnReader.dataType = schemapb.DataType_FloatVector + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read irregular float vector", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayFloat"] + columnReader.dataType = schemapb.DataType_FloatVector + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read irregular float vector", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayDouble"] + columnReader.dataType = schemapb.DataType_FloatVector + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not binary vector field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayBool"] + columnReader.dataType = schemapb.DataType_BinaryVector + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read not json field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldBool"] + columnReader.dataType = schemapb.DataType_JSON + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read illegal json field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldString"] + columnReader.dataType = schemapb.DataType_JSON + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read unknown field", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldString"] + columnReader.dataType = schemapb.DataType_None + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + + t.Run("read unsupported array", func(t *testing.T) { + columnReader := parquetParser.columnMap["FieldArrayString"] + columnReader.dataType = schemapb.DataType_Array + columnReader.elementType = schemapb.DataType_JSON + data, err := parquetParser.readData(columnReader, 1024) + assert.Error(t, err) + assert.Nil(t, data) + }) + }) + + t.Run("flush failed", func(t *testing.T) { + cm := createLocalChunkManager(t) + flushFunc := func(fields BlockData, shardID int, partID int64) error { + return fmt.Errorf("mock error") + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + updateProgress := func(percent int64) { + assert.Greater(t, percent, int64(0)) + } + + // parquet schema sizePreRecord = 5296 + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 102400, cm, filePath, flushFunc, updateProgress) + assert.NoError(t, err) + defer parquetParser.Close() + err = parquetParser.Parse() + assert.Error(t, err) + }) +} + +func TestNewParquetParser(t *testing.T) { + ctx := context.Background() + t.Run("nil collectionInfo", func(t *testing.T) { + parquetParser, err := NewParquetParser(ctx, nil, nil, 10240, nil, "", nil, nil) + assert.Error(t, err) + assert.Nil(t, parquetParser) + }) + + t.Run("nil idAlloc", func(t *testing.T) { + collectionInfo, err := NewCollectionInfo(parquetSampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, nil, 10240, nil, "", nil, nil) + assert.Error(t, err) + assert.Nil(t, parquetParser) + }) + + t.Run("nil chunk manager", func(t *testing.T) { + collectionInfo, err := NewCollectionInfo(parquetSampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + idAllocator := newIDAllocator(ctx, t, nil) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, nil, "", nil, nil) + assert.Error(t, err) + assert.Nil(t, parquetParser) + }) + + t.Run("nil flush func", func(t *testing.T) { + collectionInfo, err := NewCollectionInfo(parquetSampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + idAllocator := newIDAllocator(ctx, t, nil) + cm := createLocalChunkManager(t) + + parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, cm, "", nil, nil) + assert.Error(t, err) + assert.Nil(t, parquetParser) + }) + // + //t.Run("create reader with closed file", func(t *testing.T) { + // collectionInfo, err := NewCollectionInfo(parquetSampleSchema(), 2, []int64{1}) + // assert.NoError(t, err) + // + // idAllocator := newIDAllocator(ctx, t, nil) + // cm := createLocalChunkManager(t) + // flushFunc := func(fields BlockData, shardID int, partID int64) error { + // return nil + // } + // + // rf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + // assert.NoError(t, err) + // r := storage.NewLocalFile(rf) + // + // parquetParser, err := NewParquetParser(ctx, collectionInfo, idAllocator, 10240, cm, filePath, flushFunc, nil) + // assert.Error(t, err) + // assert.Nil(t, parquetParser) + //}) +} + +func TestVerifyFieldSchema(t *testing.T) { + ok := verifyFieldSchema(schemapb.DataType_Bool, schemapb.DataType_None, arrow.Field{Type: &arrow.BooleanType{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Bool, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.BooleanType{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Int8, schemapb.DataType_None, arrow.Field{Type: &arrow.Int8Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Int8, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int8Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Int16, schemapb.DataType_None, arrow.Field{Type: &arrow.Int16Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Int16, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int16Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Int32, schemapb.DataType_None, arrow.Field{Type: &arrow.Int32Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Int32, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int32Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Int64, schemapb.DataType_None, arrow.Field{Type: &arrow.Int64Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Int64, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int64Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Float, schemapb.DataType_None, arrow.Field{Type: &arrow.Float32Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Float, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float32Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Double, schemapb.DataType_None, arrow.Field{Type: &arrow.Float64Type{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_Double, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float64Type{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_VarChar, schemapb.DataType_None, arrow.Field{Type: &arrow.StringType{}}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_VarChar, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.StringType{}})}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_FloatVector, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float32Type{}})}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_FloatVector, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float64Type{}})}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_FloatVector, schemapb.DataType_None, arrow.Field{Type: &arrow.Float32Type{}}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_BinaryVector, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Uint8Type{}})}) + assert.True(t, ok) + ok = verifyFieldSchema(schemapb.DataType_BinaryVector, schemapb.DataType_None, arrow.Field{Type: &arrow.Uint8Type{}}) + assert.False(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Bool, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.BooleanType{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Int8, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int8Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Int16, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int16Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Int32, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int32Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Int64, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int64Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Float, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float32Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_Double, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Float64Type{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_VarChar, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.StringType{}})}) + assert.True(t, ok) + + ok = verifyFieldSchema(schemapb.DataType_Array, schemapb.DataType_None, arrow.Field{Type: arrow.ListOfField(arrow.Field{Type: &arrow.Int64Type{}})}) + assert.False(t, ok) +} + +func TestCalcRowCountPerBlock(t *testing.T) { + t.Run("dim not valid", func(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Name: "dim_invalid", + Description: "dim not invalid", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + Description: "pk", + DataType: schemapb.DataType_Int64, + AutoID: true, + }, + { + FieldID: 101, + Name: "vector", + Description: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "invalid", + }, + }, + }, + }, + EnableDynamicField: false, + } + + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + p := &ParquetParser{ + collectionInfo: collectionInfo, + } + + _, err = p.calcRowCountPerBlock() + assert.Error(t, err) + + err = p.consume() + assert.Error(t, err) + }) + + t.Run("nil schema", func(t *testing.T) { + collectionInfo := &CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: "nil_schema", + Description: "", + AutoID: false, + Fields: nil, + EnableDynamicField: false, + }, + ShardNum: 2, + } + p := &ParquetParser{ + collectionInfo: collectionInfo, + } + + _, err := p.calcRowCountPerBlock() + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + collectionInfo, err := NewCollectionInfo(parquetSampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + p := &ParquetParser{ + collectionInfo: collectionInfo, + blockSize: 10, + } + + _, err = p.calcRowCountPerBlock() + assert.NoError(t, err) + }) +}