From 88b373b0246547a579e9b7dc3bb8acbdbfd5e6a1 Mon Sep 17 00:00:00 2001 From: shaoting-huang <167743503+shaoting-huang@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:47:44 +0800 Subject: [PATCH] enhance: binlog primary key turn off dict encoding (#34358) issue: #34357 Go Parquet uses dictionary encoding by default, and it will fall back to plain encoding if the dictionary size exceeds the dictionary size page limit. Users can specify custom fallback encoding by using `parquet.WithEncoding(ENCODING_METHOD)` in writer properties. However, Go Parquet [fallbacks to plain encoding](https://github.com/apache/arrow/blob/e65c1e295d82c7076df484089a63fa3ba2bd55d1/go/parquet/file/column_writer_types.gen.go.tmpl#L238) rather than custom encoding method users provide. Therefore, this patch only turns off dictionary encoding for the primary key. With a 5 million auto ID primary key benchmark, the parquet file size improves from 13.93 MB to 8.36 MB when dictionary encoding is turned off, reducing primary key storage space by 40%. Signed-off-by: shaoting-huang --- internal/storage/binlog_iterator_test.go | 2 +- internal/storage/binlog_test.go | 10 +- internal/storage/binlog_writer.go | 18 +- internal/storage/binlog_writer_test.go | 2 +- internal/storage/data_codec.go | 37 ++- internal/storage/data_codec_test.go | 2 +- internal/storage/event_test.go | 22 +- internal/storage/event_writer.go | 29 +-- internal/storage/event_writer_test.go | 4 +- internal/storage/payload_test.go | 221 ++++++++---------- internal/storage/payload_writer.go | 87 ++++--- internal/storage/payload_writer_test.go | 88 ++++--- internal/storage/print_binlog_test.go | 4 +- internal/storage/serde.go | 53 +++-- internal/storage/serde_events.go | 4 +- internal/storage/serde_events_test.go | 12 +- internal/storage/utils.go | 18 ++ .../util/importutilv2/binlog/reader_test.go | 2 +- 18 files changed, 329 insertions(+), 286 deletions(-) diff --git a/internal/storage/binlog_iterator_test.go b/internal/storage/binlog_iterator_test.go index d387e0fedc2e7..0a9546a495b0e 100644 --- a/internal/storage/binlog_iterator_test.go +++ b/internal/storage/binlog_iterator_test.go @@ -39,7 +39,7 @@ func generateTestSchema() *schemapb.CollectionSchema { {FieldID: 13, Name: "int64", DataType: schemapb.DataType_Int64}, {FieldID: 14, Name: "float", DataType: schemapb.DataType_Float}, {FieldID: 15, Name: "double", DataType: schemapb.DataType_Double}, - {FieldID: 16, Name: "varchar", DataType: schemapb.DataType_VarChar}, + {FieldID: 16, Name: "varchar", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}, {FieldID: 17, Name: "string", DataType: schemapb.DataType_String}, {FieldID: 18, Name: "array", DataType: schemapb.DataType_Array}, {FieldID: 19, Name: "string", DataType: schemapb.DataType_JSON}, diff --git a/internal/storage/binlog_test.go b/internal/storage/binlog_test.go index b5058ab6fa562..6a7876dd820f9 100644 --- a/internal/storage/binlog_test.go +++ b/internal/storage/binlog_test.go @@ -39,7 +39,7 @@ import ( func TestInsertBinlog(t *testing.T) { w := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -49,7 +49,7 @@ func TestInsertBinlog(t *testing.T) { assert.NoError(t, err) e1.SetEventTimestamp(100, 200) - e2, err := w.NextInsertEventWriter(false) + e2, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e2.AddDataToPayload([]int64{7, 8, 9}, nil) assert.NoError(t, err) @@ -1329,7 +1329,7 @@ func TestNewBinlogReaderError(t *testing.T) { w.SetEventTimeStamp(1000, 2000) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -1393,7 +1393,7 @@ func TestNewBinlogWriterTsError(t *testing.T) { func TestInsertBinlogWriterCloseError(t *testing.T) { insertWriter := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e1, err := insertWriter.NextInsertEventWriter(false) + e1, err := insertWriter.NextInsertEventWriter() assert.NoError(t, err) sizeTotal := 2000000 @@ -1406,7 +1406,7 @@ func TestInsertBinlogWriterCloseError(t *testing.T) { err = insertWriter.Finish() assert.NoError(t, err) assert.NotNil(t, insertWriter.buffer) - insertEventWriter, err := insertWriter.NextInsertEventWriter(false) + insertEventWriter, err := insertWriter.NextInsertEventWriter() assert.Nil(t, insertEventWriter) assert.Error(t, err) insertWriter.Close() diff --git a/internal/storage/binlog_writer.go b/internal/storage/binlog_writer.go index 2e716f31e852e..173aae219e7b8 100644 --- a/internal/storage/binlog_writer.go +++ b/internal/storage/binlog_writer.go @@ -23,7 +23,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) // BinlogType is to distinguish different files saving different data. @@ -150,21 +149,12 @@ type InsertBinlogWriter struct { } // NextInsertEventWriter returns an event writer to write insert data to an event. -func (writer *InsertBinlogWriter) NextInsertEventWriter(nullable bool, dim ...int) (*insertEventWriter, error) { +func (writer *InsertBinlogWriter) NextInsertEventWriter(opts ...PayloadWriterOptions) (*insertEventWriter, error) { if writer.isClosed() { return nil, fmt.Errorf("binlog has closed") } - var event *insertEventWriter - var err error - if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseFloatVectorType(writer.PayloadDataType) { - if len(dim) != 1 { - return nil, fmt.Errorf("incorrect input numbers") - } - event, err = newInsertEventWriter(writer.PayloadDataType, nullable, dim[0]) - } else { - event, err = newInsertEventWriter(writer.PayloadDataType, nullable) - } + event, err := newInsertEventWriter(writer.PayloadDataType, opts...) if err != nil { return nil, err } @@ -179,11 +169,11 @@ type DeleteBinlogWriter struct { } // NextDeleteEventWriter returns an event writer to write delete data to an event. -func (writer *DeleteBinlogWriter) NextDeleteEventWriter() (*deleteEventWriter, error) { +func (writer *DeleteBinlogWriter) NextDeleteEventWriter(opts ...PayloadWriterOptions) (*deleteEventWriter, error) { if writer.isClosed() { return nil, fmt.Errorf("binlog has closed") } - event, err := newDeleteEventWriter(writer.PayloadDataType) + event, err := newDeleteEventWriter(writer.PayloadDataType, opts...) if err != nil { return nil, err } diff --git a/internal/storage/binlog_writer_test.go b/internal/storage/binlog_writer_test.go index 02e25d32f3a00..ff1c928e2440f 100644 --- a/internal/storage/binlog_writer_test.go +++ b/internal/storage/binlog_writer_test.go @@ -32,7 +32,7 @@ func TestBinlogWriterReader(t *testing.T) { binlogWriter.SetEventTimeStamp(1000, 2000) defer binlogWriter.Close() - eventWriter, err := binlogWriter.NextInsertEventWriter(false) + eventWriter, err := binlogWriter.NextInsertEventWriter() assert.NoError(t, err) err = eventWriter.AddInt32ToPayload([]int32{1, 2, 3}, nil) assert.NoError(t, err) diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go index 549fbe932d12f..efd4d1195ab11 100644 --- a/internal/storage/data_codec.go +++ b/internal/storage/data_codec.go @@ -243,31 +243,18 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique for _, field := range insertCodec.Schema.Schema.Fields { // encode fields writer = NewInsertBinlogWriter(field.DataType, insertCodec.Schema.ID, partitionID, segmentID, field.FieldID, field.GetNullable()) - var eventWriter *insertEventWriter - var err error - var dim int64 - if typeutil.IsVectorType(field.DataType) { - if field.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("vectorType not support null, fieldName: %s", field.GetName())) - } - switch field.DataType { - case schemapb.DataType_FloatVector, - schemapb.DataType_BinaryVector, - schemapb.DataType_Float16Vector, - schemapb.DataType_BFloat16Vector: - dim, err = typeutil.GetDim(field) - if err != nil { - return nil, err - } - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable(), int(dim)) - case schemapb.DataType_SparseFloatVector: - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable()) - default: - return nil, fmt.Errorf("undefined data type %d", field.DataType) + + // get payload writing configs, including nullable and fallback encoding method + opts := []PayloadWriterOptions{WithNullable(field.GetNullable()), WithWriterProps(getFieldWriterProps(field))} + + if typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) { + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err } - } else { - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable()) + opts = append(opts, WithDim(int(dim))) } + eventWriter, err := writer.NextInsertEventWriter(opts...) if err != nil { writer.Close() return nil, err @@ -711,7 +698,9 @@ func NewDeleteCodec() *DeleteCodec { // For each delete message, it will save "pk,ts" string to binlog. func (deleteCodec *DeleteCodec) Serialize(collectionID UniqueID, partitionID UniqueID, segmentID UniqueID, data *DeleteData) (*Blob, error) { binlogWriter := NewDeleteBinlogWriter(schemapb.DataType_String, collectionID, partitionID, segmentID) - eventWriter, err := binlogWriter.NextDeleteEventWriter() + field := &schemapb.FieldSchema{IsPrimaryKey: true, DataType: schemapb.DataType_String} + opts := []PayloadWriterOptions{WithWriterProps(getFieldWriterProps(field))} + eventWriter, err := binlogWriter.NextDeleteEventWriter(opts...) if err != nil { binlogWriter.Close() return nil, err diff --git a/internal/storage/data_codec_test.go b/internal/storage/data_codec_test.go index b37886cd20a00..cbdec1414c589 100644 --- a/internal/storage/data_codec_test.go +++ b/internal/storage/data_codec_test.go @@ -977,7 +977,7 @@ func TestDeleteData(t *testing.T) { func TestAddFieldDataToPayload(t *testing.T) { w := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e, _ := w.NextInsertEventWriter(false) + e, _ := w.NextInsertEventWriter() var err error err = AddFieldDataToPayload(e, schemapb.DataType_Bool, &BoolFieldData{[]bool{}, nil}) assert.Error(t, err) diff --git a/internal/storage/event_test.go b/internal/storage/event_test.go index 3f4ada4076a78..02eecc61624e7 100644 --- a/internal/storage/event_test.go +++ b/internal/storage/event_test.go @@ -195,7 +195,7 @@ func TestInsertEvent(t *testing.T) { } t.Run("insert_bool", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Bool, false) + w, err := newInsertEventWriter(schemapb.DataType_Bool) assert.NoError(t, err) insertT(t, schemapb.DataType_Bool, w, func(w *insertEventWriter) error { @@ -211,7 +211,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int8", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int8, false) + w, err := newInsertEventWriter(schemapb.DataType_Int8) assert.NoError(t, err) insertT(t, schemapb.DataType_Int8, w, func(w *insertEventWriter) error { @@ -227,7 +227,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int16", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int16, false) + w, err := newInsertEventWriter(schemapb.DataType_Int16) assert.NoError(t, err) insertT(t, schemapb.DataType_Int16, w, func(w *insertEventWriter) error { @@ -243,7 +243,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int32", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int32, false) + w, err := newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) insertT(t, schemapb.DataType_Int32, w, func(w *insertEventWriter) error { @@ -259,7 +259,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int64", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int64, false) + w, err := newInsertEventWriter(schemapb.DataType_Int64) assert.NoError(t, err) insertT(t, schemapb.DataType_Int64, w, func(w *insertEventWriter) error { @@ -275,7 +275,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float32", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Float, false) + w, err := newInsertEventWriter(schemapb.DataType_Float) assert.NoError(t, err) insertT(t, schemapb.DataType_Float, w, func(w *insertEventWriter) error { @@ -291,7 +291,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float64", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Double, false) + w, err := newInsertEventWriter(schemapb.DataType_Double) assert.NoError(t, err) insertT(t, schemapb.DataType_Double, w, func(w *insertEventWriter) error { @@ -307,7 +307,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_binary_vector", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_BinaryVector, false, 16) + w, err := newInsertEventWriter(schemapb.DataType_BinaryVector, WithDim(16)) assert.NoError(t, err) insertT(t, schemapb.DataType_BinaryVector, w, func(w *insertEventWriter) error { @@ -323,7 +323,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float_vector", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_FloatVector, false, 2) + w, err := newInsertEventWriter(schemapb.DataType_FloatVector, WithDim(2)) assert.NoError(t, err) insertT(t, schemapb.DataType_FloatVector, w, func(w *insertEventWriter) error { @@ -339,7 +339,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_string", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_String, false) + w, err := newInsertEventWriter(schemapb.DataType_String) assert.NoError(t, err) w.SetEventTimestamp(tsoutil.ComposeTS(10, 0), tsoutil.ComposeTS(100, 0)) err = w.AddDataToPayload("1234", nil) @@ -1101,7 +1101,7 @@ func TestEventReaderError(t *testing.T) { } func TestEventClose(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_String, false) + w, err := newInsertEventWriter(schemapb.DataType_String) assert.NoError(t, err) w.SetEventTimestamp(tsoutil.ComposeTS(10, 0), tsoutil.ComposeTS(100, 0)) err = w.AddDataToPayload("1234", nil) diff --git a/internal/storage/event_writer.go b/internal/storage/event_writer.go index 6b9390da0a387..495bab0c76c30 100644 --- a/internal/storage/event_writer.go +++ b/internal/storage/event_writer.go @@ -19,14 +19,12 @@ package storage import ( "bytes" "encoding/binary" - "fmt" "io" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) // EventTypeCode represents event type by code @@ -222,17 +220,8 @@ func NewBaseDescriptorEvent(collectionID int64, partitionID int64, segmentID int return de } -func newInsertEventWriter(dataType schemapb.DataType, nullable bool, dim ...int) (*insertEventWriter, error) { - var payloadWriter PayloadWriterInterface - var err error - if typeutil.IsVectorType(dataType) && !typeutil.IsSparseFloatVectorType(dataType) { - if len(dim) != 1 { - return nil, fmt.Errorf("incorrect input numbers") - } - payloadWriter, err = NewPayloadWriter(dataType, nullable, dim[0]) - } else { - payloadWriter, err = NewPayloadWriter(dataType, nullable) - } +func newInsertEventWriter(dataType schemapb.DataType, opts ...PayloadWriterOptions) (*insertEventWriter, error) { + payloadWriter, err := NewPayloadWriter(dataType, opts...) if err != nil { return nil, err } @@ -253,8 +242,8 @@ func newInsertEventWriter(dataType schemapb.DataType, nullable bool, dim ...int) return writer, nil } -func newDeleteEventWriter(dataType schemapb.DataType) (*deleteEventWriter, error) { - payloadWriter, err := NewPayloadWriter(dataType, false) +func newDeleteEventWriter(dataType schemapb.DataType, opts ...PayloadWriterOptions) (*deleteEventWriter, error) { + payloadWriter, err := NewPayloadWriter(dataType, opts...) if err != nil { return nil, err } @@ -280,7 +269,7 @@ func newCreateCollectionEventWriter(dataType schemapb.DataType) (*createCollecti return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -306,7 +295,7 @@ func newDropCollectionEventWriter(dataType schemapb.DataType) (*dropCollectionEv return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -332,7 +321,7 @@ func newCreatePartitionEventWriter(dataType schemapb.DataType) (*createPartition return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -358,7 +347,7 @@ func newDropPartitionEventWriter(dataType schemapb.DataType) (*dropPartitionEven return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -380,7 +369,7 @@ func newDropPartitionEventWriter(dataType schemapb.DataType) (*dropPartitionEven } func newIndexFileEventWriter(dataType schemapb.DataType) (*indexFileEventWriter, error) { - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } diff --git a/internal/storage/event_writer_test.go b/internal/storage/event_writer_test.go index 9b4997edcaaac..160bd78666f69 100644 --- a/internal/storage/event_writer_test.go +++ b/internal/storage/event_writer_test.go @@ -59,11 +59,11 @@ func TestSizeofStruct(t *testing.T) { } func TestEventWriter(t *testing.T) { - insertEvent, err := newInsertEventWriter(schemapb.DataType_Int32, false) + insertEvent, err := newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) insertEvent.Close() - insertEvent, err = newInsertEventWriter(schemapb.DataType_Int32, false) + insertEvent, err = newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) defer insertEvent.Close() diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index b477eed5c6152..82dd64498a31f 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -32,7 +32,7 @@ import ( func TestPayload_ReaderAndWriter(t *testing.T) { t.Run("TestBool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -69,7 +69,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -109,7 +109,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -147,7 +147,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -186,7 +186,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -225,7 +225,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloat32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -264,7 +264,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestDouble", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -303,7 +303,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddString", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -351,7 +351,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -423,7 +423,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -471,7 +471,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestBinaryVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -520,7 +520,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -562,7 +562,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloat16Vector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -604,7 +604,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestBFloat16Vector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -646,7 +646,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestSparseFloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) @@ -715,7 +715,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) testSparseOneBatch := func(t *testing.T, rows [][]byte, actualDim int) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) @@ -811,31 +811,8 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }, int(int32Max)) }) - // t.Run("TestAddDataToPayload", func(t *testing.T) { - // w, err := NewPayloadWriter(schemapb.DataType_Bool) - // w.colType = 999 - // require.Nil(t, err) - // require.NotNil(t, w) - - // err = w.AddDataToPayload([]bool{false, false, false, false}) - // assert.NotNil(t, err) - - // err = w.AddDataToPayload([]bool{false, false, false, false}, 0) - // assert.NotNil(t, err) - - // err = w.AddDataToPayload([]bool{false, false, false, false}, 0, 0) - // assert.NotNil(t, err) - - // err = w.AddBoolToPayload([]bool{}) - // assert.NotNil(t, err) - // err = w.FinishPayloadWriter() - // assert.Nil(t, err) - // err = w.AddBoolToPayload([]bool{false}) - // assert.NotNil(t, err) - // }) - t.Run("TestAddBoolAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -851,7 +828,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -867,7 +844,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt16AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -883,7 +860,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt32AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -899,7 +876,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt64AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -915,7 +892,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloatAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -931,7 +908,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddDoubleAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -947,7 +924,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddOneStringAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -963,7 +940,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddBinVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -987,7 +964,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloatVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1008,7 +985,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloat16VectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1032,7 +1009,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddBFloat16VectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1056,7 +1033,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddSparseFloatVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1100,7 +1077,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBoolError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1124,7 +1101,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBoolError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1145,7 +1122,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt8Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1169,7 +1146,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt8Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1190,7 +1167,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt16Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1214,7 +1191,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt16Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -1235,7 +1212,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt32Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1259,7 +1236,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt32Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -1280,7 +1257,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt64Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1304,7 +1281,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt64Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -1325,7 +1302,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1349,7 +1326,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -1370,7 +1347,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetDoubleError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1394,7 +1371,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetDoubleError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -1415,7 +1392,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetStringError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1439,7 +1416,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetStringError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1464,7 +1441,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetArrayError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1488,7 +1465,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBinaryVectorError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1512,7 +1489,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBinaryVectorError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -1533,7 +1510,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatVectorError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1557,7 +1534,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatVectorError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -1579,7 +1556,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestByteArrayDatasetError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1619,7 +1596,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { vec = append(vec, 1) } - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector) assert.NoError(t, err) err = w.AddFloatVectorToPayload(vec, 128) @@ -1635,7 +1612,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddBool with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1644,7 +1621,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1653,7 +1630,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt16 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -1662,7 +1639,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -1671,7 +1648,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt64 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -1680,7 +1657,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddFloat32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -1689,7 +1666,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddDouble with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -1698,7 +1675,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddAddString with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1707,7 +1684,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -1722,7 +1699,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -1733,7 +1710,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { func TestPayload_NullableReaderAndWriter(t *testing.T) { t.Run("TestBool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, true) + w, err := NewPayloadWriter(schemapb.DataType_Bool, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1770,7 +1747,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, true) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1810,7 +1787,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, true) + w, err := NewPayloadWriter(schemapb.DataType_Int16, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1848,7 +1825,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, true) + w, err := NewPayloadWriter(schemapb.DataType_Int32, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1887,7 +1864,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, true) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1926,7 +1903,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestFloat32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, true) + w, err := NewPayloadWriter(schemapb.DataType_Float, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1965,7 +1942,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestDouble", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, true) + w, err := NewPayloadWriter(schemapb.DataType_Double, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2004,7 +1981,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddString", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, true) + w, err := NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2052,7 +2029,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, true) + w, err := NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2124,7 +2101,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, true) + w, err := NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2172,22 +2149,22 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestBinaryVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_BinaryVector, true, 8) + _, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithNullable(true), WithDim(8)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestFloatVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_FloatVector, true, 1) + _, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithNullable(true), WithDim(1)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestFloat16Vector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_Float16Vector, true, 1) + _, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithNullable(true), WithDim(1)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestAddBool with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, true) + w, err := NewPayloadWriter(schemapb.DataType_Bool, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2196,7 +2173,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, true) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2205,7 +2182,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt16 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, true) + w, err := NewPayloadWriter(schemapb.DataType_Int16, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2214,7 +2191,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, true) + w, err := NewPayloadWriter(schemapb.DataType_Int32, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2223,7 +2200,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt64 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, true) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2232,7 +2209,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddFloat32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, true) + w, err := NewPayloadWriter(schemapb.DataType_Float, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2241,7 +2218,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddDouble with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, true) + w, err := NewPayloadWriter(schemapb.DataType_Double, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2250,25 +2227,25 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddAddString with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, true) + w, err := NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, true) + w, err = NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, false) + w, err = NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, false) + w, err = NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{true}) @@ -2276,7 +2253,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, true) + w, err := NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2288,7 +2265,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, true) + w, err = NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2301,7 +2278,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, false) + w, err = NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2313,7 +2290,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, false) + w, err = NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2327,25 +2304,25 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, true) + w, err := NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, true) + w, err = NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, false) + w, err = NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, false) + w, err = NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{true}) @@ -2355,7 +2332,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { func TestArrowRecordReader(t *testing.T) { t.Run("TestArrowRecordReader", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) assert.NoError(t, err) defer w.Close() @@ -2395,7 +2372,7 @@ func TestArrowRecordReader(t *testing.T) { } func dataGen(size int) ([]byte, error) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) if err != nil { return nil, err } @@ -2422,7 +2399,7 @@ func dataGen(size int) ([]byte, error) { } func BenchmarkDefaultReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) @@ -2446,7 +2423,7 @@ func BenchmarkDefaultReader(b *testing.B) { } func BenchmarkDataSetReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) @@ -2474,7 +2451,7 @@ func BenchmarkDataSetReader(b *testing.B) { } func BenchmarkArrowRecordReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go index 8b8b00100564b..e2ac969719a06 100644 --- a/internal/storage/payload_writer.go +++ b/internal/storage/payload_writer.go @@ -39,6 +39,26 @@ import ( var _ PayloadWriterInterface = (*NativePayloadWriter)(nil) +type PayloadWriterOptions func(*NativePayloadWriter) + +func WithNullable(nullable bool) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.nullable = nullable + } +} + +func WithWriterProps(writerProps *parquet.WriterProperties) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.writerProps = writerProps + } +} + +func WithDim(dim int) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.dim = NewNullableInt(dim) + } +} + type NativePayloadWriter struct { dataType schemapb.DataType arrowType arrow.DataType @@ -47,43 +67,42 @@ type NativePayloadWriter struct { flushedRows int output *bytes.Buffer releaseOnce sync.Once - dim int + dim *NullableInt nullable bool + writerProps *parquet.WriterProperties } -func NewPayloadWriter(colType schemapb.DataType, nullable bool, dim ...int) (PayloadWriterInterface, error) { - var arrowType arrow.DataType - var dimension int +func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions) (PayloadWriterInterface, error) { + w := &NativePayloadWriter{ + dataType: colType, + finished: false, + flushedRows: 0, + output: new(bytes.Buffer), + nullable: false, + writerProps: parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + ), + dim: &NullableInt{}, + } + for _, o := range options { + o(w) + } + // writer for sparse float vector doesn't require dim if typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) { - if len(dim) != 1 { + if w.dim.IsNull() { return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers") } - if nullable { - return nil, merr.WrapErrParameterInvalidMsg("vector type not supprot nullable") + if w.nullable { + return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable") } - arrowType = milvusDataTypeToArrowType(colType, dim[0]) - dimension = dim[0] } else { - if len(dim) != 0 { - return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers") - } - arrowType = milvusDataTypeToArrowType(colType, 1) - dimension = 1 + w.dim = NewNullableInt(1) } - - builder := array.NewBuilder(memory.DefaultAllocator, arrowType) - - return &NativePayloadWriter{ - dataType: colType, - arrowType: arrowType, - builder: builder, - finished: false, - flushedRows: 0, - output: new(bytes.Buffer), - dim: dimension, - nullable: nullable, - }, nil + w.arrowType = milvusDataTypeToArrowType(colType, *w.dim.Value) + w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType) + return w, nil } func (w *NativePayloadWriter) AddDataToPayload(data interface{}, validData []bool) error { @@ -192,25 +211,25 @@ func (w *NativePayloadWriter) AddDataToPayload(data interface{}, validData []boo if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBinaryVectorToPayload(val, w.dim) + return w.AddBinaryVectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_FloatVector: val, ok := data.([]float32) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloatVectorToPayload(val, w.dim) + return w.AddFloatVectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_Float16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloat16VectorToPayload(val, w.dim) + return w.AddFloat16VectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_BFloat16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBFloat16VectorToPayload(val, w.dim) + return w.AddBFloat16VectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_SparseFloatVector: val, ok := data.(*SparseFloatVectorFieldData) if !ok { @@ -674,14 +693,10 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error { table := array.NewTable(schema, []arrow.Column{column}, int64(column.Len())) defer table.Release() - props := parquet.NewWriterProperties( - parquet.WithCompression(compress.Codecs.Zstd), - parquet.WithCompressionLevel(3), - ) return pqarrow.WriteTable(table, w.output, 1024*1024*1024, - props, + w.writerProps, pqarrow.DefaultWriterProps(), ) } diff --git a/internal/storage/payload_writer_test.go b/internal/storage/payload_writer_test.go index 0a8e5abfb4f05..18ceb7361dd61 100644 --- a/internal/storage/payload_writer_test.go +++ b/internal/storage/payload_writer_test.go @@ -3,6 +3,7 @@ package storage import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -10,14 +11,11 @@ import ( func TestPayloadWriter_Failed(t *testing.T) { t.Run("wrong input", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_FloatVector, false) - require.Error(t, err) - - _, err = NewPayloadWriter(schemapb.DataType_Bool, false, 1) + _, err := NewPayloadWriter(schemapb.DataType_FloatVector) require.Error(t, err) }) t.Run("Test Bool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -30,7 +28,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddBoolToPayload([]bool{false}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -39,7 +37,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Byte", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty)) require.Nil(t, err) require.NotNil(t, w) @@ -52,7 +50,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddByteToPayload([]byte{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -61,7 +59,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -74,7 +72,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt8ToPayload([]int8{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -83,7 +81,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -96,7 +94,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt16ToPayload([]int16{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -105,7 +103,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -118,7 +116,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt32ToPayload([]int32{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -127,7 +125,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty)) require.Nil(t, err) require.NotNil(t, w) @@ -140,7 +138,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt64ToPayload([]int64{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -149,7 +147,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Float", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -162,7 +160,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddFloatToPayload([]float32{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -171,7 +169,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Double", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -184,7 +182,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddDoubleToPayload([]float64{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -193,7 +191,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test String", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -203,7 +201,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneStringToPayload("test", false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -212,7 +210,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Array", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -222,7 +220,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneArrayToPayload(&schemapb.ScalarField{}, false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -231,7 +229,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Json", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -241,7 +239,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneJSONToPayload([]byte{0, 1}, false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -250,7 +248,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test BinaryVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -265,7 +263,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddBinaryVectorToPayload(data, 8) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -274,7 +272,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test FloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -292,7 +290,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddFloatToPayload(data, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -300,3 +298,33 @@ func TestPayloadWriter_Failed(t *testing.T) { require.Error(t, err) }) } + +func TestParquetEncoding(t *testing.T) { + t.Run("test int64 pk", func(t *testing.T) { + field := &schemapb.FieldSchema{IsPrimaryKey: true, DataType: schemapb.DataType_Int64} + + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithWriterProps(getFieldWriterProps(field))) + + assert.NoError(t, err) + err = w.AddDataToPayload([]int64{1, 2, 3}, nil) + assert.NoError(t, err) + + err = w.FinishPayloadWriter() + assert.True(t, !w.(*NativePayloadWriter).writerProps.DictionaryEnabled()) + assert.NoError(t, err) + }) + + t.Run("test string pk", func(t *testing.T) { + field := &schemapb.FieldSchema{IsPrimaryKey: true, DataType: schemapb.DataType_String} + + w, err := NewPayloadWriter(schemapb.DataType_String, WithWriterProps(getFieldWriterProps(field))) + + assert.NoError(t, err) + err = w.AddOneStringToPayload("1", true) + assert.NoError(t, err) + + err = w.FinishPayloadWriter() + assert.True(t, !w.(*NativePayloadWriter).writerProps.DictionaryEnabled()) + assert.NoError(t, err) + }) +} diff --git a/internal/storage/print_binlog_test.go b/internal/storage/print_binlog_test.go index 0409430b32e85..dc0bee9779cdd 100644 --- a/internal/storage/print_binlog_test.go +++ b/internal/storage/print_binlog_test.go @@ -40,7 +40,7 @@ func TestPrintBinlogFilesInt64(t *testing.T) { curTS := time.Now().UnixNano() / int64(time.Millisecond) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -50,7 +50,7 @@ func TestPrintBinlogFilesInt64(t *testing.T) { assert.NoError(t, err) e1.SetEventTimestamp(tsoutil.ComposeTS(curTS+10*60*1000, 0), tsoutil.ComposeTS(curTS+20*60*1000, 0)) - e2, err := w.NextInsertEventWriter(false) + e2, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e2.AddDataToPayload([]int64{7, 8, 9}, nil) assert.NoError(t, err) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 34f070519e710..263f912a6e097 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -521,6 +521,23 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { return m }() +// Since parquet does not support custom fallback encoding for now, +// we disable dict encoding for primary key. +// It can be scale to all fields once parquet fallback encoding is available. +func getFieldWriterProps(field *schemapb.FieldSchema) *parquet.WriterProperties { + if field.GetIsPrimaryKey() { + return parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + parquet.WithDictionaryDefault(false), + ) + } + return parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + ) +} + type DeserializeReader[T any] struct { rr RecordReader deserializer Deserializer[T] @@ -654,12 +671,21 @@ func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecord var _ RecordWriter = (*singleFieldRecordWriter)(nil) +type RecordWriterOptions func(*singleFieldRecordWriter) + +func WithRecordWriterProps(writerProps *parquet.WriterProperties) RecordWriterOptions { + return func(w *singleFieldRecordWriter) { + w.writerProps = writerProps + } +} + type singleFieldRecordWriter struct { fw *pqarrow.FileWriter fieldId FieldID schema *arrow.Schema - numRows int + numRows int + writerProps *parquet.WriterProperties } func (sfw *singleFieldRecordWriter) Write(r Record) error { @@ -674,23 +700,24 @@ func (sfw *singleFieldRecordWriter) Close() { sfw.fw.Close() } -func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer) (*singleFieldRecordWriter, error) { - schema := arrow.NewSchema([]arrow.Field{field}, nil) - - // use writer properties as same as payload writer's for now - fw, err := pqarrow.NewFileWriter(schema, writer, - parquet.NewWriterProperties( +func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer, opts ...RecordWriterOptions) (*singleFieldRecordWriter, error) { + w := &singleFieldRecordWriter{ + fieldId: fieldId, + schema: arrow.NewSchema([]arrow.Field{field}, nil), + writerProps: parquet.NewWriterProperties( + parquet.WithMaxRowGroupLength(math.MaxInt64), // No additional grouping for now. parquet.WithCompression(compress.Codecs.Zstd), parquet.WithCompressionLevel(3)), - pqarrow.DefaultWriterProps()) + } + for _, o := range opts { + o(w) + } + fw, err := pqarrow.NewFileWriter(w.schema, writer, w.writerProps, pqarrow.DefaultWriterProps()) if err != nil { return nil, err } - return &singleFieldRecordWriter{ - fw: fw, - fieldId: fieldId, - schema: schema, - }, nil + w.fw = fw + return w, nil } var _ RecordWriter = (*multiFieldRecordWriter)(nil) diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 609f9e5c26d8e..3a30771399f56 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -279,7 +279,7 @@ func (bsw *BinlogStreamWriter) GetRecordWriter() (RecordWriter, error) { Name: strconv.Itoa(int(fid)), Type: serdeMap[bsw.fieldSchema.DataType].arrowType(int(dim)), Nullable: true, // No nullable check here. - }, &bsw.buf) + }, &bsw.buf, WithRecordWriterProps(getFieldWriterProps(bsw.fieldSchema))) if err != nil { return nil, err } @@ -431,7 +431,7 @@ func (dsw *DeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { Name: dsw.fieldSchema.Name, Type: serdeMap[dsw.fieldSchema.DataType].arrowType(int(dim)), Nullable: false, - }, &dsw.buf) + }, &dsw.buf, WithRecordWriterProps(getFieldWriterProps(dsw.fieldSchema))) if err != nil { return nil, err } diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go index 4e5733b364bc5..83953de999453 100644 --- a/internal/storage/serde_events_test.go +++ b/internal/storage/serde_events_test.go @@ -141,6 +141,11 @@ func TestBinlogSerializeWriter(t *testing.T) { assert.NoError(t, err) } + for _, f := range schema.Fields { + props := writers[f.FieldID].rw.writerProps + assert.Equal(t, !f.IsPrimaryKey, props.DictionaryEnabled()) + } + err = reader.Next() assert.Equal(t, io.EOF, err) err = writer.Close() @@ -158,8 +163,13 @@ func TestBinlogSerializeWriter(t *testing.T) { newblobs[i] = blob i++ } + // Both field pk and field 17 are with datatype string and auto id + // in test data. Field pk uses delta byte array encoding, while + // field 17 uses dict encoding. + assert.Less(t, writers[16].buf.Len(), writers[17].buf.Len()) + // assert.Equal(t, blobs[0].Value, newblobs[0].Value) - reader, err = NewBinlogDeserializeReader(blobs, common.RowIDField) + reader, err = NewBinlogDeserializeReader(newblobs, common.RowIDField) assert.NoError(t, err) defer reader.Close() for i := 1; i <= size; i++ { diff --git a/internal/storage/utils.go b/internal/storage/utils.go index c8b16328f9d11..a9d616f198d51 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -1303,3 +1303,21 @@ func GetFilesSize(ctx context.Context, paths []string, cm ChunkManager) (int64, } return totalSize, nil } + +type NullableInt struct { + Value *int +} + +// NewNullableInt creates a new NullableInt instance +func NewNullableInt(value int) *NullableInt { + return &NullableInt{Value: &value} +} + +func (ni NullableInt) GetValue() int { + return *ni.Value +} + +// IsNull checks if the NullableInt is null +func (ni NullableInt) IsNull() bool { + return ni.Value == nil +} diff --git a/internal/util/importutilv2/binlog/reader_test.go b/internal/util/importutilv2/binlog/reader_test.go index a179374723732..f734786143927 100644 --- a/internal/util/importutilv2/binlog/reader_test.go +++ b/internal/util/importutilv2/binlog/reader_test.go @@ -81,7 +81,7 @@ func createBinlogBuf(t *testing.T, field *schemapb.FieldSchema, data storage.Fie dim = 1 } - evt, err := w.NextInsertEventWriter(false, int(dim)) + evt, err := w.NextInsertEventWriter(storage.WithDim(int(dim))) assert.NoError(t, err) evt.SetEventTimestamp(1, math.MaxInt64)