diff --git a/internal/storage/binlog_test.go b/internal/storage/binlog_test.go index b5058ab6fa562..6a7876dd820f9 100644 --- a/internal/storage/binlog_test.go +++ b/internal/storage/binlog_test.go @@ -39,7 +39,7 @@ import ( func TestInsertBinlog(t *testing.T) { w := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -49,7 +49,7 @@ func TestInsertBinlog(t *testing.T) { assert.NoError(t, err) e1.SetEventTimestamp(100, 200) - e2, err := w.NextInsertEventWriter(false) + e2, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e2.AddDataToPayload([]int64{7, 8, 9}, nil) assert.NoError(t, err) @@ -1329,7 +1329,7 @@ func TestNewBinlogReaderError(t *testing.T) { w.SetEventTimeStamp(1000, 2000) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -1393,7 +1393,7 @@ func TestNewBinlogWriterTsError(t *testing.T) { func TestInsertBinlogWriterCloseError(t *testing.T) { insertWriter := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e1, err := insertWriter.NextInsertEventWriter(false) + e1, err := insertWriter.NextInsertEventWriter() assert.NoError(t, err) sizeTotal := 2000000 @@ -1406,7 +1406,7 @@ func TestInsertBinlogWriterCloseError(t *testing.T) { err = insertWriter.Finish() assert.NoError(t, err) assert.NotNil(t, insertWriter.buffer) - insertEventWriter, err := insertWriter.NextInsertEventWriter(false) + insertEventWriter, err := insertWriter.NextInsertEventWriter() assert.Nil(t, insertEventWriter) assert.Error(t, err) insertWriter.Close() diff --git a/internal/storage/binlog_writer.go b/internal/storage/binlog_writer.go index 2e716f31e852e..dee80c99883a7 100644 --- a/internal/storage/binlog_writer.go +++ b/internal/storage/binlog_writer.go @@ -23,7 +23,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) // BinlogType is to distinguish different files saving different data. @@ -150,21 +149,12 @@ type InsertBinlogWriter struct { } // NextInsertEventWriter returns an event writer to write insert data to an event. -func (writer *InsertBinlogWriter) NextInsertEventWriter(nullable bool, dim ...int) (*insertEventWriter, error) { +func (writer *InsertBinlogWriter) NextInsertEventWriter(options ...PayloadWriterOptions) (*insertEventWriter, error) { if writer.isClosed() { return nil, fmt.Errorf("binlog has closed") } - var event *insertEventWriter - var err error - if typeutil.IsVectorType(writer.PayloadDataType) && !typeutil.IsSparseFloatVectorType(writer.PayloadDataType) { - if len(dim) != 1 { - return nil, fmt.Errorf("incorrect input numbers") - } - event, err = newInsertEventWriter(writer.PayloadDataType, nullable, dim[0]) - } else { - event, err = newInsertEventWriter(writer.PayloadDataType, nullable) - } + event, err := newInsertEventWriter(writer.PayloadDataType, options...) if err != nil { return nil, err } diff --git a/internal/storage/binlog_writer_test.go b/internal/storage/binlog_writer_test.go index 02e25d32f3a00..ff1c928e2440f 100644 --- a/internal/storage/binlog_writer_test.go +++ b/internal/storage/binlog_writer_test.go @@ -32,7 +32,7 @@ func TestBinlogWriterReader(t *testing.T) { binlogWriter.SetEventTimeStamp(1000, 2000) defer binlogWriter.Close() - eventWriter, err := binlogWriter.NextInsertEventWriter(false) + eventWriter, err := binlogWriter.NextInsertEventWriter() assert.NoError(t, err) err = eventWriter.AddInt32ToPayload([]int32{1, 2, 3}, nil) assert.NoError(t, err) diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go index c8a1babefce90..7c485bcc476ac 100644 --- a/internal/storage/data_codec.go +++ b/internal/storage/data_codec.go @@ -243,31 +243,18 @@ func (insertCodec *InsertCodec) Serialize(partitionID UniqueID, segmentID Unique for _, field := range insertCodec.Schema.Schema.Fields { // encode fields writer = NewInsertBinlogWriter(field.DataType, insertCodec.Schema.ID, partitionID, segmentID, field.FieldID, field.GetNullable()) - var eventWriter *insertEventWriter - var err error - var dim int64 - if typeutil.IsVectorType(field.DataType) { - if field.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("vectorType not support null, fieldName: %s", field.GetName())) - } - switch field.DataType { - case schemapb.DataType_FloatVector, - schemapb.DataType_BinaryVector, - schemapb.DataType_Float16Vector, - schemapb.DataType_BFloat16Vector: - dim, err = typeutil.GetDim(field) - if err != nil { - return nil, err - } - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable(), int(dim)) - case schemapb.DataType_SparseFloatVector: - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable()) - default: - return nil, fmt.Errorf("undefined data type %d", field.DataType) + + // get payload writing configs, including nullable and fallback encoding method + options := []PayloadWriterOptions{WithNullable(field.GetNullable()), WithWriterProps(GetWriterPropsByDataType(field.DataType))} + + if typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) { + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err } - } else { - eventWriter, err = writer.NextInsertEventWriter(field.GetNullable()) + options = append(options, WithDim(int(dim))) } + eventWriter, err := writer.NextInsertEventWriter(options...) if err != nil { writer.Close() return nil, err diff --git a/internal/storage/data_codec_test.go b/internal/storage/data_codec_test.go index b37886cd20a00..cbdec1414c589 100644 --- a/internal/storage/data_codec_test.go +++ b/internal/storage/data_codec_test.go @@ -977,7 +977,7 @@ func TestDeleteData(t *testing.T) { func TestAddFieldDataToPayload(t *testing.T) { w := NewInsertBinlogWriter(schemapb.DataType_Int64, 10, 20, 30, 40, false) - e, _ := w.NextInsertEventWriter(false) + e, _ := w.NextInsertEventWriter() var err error err = AddFieldDataToPayload(e, schemapb.DataType_Bool, &BoolFieldData{[]bool{}, nil}) assert.Error(t, err) diff --git a/internal/storage/event_test.go b/internal/storage/event_test.go index 3f4ada4076a78..02eecc61624e7 100644 --- a/internal/storage/event_test.go +++ b/internal/storage/event_test.go @@ -195,7 +195,7 @@ func TestInsertEvent(t *testing.T) { } t.Run("insert_bool", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Bool, false) + w, err := newInsertEventWriter(schemapb.DataType_Bool) assert.NoError(t, err) insertT(t, schemapb.DataType_Bool, w, func(w *insertEventWriter) error { @@ -211,7 +211,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int8", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int8, false) + w, err := newInsertEventWriter(schemapb.DataType_Int8) assert.NoError(t, err) insertT(t, schemapb.DataType_Int8, w, func(w *insertEventWriter) error { @@ -227,7 +227,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int16", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int16, false) + w, err := newInsertEventWriter(schemapb.DataType_Int16) assert.NoError(t, err) insertT(t, schemapb.DataType_Int16, w, func(w *insertEventWriter) error { @@ -243,7 +243,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int32", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int32, false) + w, err := newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) insertT(t, schemapb.DataType_Int32, w, func(w *insertEventWriter) error { @@ -259,7 +259,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_int64", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Int64, false) + w, err := newInsertEventWriter(schemapb.DataType_Int64) assert.NoError(t, err) insertT(t, schemapb.DataType_Int64, w, func(w *insertEventWriter) error { @@ -275,7 +275,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float32", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Float, false) + w, err := newInsertEventWriter(schemapb.DataType_Float) assert.NoError(t, err) insertT(t, schemapb.DataType_Float, w, func(w *insertEventWriter) error { @@ -291,7 +291,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float64", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_Double, false) + w, err := newInsertEventWriter(schemapb.DataType_Double) assert.NoError(t, err) insertT(t, schemapb.DataType_Double, w, func(w *insertEventWriter) error { @@ -307,7 +307,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_binary_vector", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_BinaryVector, false, 16) + w, err := newInsertEventWriter(schemapb.DataType_BinaryVector, WithDim(16)) assert.NoError(t, err) insertT(t, schemapb.DataType_BinaryVector, w, func(w *insertEventWriter) error { @@ -323,7 +323,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_float_vector", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_FloatVector, false, 2) + w, err := newInsertEventWriter(schemapb.DataType_FloatVector, WithDim(2)) assert.NoError(t, err) insertT(t, schemapb.DataType_FloatVector, w, func(w *insertEventWriter) error { @@ -339,7 +339,7 @@ func TestInsertEvent(t *testing.T) { }) t.Run("insert_string", func(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_String, false) + w, err := newInsertEventWriter(schemapb.DataType_String) assert.NoError(t, err) w.SetEventTimestamp(tsoutil.ComposeTS(10, 0), tsoutil.ComposeTS(100, 0)) err = w.AddDataToPayload("1234", nil) @@ -1101,7 +1101,7 @@ func TestEventReaderError(t *testing.T) { } func TestEventClose(t *testing.T) { - w, err := newInsertEventWriter(schemapb.DataType_String, false) + w, err := newInsertEventWriter(schemapb.DataType_String) assert.NoError(t, err) w.SetEventTimestamp(tsoutil.ComposeTS(10, 0), tsoutil.ComposeTS(100, 0)) err = w.AddDataToPayload("1234", nil) diff --git a/internal/storage/event_writer.go b/internal/storage/event_writer.go index 6b9390da0a387..cc3c4d6ca92bf 100644 --- a/internal/storage/event_writer.go +++ b/internal/storage/event_writer.go @@ -19,14 +19,12 @@ package storage import ( "bytes" "encoding/binary" - "fmt" "io" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/typeutil" ) // EventTypeCode represents event type by code @@ -222,17 +220,8 @@ func NewBaseDescriptorEvent(collectionID int64, partitionID int64, segmentID int return de } -func newInsertEventWriter(dataType schemapb.DataType, nullable bool, dim ...int) (*insertEventWriter, error) { - var payloadWriter PayloadWriterInterface - var err error - if typeutil.IsVectorType(dataType) && !typeutil.IsSparseFloatVectorType(dataType) { - if len(dim) != 1 { - return nil, fmt.Errorf("incorrect input numbers") - } - payloadWriter, err = NewPayloadWriter(dataType, nullable, dim[0]) - } else { - payloadWriter, err = NewPayloadWriter(dataType, nullable) - } +func newInsertEventWriter(dataType schemapb.DataType, options ...PayloadWriterOptions) (*insertEventWriter, error) { + payloadWriter, err := NewPayloadWriter(dataType, options...) if err != nil { return nil, err } @@ -254,7 +243,7 @@ func newInsertEventWriter(dataType schemapb.DataType, nullable bool, dim ...int) } func newDeleteEventWriter(dataType schemapb.DataType) (*deleteEventWriter, error) { - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -280,7 +269,7 @@ func newCreateCollectionEventWriter(dataType schemapb.DataType) (*createCollecti return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -306,7 +295,7 @@ func newDropCollectionEventWriter(dataType schemapb.DataType) (*dropCollectionEv return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -332,7 +321,7 @@ func newCreatePartitionEventWriter(dataType schemapb.DataType) (*createPartition return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -358,7 +347,7 @@ func newDropPartitionEventWriter(dataType schemapb.DataType) (*dropPartitionEven return nil, errors.New("incorrect data type") } - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } @@ -380,7 +369,7 @@ func newDropPartitionEventWriter(dataType schemapb.DataType) (*dropPartitionEven } func newIndexFileEventWriter(dataType schemapb.DataType) (*indexFileEventWriter, error) { - payloadWriter, err := NewPayloadWriter(dataType, false) + payloadWriter, err := NewPayloadWriter(dataType) if err != nil { return nil, err } diff --git a/internal/storage/event_writer_test.go b/internal/storage/event_writer_test.go index 9b4997edcaaac..160bd78666f69 100644 --- a/internal/storage/event_writer_test.go +++ b/internal/storage/event_writer_test.go @@ -59,11 +59,11 @@ func TestSizeofStruct(t *testing.T) { } func TestEventWriter(t *testing.T) { - insertEvent, err := newInsertEventWriter(schemapb.DataType_Int32, false) + insertEvent, err := newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) insertEvent.Close() - insertEvent, err = newInsertEventWriter(schemapb.DataType_Int32, false) + insertEvent, err = newInsertEventWriter(schemapb.DataType_Int32) assert.NoError(t, err) defer insertEvent.Close() diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index b477eed5c6152..82dd64498a31f 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -32,7 +32,7 @@ import ( func TestPayload_ReaderAndWriter(t *testing.T) { t.Run("TestBool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -69,7 +69,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -109,7 +109,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -147,7 +147,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -186,7 +186,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestInt64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -225,7 +225,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloat32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -264,7 +264,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestDouble", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -303,7 +303,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddString", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -351,7 +351,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -423,7 +423,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -471,7 +471,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestBinaryVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -520,7 +520,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -562,7 +562,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestFloat16Vector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -604,7 +604,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestBFloat16Vector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, false, 1) + w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(1)) require.Nil(t, err) require.NotNil(t, w) @@ -646,7 +646,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestSparseFloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) @@ -715,7 +715,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) testSparseOneBatch := func(t *testing.T, rows [][]byte, actualDim int) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) @@ -811,31 +811,8 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }, int(int32Max)) }) - // t.Run("TestAddDataToPayload", func(t *testing.T) { - // w, err := NewPayloadWriter(schemapb.DataType_Bool) - // w.colType = 999 - // require.Nil(t, err) - // require.NotNil(t, w) - - // err = w.AddDataToPayload([]bool{false, false, false, false}) - // assert.NotNil(t, err) - - // err = w.AddDataToPayload([]bool{false, false, false, false}, 0) - // assert.NotNil(t, err) - - // err = w.AddDataToPayload([]bool{false, false, false, false}, 0, 0) - // assert.NotNil(t, err) - - // err = w.AddBoolToPayload([]bool{}) - // assert.NotNil(t, err) - // err = w.FinishPayloadWriter() - // assert.Nil(t, err) - // err = w.AddBoolToPayload([]bool{false}) - // assert.NotNil(t, err) - // }) - t.Run("TestAddBoolAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -851,7 +828,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -867,7 +844,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt16AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -883,7 +860,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt32AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -899,7 +876,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddInt64AfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -915,7 +892,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloatAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -931,7 +908,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddDoubleAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -947,7 +924,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddOneStringAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -963,7 +940,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddBinVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -987,7 +964,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloatVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1008,7 +985,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddFloat16VectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1032,7 +1009,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddBFloat16VectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1056,7 +1033,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestAddSparseFloatVectorAfterFinish", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector) require.Nil(t, err) require.NotNil(t, w) defer w.Close() @@ -1100,7 +1077,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBoolError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1124,7 +1101,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBoolError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1145,7 +1122,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt8Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1169,7 +1146,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt8Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1190,7 +1167,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt16Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1214,7 +1191,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt16Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -1235,7 +1212,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt32Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1259,7 +1236,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt32Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -1280,7 +1257,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt64Error", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1304,7 +1281,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetInt64Error2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -1325,7 +1302,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1349,7 +1326,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -1370,7 +1347,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetDoubleError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1394,7 +1371,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetDoubleError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -1415,7 +1392,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetStringError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1439,7 +1416,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetStringError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1464,7 +1441,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetArrayError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1488,7 +1465,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBinaryVectorError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1512,7 +1489,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetBinaryVectorError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -1533,7 +1510,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatVectorError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1557,7 +1534,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.Error(t, err) }) t.Run("TestGetFloatVectorError2", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -1579,7 +1556,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestByteArrayDatasetError", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1619,7 +1596,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { vec = append(vec, 1) } - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector) assert.NoError(t, err) err = w.AddFloatVectorToPayload(vec, 128) @@ -1635,7 +1612,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddBool with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -1644,7 +1621,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -1653,7 +1630,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt16 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -1662,7 +1639,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -1671,7 +1648,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt64 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, false) + w, err := NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -1680,7 +1657,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddFloat32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -1689,7 +1666,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddDouble with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -1698,7 +1675,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddAddString with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -1707,7 +1684,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -1722,7 +1699,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -1733,7 +1710,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { func TestPayload_NullableReaderAndWriter(t *testing.T) { t.Run("TestBool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, true) + w, err := NewPayloadWriter(schemapb.DataType_Bool, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1770,7 +1747,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, true) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1810,7 +1787,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, true) + w, err := NewPayloadWriter(schemapb.DataType_Int16, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1848,7 +1825,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, true) + w, err := NewPayloadWriter(schemapb.DataType_Int32, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1887,7 +1864,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestInt64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, true) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1926,7 +1903,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestFloat32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, true) + w, err := NewPayloadWriter(schemapb.DataType_Float, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -1965,7 +1942,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestDouble", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, true) + w, err := NewPayloadWriter(schemapb.DataType_Double, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2004,7 +1981,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddString", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, true) + w, err := NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2052,7 +2029,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, true) + w, err := NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2124,7 +2101,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, true) + w, err := NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2172,22 +2149,22 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestBinaryVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_BinaryVector, true, 8) + _, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithNullable(true), WithDim(8)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestFloatVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_FloatVector, true, 1) + _, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithNullable(true), WithDim(1)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestFloat16Vector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_Float16Vector, true, 1) + _, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithNullable(true), WithDim(1)) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) t.Run("TestAddBool with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, true) + w, err := NewPayloadWriter(schemapb.DataType_Bool, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2196,7 +2173,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt8 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, true) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2205,7 +2182,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt16 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, true) + w, err := NewPayloadWriter(schemapb.DataType_Int16, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2214,7 +2191,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, true) + w, err := NewPayloadWriter(schemapb.DataType_Int32, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2223,7 +2200,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddInt64 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, true) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2232,7 +2209,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddFloat32 with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, true) + w, err := NewPayloadWriter(schemapb.DataType_Float, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2241,7 +2218,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddDouble with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, true) + w, err := NewPayloadWriter(schemapb.DataType_Double, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2250,25 +2227,25 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddAddString with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, true) + w, err := NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, true) + w, err = NewPayloadWriter(schemapb.DataType_String, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, false) + w, err = NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_String, false) + w, err = NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload("hello0", []bool{true}) @@ -2276,7 +2253,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddArray with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, true) + w, err := NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2288,7 +2265,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, true) + w, err = NewPayloadWriter(schemapb.DataType_Array, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) @@ -2301,7 +2278,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, false) + w, err = NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2313,7 +2290,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }, []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_Array, false) + w, err = NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload(&schemapb.ScalarField{ @@ -2327,25 +2304,25 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { }) t.Run("TestAddJSON with wrong valids", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, true) + w, err := NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), nil) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, true) + w, err = NewPayloadWriter(schemapb.DataType_JSON, WithNullable(true)) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{false, false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, false) + w, err = NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{false}) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - w, err = NewPayloadWriter(schemapb.DataType_JSON, false) + w, err = NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) err = w.AddDataToPayload([]byte(`{"1":"1"}`), []bool{true}) @@ -2355,7 +2332,7 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { func TestArrowRecordReader(t *testing.T) { t.Run("TestArrowRecordReader", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) assert.NoError(t, err) defer w.Close() @@ -2395,7 +2372,7 @@ func TestArrowRecordReader(t *testing.T) { } func dataGen(size int) ([]byte, error) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) if err != nil { return nil, err } @@ -2422,7 +2399,7 @@ func dataGen(size int) ([]byte, error) { } func BenchmarkDefaultReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) @@ -2446,7 +2423,7 @@ func BenchmarkDefaultReader(b *testing.B) { } func BenchmarkDataSetReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) @@ -2474,7 +2451,7 @@ func BenchmarkDataSetReader(b *testing.B) { } func BenchmarkArrowRecordReader(b *testing.B) { - size := 1000000 + size := 10 buffer, err := dataGen(size) assert.NoError(b, err) diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go index 8b8b00100564b..e2ac969719a06 100644 --- a/internal/storage/payload_writer.go +++ b/internal/storage/payload_writer.go @@ -39,6 +39,26 @@ import ( var _ PayloadWriterInterface = (*NativePayloadWriter)(nil) +type PayloadWriterOptions func(*NativePayloadWriter) + +func WithNullable(nullable bool) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.nullable = nullable + } +} + +func WithWriterProps(writerProps *parquet.WriterProperties) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.writerProps = writerProps + } +} + +func WithDim(dim int) PayloadWriterOptions { + return func(w *NativePayloadWriter) { + w.dim = NewNullableInt(dim) + } +} + type NativePayloadWriter struct { dataType schemapb.DataType arrowType arrow.DataType @@ -47,43 +67,42 @@ type NativePayloadWriter struct { flushedRows int output *bytes.Buffer releaseOnce sync.Once - dim int + dim *NullableInt nullable bool + writerProps *parquet.WriterProperties } -func NewPayloadWriter(colType schemapb.DataType, nullable bool, dim ...int) (PayloadWriterInterface, error) { - var arrowType arrow.DataType - var dimension int +func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions) (PayloadWriterInterface, error) { + w := &NativePayloadWriter{ + dataType: colType, + finished: false, + flushedRows: 0, + output: new(bytes.Buffer), + nullable: false, + writerProps: parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + ), + dim: &NullableInt{}, + } + for _, o := range options { + o(w) + } + // writer for sparse float vector doesn't require dim if typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) { - if len(dim) != 1 { + if w.dim.IsNull() { return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers") } - if nullable { - return nil, merr.WrapErrParameterInvalidMsg("vector type not supprot nullable") + if w.nullable { + return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable") } - arrowType = milvusDataTypeToArrowType(colType, dim[0]) - dimension = dim[0] } else { - if len(dim) != 0 { - return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers") - } - arrowType = milvusDataTypeToArrowType(colType, 1) - dimension = 1 + w.dim = NewNullableInt(1) } - - builder := array.NewBuilder(memory.DefaultAllocator, arrowType) - - return &NativePayloadWriter{ - dataType: colType, - arrowType: arrowType, - builder: builder, - finished: false, - flushedRows: 0, - output: new(bytes.Buffer), - dim: dimension, - nullable: nullable, - }, nil + w.arrowType = milvusDataTypeToArrowType(colType, *w.dim.Value) + w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType) + return w, nil } func (w *NativePayloadWriter) AddDataToPayload(data interface{}, validData []bool) error { @@ -192,25 +211,25 @@ func (w *NativePayloadWriter) AddDataToPayload(data interface{}, validData []boo if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBinaryVectorToPayload(val, w.dim) + return w.AddBinaryVectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_FloatVector: val, ok := data.([]float32) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloatVectorToPayload(val, w.dim) + return w.AddFloatVectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_Float16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloat16VectorToPayload(val, w.dim) + return w.AddFloat16VectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_BFloat16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBFloat16VectorToPayload(val, w.dim) + return w.AddBFloat16VectorToPayload(val, w.dim.GetValue()) case schemapb.DataType_SparseFloatVector: val, ok := data.(*SparseFloatVectorFieldData) if !ok { @@ -674,14 +693,10 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error { table := array.NewTable(schema, []arrow.Column{column}, int64(column.Len())) defer table.Release() - props := parquet.NewWriterProperties( - parquet.WithCompression(compress.Codecs.Zstd), - parquet.WithCompressionLevel(3), - ) return pqarrow.WriteTable(table, w.output, 1024*1024*1024, - props, + w.writerProps, pqarrow.DefaultWriterProps(), ) } diff --git a/internal/storage/payload_writer_test.go b/internal/storage/payload_writer_test.go index 0a8e5abfb4f05..2e62fc049d319 100644 --- a/internal/storage/payload_writer_test.go +++ b/internal/storage/payload_writer_test.go @@ -3,6 +3,9 @@ package storage import ( "testing" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/compress" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -10,14 +13,11 @@ import ( func TestPayloadWriter_Failed(t *testing.T) { t.Run("wrong input", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_FloatVector, false) - require.Error(t, err) - - _, err = NewPayloadWriter(schemapb.DataType_Bool, false, 1) + _, err := NewPayloadWriter(schemapb.DataType_FloatVector) require.Error(t, err) }) t.Run("Test Bool", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Bool, false) + w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) require.NotNil(t, w) @@ -30,7 +30,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddBoolToPayload([]bool{false}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -39,7 +39,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Byte", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty) + w, err := NewPayloadWriter(schemapb.DataType_Int8, WithNullable(Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty)) require.Nil(t, err) require.NotNil(t, w) @@ -52,7 +52,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddByteToPayload([]byte{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -61,7 +61,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int8", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int8, false) + w, err := NewPayloadWriter(schemapb.DataType_Int8) require.Nil(t, err) require.NotNil(t, w) @@ -74,7 +74,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt8ToPayload([]int8{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -83,7 +83,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int16", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int16, false) + w, err := NewPayloadWriter(schemapb.DataType_Int16) require.Nil(t, err) require.NotNil(t, w) @@ -96,7 +96,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt16ToPayload([]int16{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -105,7 +105,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int32", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int32, false) + w, err := NewPayloadWriter(schemapb.DataType_Int32) require.Nil(t, err) require.NotNil(t, w) @@ -118,7 +118,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt32ToPayload([]int32{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -127,7 +127,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Int64", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Int64, Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty) + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithNullable(Params.CommonCfg.MaxBloomFalsePositive.PanicIfEmpty)) require.Nil(t, err) require.NotNil(t, w) @@ -140,7 +140,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddInt64ToPayload([]int64{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Float, false) + w, err = NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -149,7 +149,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Float", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Float, false) + w, err := NewPayloadWriter(schemapb.DataType_Float) require.Nil(t, err) require.NotNil(t, w) @@ -162,7 +162,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddFloatToPayload([]float32{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -171,7 +171,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Double", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Double, false) + w, err := NewPayloadWriter(schemapb.DataType_Double) require.Nil(t, err) require.NotNil(t, w) @@ -184,7 +184,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddDoubleToPayload([]float64{0}, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -193,7 +193,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test String", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_String, false) + w, err := NewPayloadWriter(schemapb.DataType_String) require.Nil(t, err) require.NotNil(t, w) @@ -203,7 +203,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneStringToPayload("test", false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -212,7 +212,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Array", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_Array, false) + w, err := NewPayloadWriter(schemapb.DataType_Array) require.Nil(t, err) require.NotNil(t, w) @@ -222,7 +222,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneArrayToPayload(&schemapb.ScalarField{}, false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -231,7 +231,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test Json", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_JSON, false) + w, err := NewPayloadWriter(schemapb.DataType_JSON) require.Nil(t, err) require.NotNil(t, w) @@ -241,7 +241,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddOneJSONToPayload([]byte{0, 1}, false) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -250,7 +250,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test BinaryVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -265,7 +265,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddBinaryVectorToPayload(data, 8) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -274,7 +274,7 @@ func TestPayloadWriter_Failed(t *testing.T) { }) t.Run("Test FloatVector", func(t *testing.T) { - w, err := NewPayloadWriter(schemapb.DataType_FloatVector, false, 8) + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(8)) require.Nil(t, err) require.NotNil(t, w) @@ -292,7 +292,7 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.AddFloatToPayload(data, nil) require.Error(t, err) - w, err = NewPayloadWriter(schemapb.DataType_Int64, false) + w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) @@ -300,3 +300,43 @@ func TestPayloadWriter_Failed(t *testing.T) { require.Error(t, err) }) } + +func BenchmarkPayloadWriter(b *testing.B) { + size := 10 + b.Run("Test default encoding ", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + w, err := NewPayloadWriter(schemapb.DataType_Int64) + assert.NoError(b, err) + for j := 0; j < size; j++ { + err = w.AddInt64ToPayload([]int64{int64(j)}, nil) + assert.NoError(b, err) + } + + err = w.FinishPayloadWriter() + assert.NoError(b, err) + } + b.ReportAllocs() + }) + + b.Run("Test RLE", func(b *testing.B) { + b.ResetTimer() + props := parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + parquet.WithEncoding(parquet.Encodings.RLE), + ) + for i := 0; i < b.N; i++ { + w, err := NewPayloadWriter(schemapb.DataType_Int64, WithWriterProps(props)) + assert.NoError(b, err) + for j := 0; j < size; j++ { + err = w.AddInt64ToPayload([]int64{int64(j)}, nil) + assert.NoError(b, err) + } + + err = w.FinishPayloadWriter() + assert.NoError(b, err) + } + b.ReportAllocs() + }) +} diff --git a/internal/storage/print_binlog_test.go b/internal/storage/print_binlog_test.go index 0409430b32e85..dc0bee9779cdd 100644 --- a/internal/storage/print_binlog_test.go +++ b/internal/storage/print_binlog_test.go @@ -40,7 +40,7 @@ func TestPrintBinlogFilesInt64(t *testing.T) { curTS := time.Now().UnixNano() / int64(time.Millisecond) - e1, err := w.NextInsertEventWriter(false) + e1, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e1.AddDataToPayload([]int64{1, 2, 3}, nil) assert.NoError(t, err) @@ -50,7 +50,7 @@ func TestPrintBinlogFilesInt64(t *testing.T) { assert.NoError(t, err) e1.SetEventTimestamp(tsoutil.ComposeTS(curTS+10*60*1000, 0), tsoutil.ComposeTS(curTS+20*60*1000, 0)) - e2, err := w.NextInsertEventWriter(false) + e2, err := w.NextInsertEventWriter() assert.NoError(t, err) err = e2.AddDataToPayload([]int64{7, 8, 9}, nil) assert.NoError(t, err) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index a20b3aaba6861..3e20eb4fe95ac 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -103,6 +103,10 @@ type serdeEntry struct { serialize func(array.Builder, any) bool // sizeof returns the size in bytes of the value sizeof func(any) uint64 + // fallbackEncoding returns the fallback encode method if parquet + // dictionary encoding is disabled, or it fallbacks if the dictionary + // grew too large. + fallbackEncoding func() parquet.Encoding } var serdeMap = func() map[schemapb.DataType]serdeEntry { @@ -136,6 +140,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 1 }, + func() parquet.Encoding { + return parquet.Encodings.RLE + }, } m[schemapb.DataType_Int8] = serdeEntry{ func(i int) arrow.DataType { @@ -166,6 +173,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 1 }, + func() parquet.Encoding { + return parquet.Encodings.DeltaBinaryPacked + }, } m[schemapb.DataType_Int16] = serdeEntry{ func(i int) arrow.DataType { @@ -196,6 +206,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 2 }, + func() parquet.Encoding { + return parquet.Encodings.DeltaBinaryPacked + }, } m[schemapb.DataType_Int32] = serdeEntry{ func(i int) arrow.DataType { @@ -226,6 +239,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 4 }, + func() parquet.Encoding { + return parquet.Encodings.DeltaBinaryPacked + }, } m[schemapb.DataType_Int64] = serdeEntry{ func(i int) arrow.DataType { @@ -256,6 +272,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 8 }, + func() parquet.Encoding { + return parquet.Encodings.DeltaBinaryPacked + }, } m[schemapb.DataType_Float] = serdeEntry{ func(i int) arrow.DataType { @@ -286,6 +305,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 4 }, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_Double] = serdeEntry{ func(i int) arrow.DataType { @@ -316,6 +338,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { func(any) uint64 { return 8 }, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } stringEntry := serdeEntry{ func(i int) arrow.DataType { @@ -349,6 +374,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { } return uint64(len(v.(string))) }, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_VarChar] = stringEntry @@ -390,6 +418,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { } return uint64(v.(*schemapb.ScalarField).XXX_Size()) }, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } sizeOfBytes := func(v any) uint64 { @@ -426,6 +457,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { return false }, sizeOfBytes, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_JSON] = byteEntry @@ -460,6 +494,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { fixedSizeDeserializer, fixedSizeSerializer, sizeOfBytes, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_Float16Vector] = serdeEntry{ func(i int) arrow.DataType { @@ -468,6 +505,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { fixedSizeDeserializer, fixedSizeSerializer, sizeOfBytes, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_BFloat16Vector] = serdeEntry{ func(i int) arrow.DataType { @@ -476,6 +516,9 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { fixedSizeDeserializer, fixedSizeSerializer, sizeOfBytes, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_FloatVector] = serdeEntry{ func(i int) arrow.DataType { @@ -516,11 +559,22 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { } return uint64(len(v.([]float32)) * 4) }, + func() parquet.Encoding { + return parquet.Encodings.Plain + }, } m[schemapb.DataType_SparseFloatVector] = byteEntry return m }() +func GetWriterPropsByDataType(dt schemapb.DataType) *parquet.WriterProperties { + return parquet.NewWriterProperties( + parquet.WithCompression(compress.Codecs.Zstd), + parquet.WithCompressionLevel(3), + parquet.WithEncoding(serdeMap[dt].fallbackEncoding()), + ) +} + type DeserializeReader[T any] struct { rr RecordReader deserializer Deserializer[T] @@ -654,12 +708,21 @@ func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecord var _ RecordWriter = (*singleFieldRecordWriter)(nil) +type RecordWriterOptions func(*singleFieldRecordWriter) + +func WithRecordWriterProps(writerProps *parquet.WriterProperties) RecordWriterOptions { + return func(w *singleFieldRecordWriter) { + w.writerProps = writerProps + } +} + type singleFieldRecordWriter struct { fw *pqarrow.FileWriter fieldId FieldID schema *arrow.Schema - numRows int + numRows int + writerProps *parquet.WriterProperties } func (sfw *singleFieldRecordWriter) Write(r Record) error { @@ -674,23 +737,24 @@ func (sfw *singleFieldRecordWriter) Close() { sfw.fw.Close() } -func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer) (*singleFieldRecordWriter, error) { - schema := arrow.NewSchema([]arrow.Field{field}, nil) - - // use writer properties as same as payload writer's for now - fw, err := pqarrow.NewFileWriter(schema, writer, - parquet.NewWriterProperties( +func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer, options ...RecordWriterOptions) (*singleFieldRecordWriter, error) { + w := &singleFieldRecordWriter{ + fieldId: fieldId, + schema: arrow.NewSchema([]arrow.Field{field}, nil), + writerProps: parquet.NewWriterProperties( + parquet.WithMaxRowGroupLength(math.MaxInt64), // No additional grouping for now. parquet.WithCompression(compress.Codecs.Zstd), parquet.WithCompressionLevel(3)), - pqarrow.DefaultWriterProps()) + } + for _, o := range options { + o(w) + } + fw, err := pqarrow.NewFileWriter(w.schema, writer, w.writerProps, pqarrow.DefaultWriterProps()) if err != nil { return nil, err } - return &singleFieldRecordWriter{ - fw: fw, - fieldId: fieldId, - schema: schema, - }, nil + w.fw = fw + return w, nil } var _ RecordWriter = (*multiFieldRecordWriter)(nil) diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 1b4911f987e25..00be4cac473de 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -278,7 +278,8 @@ func (bsw *BinlogStreamWriter) GetRecordWriter() (RecordWriter, error) { Name: strconv.Itoa(int(fid)), Type: serdeMap[bsw.fieldSchema.DataType].arrowType(int(dim)), Nullable: true, // No nullable check here. - }, &bsw.buf) + }, &bsw.buf, + WithRecordWriterProps(GetWriterPropsByDataType(bsw.fieldSchema.DataType))) if err != nil { return nil, err } @@ -430,7 +431,8 @@ func (dsw *DeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { Name: dsw.fieldSchema.Name, Type: serdeMap[dsw.fieldSchema.DataType].arrowType(int(dim)), Nullable: false, - }, &dsw.buf) + }, &dsw.buf, + WithRecordWriterProps(GetWriterPropsByDataType(schemapb.DataType_String))) if err != nil { return nil, err } diff --git a/internal/storage/serde_events_test.go b/internal/storage/serde_events_test.go index 4e5733b364bc5..4af7627d1aca0 100644 --- a/internal/storage/serde_events_test.go +++ b/internal/storage/serde_events_test.go @@ -20,12 +20,15 @@ import ( "bytes" "context" "io" + "math/rand" "strconv" "testing" "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet" + "github.com/apache/arrow/go/v12/parquet/compress" "github.com/apache/arrow/go/v12/parquet/file" "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/stretchr/testify/assert" @@ -141,6 +144,11 @@ func TestBinlogSerializeWriter(t *testing.T) { assert.NoError(t, err) } + for _, f := range schema.Fields { + expected := serdeMap[f.DataType].fallbackEncoding() + assert.Equal(t, expected, writers[f.FieldID].rw.writerProps.Encoding()) + } + err = reader.Next() assert.Equal(t, io.EOF, err) err = writer.Close() @@ -386,6 +394,44 @@ func TestDeltalogPkTsSeparateFormat(t *testing.T) { } } +func BenchmarkEncodingCompression(b *testing.B) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "int32", Type: arrow.PrimitiveTypes.Int32}, + {Name: "int64", Type: arrow.PrimitiveTypes.Int64}, + {Name: "string", Type: arrow.BinaryTypes.String}, + }, nil) + + record := createRecordBatch(schema, 1000000) + + encodings := []parquet.Encoding{ + parquet.Encodings.Plain, + parquet.Encodings.RLE, + parquet.Encodings.DeltaBinaryPacked, + } + + for _, encoding := range encodings { + name := encoding.String() + props := NewWriterProperties(encoding, compress.Codecs.Zstd) + b.Run(name, func(b *testing.B) { + benchmarkWrite(b, props, schema, record) + }) + } +} + +func NewWriterProperties(encoding parquet.Encoding, codec compress.Compression) *parquet.WriterProperties { + if encoding == parquet.Encodings.Plain { + return parquet.NewWriterProperties( + parquet.WithEncoding(encoding), + parquet.WithCompression(codec), + parquet.WithDictionaryDefault(true), + ) + } + return parquet.NewWriterProperties( + parquet.WithEncoding(encoding), + parquet.WithCompression(codec), + ) +} + func BenchmarkDeltalogReader(b *testing.B) { size := 1000000 blob, err := generateTestDeltalogData(size) @@ -483,3 +529,40 @@ func readDeltaLog(size int, blob *Blob) error { } return nil } + +func benchmarkWrite(b *testing.B, writerProps *parquet.WriterProperties, schema *arrow.Schema, record arrow.Record) { + fw, err := pqarrow.NewFileWriter(schema, &bytes.Buffer{}, writerProps, pqarrow. + DefaultWriterProps()) + assert.Nil(b, err) + defer fw.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := fw.WriteBuffered(record); err != nil { + b.Fatal(err) + } + } + b.StopTimer() +} + +func createRecordBatch(schema *arrow.Schema, num int) arrow.Record { + pool := memory.NewGoAllocator() + b := array.NewRecordBuilder(pool, schema) + defer b.Release() + + for i := 0; i < num; i++ { + b.Field(0).(*array.Int32Builder).Append(rand.Int31()) + b.Field(1).(*array.Int64Builder).Append(rand.Int63()) + b.Field(2).(*array.StringBuilder).Append(randString(10)) + } + + return b.NewRecord() +} + +func randString(n int) string { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + bytes := make([]byte, n) + for i := range bytes { + bytes[i] = letters[rand.Intn(len(letters))] + } + return string(bytes) +} diff --git a/internal/storage/utils.go b/internal/storage/utils.go index c8b16328f9d11..a9d616f198d51 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -1303,3 +1303,21 @@ func GetFilesSize(ctx context.Context, paths []string, cm ChunkManager) (int64, } return totalSize, nil } + +type NullableInt struct { + Value *int +} + +// NewNullableInt creates a new NullableInt instance +func NewNullableInt(value int) *NullableInt { + return &NullableInt{Value: &value} +} + +func (ni NullableInt) GetValue() int { + return *ni.Value +} + +// IsNull checks if the NullableInt is null +func (ni NullableInt) IsNull() bool { + return ni.Value == nil +} diff --git a/internal/util/importutilv2/binlog/reader_test.go b/internal/util/importutilv2/binlog/reader_test.go index a179374723732..f734786143927 100644 --- a/internal/util/importutilv2/binlog/reader_test.go +++ b/internal/util/importutilv2/binlog/reader_test.go @@ -81,7 +81,7 @@ func createBinlogBuf(t *testing.T, field *schemapb.FieldSchema, data storage.Fie dim = 1 } - evt, err := w.NextInsertEventWriter(false, int(dim)) + evt, err := w.NextInsertEventWriter(storage.WithDim(int(dim))) assert.NoError(t, err) evt.SetEventTimestamp(1, math.MaxInt64)