From 942ce208b7d09e5a7b3233824a747033b5516b17 Mon Sep 17 00:00:00 2001 From: MrPresent-Han Date: Sat, 12 Oct 2024 08:17:11 -0400 Subject: [PATCH] feat: support query aggregtion(#36380) Signed-off-by: MrPresent-Han --- internal/agg/aggregate.go | 813 ++++++++++++++++++ internal/agg/aggregate_test.go | 31 + internal/core/CMakeLists.txt | 6 + internal/core/src/CMakeLists.txt | 3 +- internal/core/src/common/BitUtil.h | 204 +++++ internal/core/src/common/ComplexVector.cpp | 49 ++ internal/core/src/common/FieldData.cpp | 67 ++ internal/core/src/common/FieldData.h | 5 + internal/core/src/common/FieldDataInterface.h | 3 +- internal/core/src/common/Schema.h | 22 +- internal/core/src/common/SimdUtil.h | 179 ++++ internal/core/src/common/Types.cpp | 60 ++ internal/core/src/common/Types.h | 138 ++- internal/core/src/common/Utils.h | 87 ++ internal/core/src/common/Vector.h | 90 +- internal/core/src/common/float_util_c.h | 38 + internal/core/src/exec/Driver.cpp | 33 +- internal/core/src/exec/Driver.h | 7 + internal/core/src/exec/HashTable.cpp | 368 ++++++++ internal/core/src/exec/HashTable.h | 340 ++++++++ internal/core/src/exec/VectorHasher.cpp | 99 +++ internal/core/src/exec/VectorHasher.h | 102 +++ internal/core/src/exec/expression/Utils.h | 11 + .../src/exec/operator/AggregationNode.cpp | 83 ++ .../core/src/exec/operator/AggregationNode.h | 89 ++ .../core/src/exec/operator/CallbackSink.h | 4 +- internal/core/src/exec/operator/CountNode.h | 2 +- .../core/src/exec/operator/FilterBitsNode.h | 2 +- .../src/exec/operator/IterativeFilterNode.h | 2 +- internal/core/src/exec/operator/MvccNode.h | 2 +- internal/core/src/exec/operator/Operator.cpp | 7 +- internal/core/src/exec/operator/Operator.h | 25 +- .../core/src/exec/operator/OperatorUtils.cpp | 21 + .../core/src/exec/operator/OperatorUtils.h | 23 + .../core/src/exec/operator/ProjectNode.cpp | 77 ++ internal/core/src/exec/operator/ProjectNode.h | 73 ++ ...{GroupByNode.cpp => SearchGroupByNode.cpp} | 14 +- .../{GroupByNode.h => SearchGroupByNode.h} | 13 +- .../core/src/exec/operator/VectorSearchNode.h | 2 +- internal/core/src/exec/operator/init_c.cpp | 24 + internal/core/src/exec/operator/init_c.h | 27 + .../src/exec/operator/query-agg/Aggregate.cpp | 88 ++ .../src/exec/operator/query-agg/Aggregate.h | 210 +++++ .../exec/operator/query-agg/AggregateInfo.cpp | 60 ++ .../exec/operator/query-agg/AggregateInfo.h | 41 + .../exec/operator/query-agg/AggregateUtil.h | 41 + .../operator/query-agg/CountAggregateBase.cpp | 42 + .../operator/query-agg/CountAggregateBase.h | 121 +++ .../exec/operator/query-agg/GroupingSet.cpp | 263 ++++++ .../src/exec/operator/query-agg/GroupingSet.h | 109 +++ .../query-agg/RegisterAggregateFunctions.cpp | 29 + .../query-agg/RegisterAggregateFunctions.h | 32 + .../exec/operator/query-agg/RowContainer.cpp | 180 ++++ .../exec/operator/query-agg/RowContainer.h | 519 +++++++++++ .../query-agg/SimpleNumericAggregate.h | 128 +++ .../exec/operator/query-agg/SumAggregate.cpp | 89 ++ .../operator/query-agg/SumAggregateBase.h | 107 +++ .../SearchGroupByOperator.cpp | 0 .../SearchGroupByOperator.h | 0 internal/core/src/expr/FunctionSignature.h | 109 +++ internal/core/src/expr/ITypeExpr.h | 28 +- internal/core/src/plan/CMakeLists.txt | 13 + internal/core/src/plan/PlanNode.cpp | 61 ++ internal/core/src/plan/PlanNode.h | 180 +++- .../core/src/query/ExecPlanNodeVisitor.cpp | 203 ++++- internal/core/src/query/ExecPlanNodeVisitor.h | 12 + internal/core/src/query/PlanProto.cpp | 141 ++- .../src/segcore/ChunkedSegmentSealedImpl.cpp | 94 ++ .../src/segcore/ChunkedSegmentSealedImpl.h | 15 + .../core/src/segcore/SegmentGrowingImpl.cpp | 84 ++ .../core/src/segcore/SegmentGrowingImpl.h | 15 + .../core/src/segcore/SegmentInterface.cpp | 35 +- internal/core/src/segcore/SegmentInterface.h | 14 + .../core/src/segcore/SegmentSealedImpl.cpp | 96 +++ internal/core/src/segcore/SegmentSealedImpl.h | 15 + internal/core/src/segcore/Utils.cpp | 119 ++- internal/core/src/segcore/Utils.h | 22 + internal/core/unittest/CMakeLists.txt | 5 +- internal/core/unittest/test_exec.cpp | 2 - .../core/unittest/test_query_group_by.cpp | 579 +++++++++++++ ..._group_by.cpp => test_search_group_by.cpp} | 0 internal/core/unittest/test_utils/DataGen.h | 9 +- internal/core/virtualenv | 164 ++++ internal/proto/internal.proto | 7 +- internal/proto/plan.proto | 16 + internal/proxy/agg_reducer.go | 38 + internal/proxy/agg_reducer_test.go | 538 ++++++++++++ internal/proxy/count_reducer.go | 4 +- internal/proxy/count_reducer_test.go | 3 +- internal/proxy/reducer.go | 16 +- internal/proxy/reducer_test.go | 8 +- internal/proxy/task.go | 3 + internal/proxy/task_query.go | 126 ++- internal/proxy/task_query_test.go | 8 +- internal/proxy/task_search.go | 2 +- internal/proxy/task_test.go | 54 +- internal/proxy/util.go | 45 +- internal/querynodev2/segments/agg_reducer.go | 51 ++ .../querynodev2/segments/agg_reducer_test.go | 535 ++++++++++++ .../querynodev2/segments/count_reducer.go | 4 +- .../segments/count_reducer_test.go | 4 +- internal/querynodev2/segments/reducer.go | 9 +- internal/querynodev2/segments/reducer_test.go | 4 +- internal/querynodev2/segments/segment.go | 2 +- internal/querynodev2/server.go | 3 + internal/querynodev2/tasks/search_task.go | 1 + internal/util/typeutil/hash.go | 38 + 107 files changed, 8637 insertions(+), 201 deletions(-) create mode 100644 internal/agg/aggregate.go create mode 100644 internal/agg/aggregate_test.go create mode 100644 internal/core/src/common/BitUtil.h create mode 100644 internal/core/src/common/ComplexVector.cpp create mode 100644 internal/core/src/common/SimdUtil.h create mode 100644 internal/core/src/common/Types.cpp create mode 100644 internal/core/src/common/float_util_c.h create mode 100644 internal/core/src/exec/HashTable.cpp create mode 100644 internal/core/src/exec/HashTable.h create mode 100644 internal/core/src/exec/VectorHasher.cpp create mode 100644 internal/core/src/exec/VectorHasher.h create mode 100644 internal/core/src/exec/operator/AggregationNode.cpp create mode 100644 internal/core/src/exec/operator/AggregationNode.h create mode 100644 internal/core/src/exec/operator/OperatorUtils.cpp create mode 100644 internal/core/src/exec/operator/OperatorUtils.h create mode 100644 internal/core/src/exec/operator/ProjectNode.cpp create mode 100644 internal/core/src/exec/operator/ProjectNode.h rename internal/core/src/exec/operator/{GroupByNode.cpp => SearchGroupByNode.cpp} (91%) rename internal/core/src/exec/operator/{GroupByNode.h => SearchGroupByNode.h} (86%) create mode 100644 internal/core/src/exec/operator/init_c.cpp create mode 100644 internal/core/src/exec/operator/init_c.h create mode 100644 internal/core/src/exec/operator/query-agg/Aggregate.cpp create mode 100644 internal/core/src/exec/operator/query-agg/Aggregate.h create mode 100644 internal/core/src/exec/operator/query-agg/AggregateInfo.cpp create mode 100644 internal/core/src/exec/operator/query-agg/AggregateInfo.h create mode 100644 internal/core/src/exec/operator/query-agg/AggregateUtil.h create mode 100644 internal/core/src/exec/operator/query-agg/CountAggregateBase.cpp create mode 100644 internal/core/src/exec/operator/query-agg/CountAggregateBase.h create mode 100644 internal/core/src/exec/operator/query-agg/GroupingSet.cpp create mode 100644 internal/core/src/exec/operator/query-agg/GroupingSet.h create mode 100644 internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp create mode 100644 internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h create mode 100644 internal/core/src/exec/operator/query-agg/RowContainer.cpp create mode 100644 internal/core/src/exec/operator/query-agg/RowContainer.h create mode 100644 internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h create mode 100644 internal/core/src/exec/operator/query-agg/SumAggregate.cpp create mode 100644 internal/core/src/exec/operator/query-agg/SumAggregateBase.h rename internal/core/src/exec/operator/{groupby => search-groupby}/SearchGroupByOperator.cpp (100%) rename internal/core/src/exec/operator/{groupby => search-groupby}/SearchGroupByOperator.h (100%) create mode 100644 internal/core/src/expr/FunctionSignature.h create mode 100644 internal/core/src/plan/CMakeLists.txt create mode 100644 internal/core/src/plan/PlanNode.cpp create mode 100644 internal/core/unittest/test_query_group_by.cpp rename internal/core/unittest/{test_group_by.cpp => test_search_group_by.cpp} (100%) create mode 100644 internal/core/virtualenv create mode 100644 internal/proxy/agg_reducer.go create mode 100644 internal/proxy/agg_reducer_test.go create mode 100644 internal/querynodev2/segments/agg_reducer.go create mode 100644 internal/querynodev2/segments/agg_reducer_test.go diff --git a/internal/agg/aggregate.go b/internal/agg/aggregate.go new file mode 100644 index 0000000000000..638b42a76efe0 --- /dev/null +++ b/internal/agg/aggregate.go @@ -0,0 +1,813 @@ +package agg + +import ( + "context" + "encoding/binary" + "fmt" + "hash" + "hash/fnv" + "math" + "regexp" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" + typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +const ( + kSum = "sum" + kCount = "count" + kAvg = "avg" + kMin = "min" + kMax = "max" +) + +var ( + // Define the regular expression pattern once to avoid repeated concatenation. + aggregationTypes = kSum + `|` + kCount + `|` + kAvg + `|` + kMin + `|` + kMax + aggregationPattern = regexp.MustCompile(`(?i)^(` + aggregationTypes + `)\s*\(\s*([\w\*]*)\s*\)$`) +) + +// MatchAggregationExpression return isAgg, operator name, operator parameter +func MatchAggregationExpression(expression string) (bool, string, string) { + // FindStringSubmatch returns the full match and submatches. + matches := aggregationPattern.FindStringSubmatch(expression) + if len(matches) > 0 { + // Return true, the operator, and the captured parameter. + return true, strings.ToLower(matches[1]), strings.TrimSpace(matches[2]) + } + return false, "", "" +} + +type AggregateBase interface { + Name() string + Update(target *Entry, new *Entry) error + ToPB() *planpb.Aggregate + FieldID() int64 + OriginalName() string +} + +func NewAggregate(aggregateName string, aggFieldID int64, originalName string) (AggregateBase, error) { + switch aggregateName { + case kCount: + return &CountAggregate{fieldID: aggFieldID, originalName: originalName}, nil + case kSum: + return &SumAggregate{fieldID: aggFieldID, originalName: originalName}, nil + case kMin: + return &MinAggregate{fieldID: aggFieldID, originalName: originalName}, nil + case kMax: + return &MaxAggregate{fieldID: aggFieldID, originalName: originalName}, nil + default: + return nil, fmt.Errorf("invalid Aggregation operator %s", aggregateName) + } +} + +func FromPB(pb *planpb.Aggregate) (AggregateBase, error) { + switch pb.Op { + case planpb.AggregateOp_count: + return &CountAggregate{fieldID: pb.GetFieldId()}, nil + case planpb.AggregateOp_sum: + return &SumAggregate{fieldID: pb.GetFieldId()}, nil + case planpb.AggregateOp_min: + return &MinAggregate{fieldID: pb.GetFieldId()}, nil + case planpb.AggregateOp_max: + return &MaxAggregate{fieldID: pb.GetFieldId()}, nil + default: + return nil, fmt.Errorf("invalid Aggregation operator %d", pb.Op) + } +} + +func AccumulateEntryVal(target *Entry, new *Entry) error { + if target == nil || new == nil { + return fmt.Errorf("target or new entry is nil") + } + + // Handle nil `val` for initialization + if target.val == nil { + target.val = new.val + return nil + } + // ensure the value type outside + switch target.val.(type) { + case int: + target.val = target.val.(int) + new.val.(int) + case int32: + target.val = target.val.(int32) + new.val.(int32) + case int64: + target.val = target.val.(int64) + new.val.(int64) + case float32: + target.val = target.val.(float32) + new.val.(float32) + case float64: + target.val = target.val.(float64) + new.val.(float64) + default: + return fmt.Errorf("unsupported type: %T", target.val) + } + return nil +} + +type SumAggregate struct { + fieldID int64 + originalName string +} + +func (sum *SumAggregate) Name() string { + return kSum +} + +func (sum *SumAggregate) Update(target *Entry, new *Entry) error { + return AccumulateEntryVal(target, new) +} + +func (sum *SumAggregate) ToPB() *planpb.Aggregate { + return &planpb.Aggregate{Op: planpb.AggregateOp_sum, FieldId: sum.FieldID()} +} + +func (sum *SumAggregate) FieldID() int64 { + return sum.fieldID +} + +func (sum *SumAggregate) OriginalName() string { + return sum.originalName +} + +type CountAggregate struct { + fieldID int64 + originalName string +} + +func (count *CountAggregate) Name() string { + return kCount +} + +func (count *CountAggregate) Update(target *Entry, new *Entry) error { + return AccumulateEntryVal(target, new) +} + +func (count *CountAggregate) ToPB() *planpb.Aggregate { + return &planpb.Aggregate{Op: planpb.AggregateOp_count, FieldId: count.FieldID()} +} + +func (count *CountAggregate) FieldID() int64 { + return count.fieldID +} + +func (count *CountAggregate) OriginalName() string { + return count.originalName +} + +type MinAggregate struct { + fieldID int64 + originalName string +} + +func (min *MinAggregate) Name() string { + return kMin +} + +func (min *MinAggregate) Update(target *Entry, new *Entry) error { + return nil +} + +func (min *MinAggregate) ToPB() *planpb.Aggregate { + return &planpb.Aggregate{Op: planpb.AggregateOp_min, FieldId: min.FieldID()} +} + +func (min *MinAggregate) FieldID() int64 { + return min.fieldID +} + +func (min *MinAggregate) OriginalName() string { + return min.originalName +} + +type MaxAggregate struct { + fieldID int64 + originalName string +} + +func (max *MaxAggregate) Name() string { + return kMax +} + +func (max *MaxAggregate) Update(target *Entry, new *Entry) error { + return nil +} + +func (max *MaxAggregate) ToPB() *planpb.Aggregate { + return &planpb.Aggregate{Op: planpb.AggregateOp_max, FieldId: max.FieldID()} +} + +func (max *MaxAggregate) FieldID() int64 { + return max.fieldID +} + +func (max *MaxAggregate) OriginalName() string { + return max.originalName +} + +func AggregatesToPB(aggregates []AggregateBase) []*planpb.Aggregate { + ret := make([]*planpb.Aggregate, len(aggregates)) + for idx, agg := range aggregates { + ret[idx] = agg.ToPB() + } + return ret +} + +type Entry struct { + val interface{} +} + +func NewEntry(v interface{}) *Entry { + return &Entry{val: v} +} + +type Row struct { + entries []*Entry +} + +func (r *Row) Count() int { + return len(r.entries) +} + +func (r *Row) ValAt(col int) interface{} { + return r.entries[col].val +} + +func (r *Row) Equal(other *Row, keyCount int) bool { + // Check if the number of entries is the same + if len(r.entries) != len(other.entries) { + return false + } + // Compare each entry for equality + for i := 0; i < keyCount; i++ { + if r.entries[i].val != other.entries[i].val { + return false + } + } + return true +} + +func (r *Row) UpdateEntry(newRow *Row, col int, agg AggregateBase) { + agg.Update(r.entries[col], newRow.entries[col]) +} + +func (r *Row) ToString() string { + var builder strings.Builder + builder.WriteString("agg-row:") + for _, entry := range r.entries { + builder.WriteString(fmt.Sprintf("%v,", entry.val)) + } + return builder.String() +} + +func NewRow(entries []*Entry) *Row { + return &Row{entries: entries} +} + +type Bucket struct { + rows []*Row +} + +func (bucket *Bucket) AddRow(row *Row) { + bucket.rows = append(bucket.rows, row) +} + +func (bucket *Bucket) RowAt(idx int) *Row { + return bucket.rows[idx] +} + +func (bucket *Bucket) RowCount() int { + return len(bucket.rows) +} + +func (bucket *Bucket) Accumulate(row *Row, idx int, keyCount int, aggs []AggregateBase) error { + if idx >= len(bucket.rows) || idx < 0 { + return fmt.Errorf("wrong idx:%d for bucket", idx) + } + targetRow := bucket.rows[idx] + if targetRow == nil { + return fmt.Errorf("nil row at the target idx:%d, cannot accumulate the row", idx) + } + if row.Count() != targetRow.Count() { + return fmt.Errorf("column count:%d in the row must be equal to the target row:%d", row.Count(), bucket.rows[idx].Count()) + } + if row.Count() != keyCount+len(aggs) { + return fmt.Errorf("column count:%d in the row must be sum of keyCount:%d and the number of aggs:%d", row.Count(), keyCount, len(aggs)) + } + for col := keyCount; col < row.Count(); col++ { + targetRow.UpdateEntry(row, col, aggs[col-keyCount]) + } + return nil +} + +const NONE int = -1 + +func (bucket *Bucket) Find(row *Row, keyCount int) int { + for idx, existingRow := range bucket.rows { + if existingRow.Equal(row, keyCount) { + return idx + } + } + return NONE +} + +func NewBucket() *Bucket { + return &Bucket{rows: make([]*Row, 0, 1)} +} + +func NewFieldAccessor(fieldType schemapb.DataType) (FieldAccessor, error) { + switch fieldType { + case schemapb.DataType_Bool: + return newBoolFieldAccessor(), nil + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + return newInt32FieldAccessor(), nil + case schemapb.DataType_Int64: + return newInt64FieldAccessor(), nil + case schemapb.DataType_VarChar, schemapb.DataType_String: + return newStringFieldAccessor(), nil + case schemapb.DataType_Float: + return newFloat32FieldAccessor(), nil + case schemapb.DataType_Double: + return newFloat64FieldAccessor(), nil + default: + return nil, fmt.Errorf("unsupported data type for hasher") + } +} + +type FieldAccessor interface { + Hash(idx int) uint64 + ValAt(idx int) interface{} + SetVals(fieldData *schemapb.FieldData) + RowCount() int +} + +type Int32FieldAccessor struct { + vals []int32 + hasher hash.Hash64 + buffer []byte +} + +func (i32Field *Int32FieldAccessor) Hash(idx int) uint64 { + i32Field.hasher.Reset() + val := i32Field.vals[idx] + binary.LittleEndian.PutUint32(i32Field.buffer, uint32(val)) + i32Field.hasher.Write(i32Field.buffer) + return i32Field.hasher.Sum64() +} + +func (i32Field *Int32FieldAccessor) SetVals(fieldData *schemapb.FieldData) { + i32Field.vals = fieldData.GetScalars().GetIntData().GetData() +} + +func (i32Field *Int32FieldAccessor) RowCount() int { + return len(i32Field.vals) +} + +func (i32Field *Int32FieldAccessor) ValAt(idx int) interface{} { + return i32Field.vals[idx] +} + +func newInt32FieldAccessor() FieldAccessor { + return &Int32FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 4)} +} + +type Int64FieldAccessor struct { + vals []int64 + hasher hash.Hash64 + buffer []byte +} + +func (i64Field *Int64FieldAccessor) Hash(idx int) uint64 { + i64Field.hasher.Reset() + val := i64Field.vals[idx] + binary.LittleEndian.PutUint64(i64Field.buffer, uint64(val)) + i64Field.hasher.Write(i64Field.buffer) + return i64Field.hasher.Sum64() +} + +func (i64Field *Int64FieldAccessor) SetVals(fieldData *schemapb.FieldData) { + i64Field.vals = fieldData.GetScalars().GetLongData().GetData() +} + +func (i64Field *Int64FieldAccessor) RowCount() int { + return len(i64Field.vals) +} + +func (i64Field *Int64FieldAccessor) ValAt(idx int) interface{} { + return i64Field.vals[idx] +} + +func newInt64FieldAccessor() FieldAccessor { + return &Int64FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 8)} +} + +// BoolFieldAccessor +type BoolFieldAccessor struct { + vals []bool + hasher hash.Hash64 + buffer []byte +} + +func (boolField *BoolFieldAccessor) Hash(idx int) uint64 { + boolField.hasher.Reset() + val := boolField.vals[idx] + if val { + boolField.buffer[0] = 1 + } else { + boolField.buffer[0] = 0 + } + boolField.hasher.Write(boolField.buffer[:1]) + return boolField.hasher.Sum64() +} + +func (boolField *BoolFieldAccessor) SetVals(fieldData *schemapb.FieldData) { + boolField.vals = fieldData.GetScalars().GetBoolData().GetData() +} + +func (boolField *BoolFieldAccessor) RowCount() int { + return len(boolField.vals) +} + +func (boolField *BoolFieldAccessor) ValAt(idx int) interface{} { + return boolField.vals[idx] +} + +func newBoolFieldAccessor() FieldAccessor { + return &BoolFieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 1)} +} + +// Float32FieldAccessor +type Float32FieldAccessor struct { + vals []float32 + hasher hash.Hash64 + buffer []byte +} + +func (f32FieldAccessor *Float32FieldAccessor) Hash(idx int) uint64 { + f32FieldAccessor.hasher.Reset() + val := f32FieldAccessor.vals[idx] + binary.LittleEndian.PutUint32(f32FieldAccessor.buffer, math.Float32bits(val)) + f32FieldAccessor.hasher.Write(f32FieldAccessor.buffer[:4]) + return f32FieldAccessor.hasher.Sum64() +} + +func (f32FieldAccessor *Float32FieldAccessor) SetVals(fieldData *schemapb.FieldData) { + f32FieldAccessor.vals = fieldData.GetScalars().GetFloatData().GetData() +} + +func (f32FieldAccessor *Float32FieldAccessor) RowCount() int { + return len(f32FieldAccessor.vals) +} + +func (f32FieldAccessor *Float32FieldAccessor) ValAt(idx int) interface{} { + return f32FieldAccessor.vals[idx] +} + +func newFloat32FieldAccessor() FieldAccessor { + return &Float32FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 4)} +} + +// Float64FieldAccessor +type Float64FieldAccessor struct { + vals []float64 + hasher hash.Hash64 + buffer []byte +} + +func (f64Field *Float64FieldAccessor) Hash(idx int) uint64 { + f64Field.hasher.Reset() + val := f64Field.vals[idx] + binary.LittleEndian.PutUint64(f64Field.buffer, math.Float64bits(val)) + f64Field.hasher.Write(f64Field.buffer) + return f64Field.hasher.Sum64() +} + +func (f64Field *Float64FieldAccessor) SetVals(fieldData *schemapb.FieldData) { + f64Field.vals = fieldData.GetScalars().GetDoubleData().GetData() +} + +func (f64Field *Float64FieldAccessor) RowCount() int { + return len(f64Field.vals) +} + +func (f64Field *Float64FieldAccessor) ValAt(idx int) interface{} { + return f64Field.vals[idx] +} + +func newFloat64FieldAccessor() FieldAccessor { + return &Float64FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 8)} +} + +// StringFieldAccessor +type StringFieldAccessor struct { + vals []string + hasher hash.Hash64 + buffer []byte +} + +func (stringField *StringFieldAccessor) Hash(idx int) uint64 { + stringField.hasher.Reset() + val := stringField.vals[idx] + if len(val) > len(stringField.buffer) { + newSize := typeutil2.NextPowerOfTwo(len(val)) + stringField.buffer = make([]byte, newSize) + } + copy(stringField.buffer, val) + stringField.hasher.Write(stringField.buffer[0:len(val)]) + return stringField.hasher.Sum64() +} + +func (stringField *StringFieldAccessor) SetVals(fieldData *schemapb.FieldData) { + stringField.vals = fieldData.GetScalars().GetStringData().GetData() +} + +func (stringField *StringFieldAccessor) RowCount() int { + return len(stringField.vals) +} + +func (stringField *StringFieldAccessor) ValAt(idx int) interface{} { + return stringField.vals[idx] +} + +func newStringFieldAccessor() FieldAccessor { + return &StringFieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 1024)} +} + +func AssembleBucket(bucket *Bucket, fieldDatas []*schemapb.FieldData) error { + colCount := len(fieldDatas) + for r := 0; r < bucket.RowCount(); r++ { + row := bucket.RowAt(r) + AssembleSingleRow(colCount, row, fieldDatas) + } + return nil +} + +func AssembleSingleRow(colCount int, row *Row, fieldDatas []*schemapb.FieldData) error { + for c := 0; c < colCount; c++ { + err := AssembleSingleValue(row.ValAt(c), fieldDatas[c]) + if err != nil { + return err + } + } + return nil +} + +func AssembleSingleValue(val interface{}, fieldData *schemapb.FieldData) error { + switch fieldData.GetType() { + case schemapb.DataType_Bool: + fieldData.GetScalars().GetBoolData().Data = append(fieldData.GetScalars().GetBoolData().GetData(), val.(bool)) + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + fieldData.GetScalars().GetIntData().Data = append(fieldData.GetScalars().GetIntData().GetData(), val.(int32)) + case schemapb.DataType_Int64: + fieldData.GetScalars().GetLongData().Data = append(fieldData.GetScalars().GetLongData().GetData(), val.(int64)) + case schemapb.DataType_Float: + fieldData.GetScalars().GetFloatData().Data = append(fieldData.GetScalars().GetFloatData().GetData(), val.(float32)) + case schemapb.DataType_Double: + fieldData.GetScalars().GetDoubleData().Data = append(fieldData.GetScalars().GetDoubleData().GetData(), val.(float64)) + case schemapb.DataType_VarChar, schemapb.DataType_String: + fieldData.GetScalars().GetStringData().Data = append(fieldData.GetScalars().GetStringData().GetData(), val.(string)) + default: + return fmt.Errorf("unsupported DataType:%d", fieldData.GetType()) + } + return nil +} + +type GroupAggReducer struct { + groupByFieldIds []int64 + aggregates []*planpb.Aggregate + hashValsMap map[uint64]*Bucket +} + +func NewGroupAggReducer(groupByFieldIds []int64, aggregates []*planpb.Aggregate) *GroupAggReducer { + return &GroupAggReducer{ + groupByFieldIds: groupByFieldIds, + aggregates: aggregates, + hashValsMap: make(map[uint64]*Bucket), // Initialize hashValsMap + } +} + +type AggregationResult struct { + fieldDatas []*schemapb.FieldData +} + +func NewAggregationResult(fieldDatas []*schemapb.FieldData) *AggregationResult { + return &AggregationResult{ + fieldDatas: fieldDatas, + } +} + +// GetFieldDatas returns the fieldDatas slice +func (ar *AggregationResult) GetFieldDatas() []*schemapb.FieldData { + return ar.fieldDatas +} + +func (reducer *GroupAggReducer) Reduce(ctx context.Context, results []*AggregationResult) (*AggregationResult, error) { + if len(results) == 0 { + return nil, fmt.Errorf("no input segment's retrieved results can be reduced") + } + // 0. set up aggregates + aggs := make([]AggregateBase, len(reducer.aggregates)) + for idx, aggPb := range reducer.aggregates { + agg, err := FromPB(aggPb) + if err != nil { + return nil, err + } + aggs[idx] = agg + } + + // 1. set up hashers and accumulators + numGroupingKeys := len(reducer.groupByFieldIds) + numAggs := len(reducer.aggregates) + hashers := make([]FieldAccessor, numGroupingKeys) + accumulators := make([]FieldAccessor, numAggs) + firstFieldData := results[0].GetFieldDatas() + outputColumnCount := len(firstFieldData) + for idx, fieldData := range firstFieldData { + if idx < numGroupingKeys { + hasher, err := NewFieldAccessor(fieldData.GetType()) + if err != nil { + return nil, err + } + hashers[idx] = hasher + } + if idx >= numGroupingKeys { + accumulator, err := NewFieldAccessor(fieldData.GetType()) + if err != nil { + return nil, err + } + accumulators[idx-numGroupingKeys] = accumulator + } + } + + isGlobal := numGroupingKeys == 0 + if isGlobal { + reducedResult := NewAggregationResult(nil) + reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, 1) + rows := make([]*Row, len(results)) + for idx, result := range results { + entries := make([]*Entry, outputColumnCount) + for col := 0; col < outputColumnCount; col++ { + fieldData := result.GetFieldDatas()[col] + accumulators[col].SetVals(fieldData) + entries[col] = NewEntry(accumulators[col].ValAt(0)) + } + rows[idx] = NewRow(entries) + } + for r := 1; r < len(rows); r++ { + for c := 0; c < outputColumnCount; c++ { + rows[0].UpdateEntry(rows[r], c, aggs[c]) + } + } + AssembleSingleRow(outputColumnCount, rows[0], reducedResult.fieldDatas) + return reducedResult, nil + } + + // 2. compute hash values for all rows in the result retrieved + totalRowCount := 0 + for _, result := range results { + if result == nil { + return nil, fmt.Errorf("input result from any sources cannot be nil") + } + fieldDatas := result.GetFieldDatas() + if outputColumnCount != len(fieldDatas) { + return nil, fmt.Errorf("retrieved results from different segments have different size of columns") + } + if outputColumnCount == 0 { + return nil, fmt.Errorf("retrieved results have no column data") + } + rowCount := -1 + for i := 0; i < outputColumnCount; i++ { + fieldData := fieldDatas[i] + if i < numGroupingKeys { + hashers[i].SetVals(fieldData) + } else { + accumulators[i-numGroupingKeys].SetVals(fieldData) + } + if rowCount == -1 { + rowCount = hashers[i].RowCount() + } else if i < numGroupingKeys { + if rowCount != hashers[i].RowCount() { + return nil, fmt.Errorf("field data:%d for different columns have different row count, %d vs %d, wrong state", + i, rowCount, hashers[i].RowCount()) + } + } else if rowCount != accumulators[i-numGroupingKeys].RowCount() { + return nil, fmt.Errorf("field data:%d for different columns have different row count, %d vs %d, wrong state", + i, rowCount, accumulators[i-numGroupingKeys].RowCount()) + } + } + for row := 0; row < rowCount; row++ { + rowEntries := make([]*Entry, outputColumnCount) + var hashVal uint64 + for col := 0; col < outputColumnCount; col++ { + if col < numGroupingKeys { + if col > 0 { + hashVal = typeutil2.HashMix(hashVal, hashers[col].Hash(row)) + } else { + hashVal = hashers[col].Hash(row) + } + rowEntries[col] = NewEntry(hashers[col].ValAt(row)) + } else { + rowEntries[col] = NewEntry(accumulators[col-numGroupingKeys].ValAt(row)) + } + } + newRow := NewRow(rowEntries) + if bucket := reducer.hashValsMap[hashVal]; bucket == nil { + newBucket := NewBucket() + newBucket.AddRow(newRow) + totalRowCount++ + reducer.hashValsMap[hashVal] = newBucket + } else { + if rowIdx := bucket.Find(newRow, numGroupingKeys); rowIdx == NONE { + bucket.AddRow(newRow) + totalRowCount++ + } else { + bucket.Accumulate(newRow, rowIdx, numGroupingKeys, aggs) + } + } + } + } + + // 3. assemble reduced buckets into retrievedResult + reducedResult := NewAggregationResult(nil) + reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, int64(totalRowCount)) + for _, bucket := range reducer.hashValsMap { + err := AssembleBucket(bucket, reducedResult.GetFieldDatas()) + if err != nil { + return nil, err + } + } + return reducedResult, nil +} + +func InternalResult2AggResult(results []*internalpb.RetrieveResults) []*AggregationResult { + aggResults := make([]*AggregationResult, len(results)) + for i := 0; i < len(results); i++ { + aggResults[i] = NewAggregationResult(results[i].GetFieldsData()) + } + return aggResults +} + +func AggResult2internalResult(aggRes *AggregationResult) *internalpb.RetrieveResults { + return &internalpb.RetrieveResults{FieldsData: aggRes.GetFieldDatas()} +} + +func SegcoreResults2AggResult(results []*segcorepb.RetrieveResults) ([]*AggregationResult, error) { + aggResults := make([]*AggregationResult, len(results)) + for i := 0; i < len(results); i++ { + if results[i] == nil { + return nil, fmt.Errorf("input segcore query results from any sources cannot be nil") + } + aggResults[i] = NewAggregationResult(results[i].GetFieldsData()) + } + return aggResults, nil +} + +func AggResult2segcoreResult(aggRes *AggregationResult) *segcorepb.RetrieveResults { + return &segcorepb.RetrieveResults{FieldsData: aggRes.GetFieldDatas()} +} + +type AggregationFieldMap struct { + userOriginalOutputFields []string + userOriginalOutputFieldIdxes []int +} + +func (aggMap *AggregationFieldMap) Count() int { + return len(aggMap.userOriginalOutputFields) +} + +func (aggMap *AggregationFieldMap) IndexAt(idx int) int { + return aggMap.userOriginalOutputFieldIdxes[idx] +} + +func (aggMap *AggregationFieldMap) NameAt(idx int) string { + return aggMap.userOriginalOutputFields[idx] +} + +func NewAggregationFieldMap(originalUserOutputFields []string, groupByFields []string, aggs []AggregateBase) *AggregationFieldMap { + numGroupingKeys := len(groupByFields) + + groupByFieldMap := make(map[string]int, len(groupByFields)) + for i, field := range groupByFields { + groupByFieldMap[field] = i + } + aggFieldMap := make(map[string]int, len(aggs)) + for i, agg := range aggs { + aggFieldMap[agg.OriginalName()] = i + numGroupingKeys + } + + userOriginalOutputFieldIdxes := make([]int, len(originalUserOutputFields)) + for i, outputField := range originalUserOutputFields { + if idx, exist := groupByFieldMap[outputField]; exist { + userOriginalOutputFieldIdxes[i] = idx + } + if idx, exist := aggFieldMap[outputField]; exist { + userOriginalOutputFieldIdxes[i] = idx + } + } + + return &AggregationFieldMap{originalUserOutputFields, userOriginalOutputFieldIdxes} +} diff --git a/internal/agg/aggregate_test.go b/internal/agg/aggregate_test.go new file mode 100644 index 0000000000000..e20b7209756bc --- /dev/null +++ b/internal/agg/aggregate_test.go @@ -0,0 +1,31 @@ +package agg + +import "testing" + +func TestMatchAggregationExpression(t *testing.T) { + tests := []struct { + expression string + expectedIsValid bool + expectedOperator string + expectedParam string + }{ + {"count(*)", true, "count", "*"}, + {"count(a)", true, "count", "a"}, + {"sum(b)", true, "sum", "b"}, + {"avg(c)", true, "avg", "c"}, + {"min(d)", true, "min", "d"}, + {"max(e)", true, "max", "e"}, + {"invalidExpression", false, "", ""}, + {"sum ( x )", true, "sum", "x"}, + {"SUM(Z)", true, "sum", "Z"}, + {"AVG( y )", true, "avg", "y"}, + } + + for _, test := range tests { + isValid, operator, param := MatchAggregationExpression(test.expression) + if isValid != test.expectedIsValid || operator != test.expectedOperator || param != test.expectedParam { + t.Errorf("MatchAggregationExpression(%q) = (%v, %q, %q), want (%v, %q, %q)", + test.expression, isValid, operator, param, test.expectedIsValid, test.expectedOperator, test.expectedParam) + } + } +} diff --git a/internal/core/CMakeLists.txt b/internal/core/CMakeLists.txt index a7a835f4627d2..8a64844f76e2c 100644 --- a/internal/core/CMakeLists.txt +++ b/internal/core/CMakeLists.txt @@ -295,6 +295,12 @@ install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/segcore/ FILES_MATCHING PATTERN "*_c.h" ) +# Install exec/operator/ +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/exec/operator/ + DESTINATION include/exec/operator + FILES_MATCHING PATTERN "*_c.h" + ) + # Install exec/expression/function install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/exec/expression/function/ DESTINATION include/exec/expression/function diff --git a/internal/core/src/CMakeLists.txt b/internal/core/src/CMakeLists.txt index 0c17d074bd224..ac3de9989169b 100644 --- a/internal/core/src/CMakeLists.txt +++ b/internal/core/src/CMakeLists.txt @@ -48,6 +48,7 @@ add_subdirectory( clustering ) add_subdirectory( exec ) add_subdirectory( bitset ) add_subdirectory( futures ) +add_subdirectory( plan ) milvus_add_pkg_config("milvus_core") @@ -66,7 +67,7 @@ add_library(milvus_core SHARED $ $ $ -) + $) set(LINK_TARGETS boost_bitset_ext diff --git a/internal/core/src/common/BitUtil.h b/internal/core/src/common/BitUtil.h new file mode 100644 index 0000000000000..01dd4fc76ae78 --- /dev/null +++ b/internal/core/src/common/BitUtil.h @@ -0,0 +1,204 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace milvus { +namespace bits { + +template +inline bool +isBitSet(const T* bits, uint32_t idx) { + return bits[idx / (sizeof(bits[0]) * 8)] & + (static_cast(1) << (idx & ((sizeof(bits[0]) * 8) - 1))); +} + +template +constexpr T +roundUp(T value, U factor) { + return (value + (factor - 1)) / factor * factor; +} + +constexpr uint64_t +nBytes(int32_t value) { + return roundUp(value, 8) / 8; +} + +constexpr inline uint64_t +lowMask(int32_t bits) { + return (1UL << bits) - 1; +} + +inline int32_t +getAndClearLastSetBit(uint16_t& bits) { + int32_t trailingZeros = __builtin_ctz(bits); + bits &= bits - 1; + return trailingZeros; +} + +constexpr inline uint64_t +highMask(int32_t bits) { + return lowMask(bits) << (64 - bits); +} + +/** + * Invokes a function for each batch of bits (partial or full words) + * in a given range. + * + * @param begin first bit to check (inclusive) + * @param end last bit to check (exclusive) + * @param partialWordFunc function to invoke for a partial word; + * takes index of the word and mask + * @param fullWordFunc function to invoke for a full word; + * takes index of the word + */ +template +inline void +forEachWord(int32_t begin, + int32_t end, + PartialWordFunc partialWordFunc, + FullWordFunc fullWordFunc) { + if (begin >= end) { + return; + } + int32_t firstWord = roundUp(begin, 64); + int32_t lastWord = end & ~63L; + if (lastWord < firstWord) { + partialWordFunc(lastWord / 64, + lowMask(end - lastWord) & highMask(firstWord - begin)); + return; + } + if (begin != firstWord) { + partialWordFunc(begin / 64, highMask(firstWord - begin)); + } + for (int32_t i = firstWord; i + 64 <= lastWord; i += 64) { + fullWordFunc(i / 64); + } + if (end != lastWord) { + partialWordFunc(lastWord / 64, lowMask(end - lastWord)); + } +} + +inline int32_t +countBits(const uint64_t* bits, int32_t begin, int32_t end) { + int32_t count = 0; + forEachWord( + begin, + end, + [&count, bits](int32_t idx, uint64_t mask) { + count += __builtin_popcountll(bits[idx] & mask); + }, + [&count, bits](int32_t idx) { + count += __builtin_popcountll(bits[idx]); + }); + return count; +} + +inline bool +isPowerOfTwo(uint64_t size) { + return bits::countBits(&size, 0, sizeof(uint64_t) * 8) <= 1; +} + +template +inline int32_t +countLeadingZeros(T word) { + static_assert(std::is_same_v || + std::is_same_v); + /// Built-in Function: int __builtin_clz (unsigned int x) returns the number + /// of leading 0-bits in x, starting at the most significant bit position. If + /// x is 0, the result is undefined. + if (word == 0) { + return sizeof(T) * 8; + } + if constexpr (std::is_same_v) { + return __builtin_clzll(word); + } else { + uint64_t hi = word >> 64; + uint64_t lo = static_cast(word); + return (hi == 0) ? 64 + __builtin_clzll(lo) : __builtin_clzll(hi); + } +} + +inline uint64_t +nextPowerOfTwo(uint64_t size) { + if (size == 0) { + return 0; + } + uint32_t bits = 63 - countLeadingZeros(size); + uint64_t lower = 1ULL << bits; + // Size is a power of 2. + if (lower == size) { + return size; + } + return 2 * lower; +} + +// This is the Hash128to64 function from Google's cityhash (available +// under the MIT License). We use it to reduce multiple 64 bit hashes +// into a single hash. +#if defined(FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER) +FOLLY_DISABLE_UNDEFINED_BEHAVIOR_SANITIZER("unsigned-integer-overflow") +#endif +inline uint64_t +hashMix(const uint64_t upper, const uint64_t lower) noexcept { + // Murmur-inspired hashing. + const uint64_t kMul = 0x9ddfea08eb382d69ULL; + uint64_t a = (lower ^ upper) * kMul; + a ^= (a >> 47); + uint64_t b = (upper ^ a) * kMul; + b ^= (b >> 47); + b *= kMul; + return b; +} + +/// Extract bits from integer 'a' at the corresponding bit locations specified +/// by 'mask' to contiguous low bits in return value; the remaining upper bits +/// in return value are set to zero. +template +inline T +extractBits(T a, T mask); + +#ifdef __BMI2__ +template <> +inline uint32_t +extractBits(uint32_t a, uint32_t mask) { + return _pext_u32(a, mask); +} +template <> +inline uint64_t +extractBits(uint64_t a, uint64_t mask) { + return _pext_u64(a, mask); +} +#else +template +T +extractBits(T a, T mask) { + constexpr int kBitsCount = 8 * sizeof(T); + T dst = 0; + for (int i = 0, k = 0; i < kBitsCount; ++i) { + if (mask & 1) { + dst |= ((a & 1) << k); + ++k; + } + a >>= 1; + mask >>= 1; + } + return dst; +} +#endif +} // namespace bits +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/common/ComplexVector.cpp b/internal/core/src/common/ComplexVector.cpp new file mode 100644 index 0000000000000..0c8fccbaa1f0b --- /dev/null +++ b/internal/core/src/common/ComplexVector.cpp @@ -0,0 +1,49 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Vector.h" + +namespace milvus { + +void +BaseVector::prepareForReuse(milvus::VectorPtr& vector, + milvus::vector_size_t size) { + if (!vector.unique()) { + vector = std::make_shared( + vector->type(), size, vector->nullCount()); + } else { + vector->prepareForReuse(); + vector->resize(size); + } +} + +void +BaseVector::prepareForReuse() { + null_count_ = std::nullopt; +} + +void +RowVector::resize(milvus::vector_size_t new_size, bool setNotNull) { + const auto oldSize = size(); + BaseVector::resize(new_size, setNotNull); + for (auto& child : childrens()) { + if (new_size > oldSize) { + child->resize(new_size, setNotNull); + } + } +} + +} // namespace milvus diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index 69015dd2743c0..0b59b9424471d 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -321,4 +321,71 @@ InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) { } } +void +ResizeScalarFieldData(const DataType& type, + int64_t new_num_rows, + FieldDataPtr& field_data) { + switch (type) { + case DataType::BOOL: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::INT8: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::INT16: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::INT32: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::INT64: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::FLOAT: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::DOUBLE: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::STRING: + case DataType::VARCHAR: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + case DataType::JSON: { + auto inner_field_data = + std::dynamic_pointer_cast>(field_data); + inner_field_data->resize_field_data(new_num_rows); + return; + } + default: + PanicInfo(DataTypeInvalid, + "ResizeScalarFieldData not support data type " + + GetDataTypeName(type)); + } +} + } // namespace milvus diff --git a/internal/core/src/common/FieldData.h b/internal/core/src/common/FieldData.h index cdd2b735464a2..0c0ea571220d4 100644 --- a/internal/core/src/common/FieldData.h +++ b/internal/core/src/common/FieldData.h @@ -165,4 +165,9 @@ using ArrowReaderChannel = Channel>; FieldDataPtr InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows); +void +ResizeScalarFieldData(const DataType& type, + int64_t new_size, + FieldDataPtr& field_data); + } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/common/FieldDataInterface.h b/internal/core/src/common/FieldDataInterface.h index 5703acce16106..5b8ec94ceccdc 100644 --- a/internal/core/src/common/FieldDataInterface.h +++ b/internal/core/src/common/FieldDataInterface.h @@ -25,7 +25,6 @@ #include #include -#include "Types.h" #include "arrow/api.h" #include "arrow/array/array_binary.h" #include "common/FieldMeta.h" @@ -330,6 +329,7 @@ class FieldDataImpl : public FieldDataBase { data_ = std::move(data); Assert(data_.size() % dim == 0); num_rows_ = data_.size() / dim; + length_ = num_rows_; } explicit FieldDataImpl(size_t dim, @@ -344,6 +344,7 @@ class FieldDataImpl : public FieldDataBase { valid_data_ = std::move(valid_data); Assert(data_.size() % dim == 0); num_rows_ = data_.size() / dim; + length_ = num_rows_; } void diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index b839df0d6a86f..d783d0c13f1b8 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -183,7 +183,9 @@ class Schema { FieldId get_field_id(const FieldName& field_name) const { - AssertInfo(name_ids_.count(field_name), "Cannot find field_name"); + AssertInfo(name_ids_.count(field_name), + "Cannot find field_name:{}", + field_name.get()); return name_ids_.at(field_name); } @@ -232,6 +234,24 @@ class Schema { field_ids_.emplace_back(field_id); } + DataType + GetFieldType(const FieldId& field_id) const { + AssertInfo(fields_.count(field_id), + "field_id:{} is not existed in the schema", + field_id.get()); + auto& meta = fields_.at(field_id); + return meta.get_data_type(); + } + + const std::string& + GetFieldName(const FieldId& field_id) const { + AssertInfo(fields_.count(field_id), + "field_id:{} is not existed in the schema", + field_id.get()); + auto& meta = fields_.at(field_id); + return meta.get_name().get(); + } + private: int64_t debug_id = START_USER_FIELDID; std::vector field_ids_; diff --git a/internal/core/src/common/SimdUtil.h b/internal/core/src/common/SimdUtil.h new file mode 100644 index 0000000000000..13e8264ba9522 --- /dev/null +++ b/internal/core/src/common/SimdUtil.h @@ -0,0 +1,179 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "common/BitUtil.h" + +namespace milvus { +template +struct BitMask; + +template +struct BitMask { + static constexpr int kAllSet = + milvus::bits::lowMask(xsimd::batch_bool::size); + +#if XSIMD_WITH_AVX2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::avx2&) { + return _mm256_movemask_epi8(mask); + } +#endif + +#if XSIMD_WITH_SSE2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::sse2&) { + return _mm_movemask_epi8(mask); + } +#endif + +#if XSIMD_WITH_NEON + static int + toBitMask(xsimd::batch_bool mask, const xsimd::neon&) { + alignas(A::alignment()) static const int8_t kShift[] = { + -7, -6, -5, -4, -3, -2, -1, 0, -7, -6, -5, -4, -3, -2, -1, 0}; + int8x16_t vshift = vld1q_s8(kShift); + uint8x16_t vmask = vshlq_u8(vandq_u8(mask, vdupq_n_u8(0x80)), vshift); + return (vaddv_u8(vget_high_u8(vmask)) << 8) | + vaddv_u8(vget_low_u8(vmask)); + } +#endif +}; + +template +struct BitMask { + static constexpr int kAllSet = + milvus::bits::lowMask(xsimd::batch_bool::size); + +#if XSIMD_WITH_AVX2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::avx2&) { + // There is no intrinsic for extracting high bits of a 16x16 + // vector. Hence take every second bit of the high bits of a 32x1 + // vector. + // + // NOTE: TVL might have a more efficient implementation for this. + return bits::extractBits(_mm256_movemask_epi8(mask), + 0xAAAAAAAA); + } +#endif + +#if XSIMD_WITH_SSE2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::sse2&) { + return milvus::bits::extractBits(_mm_movemask_epi8(mask), + 0xAAAA); + } +#endif + + static int + toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { + return genericToBitMask(mask); + } +}; + +template +struct FromBitMask { + FromBitMask() { + static_assert(N <= 8); + for (int i = 0; i < (1 << N); ++i) { + bool tmp[N]; + for (int bit = 0; bit < N; ++bit) { + tmp[bit] = (i & (1 << bit)) ? true : false; + } + memo_[i] = xsimd::batch_bool::load_unaligned(tmp); + } + } + + xsimd::batch_bool + operator[](size_t i) const { + return memo_[i]; + } + + private: + static constexpr int N = xsimd::batch_bool::size; + xsimd::batch_bool memo_[1 << N]; +}; + +extern const FromBitMask fromBitMask32; +extern const FromBitMask fromBitMask64; + +template +struct BitMask { + static constexpr int kAllSet = + milvus::bits::lowMask(xsimd::batch_bool::size); + +#if XSIMD_WITH_AVX + static int + toBitMask(xsimd::batch_bool mask, const xsimd::avx&) { + return _mm256_movemask_ps(reinterpret_cast<__m256>(mask.data)); + } +#endif + +#if XSIMD_WITH_SSE2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::sse2&) { + return _mm_movemask_ps(reinterpret_cast<__m128>(mask.data)); + } +#endif + + static int + toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { + return genericToBitMask(mask); + } + + static xsimd::batch_bool + fromBitMask(int mask, const xsimd::default_arch&) { + return fromBitMask32[mask]; + } +}; + +template +struct BitMask { + static constexpr int kAllSet = + milvus::bits::lowMask(xsimd::batch_bool::size); + +#if XSIMD_WITH_AVX + static int + toBitMask(xsimd::batch_bool mask, const xsimd::avx&) { + return _mm256_movemask_pd(reinterpret_cast<__m256d>(mask.data)); + } +#endif + +#if XSIMD_WITH_SSE2 + static int + toBitMask(xsimd::batch_bool mask, const xsimd::sse2&) { + return _mm_movemask_pd(reinterpret_cast<__m128d>(mask.data)); + } +#endif + + static int + toBitMask(xsimd::batch_bool mask, const xsimd::generic&) { + return genericToBitMask(mask); + } + + static xsimd::batch_bool + fromBitMask(int mask, const xsimd::default_arch&) { + return fromBitMask64[mask]; + } +}; + +template +auto +toBitMask(xsimd::batch_bool mask, const A& arch = {}) { + return BitMask::toBitMask(mask, arch); +} +} // namespace milvus diff --git a/internal/core/src/common/Types.cpp b/internal/core/src/common/Types.cpp new file mode 100644 index 0000000000000..5779cf20abcd5 --- /dev/null +++ b/internal/core/src/common/Types.cpp @@ -0,0 +1,60 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "Types.h" + +const RowTypePtr RowType::None = std::make_shared( + std::vector{}, std::vector{}); +namespace milvus { +bool +IsFixedSizeType(DataType type) { + switch (type) { + case DataType::NONE: + return false; + case DataType::BOOL: + return TypeTraits::IsFixedWidth; + case DataType::INT8: + return TypeTraits::IsFixedWidth; + case DataType::INT16: + return TypeTraits::IsFixedWidth; + case DataType::INT32: + return TypeTraits::IsFixedWidth; + case DataType::INT64: + return TypeTraits::IsFixedWidth; + case DataType::FLOAT: + return TypeTraits::IsFixedWidth; + case DataType::DOUBLE: + return TypeTraits::IsFixedWidth; + case DataType::STRING: + return TypeTraits::IsFixedWidth; + case DataType::VARCHAR: + return TypeTraits::IsFixedWidth; + case DataType::ARRAY: + return TypeTraits::IsFixedWidth; + case DataType::JSON: + return TypeTraits::IsFixedWidth; + case DataType::ROW: + return TypeTraits::IsFixedWidth; + case DataType::VECTOR_BINARY: + return TypeTraits::IsFixedWidth; + case DataType::VECTOR_FLOAT: + return TypeTraits::IsFixedWidth; + default: + PanicInfo(DataTypeInvalid, "unknown data type: {}", type); + } +} +} // namespace milvus diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index d26d2ee2ed9dc..b6491ab9671f9 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -47,6 +47,7 @@ #include "Json.h" #include "CustomBitset.h" +#include "log/Log.h" namespace milvus { @@ -60,6 +61,8 @@ using bfloat16 = knowhere::bf16; using bin1 = knowhere::bin1; // See also: https://github.com/milvus-io/milvus-proto/blob/master/proto/schema.proto +using vector_size_t = int32_t; + enum class DataType { NONE = 0, BOOL = 1, @@ -100,6 +103,54 @@ using IdArray = proto::schema::IDs; using InsertRecordProto = proto::segcore::InsertRecord; using PkType = std::variant; +inline milvus::proto::schema::DataType +GetProtoDataType(DataType internal_data_type) { + switch (internal_data_type) { + case DataType::BOOL: + return milvus::proto::schema::Bool; + case DataType::INT8: + return milvus::proto::schema::Int8; + case DataType::INT16: + return milvus::proto::schema::Int16; + case DataType::INT32: + return milvus::proto::schema::Int32; + case DataType::INT64: + return milvus::proto::schema::Int64; + case DataType::FLOAT: + return milvus::proto::schema::Float; + case DataType::DOUBLE: + return milvus::proto::schema::Double; + case DataType::STRING: + return milvus::proto::schema::String; + case DataType::VARCHAR: + return milvus::proto::schema::VarChar; + case DataType::ARRAY: + return milvus::proto::schema::Array; + case DataType::JSON: + return milvus::proto::schema::JSON; + case DataType::VECTOR_FLOAT: + return milvus::proto::schema::FloatVector; + case DataType::VECTOR_BINARY: { + return milvus::proto::schema::BinaryVector; + } + case DataType::VECTOR_FLOAT16: { + return milvus::proto::schema::Float16Vector; + } + case DataType::VECTOR_BFLOAT16: { + return milvus::proto::schema::BFloat16Vector; + } + case DataType::VECTOR_SPARSE_FLOAT: { + return milvus::proto::schema::SparseFloatVector; + } + default: { + PanicInfo( + DataTypeInvalid, + fmt::format("failed to get data type size, invalid type {}", + internal_data_type)); + } + } +} + inline size_t GetDataTypeSize(DataType data_type, int dim = 1) { switch (data_type) { @@ -478,7 +529,7 @@ struct TypeTraits { template <> struct TypeTraits { - using NativeType = int32_t; + using NativeType = int64_t; static constexpr DataType TypeKind = DataType::INT64; static constexpr bool IsPrimitiveType = true; static constexpr bool IsFixedWidth = true; @@ -563,6 +614,9 @@ struct TypeTraits { static constexpr const char* Name = "VECTOR_FLOAT"; }; +bool +IsFixedSizeType(DataType type); + } // namespace milvus template <> struct fmt::formatter : formatter { @@ -684,3 +738,85 @@ struct fmt::formatter : formatter { return formatter::format(name, ctx); } }; + +using column_index_t = uint32_t; +class RowType final { + public: + RowType(std::vector&& names, + std::vector&& types) + : names_(std::move(names)), columns_types_(std::move(types)) { + AssertInfo(names_.size() == columns_types_.size(), + "Name count:{} and column count:{} must be the same", + names_.size(), + columns_types_.size()); + }; + + static const std::shared_ptr None; + + column_index_t + GetChildIndex(std::string name) const { + std::optional idx; + for (auto i = 0; i < names_.size(); i++) { + if (names_[i] == name) { + idx = i; + break; + } + } + AssertInfo(idx.has_value(), + "Cannot find target column in the rowType list"); + return idx.value(); + } + + milvus::DataType + column_type(uint32_t idx) const { + return columns_types_.at(idx); + } + + size_t + column_count() const { + return names_.size(); + } + + private: + const std::vector names_; + const std::vector columns_types_; +}; + +using RowTypePtr = std::shared_ptr; + +#define MILVUS_DYNAMIC_TYPE_DISPATCH(TEMPLATE_FUNC, DATETYPE, ...) \ + MILVUS_DYNAMIC_TYPE_DISPATCH_IMPL(TEMPLATE_FUNC, , DATETYPE, __VA_ARGS__) + +#define MILVUS_DYNAMIC_TYPE_DISPATCH_IMPL(PREFIX, SUFFIX, DATATYPE, ...) \ + [&]() { \ + switch (DATATYPE) { \ + case milvus::DataType::BOOL: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::INT8: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::INT16: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::INT32: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::INT64: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::FLOAT: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::DOUBLE: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::VARCHAR: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::STRING: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::JSON: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::ARRAY: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + case milvus::DataType::ROW: \ + return PREFIX SUFFIX(__VA_ARGS__); \ + default: \ + PanicInfo(milvus::DataTypeInvalid, \ + "UnsupportedDataType for " \ + "MILVUS_DYNAMIC_TYPE_DISPATCH_IMPL"); \ + } \ + }() diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 0e52db367ad52..2f7fa4cd11fab 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -311,4 +311,91 @@ class Defer { #define DeferLambda(fn) Defer Defer_##__COUNTER__(fn); +template +FOLLY_ALWAYS_INLINE int +comparePrimitiveAsc(const T& left, const T& right) { + if constexpr (std::is_floating_point::value) { + bool leftNan = std::isnan(left); + bool rightNan = std::isnan(right); + if (leftNan) { + return rightNan ? 0 : 1; + } + if (rightNan) { + return -1; + } + } + return left < right ? -1 : left == right ? 0 : 1; +} + +inline std::string +lowerString(const std::string& str) { + std::string ret; + ret.resize(str.size()); + std::transform(str.begin(), str.end(), ret.begin(), [](unsigned char c) { + return std::tolower(c); + }); + return ret; +} + +template +T +checkPlus(const T& a, const T& b, const char* typeName = "integer") { + T result; + bool overflow = __builtin_add_overflow(a, b, &result); + if (UNLIKELY(overflow)) { + PanicInfo(DataTypeInvalid, "{} overflow: {} + {}", typeName, a, b); + } + return result; +} + +template +T +checkedMultiply(const T& a, const T& b, const char* typeName = "integer") { + T result; + bool overflow = __builtin_mul_overflow(a, b, &result); + if (UNLIKELY(overflow)) { + PanicInfo(DataTypeInvalid, "{} overflow: {} * {}", typeName, a, b); + } + return result; +} + +const char* const kSum = "sum"; +const char* const KMin = "min"; +const char* const KMax = "max"; +const char* const KCount = "count"; + +inline DataType +GetAggResultType(std::string func_name, DataType input_type) { + if (func_name == kSum) { + switch (input_type) { + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + case DataType::INT64: { + return DataType::INT64; + } + case DataType::FLOAT: { + return DataType::DOUBLE; + } + case DataType::DOUBLE: { + return DataType::DOUBLE; + } + default: { + PanicInfo(DataTypeInvalid, + "Unsupported data type for type:{}", + input_type); + } + } + } + if (func_name == KCount) { + return DataType::INT64; + } + PanicInfo(OpTypeInvalid, "Unsupported func type:{}", func_name); +} + +inline int32_t +Align(int32_t number, int32_t alignment) { + return (number + alignment - 1) & ~(alignment - 1); +} + } // namespace milvus diff --git a/internal/core/src/common/Vector.h b/internal/core/src/common/Vector.h index 6fa073e1d714c..a1b50e7739003 100644 --- a/internal/core/src/common/Vector.h +++ b/internal/core/src/common/Vector.h @@ -26,6 +26,8 @@ #include "common/Types.h" namespace milvus { +class BaseVector; +using VectorPtr = std::shared_ptr; /** * @brief base class for different type vector @@ -41,15 +43,40 @@ class BaseVector { virtual ~BaseVector() = default; int64_t - size() { + size() const { return length_; } DataType - type() { + type() const { return type_kind_; } + int32_t + elementSize() const { + return GetDataTypeSize(type_kind_); + }; + + size_t + nullCount() const { + return null_count_.has_value() ? null_count_.value() : 0; + } + + virtual void + resize(vector_size_t newSize, bool setNotNull = true) { + length_ = newSize; + } + + static void + prepareForReuse(VectorPtr& vector, vector_size_t size); + + /// Resets non-reusable buffers and updates child vectors by calling + /// BaseVector::prepareForReuse. + /// Base implementation checks and resets nulls buffer if needed. Keeps the + /// nulls buffer if singly-referenced, mutable and has at least one null bit set. + virtual void + prepareForReuse(); + protected: DataType type_kind_; size_t length_; @@ -57,8 +84,6 @@ class BaseVector { std::optional null_count_; }; -using VectorPtr = std::shared_ptr; - /** * SimpleVector abstracts over various Columnar Storage Formats, * it is used in custom functions. @@ -110,6 +135,13 @@ class ColumnVector final : public SimpleVector { std::move(bitmap)); } + ColumnVector(FieldDataPtr&& value, TargetBitmap&& valid_bitmap) + : SimpleVector(value->get_data_type(), value->Length()), + is_bitmap_(false), + valid_values_(std::move(valid_bitmap)) { + values_ = std::move(value); + } + virtual ~ColumnVector() override { values_.reset(); valid_values_.reset(); @@ -120,6 +152,28 @@ class ColumnVector final : public SimpleVector { return reinterpret_cast(GetRawData()) + index * size_of_element; } + template + T + ValueAt(size_t index) const { + return *(reinterpret_cast(GetRawData()) + index); + } + + template + void + SetValueAt(size_t index, const T& value) { + *(reinterpret_cast(values_->Data()) + index) = value; + } + + void + nullAt(size_t index) { + valid_values_.set(index, false); + } + + void + clearNullAt(size_t index) { + valid_values_.set(index, true); + } + bool ValidAt(size_t index) override { return valid_values_[index]; @@ -146,6 +200,19 @@ class ColumnVector final : public SimpleVector { return is_bitmap_; } + void + resize(vector_size_t new_size, bool setNotNull = true) override { + AssertInfo(!is_bitmap_, "Cannot resize bitmap column vector"); + BaseVector::resize(new_size, setNotNull); + ResizeScalarFieldData(type(), new_size, values_); + valid_values_.resize(new_size); + } + + void + append(const ColumnVector& other) { + values_->FillFieldData(other.GetRawData(), other.size()); + } + private: bool is_bitmap_; // TODO: remove the field after implementing BitmapVector FieldDataPtr values_; @@ -226,6 +293,18 @@ class RowVector : public BaseVector { } } + RowVector(const RowTypePtr& rowType, + size_t size, + std::optional nullCount = std::nullopt) + : BaseVector(DataType::ROW, size, nullCount) { + auto column_count = rowType->column_count(); + for (auto i = 0; i < column_count; i++) { + auto column_type = rowType->column_type(i); + children_values_.emplace_back( + std::make_shared(column_type, size)); + } + } + const std::vector& childrens() const { return children_values_; @@ -237,6 +316,9 @@ class RowVector : public BaseVector { return children_values_[index]; } + void + resize(vector_size_t new_size, bool setNotNull = true) override; + private: std::vector children_values_; }; diff --git a/internal/core/src/common/float_util_c.h b/internal/core/src/common/float_util_c.h new file mode 100644 index 0000000000000..88eabee242c9d --- /dev/null +++ b/internal/core/src/common/float_util_c.h @@ -0,0 +1,38 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace milvus { + +template ::value, bool> = true> +struct NaNAwareHash { + std::size_t + operator()(const FLOAT& val) const noexcept { + static const std::size_t kNanHash = + folly::hasher{}(std::numeric_limits::quiet_NaN()); + if (std::isnan(val)) { + return kNanHash; + } + return folly::hasher{}(val); + } +}; + +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/Driver.cpp b/internal/core/src/exec/Driver.cpp index 39ef70d14dc38..c77ff85abea2a 100644 --- a/internal/core/src/exec/Driver.cpp +++ b/internal/core/src/exec/Driver.cpp @@ -27,7 +27,9 @@ #include "exec/operator/MvccNode.h" #include "exec/operator/Operator.h" #include "exec/operator/VectorSearchNode.h" -#include "exec/operator/GroupByNode.h" +#include "exec/operator/SearchGroupByNode.h" +#include "exec/operator/AggregationNode.h" +#include "exec/operator/ProjectNode.h" #include "exec/Task.h" #include "common/EasyAssert.h" @@ -78,11 +80,21 @@ DriverFactory::CreateDriver(std::unique_ptr ctx, plannode)) { operators.push_back(std::make_unique( id, ctx.get(), vectorsearchnode)); - } else if (auto groupbynode = - std::dynamic_pointer_cast( + } else if (auto vectorGroupByNode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back(std::make_unique( + id, ctx.get(), vectorGroupByNode)); + } else if (auto queryGroupByNode = + std::dynamic_pointer_cast( + plannode)) { + operators.push_back(std::make_unique( + id, ctx.get(), queryGroupByNode)); + } else if (auto projectNode = + std::dynamic_pointer_cast( plannode)) { operators.push_back( - std::make_unique(id, ctx.get(), groupbynode)); + std::make_unique(id, ctx.get(), projectNode)); } // TODO: add more operators } @@ -135,6 +147,17 @@ Driver::Run(std::shared_ptr self) { } } +void +Driver::initializeOperators() { + if (operatorsInitialized_) { + return; + } + operatorsInitialized_ = true; + for (auto& op : operators_) { + op->initialize(); + } +} + void Driver::Init(std::unique_ptr ctx, std::vector> operators) { @@ -200,13 +223,13 @@ Driver::RunInternal(std::shared_ptr& self, std::shared_ptr& blocking_state, RowVectorPtr& result) { try { + initializeOperators(); int num_operators = operators_.size(); ContinueFuture future; for (;;) { for (int32_t i = num_operators - 1; i >= 0; --i) { auto op = operators_[i].get(); - current_operator_index_ = i; CALL_OPERATOR( blocking_reason_ = op->IsBlocked(&future), op, "IsBlocked"); diff --git a/internal/core/src/exec/Driver.h b/internal/core/src/exec/Driver.h index ef513b88dee47..ccd37d9233de4 100644 --- a/internal/core/src/exec/Driver.h +++ b/internal/core/src/exec/Driver.h @@ -219,6 +219,11 @@ class Driver : public std::enable_shared_from_this { EnqueueInternal() { } + /// Invoked to initialize the operators from this driver once on its first + /// execution. + void + initializeOperators(); + static void Run(std::shared_ptr self); @@ -238,6 +243,8 @@ class Driver : public std::enable_shared_from_this { size_t current_operator_index_{0}; + bool operatorsInitialized_{false}; + BlockingReason blocking_reason_{BlockingReason::kNotBlocked}; friend struct DriverFactory; diff --git a/internal/core/src/exec/HashTable.cpp b/internal/core/src/exec/HashTable.cpp new file mode 100644 index 0000000000000..cc3e429823df3 --- /dev/null +++ b/internal/core/src/exec/HashTable.cpp @@ -0,0 +1,368 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "HashTable.h" +#include +#include "common/SimdUtil.h" +#include "exec/operator/OperatorUtils.h" + +namespace milvus { +namespace exec { + +void +populateLookupRows(const TargetBitmapView& activeRows, + std::vector& lookupRows) { + if (activeRows.all()) { + std::iota(lookupRows.begin(), lookupRows.end(), 0); + } else { + auto start = -1; + lookupRows.clear(); + lookupRows.reserve(activeRows.count()); + do { + auto next_active = activeRows.find_next(start); + if (!next_active.has_value()) + break; + auto next_active_row = next_active.value(); + lookupRows.emplace_back(next_active_row); + start = next_active_row; + } while (true); + } +} + +void +BaseHashTable::prepareForGroupProbe(HashLookup& lookup, + const RowVectorPtr& input, + TargetBitmap& activeRows, + bool ignoreNullKeys) { + auto& hashers = lookup.hashers_; + int numKeys = hashers.size(); + // set up column vector to each column + for (auto i = 0; i < numKeys; i++) { + auto& hasher = hashers[i]; + auto column_idx = hasher->ChannelIndex(); + ColumnVectorPtr column_ptr = + std::dynamic_pointer_cast(input->child(column_idx)); + AssertInfo(column_ptr != nullptr, + "Failed to get column vector from row vector input"); + hashers[i]->setColumnData(column_ptr); + // deselect null values + if (ignoreNullKeys) { + int64_t length = column_ptr->size(); + TargetBitmapView valid_bits_view(column_ptr->GetValidRawData(), + length); + activeRows &= valid_bits_view; + } + } + lookup.reset(activeRows.size()); + + const auto mode = hashMode(); + for (auto i = 0; i < hashers.size(); i++) { + if (mode == BaseHashTable::HashMode::kHash) { + hashers[i]->hash(i > 0, activeRows, lookup.hashes_); + } else { + PanicInfo( + milvus::OpTypeInvalid, + "Not support target hashMode, only support kHash for now"); + } + } + populateLookupRows(activeRows, lookup.rows_); +} + +class ProbeState { + public: + enum class Operation { kProbe, kInsert, kErase }; + // Special tag for an erased entry. This counts as occupied for probe and as + // empty for insert. If a tag word with empties gets an erase, we make the + // erased tag empty. If the tag word getting the erase has no empties, the + // erase is marked with a tombstone. A probe always stops with a tag word with + // empties. Adding an empty to a tag word with no empties would break probes + // that needed to skip this tag word. This is standard practice for open + // addressing hash tables. F14 has more sophistication in this but we do not + // need it here since erase is very rare except spilling and is not expected + // to change the load factor by much in the expected uses. + //static constexpr uint8_t kTombstoneTag = 0x7f; + static constexpr uint8_t kEmptyTag = 0x00; + static constexpr int32_t kFullMask = 0xffff; + + int32_t + row() const { + return row_; + } + + template + inline void + preProbe(const Table& table, uint64_t hash, int32_t row) { + row_ = row; + bucketOffset_ = table.bucketOffset(hash); + const auto tag = BaseHashTable::hashTag(hash); + wantedTags_ = BaseHashTable::TagVector::broadcast(tag); + group_ = nullptr; + __builtin_prefetch(reinterpret_cast(table.table_) + + bucketOffset_); + } + + template + inline void + firstProbe(const Table& table, int32_t firstKey) { + tagsInTable_ = BaseHashTable::loadTags( + reinterpret_cast(table.table_), bucketOffset_); + hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); + if (hits_) { + loadNextHit(table, firstKey); + } + } + + template + inline char* + fullProbe(Table& table, + int32_t firstKey, + Compare compare, + Insert insert, + bool extraCheck) { + AssertInfo(op == Operation::kInsert, + "Only support insert operation for group cases"); + if (group_ && compare(group_, row_)) { + return group_; + } + + if (extraCheck) { + tagsInTable_ = table.loadTags(bucketOffset_); + hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); + } + + const auto kEmptyGroup = BaseHashTable::TagVector::broadcast(0); + for (int64_t numProbedBuckets = 0; + numProbedBuckets < table.numBuckets(); + ++numProbedBuckets) { + while (hits_ > 0) { + loadNextHit(table, firstKey); + if (compare(group_, row_)) { + return group_; + } + } + + uint16_t empty = + milvus::toBitMask(tagsInTable_ == kEmptyGroup) & kFullMask; + // if there are still empty slot available, try to insert into existing empty slot or tombstone slot + if (empty > 0) { + auto pos = milvus::bits::getAndClearLastSetBit(empty); + return insert(row_, bucketOffset_ + pos); + } + bucketOffset_ = table.nextBucketOffset(bucketOffset_); + tagsInTable_ = table.loadTags(bucketOffset_); + hits_ = milvus::toBitMask(tagsInTable_ == wantedTags_); + } + PanicInfo(UnexpectedError, + "Slots in hash table is not enough for hash operation, fail " + "the request"); + } + + private: + static constexpr uint8_t kNotSet = 0xff; + template + inline void + loadNextHit(Table& table, int32_t firstKey) { + const int32_t hit = milvus::bits::getAndClearLastSetBit(hits_); + group_ = table.row(bucketOffset_, hit); + __builtin_prefetch(group_ + firstKey); + } + + char* group_; + BaseHashTable::TagVector wantedTags_; + BaseHashTable::TagVector tagsInTable_; + int32_t row_; + int64_t bucketOffset_; + BaseHashTable::MaskType hits_; + //uint8_t indexInTags_ = kNotSet; +}; + +template +void +HashTable::allocateTables(uint64_t size) { + AssertInfo(milvus::bits::isPowerOfTwo(size), + "Size:{} for allocating tables must be a power of two", + size); + AssertInfo(size > 0, + "Size:{} for allocating tables must be larger than zero", + size); + capacity_ = size; + const uint64_t byteSize = capacity_ * tableSlotSize(); + AssertInfo(byteSize % kBucketSize == 0, + "byteSize:{} for hashTable must be a multiple of kBucketSize:{}", + byteSize, + kBucketSize); + numBuckets_ = byteSize / kBucketSize; + sizeMask_ = byteSize - 1; + sizeBits_ = __builtin_popcountll(sizeMask_); + bucketOffsetMask_ = sizeMask_ & ~(kBucketSize - 1); + // The total size is 8 bytes per slot, in groups of 16 slots with 16 bytes of + // tags and 16 * 6 bytes of pointers and a padding of 16 bytes to round up the + // cache line. + // TODO support memory pool here to avoid OOM + table_ = new char*[capacity_]; + memset(table_, 0, capacity_ * sizeof(char*)); +} + +template +void +HashTable::checkSize(int32_t numNew) { + AssertInfo(capacity_ == 0 || capacity_ > numDistinct_, + "capacity_ {}, numDistinct {}", + capacity_, + numDistinct_); + const int64_t newNumDistinct = numNew + numDistinct_; + if (table_ == nullptr || capacity_ == 0) { + const auto newSize = newHashTableEntriesNumber(numDistinct_, numNew); + allocateTables(newSize); + } else if (newNumDistinct > rehashSize()) { + const auto newCapacity = milvus::bits::nextPowerOfTwo( + std::max(newNumDistinct, capacity_) + 1); + allocateTables(newCapacity); + } +} + +template +bool +HashTable::compareKeys(const char* group, + milvus::exec::HashLookup& lookup, + milvus::vector_size_t row) { + int32_t numKeys = lookup.hashers_.size(); + int32_t i = 0; + do { + auto& hasher = lookup.hashers_[i]; + if (!rows_->equals( + group, rows()->columnAt(i), hasher->columnData(), row)) { + return false; + } + } while (++i < numKeys); + return true; +} + +template +void +HashTable::storeKeys(milvus::exec::HashLookup& lookup, + milvus::vector_size_t row) { + for (int32_t i = 0; i < hashers_.size(); i++) { + auto& hasher = hashers_[i]; + rows_->store(hasher->columnData(), row, lookup.hits_[row], i); + } +} + +template +void +HashTable::storeRowPointer(uint64_t index, + uint64_t hash, + char* row) { + const int64_t bktOffset = bucketOffset(index); + auto* bucket = bucketAt(bktOffset); + const auto slotIndex = index & (sizeof(TagVector) - 1); + bucket->setTag(slotIndex, hashTag(hash)); + bucket->setPointer(slotIndex, row); +} + +template +char* +HashTable::insertEntry(milvus::exec::HashLookup& lookup, + uint64_t index, + milvus::vector_size_t row) { + char* group = rows_->newRow(); + lookup.hits_[row] = group; + storeKeys(lookup, row); + storeRowPointer(index, lookup.hashes_[row], group); + numDistinct_++; + lookup.newGroups_.push_back(row); + return group; +} + +template +FOLLY_ALWAYS_INLINE void +HashTable::fullProbe(HashLookup& lookup, + ProbeState& state, + bool extraCheck) { + constexpr ProbeState::Operation op = ProbeState::Operation::kInsert; + lookup.hits_[state.row()] = state.fullProbe( + *this, + 0, + [&](char* group, int32_t row) { + return compareKeys(group, lookup, row); + }, + [&](int32_t row, uint64_t index) { + return insertEntry(lookup, index, row); + }, + extraCheck); +} + +template +void +HashTable::groupProbe(milvus::exec::HashLookup& lookup) { + AssertInfo(hashMode_ == HashMode::kHash, "Only support kHash mode for now"); + checkSize(lookup.rows_.size()); + ProbeState state1; + ProbeState state2; + ProbeState state3; + ProbeState state4; + int32_t probeIdx = 0; + int32_t numProbes = lookup.rows_.size(); + auto rows = lookup.rows_.data(); + for (; probeIdx + 4 <= numProbes; probeIdx += 4) { + int32_t row = rows[probeIdx]; + state1.preProbe(*this, lookup.hashes_[row], row); + row = rows[probeIdx + 1]; + state2.preProbe(*this, lookup.hashes_[row], row); + row = rows[probeIdx + 2]; + state3.preProbe(*this, lookup.hashes_[row], row); + row = rows[probeIdx + 3]; + state4.preProbe(*this, lookup.hashes_[row], row); + + state1.firstProbe(*this, 0); + state2.firstProbe(*this, 0); + state3.firstProbe(*this, 0); + state4.firstProbe(*this, 0); + + fullProbe(lookup, state1, false); + fullProbe(lookup, state2, true); + fullProbe(lookup, state3, true); + fullProbe(lookup, state4, true); + } + for (; probeIdx < numProbes; probeIdx++) { + int32_t row = rows[probeIdx]; + state1.preProbe(*this, lookup.hashes_[row], row); + state1.firstProbe(*this, 0); + fullProbe(lookup, state1, false); + } +} + +template +void +HashTable::setHashMode(HashMode mode, int32_t numNew) { + // TODO set hash mode kArray/kHash/kNormalizedKey +} + +template +void +HashTable::clear(bool freeTable) { + if (table_) { + delete[] table_; + table_ = nullptr; + } + rows_->clear(); + numDistinct_ = 0; +} + +template class HashTable; +template class HashTable; +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/HashTable.h b/internal/core/src/exec/HashTable.h new file mode 100644 index 0000000000000..27eedfe7a6f8b --- /dev/null +++ b/internal/core/src/exec/HashTable.h @@ -0,0 +1,340 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "VectorHasher.h" +#include "exec/operator/query-agg/RowContainer.h" +#include "xsimd/xsimd.hpp" +#include "common/BitUtil.h" + +namespace milvus { +namespace exec { + +struct HashLookup { + explicit HashLookup( + const std::vector>& hashers) + : hashers_(hashers) { + } + + void + reset(vector_size_t size) { + rows_.resize(size); + hashes_.resize(size); + hits_.resize(size); + newGroups_.clear(); + } + + /// One entry per group-by + const std::vector>& hashers_; + + /// Set of row numbers of row to probe. + std::vector rows_; + + /// Hashes or value IDs for rows in 'rows'. Not aligned with 'rows'. Index is + /// the row number. + std::vector hashes_; + + /// Contains one entry for each row in 'rows'. Index is the row number. + /// For groupProbe, a pointer to an existing or new row with matching grouping + /// keys. + std::vector hits_; + + /// For groupProbe, row numbers for which a new entry was inserted (didn't + /// exist before the groupProbe). Empty for joinProbe. + std::vector newGroups_; +}; + +class BaseHashTable { + public: +#if XSIMD_WITH_SSE2 + using TagVector = xsimd::batch; +#elif XSIMD_WITH_NEON + using TagVector = xsimd::batch; +#endif + using MaskType = uint16_t; + + enum class HashMode { kHash, kArray, kNormalizedKey }; + + explicit BaseHashTable(std::vector>&& hashers) + : hashers_(std::move(hashers)) { + } + + virtual ~BaseHashTable() = default; + + RowContainer* + rows() const { + return rows_.get(); + } + + /// Extracts a 7 bit tag from a hash number. The high bit is always set. + static uint8_t + hashTag(uint64_t hash) { + // This is likely all 0 for small key types (<= 32 bits). Not an issue + // because small types have a range that makes them normalized key cases. + // If there are multiple small type keys, they are mixed which makes them a + // 64 bit hash. Normalized keys are mixed before being used as hash + // numbers. + return static_cast(hash >> 38) | 0x80; + } + + static FOLLY_ALWAYS_INLINE size_t + tableSlotSize() { + // Each slot is 8 bytes. + return sizeof(void*); + } + + static TagVector + loadTags(uint8_t* tags, int64_t tagIndex) { + auto src = tags + tagIndex; +#if XSIMD_WITH_SSE2 + return TagVector( + _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); +#elif XSIMD_WITH_NEON + return TagVector(vld1q_u8(src)); +#endif + } + + const std::vector>& + hashers() const { + return hashers_; + } + + /// Returns the hash mode. This is needed for the caller to calculate + /// the hash numbers using the appropriate method of the + /// VectorHashers of 'this'. + virtual HashMode + hashMode() const = 0; + + virtual void + setHashMode(HashMode mode, int32_t numNew) = 0; + + /// Disables use of array or normalized key hash modes. + void + forceGenericHashMode() { + setHashMode(HashMode::kHash, 0); + } + + /// Populates 'hashes' and 'rows' fields in 'lookup' in preparation for + /// 'groupProbe' call. Rehashes the table if necessary. Uses lookup.hashes to + /// decode grouping keys from 'input'. If 'ignoreNullKeys_' is true, updates + /// 'rows' to remove entries with null grouping keys. After this call, 'rows' + /// may have no entries selected. + void + prepareForGroupProbe(HashLookup& lookup, + const RowVectorPtr& input, + TargetBitmap& activeRows, + bool nullableKeys); + + /// Finds or creates a group for each key in 'lookup'. The keys are + /// returned in 'lookup.hits'. + virtual void + groupProbe(HashLookup& lookup) = 0; + + virtual void + clear(bool freeTable = false) = 0; + + protected: + std::vector> hashers_; + std::unique_ptr rows_; +}; + +class ProbeState; + +template +class HashTable : public BaseHashTable { + public: + HashTable(std::vector>&& hashers, + const std::vector& accumulators) + : BaseHashTable(std::move(hashers)) { + std::vector keyTypes; + for (auto& hasher : hashers_) { + keyTypes.push_back(hasher->ChannelDataType()); + } + hashMode_ = HashMode::kHash; + rows_ = std::make_unique( + keyTypes, accumulators, ignoreNullKeys); + }; + + ~HashTable() override { + } + + void + setHashMode(HashMode mode, int32_t numNew) override; + + void + groupProbe(HashLookup& lookup) override; + + // The table in non-kArray mode has a power of two number of buckets each with + // 16 slots. Each slot has a 1 byte tag (a field of hash number) and a 48 bit + // pointer. All the tags are in a 16 byte SIMD word followed by the 6 byte + // pointers. There are 16 bytes of padding at the end to make the bucket + // occupy exactly two (64 bytes) cache lines. + class Bucket { + public: + uint8_t + tagAt(int32_t slotIndex) { + return reinterpret_cast(&tags_)[slotIndex]; + } + + char* + pointerAt(int32_t slotIndex) { + return reinterpret_cast( + *reinterpret_cast( + &pointers_[kPointerSize * slotIndex]) & + kPointerMask); + } + + void + setTag(int32_t slotIndex, uint8_t tag) { + reinterpret_cast(&tags_)[slotIndex] = tag; + } + + void + setPointer(int32_t slotIndex, void* pointer) { + auto* const slot = reinterpret_cast( + &pointers_[slotIndex * kPointerSize]); + *slot = + (*slot & ~kPointerMask) | reinterpret_cast(pointer); + } + + private: + static constexpr uint8_t kPointerSignificantBits = 48; + static constexpr uint64_t kPointerMask = + milvus::bits::lowMask(kPointerSignificantBits); + static constexpr int32_t kPointerSize = kPointerSignificantBits / 8; + + TagVector tags_; + char pointers_[sizeof(TagVector) * kPointerSize]; + char padding_[16]; + }; + static_assert(sizeof(Bucket) == 128); + static constexpr uint64_t kBucketSize = sizeof(Bucket); + + Bucket* + bucketAt(int64_t offset) const { + //AssertInfo(offset&(kBucketSize-1)==0, "Invalid offset:{} and kBucketSize:{}", offset, kBucketSize); + return reinterpret_cast(reinterpret_cast(table_) + + offset); + } + + int64_t + bucketOffset(uint64_t hash) const { + return hash & bucketOffsetMask_; + } + + int64_t + nextBucketOffset(int64_t bucketOffset) const { + //AssertInfo(bucketOffset&(kBucketSize - 1) == 0, "Invalid bucketOffset:{} for nextBucketOffset", bucketOffset); + AssertInfo(bucketOffset < sizeMask_, + "BucketOffset:{} must be less than sizeMask_:{} for " + "nextBucketOffset", + bucketOffset, + sizeMask_); + return sizeMask_ & (bucketOffset + kBucketSize); + } + + bool + compareKeys(const char* group, HashLookup& lookup, vector_size_t row); + + char* + row(int64_t bucketOffset, int32_t slotIndex) const { + return bucketAt(bucketOffset)->pointerAt(slotIndex); + } + + int64_t + numBuckets() const { + return numBuckets_; + } + + TagVector + loadTags(int64_t bucketOffset) const { + return BaseHashTable::loadTags(reinterpret_cast(table_), + bucketOffset); + } + + char* + insertEntry(HashLookup& lookup, uint64_t index, vector_size_t row); + + void + storeKeys(HashLookup& lookup, vector_size_t row); + + void + storeRowPointer(uint64_t index, uint64_t hash, char* row); + + // Allocates new tables for tags and payload pointers. The size must + // a power of 2. + void + allocateTables(uint64_t size); + + void + fullProbe(HashLookup& lookup, ProbeState& state, bool extraCheck); + + void + clear(bool freeTable = false) override; + + void + checkSize(int32_t numNew); + + // Returns the number of entries after which the table gets rehashed. + static uint64_t + rehashSize(int64_t size) { + // This implements the F14 load factor: Resize if less than 1/8 unoccupied. + return size - (size / 8); + } + + uint64_t + rehashSize() const { + return rehashSize(capacity_); + } + + static uint64_t + newHashTableEntriesNumber(uint64_t numDistinct, uint64_t numNew) { + auto numNewEntries = + std::max((uint64_t)2048, + milvus::bits::nextPowerOfTwo(numNew * 2 + numDistinct)); + const auto newNumDistinct = numDistinct + numNew; + if (newNumDistinct > rehashSize(numNewEntries)) { + numNewEntries *= 2; + } + return numNewEntries; + } + + private: + HashMode hashMode_ = HashMode::kHash; + int64_t bucketOffsetMask_{0}; + int64_t numBuckets_{0}; + int64_t numDistinct_{0}; + + // Number of slots across all buckets. + int64_t capacity_{0}; + // Mask for extracting low bits of hash number for use as byte offsets into + // the table. This is set to 'capacity_ * sizeof(void*) - 1'. + int64_t sizeMask_{0}; + int8_t sizeBits_; + + int64_t numRehashes_{0}; + char** table_ = nullptr; + + HashMode + hashMode() const override { + return hashMode_; + } + friend class ProbeState; +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/VectorHasher.cpp b/internal/core/src/exec/VectorHasher.cpp new file mode 100644 index 0000000000000..aaa81b65befb1 --- /dev/null +++ b/internal/core/src/exec/VectorHasher.cpp @@ -0,0 +1,99 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "VectorHasher.h" +#include "common/float_util_c.h" +#include +#include "common/BitUtil.h" + +namespace milvus { +namespace exec { +std::vector> +createVectorHashers(const RowTypePtr& rowType, + const std::vector& exprs) { + std::vector> hashers; + hashers.reserve(exprs.size()); + for (const auto& expr : exprs) { + auto column_idx = rowType->GetChildIndex(expr->name()); + hashers.emplace_back(VectorHasher::create(expr->type(), column_idx)); + } + return hashers; +} + +template +void +VectorHasher::hashValues(const ColumnVectorPtr& column_data, + const TargetBitmapView& activeRows, + bool mix, + uint64_t* result) { + if constexpr (Type == DataType::ROW || Type == DataType::ARRAY || + Type == DataType::JSON) { + PanicInfo(milvus::DataTypeInvalid, + "NotSupport hash for complext type row/array/json:{}", + Type); + } else { + using T = typename TypeTraits::NativeType; + auto start = -1; + do { + auto next_valid_op = activeRows.find_next(start); + if (!next_valid_op.has_value()) { + break; + } + auto next_valid_row = next_valid_op.value(); + if (!column_data->ValidAt(next_valid_row)) { + result[next_valid_row] = + mix ? milvus::bits::hashMix(result[next_valid_row], + kNullHash) + : kNullHash; + } else { + T raw_value = column_data->ValueAt(next_valid_row); + uint64_t hash_value = kNullHash; + if constexpr (std::is_floating_point_v) { + hash_value = milvus::NaNAwareHash()(raw_value); + } else { + hash_value = folly::hasher()(raw_value); + } + result[next_valid_row] = + mix ? milvus::bits::hashMix(result[next_valid_row], + hash_value) + : hash_value; + } + start = next_valid_row; + } while (true); + } +} + +void +VectorHasher::hash(bool mix, + const TargetBitmapView& activeRows, + std::vector& result) { + // auto element_size = GetDataTypeSize(element_data_type); + // auto element_count = column_data->size(); + + // for(auto i = 0; i < element_count; i++) { + // void* raw_value = column_data->RawValueAt(i, element_size); + // } + auto element_data_type = ChannelDataType(); + MILVUS_DYNAMIC_TYPE_DISPATCH(hashValues, + element_data_type, + columnData(), + activeRows, + mix, + result.data()); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/VectorHasher.h b/internal/core/src/exec/VectorHasher.h new file mode 100644 index 0000000000000..e06638c845616 --- /dev/null +++ b/internal/core/src/exec/VectorHasher.h @@ -0,0 +1,102 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "common/Vector.h" +#include "common/Types.h" +#include "expr/ITypeExpr.h" + +namespace milvus { +namespace exec { +class VectorHasher { + public: + VectorHasher(DataType data_type, column_index_t column_idx) + : channel_type_(data_type), channel_idx_(column_idx) { + } + + static std::unique_ptr + create(DataType data_type, column_index_t col_idx) { + return std::make_unique(data_type, col_idx); + } + + column_index_t + ChannelIndex() const { + return channel_idx_; + } + + DataType + ChannelDataType() const { + return channel_type_; + } + + void + hash(bool mix, + const TargetBitmapView& activeRows, + std::vector& result); + + static constexpr uint64_t kNullHash = 1; + + static bool + typeSupportValueIds(DataType type) { + switch (type) { + case DataType::BOOL: + case DataType::INT8: + case DataType::INT16: + case DataType::INT32: + case DataType::INT64: + case DataType::VARCHAR: + case DataType::STRING: + return true; + default: + return false; + } + } + + template + void + hashValues(const ColumnVectorPtr& column_data, + const TargetBitmapView& activeRows, + bool mix, + uint64_t* result); + + void + setColumnData(const ColumnVectorPtr& column_data) { + column_data_ = column_data; + } + + const ColumnVectorPtr& + columnData() const { + return column_data_; + } + + private: + const column_index_t channel_idx_; + const DataType channel_type_; + ColumnVectorPtr column_data_; +}; + +std::vector> +createVectorHashers(const RowTypePtr& rowType, + const std::vector& exprs); + +static std::unique_ptr +create(DataType dataType, column_index_t column_idx) { + return std::make_unique(dataType, column_idx); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/expression/Utils.h b/internal/core/src/exec/expression/Utils.h index 5b6549250cb5c..a6b67c168fe29 100644 --- a/internal/core/src/exec/expression/Utils.h +++ b/internal/core/src/exec/expression/Utils.h @@ -162,5 +162,16 @@ GetValueFromProtoWithOverflow( return GetValueFromProtoInternal(value_proto, overflowed); } +inline std::string +sanitizeName(const std::string& name) { + std::string sanitizedName; + sanitizedName.resize(name.size()); + std::transform( + name.begin(), name.end(), sanitizedName.begin(), [](unsigned char c) { + return std::tolower(c); + }); + return sanitizedName; +} + } // namespace exec } // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/AggregationNode.cpp b/internal/core/src/exec/operator/AggregationNode.cpp new file mode 100644 index 0000000000000..99618852cf01e --- /dev/null +++ b/internal/core/src/exec/operator/AggregationNode.cpp @@ -0,0 +1,83 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +// +// Created by hanchun on 24-10-18. +// + +#include "AggregationNode.h" +#include "common/Utils.h" + +namespace milvus { +namespace exec { + +PhyAggregationNode::PhyAggregationNode( + int32_t operator_id, + milvus::exec::DriverContext* ctx, + const std::shared_ptr& node) + : Operator(ctx, node->output_type(), operator_id, node->id()), + aggregationNode_(node), + isGlobal_(node->GroupingKeys().empty()) { +} + +void +PhyAggregationNode::prepareOutput(vector_size_t size) { + if (output_) { + VectorPtr new_output = std::move(output_); + BaseVector::prepareForReuse(new_output, size); + output_ = std::static_pointer_cast(new_output); + } else { + output_ = std::make_shared(output_type_, size); + } +} + +RowVectorPtr +PhyAggregationNode::GetOutput() { + if (finished_ || (!no_more_input_ && !grouping_set_->hasOutput())) { + input_ = nullptr; + return nullptr; + } + DeferLambda([&]() { finished_ = true; }); + const auto outputRowCount = isGlobal_ ? 1 : grouping_set_->outputRowCount(); + prepareOutput(outputRowCount); + const bool hasData = grouping_set_->getOutput(output_); + if (!hasData) { + return nullptr; + } + numOutputRows_ += output_->size(); + return output_; +} + +void +PhyAggregationNode::initialize() { + Operator::initialize(); + const auto& input_type = aggregationNode_->sources()[0]->output_type(); + auto hashers = + createVectorHashers(input_type, aggregationNode_->GroupingKeys()); + auto numHashers = hashers.size(); + std::vector aggregateInfos = + toAggregateInfo(*aggregationNode_, *operator_context_, numHashers); + grouping_set_ = + std::make_unique(input_type, + std::move(hashers), + std::move(aggregateInfos), + aggregationNode_->ignoreNullKeys()); + aggregationNode_.reset(); +} + +void +PhyAggregationNode::AddInput(milvus::RowVectorPtr& input) { + grouping_set_->addInput(input); + numInputRows_ += input->size(); +} + +}; // namespace exec +}; // namespace milvus diff --git a/internal/core/src/exec/operator/AggregationNode.h b/internal/core/src/exec/operator/AggregationNode.h new file mode 100644 index 0000000000000..36607ada3df80 --- /dev/null +++ b/internal/core/src/exec/operator/AggregationNode.h @@ -0,0 +1,89 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "exec/operator/Operator.h" +#include "exec/operator/query-agg/GroupingSet.h" +#include "common/Types.h" + +namespace milvus { +namespace exec { +class PhyAggregationNode : public Operator { + public: + PhyAggregationNode( + int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node); + + bool + NeedInput() const override { + return true; + } + + void + AddInput(RowVectorPtr& input) override; + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override { + return finished_; + } + + bool + IsFilter() const override { + return false; + } + + BlockingReason + IsBlocked(ContinueFuture* future) { + return BlockingReason::kNotBlocked; + } + + void + Close() override { + input_ = nullptr; + results_.clear(); + } + + void + initialize() override; + + std::string + ToString() const override { + return "PhyAggregationNode"; + } + + private: + void + prepareOutput(vector_size_t size); + + RowVectorPtr output_; + std::unique_ptr grouping_set_; + std::shared_ptr aggregationNode_; + const bool isGlobal_; + + // Count the number of input rows. It is reset on partial aggregation output + // flush. + int64_t numInputRows_ = 0; + // Count the number of output rows. It is reset on partial aggregation output + // flush. + int64_t numOutputRows_ = 0; + bool finished_ = false; +}; +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/CallbackSink.h b/internal/core/src/exec/operator/CallbackSink.h index d0f5e2d37afc1..a069b68a1fbdb 100644 --- a/internal/core/src/exec/operator/CallbackSink.h +++ b/internal/core/src/exec/operator/CallbackSink.h @@ -26,7 +26,7 @@ class CallbackSink : public Operator { int32_t operator_id, DriverContext* ctx, std::function callback) - : Operator(ctx, DataType::NONE, operator_id, "N/A", "CallbackSink"), + : Operator(ctx, RowType::None, operator_id, "N/A", "CallbackSink"), callback_(callback) { } @@ -52,7 +52,7 @@ class CallbackSink : public Operator { } bool - IsFilter() override { + IsFilter() const override { return false; } diff --git a/internal/core/src/exec/operator/CountNode.h b/internal/core/src/exec/operator/CountNode.h index cfb9512a555b2..4692ef25f53d6 100644 --- a/internal/core/src/exec/operator/CountNode.h +++ b/internal/core/src/exec/operator/CountNode.h @@ -34,7 +34,7 @@ class PhyCountNode : public Operator { const std::shared_ptr& node); bool - IsFilter() override { + IsFilter() const override { return false; } diff --git a/internal/core/src/exec/operator/FilterBitsNode.h b/internal/core/src/exec/operator/FilterBitsNode.h index de6d472a508e6..9eea755287f1a 100644 --- a/internal/core/src/exec/operator/FilterBitsNode.h +++ b/internal/core/src/exec/operator/FilterBitsNode.h @@ -34,7 +34,7 @@ class PhyFilterBitsNode : public Operator { const std::shared_ptr& filter); bool - IsFilter() override { + IsFilter() const override { return true; } diff --git a/internal/core/src/exec/operator/IterativeFilterNode.h b/internal/core/src/exec/operator/IterativeFilterNode.h index 07404d974b7ca..23eb6052aed54 100644 --- a/internal/core/src/exec/operator/IterativeFilterNode.h +++ b/internal/core/src/exec/operator/IterativeFilterNode.h @@ -37,7 +37,7 @@ class PhyIterativeFilterNode : public Operator { const std::shared_ptr& filter); bool - IsFilter() override { + IsFilter() const override { return true; } diff --git a/internal/core/src/exec/operator/MvccNode.h b/internal/core/src/exec/operator/MvccNode.h index 332dc71c5333a..5a9f998aa0a1a 100644 --- a/internal/core/src/exec/operator/MvccNode.h +++ b/internal/core/src/exec/operator/MvccNode.h @@ -34,7 +34,7 @@ class PhyMvccNode : public Operator { const std::shared_ptr& mvcc_node); bool - IsFilter() override { + IsFilter() const override { return false; } diff --git a/internal/core/src/exec/operator/Operator.cpp b/internal/core/src/exec/operator/Operator.cpp index 972482c797d0a..6af1b467211e3 100644 --- a/internal/core/src/exec/operator/Operator.cpp +++ b/internal/core/src/exec/operator/Operator.cpp @@ -17,5 +17,10 @@ #include "Operator.h" namespace milvus { -namespace exec {} +namespace exec { +void +Operator::initialize() { + // TODO check memory and set up memory pool in the future +} +} // namespace exec } // namespace milvus diff --git a/internal/core/src/exec/operator/Operator.h b/internal/core/src/exec/operator/Operator.h index 1115ee263ac50..8b39402952fe4 100644 --- a/internal/core/src/exec/operator/Operator.h +++ b/internal/core/src/exec/operator/Operator.h @@ -94,16 +94,26 @@ class OperatorContext { class Operator { public: Operator(DriverContext* ctx, - DataType output_type, + RowTypePtr output_type, int32_t operator_id, const std::string& plannode_id, const std::string& operator_type = "") : operator_context_(std::make_unique( - ctx, plannode_id, operator_id, operator_type)) { + ctx, plannode_id, operator_id, operator_type)), + output_type_(output_type) { } virtual ~Operator() = default; + /// Does initialization work for this operator which requires memory + /// allocation from memory pool that can't be done under operator constructor. + /// + /// NOTE: the default implementation set 'initialized_' to true to ensure we + /// never call this more than once. The overload initialize() implementation + /// must call this base implementation first. + virtual void + initialize(); + virtual bool NeedInput() const = 0; @@ -122,7 +132,7 @@ class Operator { IsFinished() = 0; virtual bool - IsFilter() = 0; + IsFilter() const = 0; virtual BlockingReason IsBlocked(ContinueFuture* future) = 0; @@ -158,10 +168,15 @@ class Operator { return "Base Operator"; } + virtual const RowTypePtr& + OutputType() const { + return output_type_; + } + protected: std::unique_ptr operator_context_; - DataType output_type_; + RowTypePtr output_type_; RowVectorPtr input_; @@ -173,7 +188,7 @@ class Operator { class SourceOperator : public Operator { public: SourceOperator(DriverContext* driver_ctx, - DataType out_type, + RowTypePtr out_type, int32_t operator_id, const std::string& plannode_id, const std::string& operator_type) diff --git a/internal/core/src/exec/operator/OperatorUtils.cpp b/internal/core/src/exec/operator/OperatorUtils.cpp new file mode 100644 index 0000000000000..2c5704453ccf6 --- /dev/null +++ b/internal/core/src/exec/operator/OperatorUtils.cpp @@ -0,0 +1,21 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "OperatorUtils.h" + +namespace milvus { +namespace exec {} +} // namespace milvus diff --git a/internal/core/src/exec/operator/OperatorUtils.h b/internal/core/src/exec/operator/OperatorUtils.h new file mode 100644 index 0000000000000..31cd51c4e6885 --- /dev/null +++ b/internal/core/src/exec/operator/OperatorUtils.h @@ -0,0 +1,23 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "exec/VectorHasher.h" +#include "common/Types.h" + +namespace milvus { +namespace exec {} +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/ProjectNode.cpp b/internal/core/src/exec/operator/ProjectNode.cpp new file mode 100644 index 0000000000000..f2a54a1b784a3 --- /dev/null +++ b/internal/core/src/exec/operator/ProjectNode.cpp @@ -0,0 +1,77 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ProjectNode.h" +#include "exec/expression/Utils.h" +#include "segcore/Utils.h" + +namespace milvus { +namespace exec { +PhyProjectNode::PhyProjectNode( + int32_t operator_id, + milvus::exec::DriverContext* ctx, + const std::shared_ptr& projectNode) + : Operator(ctx, + projectNode->output_type(), + operator_id, + projectNode->id(), + "Project"), + fields_to_project_(projectNode->FieldsToProject()) { + auto exec_context = operator_context_->get_exec_context(); + segment_ = exec_context->get_query_context()->get_segment(); +} + +void +PhyProjectNode::AddInput(milvus::RowVectorPtr& input) { + input_ = std::move(input); +} + +RowVectorPtr +PhyProjectNode::GetOutput() { + if (is_finished_ || input_ == nullptr) { + return nullptr; + } + auto col_input = GetColumnVector(input_); + // raw data view + TargetBitmapView raw_data_view(col_input->GetRawData(), col_input->size()); + auto result_pair = segment_->find_first(-1, raw_data_view); + auto selected_offsets = result_pair.first; + auto selected_count = selected_offsets.size(); + auto row_type = OutputType(); + std::vector column_vectors; + for (int i = 0; i < fields_to_project_.size(); i++) { + auto column_type = row_type->column_type(i); + auto field_id = fields_to_project_.at(i); + + TargetBitmap valid_map(selected_count); + TargetBitmapView valid_view(valid_map.data(), selected_count); + auto field_data = bulk_script_field_data(field_id, + column_type, + selected_offsets.data(), + selected_count, + segment_, + valid_view); + auto column_vector = std::make_shared( + std::move(field_data), std::move(valid_view)); + column_vectors.emplace_back(column_vector); + } + is_finished_ = true; + auto row_vector = std::make_shared(std::move(column_vectors)); + return row_vector; +} + +}; // namespace exec +}; // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/ProjectNode.h b/internal/core/src/exec/operator/ProjectNode.h new file mode 100644 index 0000000000000..971a504fa6a45 --- /dev/null +++ b/internal/core/src/exec/operator/ProjectNode.h @@ -0,0 +1,73 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "Operator.h" +#include "plan/PlanNode.h" + +namespace milvus { +namespace exec { +class PhyProjectNode : public Operator { + public: + PhyProjectNode(int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& projectNode); + + bool + IsFilter() const override { + return false; + } + + bool + NeedInput() const override { + return true; + } + + void + AddInput(RowVectorPtr& input) override; + + RowVectorPtr + GetOutput() override; + + bool + IsFinished() override { + return is_finished_; + } + + BlockingReason + IsBlocked(ContinueFuture* /* unused */) override { + return BlockingReason::kNotBlocked; + } + + std::string + ToString() const override { + return "Project Operator"; + } + + private: + FieldDataPtr + projectFieldData(FieldId fieldId, + milvus::DataType dataType, + const int64_t* seg_offsets, + int64_t count) const; + + private: + const segcore::SegmentInternalInterface* segment_; + bool is_finished_{false}; + const std::vector fields_to_project_; +}; +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/GroupByNode.cpp b/internal/core/src/exec/operator/SearchGroupByNode.cpp similarity index 91% rename from internal/core/src/exec/operator/GroupByNode.cpp rename to internal/core/src/exec/operator/SearchGroupByNode.cpp index bd13eab202096..ff4657df08287 100644 --- a/internal/core/src/exec/operator/GroupByNode.cpp +++ b/internal/core/src/exec/operator/SearchGroupByNode.cpp @@ -14,18 +14,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "GroupByNode.h" +#include "SearchGroupByNode.h" -#include "exec/operator/groupby/SearchGroupByOperator.h" +#include "exec/operator/search-groupby/SearchGroupByOperator.h" #include "monitor/prometheus_client.h" namespace milvus { namespace exec { -PhyGroupByNode::PhyGroupByNode( +PhySearchGroupByNode::PhySearchGroupByNode( int32_t operator_id, DriverContext* driverctx, - const std::shared_ptr& node) + const std::shared_ptr& node) : Operator(driverctx, node->output_type(), operator_id, node->id()) { ExecContext* exec_context = operator_context_->get_exec_context(); query_context_ = exec_context->get_query_context(); @@ -34,12 +34,12 @@ PhyGroupByNode::PhyGroupByNode( } void -PhyGroupByNode::AddInput(RowVectorPtr& input) { +PhySearchGroupByNode::AddInput(RowVectorPtr& input) { input_ = std::move(input); } RowVectorPtr -PhyGroupByNode::GetOutput() { +PhySearchGroupByNode::GetOutput() { if (is_finished_ || !no_more_input_) { return nullptr; } @@ -86,7 +86,7 @@ PhyGroupByNode::GetOutput() { } bool -PhyGroupByNode::IsFinished() { +PhySearchGroupByNode::IsFinished() { return is_finished_; } diff --git a/internal/core/src/exec/operator/GroupByNode.h b/internal/core/src/exec/operator/SearchGroupByNode.h similarity index 86% rename from internal/core/src/exec/operator/GroupByNode.h rename to internal/core/src/exec/operator/SearchGroupByNode.h index a5d05898f999a..ee0b64c122ec2 100644 --- a/internal/core/src/exec/operator/GroupByNode.h +++ b/internal/core/src/exec/operator/SearchGroupByNode.h @@ -27,14 +27,15 @@ namespace milvus { namespace exec { -class PhyGroupByNode : public Operator { +class PhySearchGroupByNode : public Operator { public: - PhyGroupByNode(int32_t operator_id, - DriverContext* ctx, - const std::shared_ptr& node); + PhySearchGroupByNode( + int32_t operator_id, + DriverContext* ctx, + const std::shared_ptr& node); bool - IsFilter() override { + IsFilter() const override { return false; } @@ -63,7 +64,7 @@ class PhyGroupByNode : public Operator { virtual std::string ToString() const override { - return "PhyGroupByNode"; + return "PhySearchGroupByNode"; } private: diff --git a/internal/core/src/exec/operator/VectorSearchNode.h b/internal/core/src/exec/operator/VectorSearchNode.h index e6ec630eed9c9..7dd7b32dfb638 100644 --- a/internal/core/src/exec/operator/VectorSearchNode.h +++ b/internal/core/src/exec/operator/VectorSearchNode.h @@ -35,7 +35,7 @@ class PhyVectorSearchNode : public Operator { const std::shared_ptr& search_node); bool - IsFilter() override { + IsFilter() const override { return false; } diff --git a/internal/core/src/exec/operator/init_c.cpp b/internal/core/src/exec/operator/init_c.cpp new file mode 100644 index 0000000000000..417ab8d44b317 --- /dev/null +++ b/internal/core/src/exec/operator/init_c.cpp @@ -0,0 +1,24 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exec/operator/init_c.h" +#include "exec/operator/query-agg/RegisterAggregateFunctions.h" +#include "log/Log.h" + +void +RegisterAggregationFunctions() { + milvus::exec::registerAllAggregateFunctions(); +} \ No newline at end of file diff --git a/internal/core/src/exec/operator/init_c.h b/internal/core/src/exec/operator/init_c.h new file mode 100644 index 0000000000000..924e667ef721e --- /dev/null +++ b/internal/core/src/exec/operator/init_c.h @@ -0,0 +1,27 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void +RegisterAggregationFunctions(); + +#ifdef __cplusplus +}; +#endif diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.cpp b/internal/core/src/exec/operator/query-agg/Aggregate.cpp new file mode 100644 index 0000000000000..d5f27a4bedfdc --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/Aggregate.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +#include "Aggregate.h" +#include "exec/expression/Utils.h" + +namespace milvus { +namespace exec { + +void +Aggregate::setOffsetsInternal(int32_t offset, + int32_t nullByte, + uint8_t nullMask, + int32_t initializedByte, + uint8_t initializedMask, + int32_t rowSizeOffset) { + offset_ = offset; + nullByte_ = nullByte; + nullMask_ = nullMask; + initializedByte_ = initializedByte; + initializedMask_ = initializedMask; + rowSizeOffset_ = rowSizeOffset; +} + +const AggregateFunctionEntry* +getAggregateFunctionEntry(const std::string& name) { + auto sanitizedName = milvus::exec::sanitizeName(name); + + return aggregateFunctions().withRLock( + [&](const auto& functionsMap) -> const AggregateFunctionEntry* { + auto it = functionsMap.find(sanitizedName); + if (it != functionsMap.end()) { + return &it->second; + } + return nullptr; + }); +} + +std::unique_ptr +Aggregate::create(const std::string& name, + plan::AggregationNode::Step step, + const std::vector& argTypes, + const QueryConfig& query_config) { + if (auto func = getAggregateFunctionEntry(name)) { + return func->factory(step, argTypes, query_config); + } + PanicInfo(UnexpectedError, "Aggregate function not registered: {}", name); +} + +bool +isPartialOutput(milvus::plan::AggregationNode::Step step) { + return step == milvus::plan::AggregationNode::Step::kPartial || + step == milvus::plan::AggregationNode::Step::kIntermediate; +} + +void +registerAggregateFunction( + const std::string& name, + const std::vector>& + signatures, + const AggregateFunctionFactory& factory) { + auto realName = lowerString(name); + LOG_INFO("hc=== try to register agg function, name:{}, realName:{}", + name, + realName); + aggregateFunctions().withWLock([&](auto& aggFunctionMap) { + aggFunctionMap[realName] = {signatures, factory}; + LOG_INFO("hc=== registered agg function, name:{}, realName:{}", + name, + realName); + }); +} + +AggregateFunctionMap& +aggregateFunctions() { + static AggregateFunctionMap aggFunctionMap; + return aggFunctionMap; +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/Aggregate.h b/internal/core/src/exec/operator/query-agg/Aggregate.h new file mode 100644 index 0000000000000..1e4c17335ec04 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/Aggregate.h @@ -0,0 +1,210 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +#pragma once + +#include "common/Types.h" +#include "plan/PlanNode.h" +#include "expr/FunctionSignature.h" +#include "plan/PlanNode.h" +#include "exec/QueryContext.h" +#include + +namespace milvus { +namespace exec { +class Aggregate { + protected: + explicit Aggregate(DataType result_type) : result_type_(result_type) { + } + + private: + const DataType result_type_; + + // Byte position of null flag in group row. + int32_t nullByte_; + uint8_t nullMask_; + // Byte position of the initialized flag in group row. + int32_t initializedByte_; + uint8_t initializedMask_; + // Offset of fixed length accumulator state in group row. + int32_t offset_; + // Offset of uint32_t row byte size of row. 0 if there are no + // variable width fields or accumulators on the row. The size is + // capped at 4G and will stay at 4G and not wrap around if growing + // past this. This serves to track the batch size when extracting + // rows. A size in excess of 4G would finish the batch in any case, + // so larger values need not be represented. + int32_t rowSizeOffset_ = 0; + + public: + DataType + resultType() const { + return result_type_; + } + + static std::unique_ptr + create(const std::string& name, + plan::AggregationNode::Step step, + const std::vector& argTypes, + const QueryConfig& query_config); + + void + setOffsets(int32_t offset, + int32_t nullByte, + uint8_t nullMask, + int32_t initializedByte, + int8_t initializedMask, + int32_t rowSizeOffset) { + setOffsetsInternal(offset, + nullByte, + nullMask, + initializedByte, + initializedMask, + rowSizeOffset); + } + + virtual void + initializeNewGroups(char** groups, + folly::Range indices) { + initializeNewGroupsInternal(groups, indices); + for (auto index : indices) { + groups[index][initializedByte_] |= initializedMask_; + } + } + + virtual void + addSingleGroupRawInput(char* group, + const TargetBitmapView& activeRows, + const std::vector& input) = 0; + + virtual void + addRawInput(char** groups, + const TargetBitmapView& activeRows, + const std::vector& input) = 0; + + virtual void + extractValues(char** groups, int32_t numGroups, VectorPtr* result) = 0; + + template + T* + value(char* group) const { + AssertInfo(reinterpret_cast(group + offset_) % + accumulatorAlignmentSize() == + 0, + "aggregation value in the groups is not aligned"); + return reinterpret_cast(group + offset_); + } + + bool + isNull(char* group) const { + return numNulls_ && (group[nullByte_] & nullMask_); + } + + // Returns true if the accumulator never takes more than + // accumulatorFixedWidthSize() bytes. If this is false, the + // accumulator needs to track its changing variable length footprint + // using RowSizeTracker (Aggregate::trackRowSize), see ArrayAggAggregate for + // sample usage. A group row with at least one variable length key or + // aggregate will have a 32-bit slot at offset RowContainer::rowSize_ for + // keeping track of per-row size. The size is relevant for keeping caps on + // result set and spilling batch sizes with skewed data. + virtual bool + isFixedSize() const { + return true; + } + + // Returns the fixed number of bytes the accumulator takes on a group + // row. Variable width accumulators will reference the variable + // width part of the state from the fixed part. + virtual int32_t + accumulatorFixedWidthSize() const = 0; + + /// Returns the alignment size of the accumulator. Some types such as + /// int128_t require aligned access. This value must be a power of 2. + virtual int32_t + accumulatorAlignmentSize() const { + return 1; + } + + protected: + virtual void + setOffsetsInternal(int32_t offset, + int32_t nullByte, + uint8_t nullMask, + int32_t initializedByte, + uint8_t initializedMask, + int32_t rowSizeOffset); + + virtual void + initializeNewGroupsInternal(char** groups, + folly::Range indices) = 0; + // Number of null accumulators in the current state of the aggregation + // operator for this aggregate. If 0, clearing the null as part of update + // is not needed. + uint64_t numNulls_ = 0; + + inline bool + clearNull(char* group) { + if (numNulls_) { + uint8_t mask = group[nullByte_]; + if (mask & nullMask_) { + group[nullByte_] = mask & ~nullMask_; + numNulls_--; + return true; + } + } + return false; + } + + void + setAllNulls(char** groups, folly::Range indices) { + for (auto i : indices) { + groups[i][nullByte_] = nullMask_; + } + numNulls_ += indices.size(); + } +}; + +using AggregateFunctionFactory = std::function( + plan::AggregationNode::Step step, + const std::vector& argTypes, + const QueryConfig& config)>; + +struct AggregateFunctionEntry { + std::vector signatures; + AggregateFunctionFactory factory; +}; + +const AggregateFunctionEntry* +getAggregateFunctionEntry(const std::string& name); + +using AggregateFunctionMap = folly::Synchronized< + std::unordered_map>; + +AggregateFunctionMap& +aggregateFunctions(); + +/// Register an aggregate function with the specified name and signatures. If +/// registerCompanionFunctions is true, also register companion aggregate and +/// scalar functions with it. When functions with `name` already exist, if +/// overwrite is true, existing registration will be replaced. Otherwise, return +/// false without overwriting the registry. +void +registerAggregateFunction( + const std::string& name, + const std::vector>& + signatures, + const AggregateFunctionFactory& factory); + +bool +isPartialOutput(milvus::plan::AggregationNode::Step step); + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp b/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp new file mode 100644 index 0000000000000..9280cfd5d2219 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/AggregateInfo.cpp @@ -0,0 +1,60 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +// +// Created by hanchun on 24-10-22. +// +#include "AggregateInfo.h" +#include "common/Types.h" + +namespace milvus { +namespace exec { + +std::vector +toAggregateInfo(const plan::AggregationNode& aggregationNode, + const milvus::exec::OperatorContext& operatorCtx, + uint32_t numKeys) { + const auto numAggregates = aggregationNode.aggregates().size(); + std::vector aggregates; + aggregates.reserve(numAggregates); + const auto& inputType = aggregationNode.sources()[0]->output_type(); + const auto& outputType = aggregationNode.output_type(); + const auto step = aggregationNode.step(); + + for (auto i = 0; i < numAggregates; i++) { + const auto& aggregate = aggregationNode.aggregates()[i]; + AggregateInfo info; + auto& inputColumnIdxes = info.input_column_idxes_; + for (const auto& inputExpr : aggregate.call_->inputs()) { + if (auto fieldExpr = dynamic_cast( + inputExpr.get())) { + inputColumnIdxes.emplace_back( + inputType->GetChildIndex(fieldExpr->name())); + } else if (inputExpr != nullptr) { + PanicInfo(ExprInvalid, + "Only support aggregation towards column for now"); + } + } + auto index = numKeys + i; + info.function_ = Aggregate::create( + aggregate.call_->fun_name(), + isPartialOutput(step) ? plan::AggregationNode::Step::kPartial + : plan::AggregationNode::Step::kSingle, + aggregate.rawInputTypes_, + *(operatorCtx.get_exec_context()->get_query_config())); + info.output_ = index; + aggregates.emplace_back(std::move(info)); + } + return aggregates; +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/query-agg/AggregateInfo.h b/internal/core/src/exec/operator/query-agg/AggregateInfo.h new file mode 100644 index 0000000000000..af4c004e8695a --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/AggregateInfo.h @@ -0,0 +1,41 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#include +#include + +#include "common/Types.h" +#include "Aggregate.h" +#include "plan/PlanNode.h" +#include "exec/operator/Operator.h" + +namespace milvus { +namespace exec { + +/// Information needed to evaluate an aggregate function. +struct AggregateInfo { + /// Instance of the Aggregate class. + std::unique_ptr function_; + + /// Indices of the input columns in the input RowVector. + std::vector input_column_idxes_; + + /// Index of the result column in the output RowVector. + column_index_t output_; +}; + +std::vector +toAggregateInfo(const plan::AggregationNode& aggregationNode, + const milvus::exec::OperatorContext& operatorCtx, + uint32_t numKeys); + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/AggregateUtil.h b/internal/core/src/exec/operator/query-agg/AggregateUtil.h new file mode 100644 index 0000000000000..a70d38e29e249 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/AggregateUtil.h @@ -0,0 +1,41 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "common/Types.h" + +namespace milvus { +namespace exec { +// The result of aggregation function registration. +struct AggregateRegistrationResult { + bool mainFunction{false}; + bool partialFunction{false}; + bool mergeFunction{false}; + bool extractFunction{false}; + bool mergeExtractFunction{false}; + + bool + operator==(const AggregateRegistrationResult& other) const { + return mainFunction == other.mainFunction && + partialFunction == other.partialFunction && + mergeFunction == other.mergeFunction && + extractFunction == other.extractFunction && + mergeExtractFunction == other.mergeExtractFunction; + } +}; + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/CountAggregateBase.cpp b/internal/core/src/exec/operator/query-agg/CountAggregateBase.cpp new file mode 100644 index 0000000000000..703cd4c2a96b8 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/CountAggregateBase.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +#pragma once + +#include "CountAggregateBase.h" + +namespace milvus { +namespace exec { + +void +registerCount(const std::string name) { + std::vector> signatures{ + expr::AggregateFunctionSignatureBuilder() + .argumentType(DataType::INT64) + .intermediateType(DataType::INT64) + .returnType(DataType::INT64) + .build()}; + + exec::registerAggregateFunction( + name, + signatures, + [name](plan::AggregationNode::Step /*step*/, + const std::vector& /*argumentTypes*/, + const QueryConfig& /*config*/) -> std::unique_ptr { + return std::make_unique(); + }); +} + +void +registerCountAggregate(const std::string& prefix) { + registerCount(prefix + KCount); +} +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/CountAggregateBase.h b/internal/core/src/exec/operator/query-agg/CountAggregateBase.h new file mode 100644 index 0000000000000..42eb3a494956e --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/CountAggregateBase.h @@ -0,0 +1,121 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +#pragma once + +#include "SimpleNumericAggregate.h" + +namespace milvus { +namespace exec { +class CountAggregate : public SimpleNumericAggregate { + using BaseAggregate = SimpleNumericAggregate; + + public: + explicit CountAggregate() : BaseAggregate(DataType::INT64) { + } + + int32_t + accumulatorFixedWidthSize() const override { + return sizeof(int64_t); + } + + void + extractValues(char** groups, + int32_t numGroups, + VectorPtr* result) override { + LOG_INFO("hc===extractValues for count"); + BaseAggregate::doExtractValues( + groups, numGroups, result, [&](char* group) { + LOG_INFO("hc===extractValues for count, value:{}", + *value(group)); + return *value(group); + }); + } + + void + addRawInput(char** groups, + const TargetBitmapView& activeRows, + const std::vector& input) override { + LOG_INFO("hc===addRawInput for count, active_rows:{}, active_size:{}", + activeRows.count(), + activeRows.size()); + ColumnVectorPtr input_column = nullptr; + AssertInfo(input.empty() || input.size() == 1, + fmt::format("input column count for count aggregation " + "must be one or zero for now, but got:{}", + input.size())); + if (input.size() == 1) { + input_column = std::dynamic_pointer_cast(input[0]); + } + auto start = -1; + do { + auto next_active_idx = activeRows.find_next(start); + if (!next_active_idx.has_value()) { + break; + } + auto active_idx = next_active_idx.value(); + if ((input_column && input_column->ValidAt(active_idx)) || + !input_column) { + LOG_INFO("hc===addToGroup, active_idx:{}", active_idx); + addToGroup(groups[active_idx], 1); + } else { + LOG_INFO("hc===addRawInput failed to add count"); + } + start = active_idx; + } while (true); + } + + void + addSingleGroupRawInput(char* group, + const TargetBitmapView& activeRows, + const std::vector& input) override { + if (input.empty()) { + addToGroup(group, activeRows.count()); + } else { + AssertInfo(input.size() == 1, + fmt::format("input column count for count aggregation " + "must be exactly one for now, but got:{}", + input.size())); + const auto& column = + std::dynamic_pointer_cast(input[0]); + auto start = -1; + do { + auto next_active_idx = activeRows.find_next(start); + if (!next_active_idx.has_value()) { + break; + } + auto active_idx = next_active_idx.value(); + if (column->ValidAt(active_idx)) { + addToGroup(group, 1); + } + start = active_idx; + } while (true); + } + } + + void + initializeNewGroupsInternal( + char** groups, folly::Range indices) override { + for (auto i : indices) { + LOG_INFO("hc===initializeNewGroupsInternal for count, i:{}", i); + // initialized result of count is always zero + *value(groups[i]) = static_cast(0); + } + } + + private: + inline void + addToGroup(char* group, int64_t count) { + *value(group) += count; + } +}; + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/GroupingSet.cpp b/internal/core/src/exec/operator/query-agg/GroupingSet.cpp new file mode 100644 index 0000000000000..90a904ae3c759 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/GroupingSet.cpp @@ -0,0 +1,263 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "GroupingSet.h" +#include "common/Utils.h" +#include "SumAggregateBase.h" + +namespace milvus{ +namespace exec{ +GroupingSet::~GroupingSet(){ + if(isGlobal_) { + AssertInfo(lookup_->hits_.size()==1, "GlobalAggregation should have exactly one output line"); + char* global_line = lookup_->hits_[0]; + delete[] global_line; + lookup_->hits_[0] = nullptr; + } +} + +void +GroupingSet::addInput(const RowVectorPtr& input) { + if (isGlobal_) { + addGlobalAggregationInput(input); + return; + } + auto numRows = input->size(); + numInputRows_ += numRows; + active_rows_.resize(numRows); + active_rows_.set(); + addInputForActiveRows(input); +} + +void +GroupingSet::initializeGlobalAggregation() { + if (globalAggregationInitialized_) { + return; + } + lookup_ = std::make_unique(hashers_); + lookup_->reset(1); + + // Row layout is: + // - alternating null flag, intialized flag - one bit per flag, one pair per + // aggregation, + // - uint32_t row size, + // - fixed-width accumulators - one per aggregate + // + // Here we always make space for a row size since we only have one row and no + // RowContainer. The whole row is allocated to guarantee that alignment + // requirements of all aggregate functions are satisfied. + + // Allocate space for the null and initialized flags. + size_t numAggregates = aggregates_.size(); + int32_t rowSizeOffset = milvus::bits::nBytes( + numAggregates * RowContainer::kNumAccumulatorFlags); + int32_t offset = rowSizeOffset + sizeof(int32_t); + int32_t accumulatorFlagsOffset = 0; + int32_t alignment = 1; + + for (auto& aggregate : aggregates_) { + auto& function = aggregate.function_; + Accumulator accumulator(function.get()); + // Accumulator offset must be aligned by their alignment size. + offset = milvus::bits::roundUp(offset, accumulator.alignment()); + function->setOffsets( + offset, + RowContainer::nullByte(accumulatorFlagsOffset), + RowContainer::nullMask(accumulatorFlagsOffset), + RowContainer::initializedByte(accumulatorFlagsOffset), + RowContainer::initializedMask(accumulatorFlagsOffset), + rowSizeOffset); + offset += accumulator.fixedWidthSize(); + accumulatorFlagsOffset += RowContainer::kNumAccumulatorFlags; + alignment = + RowContainer::combineAlignments(accumulator.alignment(), alignment); + } + AssertInfo(__builtin_popcount(alignment) == 1, + "alignment of aggregations must be power of two"); + offset = milvus::Align(offset, alignment); + lookup_->hits_[0] = new char[offset]; //TODO memory allocation control + const auto singleGroup = std::vector{0}; + for (auto& aggregate : aggregates_) { + aggregate.function_->initializeNewGroups(lookup_->hits_.data(), + singleGroup); + } + globalAggregationInitialized_ = true; +} + +void +GroupingSet::addGlobalAggregationInput(const milvus::RowVectorPtr& input) { + initializeGlobalAggregation(); + auto numRows = input->size(); + active_rows_.resize(numRows); + active_rows_.set(); + auto* group = lookup_->hits_[0]; + for (auto i = 0; i < aggregates_.size(); i++) { + auto& function = aggregates_[i].function_; + populateTempVectors(i, input); + function->addSingleGroupRawInput(group, active_rows_, tempVectors_); + } + tempVectors_.clear(); +} + +bool +GroupingSet::getGlobalAggregationOutput(milvus::RowVectorPtr& result) { + initializeGlobalAggregation(); + AssertInfo(lookup_->hits_.size()==1, "GlobalAggregation should have exactly one output line"); + auto groups = lookup_->hits_.data(); + for (auto i = 0; i < aggregates_.size(); i++) { + auto& function = aggregates_[i].function_; + auto resultVector = result->child(aggregates_[i].output_); + function->extractValues(groups, 1, &resultVector); + } + return true; +} + +bool +GroupingSet::getOutput(milvus::RowVectorPtr& result) { + if (isGlobal_) { + return getGlobalAggregationOutput(result); + } + if (hash_table_ == nullptr) { + return false; + } + const auto& all_rows = hash_table_->rows()->allRows(); + DeferLambda([&]() { hash_table_->clear(); }); + if (!all_rows.empty()) { + extractGroups(folly::Range(const_cast(all_rows.data()), + all_rows.size()), + result); + return true; + } + return false; +} + +std::vector +GroupingSet::accumulators() { + std::vector accumulators; + accumulators.reserve(aggregates_.size()); + for (auto& aggregate : aggregates_) { + accumulators.emplace_back(Accumulator{aggregate.function_.get()}); + } + return accumulators; +} + +void +GroupingSet::ensureInputFits(const RowVectorPtr& input) { + //TODO memory check +} + +void +GroupingSet::extractGroups(folly::Range groups, + const milvus::RowVectorPtr& result) { + if (groups.empty()) { + return; + } + result->resize(groups.size()); + RowContainer* rows = hash_table_->rows(); + auto totalKeys = rows->KeyTypes().size(); + for (auto i = 0; i < totalKeys; i++) { + auto keyVector = result->child(i); + rows->extractColumn(groups.data(), groups.size(), i, keyVector); + } + for (auto i = 0; i < aggregates_.size(); i++) { + auto& function = aggregates_[i].function_; + auto aggregateVector = result->child(totalKeys + i); + function->extractValues(groups.data(), groups.size(), &aggregateVector); + } +} + +void +GroupingSet::addInputForActiveRows(const RowVectorPtr& input) { + AssertInfo( + !isGlobal_, + "Global aggregations should not reach add input for active rows"); + if (!hash_table_) { + createHashTable(); + } + ensureInputFits(input); + hash_table_->prepareForGroupProbe( + *lookup_, input, active_rows_, ignoreNullKeys_); + if (lookup_->rows_.empty()) { + // No rows to probe. Can happen when ignoreNullKeys_ is true and all rows + // have null keys. + return; + } + hash_table_->groupProbe(*lookup_); + auto* groups = lookup_->hits_.data(); + const auto& newGroups = lookup_->newGroups_; + for (auto i = 0; i < aggregates_.size(); i++) { + auto& function = aggregates_[i].function_; + if (!newGroups.empty()) { + function->initializeNewGroups(groups, newGroups); + } + if (!active_rows_.any()) { + continue; + } + populateTempVectors(i, input); + function->addRawInput(groups, active_rows_, tempVectors_); + } + tempVectors_.clear(); +} + +void +GroupingSet::populateTempVectors(int32_t aggregateIndex, + const milvus::RowVectorPtr& input) { + const auto& channel_idxes = aggregates_[aggregateIndex].input_column_idxes_; + tempVectors_.resize(channel_idxes.size()); + for (auto i = 0; i < channel_idxes.size(); i++) { + tempVectors_[i] = input->child(channel_idxes[i]); + } +} + +int32_t +GroupingSet::outputRowCount() const { + return lookup_->newGroups_.size(); +} + +void +initializeAggregates(const std::vector& aggregates, + RowContainer& rows) { + const auto numKeys = rows.KeyTypes().size(); + int i = 0; + for (auto& aggregate : aggregates) { + auto& function = aggregate.function_; + const auto& rowColumn = rows.columnAt(numKeys + i); + function->setOffsets(rowColumn.offset(), + rowColumn.nullByte(), + rowColumn.nullMask(), + rowColumn.initializedByte(), + rowColumn.initializedMask(), + rows.rowSizeOffset()); + i++; + } +} + +void +GroupingSet::createHashTable() { + if (ignoreNullKeys_) { + hash_table_ = std::make_unique>(std::move(hashers_), + accumulators()); + } else { + hash_table_ = std::make_unique>(std::move(hashers_), + accumulators()); + } + auto& rows = *(hash_table_->rows()); + initializeAggregates(aggregates_, rows); + lookup_ = std::make_unique(hash_table_->hashers()); +} + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/query-agg/GroupingSet.h b/internal/core/src/exec/operator/query-agg/GroupingSet.h new file mode 100644 index 0000000000000..4da06b76adbe4 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/GroupingSet.h @@ -0,0 +1,109 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "common/Types.h" +#include "exec/VectorHasher.h" +#include "AggregateInfo.h" +#include "exec/HashTable.h" +#include "plan/PlanNode.h" +#include "RowContainer.h" + +namespace milvus { +namespace exec { + +class GroupingSet { + public: + GroupingSet(const RowTypePtr& input_type, + std::vector>&& hashers, + std::vector&& aggregates, + bool ignoreNullKeys) + : hashers_(std::move(hashers)), + aggregates_(std::move(aggregates)), + ignoreNullKeys_(ignoreNullKeys) { + isGlobal_ = hashers_.empty(); + } + + ~GroupingSet(); + + void + addInput(const RowVectorPtr& input); + + void + initializeGlobalAggregation(); + + void + addGlobalAggregationInput(const RowVectorPtr& input); + + void + addInputForActiveRows(const RowVectorPtr& input); + + void + createHashTable(); + + std::vector + accumulators(); + + // Checks if input will fit in the existing memory and increases reservation + // if not. If reservation cannot be increased, spills enough to make 'input' + // fit. + void + ensureInputFits(const RowVectorPtr& input); + + bool + getOutput(RowVectorPtr& result); + + bool + hasOutput() { + return noMoreInput_; + } + + void + extractGroups(folly::Range groups, const RowVectorPtr& result); + + void + populateTempVectors(int32_t aggregateIndex, const RowVectorPtr& input); + + bool + getGlobalAggregationOutput(RowVectorPtr& result); + + int32_t + outputRowCount() const; + + private: + bool isGlobal_; + const bool ignoreNullKeys_; + + std::vector> hashers_; + std::vector aggregates_; + + // Place for the arguments of the aggregate being updated. + std::vector tempVectors_; + std::unique_ptr hash_table_; + std::unique_ptr lookup_; + TargetBitmap active_rows_; + + uint64_t numInputRows_ = 0; + + bool noMoreInput_{false}; + + // Boolean indicating whether accumulators for a global aggregation (i.e. + // aggregation with no grouping keys) have been initialized. + bool globalAggregationInitialized_{false}; +}; + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp new file mode 100644 index 0000000000000..e7f7779f79250 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.cpp @@ -0,0 +1,29 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RegisterAggregateFunctions.h" +#include "log/Log.h" + +namespace milvus { +namespace exec { +void +registerAllAggregateFunctions(const std::string& prefix) { + registerSumAggregate(prefix); + registerCountAggregate(prefix); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h new file mode 100644 index 0000000000000..0d1f323f89df4 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/RegisterAggregateFunctions.h @@ -0,0 +1,32 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace milvus { +namespace exec { +void +registerAllAggregateFunctions(const std::string& prefix = ""); + +extern void +registerSumAggregate(const std::string& prefix); + +extern void +registerCountAggregate(const std::string& prefix); +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/RowContainer.cpp b/internal/core/src/exec/operator/query-agg/RowContainer.cpp new file mode 100644 index 0000000000000..508f5ebf961b2 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/RowContainer.cpp @@ -0,0 +1,180 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RowContainer.h" +#include "common/BitUtil.h" +#include "common/Vector.h" + +namespace milvus { +namespace exec { + +RowContainer::RowContainer(const std::vector& keyTypes, + const std::vector& accumulators, + bool ignoreNullKeys) + : keyTypes_(keyTypes), + accumulators_(accumulators), + ignoreNullKeys_(ignoreNullKeys) { + int32_t offset = 0; + int32_t nullOffset = 0; + bool isVariableWidth = false; + int idx = 0; + for (auto& type : keyTypes_) { + bool varLength = !IsFixedSizeType(type); + isVariableWidth |= varLength; + if (varLength) { + variable_offsets.emplace_back(offset); + variable_idxes.emplace_back(idx); + } + offsets_.push_back(offset); + if (type == DataType::VARCHAR || type == DataType::STRING) { + offset += 8; //use a pointer to store string + } else { + offset += GetDataTypeSize(type, 1); + } + nullOffsets_.push_back(nullOffset); + if (!ignoreNullKeys_) { + ++nullOffset; + } + idx++; + } + // Make offset at least sizeof pointer so that there is space for a + // free list next pointer below the bit at 'freeFlagOffset_'. + offset = std::max(offset, sizeof(void*)); + const int32_t firstAggregateOffset = offset; + if (!accumulators.empty()) { + // This moves nullOffset to the start of the next byte. + // This is to guarantee the null and initialized bits for an aggregate + // always appear in the same byte. + nullOffset = (nullOffset + 7) & -8; + } + for (const auto& accumulator : accumulators) { + // Initialized bit. Set when the accumulator is initialized. + nullOffsets_.push_back(nullOffset); + ++nullOffset; + // Null bit. + nullOffsets_.push_back(nullOffset); + ++nullOffset; + isVariableWidth |= !accumulator.isFixedSize(); + alignment_ = combineAlignments(accumulator.alignment(), alignment_); + } + + // Free flag. + nullOffsets_.push_back(nullOffset); + freeFlagOffset_ = nullOffset + firstAggregateOffset * 8; + ++nullOffset; + // Add 1 to the last null offset to get the number of bits. + flagBytes_ = milvus::bits::nBytes(nullOffsets_.back() + 1); + for (auto i = 0; i < nullOffsets_.size(); i++) { + nullOffsets_[i] += firstAggregateOffset * 8; + } + offset += flagBytes_; + + for (const auto& accumulator : accumulators) { + offset = milvus::bits::roundUp(offset, accumulator.alignment()); + offsets_.push_back(offset); + offset += accumulator.fixedWidthSize(); + } + if (isVariableWidth) { + rowSizeOffset_ = offset; + offset += sizeof(uint32_t); + } + fixedRowSize_ = milvus::bits::roundUp(offset, alignment_); + + // A distinct hash table has no aggregates and if the hash table has + // no nulls, it may be that there are no null flags. + if (!nullOffsets_.empty()) { + // All flags like free and null flags for keys and non-keys + // start as 0. This is also used to mark aggregates as uninitialized on row + // creation. + initialNulls_.resize(flagBytes_, 0x0); + } + size_t nullOffsetsPos = 0; + uint16_t column_sum = keyTypes_.size() + accumulators.size(); + for (auto i = 0; i < offsets_.size(); i++) { + rowColumns_.emplace_back(offsets_[i], + (!ignoreNullKeys_ || i >= keyTypes_.size()) + ? nullOffsets_[nullOffsetsPos] + : RowColumn::kNotNullOffset); + // offsets_ contains the offsets for keys, then accumulators + // This captures the case where i is the index of an accumulator. + if (!accumulators.empty() && i >= keyTypes_.size() && i < column_sum) { + nullOffsetsPos += kNumAccumulatorFlags; + } else { + ++nullOffsetsPos; + } + } +} + +char* +RowContainer::initializeRow(char* row) { + std::memset(row, 0, fixedRowSize_); + return row; +} + +char* +RowContainer::newRow() { + char* row = new char[fixedRowSize_]; + if (rows_.size() < numRows_ + 1) { + rows_.reserve(numRows_ + 1024); + } + rows_.emplace_back(row); + ++numRows_; + return initializeRow(row); +} + +void +RowContainer::store(const milvus::ColumnVectorPtr& column_data, + milvus::vector_size_t index, + char* row, + int32_t column_index) { + auto numKeys = keyTypes_.size(); + bool isKey = column_index < numKeys; + if (isKey && ignoreNullKeys_) { + MILVUS_DYNAMIC_TYPE_DISPATCH(storeNoNulls, + keyTypes_[column_index], + column_data, + index, + row, + offsets_[column_index]); + } else { + AssertInfo(isKey || accumulators_.empty(), + "Should only store into rows for key"); + auto rowColumn = rowColumns_[column_index]; + MILVUS_DYNAMIC_TYPE_DISPATCH(storeWithNull, + keyTypes_[column_index], + column_data, + index, + row, + rowColumn.offset(), + rowColumn.nullByte(), + rowColumn.nullMask()); + } +} + +Accumulator::Accumulator(bool isFixedSize, int32_t fixedSize, int32_t alignment) + : isFixedSize_{isFixedSize}, fixedSize_{fixedSize}, alignment_{alignment} { +} + +Accumulator::Accumulator(milvus::exec::Aggregate* aggregate) + : isFixedSize_(aggregate->isFixedSize()), + fixedSize_{aggregate->accumulatorFixedWidthSize()}, + alignment_(aggregate->accumulatorAlignmentSize()) { + AssertInfo(aggregate != nullptr, + "Input aggregate for accumulator cannot be nullptr!"); +} + +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/RowContainer.h b/internal/core/src/exec/operator/query-agg/RowContainer.h new file mode 100644 index 0000000000000..89824e1c416c6 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/RowContainer.h @@ -0,0 +1,519 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "common/Types.h" +#include "common/Vector.h" +#include "common/Utils.h" +#include "Aggregate.h" +#include "storage/Util.h" + +namespace milvus { +namespace exec { + +class Accumulator { + public: + Accumulator(bool isFixedSize, int32_t fixedSize, int32_t alignment); + + explicit Accumulator(Aggregate* aggregate); + + bool + isFixedSize() const { + return isFixedSize_; + } + + int32_t + alignment() const { + return alignment_; + } + + int32_t + fixedWidthSize() const { + return fixedSize_; + } + + private: + const bool isFixedSize_; + const int32_t fixedSize_; + const int32_t alignment_; +}; + +/// Packed representation of offset, null byte offset and null mask for +/// a column inside a RowContainer. +class RowColumn { + public: + /// Used as null offset for a non-null column. + static constexpr int32_t kNotNullOffset = -1; + + RowColumn(int32_t offset, int32_t nullOffset) + : packedOffsets_(PackOffsets(offset, nullOffset)) { + } + + int32_t + offset() const { + return packedOffsets_ >> 32; + } + + int32_t + nullByte() const { + return static_cast(packedOffsets_) >> 8; + } + + uint8_t + nullMask() const { + return packedOffsets_ & 0xff; + } + + int32_t + initializedByte() const { + return nullByte(); + } + + int32_t + initializedMask() const { + return nullMask() << 1; + } + + private: + static uint64_t + PackOffsets(int32_t offset, int32_t nullOffset) { + if (nullOffset == kNotNullOffset) { + // If the column is not nullable, The low word is 0, meaning + // that a null check will AND 0 to the 0th byte of the row, + // which is always false and always safe to do. + return static_cast(offset) << 32; + } + return (1UL << (nullOffset & 7)) | ((nullOffset & ~7UL) << 5) | + static_cast(offset) << 32; + } + + const uint64_t packedOffsets_; +}; + +class RowContainer { + public: + RowContainer(const std::vector& keyTypes, + const std::vector& accumulators, + bool ignoreNullKeys); + + // The number of flags (bits) per accumulator, one for null and one for + // initialized. + static constexpr size_t kNumAccumulatorFlags = 2; + + /// Allocates a new row and initializes possible aggregates to null. + char* + newRow(); + + const std::vector& + KeyTypes() const { + return keyTypes_; + } + + const RowColumn& + columnAt(int32_t column_idx) const { + return rowColumns_[column_idx]; + } + + static int32_t + combineAlignments(int32_t a, int32_t b) { + AssertInfo(__builtin_popcount(a) == 1, + "Alignment can only be power of 2, but got{}", + a); + AssertInfo(__builtin_popcount(b) == 1, + "Alignment can only be power of 2, but got{}", + b); + return std::max(a, b); + } + + int32_t + rowSizeOffset() const { + return rowSizeOffset_; + } + + static inline bool + isNullAt(const char* row, int32_t nullByte, uint8_t nullMask) { + return (row[nullByte] & nullMask) != 0; + } + + static inline const std::string*& + strAt(const char* group, int32_t offset) { + return *reinterpret_cast( + const_cast(group + offset)); + } + + template + static inline T + valueAt(const char* group, int32_t offset) { + return *reinterpret_cast(group + offset); + } + + template + inline bool + equalsNoNulls(const char* row, + int32_t offset, + const ColumnVectorPtr& column, + vector_size_t index) { + if constexpr (Type == DataType::NONE || Type == DataType::ROW || + Type == DataType::JSON || Type == DataType::ARRAY) { + PanicInfo(DataTypeInvalid, + "Cannot support complex data type:[ROW/JSON/ARRAY] in " + "rows container for now"); + } else { + using T = typename TypeTraits::NativeType; + T raw_value = column->ValueAt(index); + bool equal = false; + if constexpr (std::is_same_v) { + equal = (raw_value == *(strAt(row, offset))); + } else { + equal = (milvus::comparePrimitiveAsc( + raw_value, valueAt(row, offset)) == 0); + } + return equal; + } + } + + template + inline bool + equalsWithNulls(const char* row, + int32_t offset, + int32_t nullByte, + uint8_t nullMask, + const ColumnVectorPtr& column, + vector_size_t index) { + bool rowIsNull = isNullAt(row, nullByte, nullMask); + bool columnIsNull = !column->ValidAt(index); + if (rowIsNull || columnIsNull) { + return rowIsNull == columnIsNull; + } + return equalsNoNulls(row, offset, column, index); + } + + template + inline bool + equals(const char* row, + RowColumn column, + const ColumnVectorPtr& column_data, + vector_size_t index) { + auto type = column_data->type(); + if constexpr (mayHaveNulls) { + return MILVUS_DYNAMIC_TYPE_DISPATCH(equalsWithNulls, + type, + row, + column.offset(), + column.nullByte(), + column.nullMask(), + column_data, + index); + } else { + return MILVUS_DYNAMIC_TYPE_DISPATCH( + equalsNoNulls, type, row, column.offset(), column_data, index); + } + } + + /// Stores the 'index'th value in 'columnVector' into 'row' at 'columnIndex'. + void + store(const ColumnVectorPtr& column_data, + vector_size_t index, + char* row, + int32_t column_index); + + template + inline void + storeWithNull(const ColumnVectorPtr& column, + vector_size_t index, + char* row, + int32_t offset, + int32_t nullByte, + uint8_t nullMask) { + static std::string null_string_val = ""; + static std::string* null_string_val_ptr = &null_string_val; + if constexpr (Type == DataType::NONE || Type == DataType::ROW || + Type == DataType::JSON || Type == DataType::ARRAY) { + PanicInfo(DataTypeInvalid, + "Cannot support complex data type:[ROW/JSON/ARRAY] in " + "rows container for now"); + } else { + using T = typename milvus::TypeTraits::NativeType; + if (!column->ValidAt(index)) { + row[nullByte] |= nullMask; + if constexpr (std::is_same_v) { + *reinterpret_cast(row + offset) = + null_string_val_ptr; + } else { + *reinterpret_cast(row + offset) = T(); + } + return; + } + storeNoNulls(column, index, row, offset); + } + } + + template + inline void + storeNoNulls(const ColumnVectorPtr& column, + vector_size_t index, + char* group, + int32_t offset) { + using T = typename milvus::TypeTraits::NativeType; + if constexpr (Type == DataType::NONE || Type == DataType::ROW || + Type == DataType::JSON || Type == DataType::ARRAY) { + PanicInfo(DataTypeInvalid, + "Cannot support complex data type:[ROW/JSON/ARRAY] in " + "rows container for now"); + } else { + auto raw_val_ptr = column->RawValueAt(index, sizeof(T)); + if constexpr (std::is_same_v) { + // the string object and also the underlying char array are both allocated on the heap + // must call clear method to deallocate these memory allocated for varchar type to avoid memory leak + *reinterpret_cast(group + offset) = + new std::string(*static_cast(raw_val_ptr)); + } else { + *reinterpret_cast(group + offset) = + *(static_cast(raw_val_ptr)); + } + } + } + + template + static void + extractValuesWithNulls(const char* const* rows, + int32_t numRows, + int32_t offset, + int32_t nullByte, + uint8_t nullMask, + int32_t resultOffset, + const VectorPtr& result) { + auto maxRows = numRows + resultOffset; + AssertInfo(maxRows == result->size(), + "extracted rows number should be equal to the size of " + "result vector"); + auto result_column_vec = + std::dynamic_pointer_cast(result); + AssertInfo( + result_column_vec != nullptr, + "Input column to extract result must be of ColumnVector type"); + for (auto i = 0; i < numRows; i++) { + const char* row = rows[i]; + auto resultIndex = resultOffset + i; + if (row == nullptr || isNullAt(row, nullByte, nullMask)) { + result_column_vec->nullAt(resultIndex); + } else { + if constexpr (std::is_same_v || + std::is_same_v) { + auto* str_ptr = strAt(row, offset); + result_column_vec->SetValueAt(resultIndex, *str_ptr); + } else { + result_column_vec->SetValueAt(resultIndex, + valueAt(row, offset)); + } + } + } + } + + template + static void + extractValuesNoNulls(const char* const* rows, + int32_t numRows, + int32_t offset, + int32_t resultOffset, + const VectorPtr& result) { + auto maxRows = numRows + resultOffset; + AssertInfo(maxRows == result->size(), + "extracted rows number should be equal to the size of " + "result vector"); + auto result_column_vec = + std::dynamic_pointer_cast(result); + AssertInfo( + result_column_vec != nullptr, + "Input column to extract result must be of ColumnVector type"); + for (auto i = 0; i < numRows; i++) { + const char* row = rows[i]; + auto resultIndex = resultOffset + i; + if (row == nullptr) { + result_column_vec->nullAt(resultIndex); + } else { + if constexpr (std::is_same_v || + std::is_same_v) { + auto* str_ptr = strAt(row, offset); + result_column_vec->SetValueAt(resultIndex, *str_ptr); + } else { + result_column_vec->SetValueAt(resultIndex, + valueAt(row, offset)); + } + } + } + } + + template + static void + extractColumnTypedInternal(const char* const* rows, + int32_t numRows, + RowColumn column, + int32_t resultOffset, + const VectorPtr& result) { + result->resize(numRows + resultOffset); + if constexpr (Type == DataType::ROW || Type == DataType::JSON || + Type == DataType::ARRAY || Type == DataType::NONE) { + PanicInfo(DataTypeInvalid, + "Not Support Extract types:[ROW/JSON/ARRAY/NONE]"); + } else { + using T = typename milvus::TypeTraits::NativeType; + auto nullMask = column.nullMask(); + auto offset = column.offset(); + if (nullMask) { + extractValuesWithNulls(rows, + numRows, + offset, + column.nullByte(), + nullMask, + resultOffset, + result); + } else { + extractValuesNoNulls( + rows, numRows, offset, resultOffset, result); + } + } + } + + template + static void + extractColumnTyped(const char* const* rows, + int32_t numRows, + RowColumn column, + int32_t resultOffset, + const VectorPtr& result) { + extractColumnTypedInternal( + rows, numRows, column, resultOffset, result); + } + + static void + extractColumn(const char* const* rows, + int32_t num_rows, + RowColumn column, + vector_size_t result_offset, + const VectorPtr& result); + + void + extractColumn(const char* const* rows, + int32_t numRows, + int32_t column_idx, + const VectorPtr& result) { + extractColumn(rows, numRows, columnAt(column_idx), 0, result); + } + + const std::vector& + allRows() const { + return rows_; + } + + static inline int32_t + nullByte(int32_t nullOffset) { + return nullOffset / 8; + } + + static inline uint8_t + nullMask(int32_t nullOffset) { + return 1 << (nullOffset & 7); + } + // Only accumulators have initialized flags. accumulatorFlagsOffset is the + // offset at which the flags for an accumulator begin. Currently this is the + // null flag, followed by the initialized flag. So it's equivalent to the + // nullOffset. + + // It's guaranteed that the flags for an accumulator appear in the same byte. + static inline int32_t + initializedByte(int32_t accumulatorFlagsOffset) { + return nullByte(accumulatorFlagsOffset); + } + + // accumulatorFlagsOffset is the offset at which the flags for an accumulator + // begin. + static inline int32_t + initializedMask(int32_t accumulatorFlagsOffset) { + return nullMask(accumulatorFlagsOffset) << 1; + } + + void + clear() { + for (auto row : rows_) { + for (auto i = 0; i < variable_offsets.size(); i++) { + auto& off = variable_offsets[i]; + auto& row_col = columnAt(variable_idxes[i]); + bool isStrNull = + isNullAt(row, row_col.nullByte(), row_col.nullMask()); + auto str = *reinterpret_cast(row + off); + if (!isStrNull && str) { + delete str; + str = nullptr; + *reinterpret_cast(row + off) = nullptr; + } + } + delete[] row; + } + numRows_ = 0; + } + + char* + initializeRow(char* row); + + private: + const std::vector keyTypes_; + std::vector variable_offsets{}; + std::vector variable_idxes{}; + const bool ignoreNullKeys_; + std::vector offsets_; + std::vector nullOffsets_; + + std::vector rowColumns_; + + // How many bytes do the flags (null, free) occupy. + uint32_t fixedRowSize_; + uint32_t flagBytes_; + + // Bit position of free bit. + uint32_t freeFlagOffset_ = 0; + uint32_t rowSizeOffset_ = 0; + + int alignment_ = 1; + + // Copied over the null bits of each row on initialization. Keys are + // not null, aggregates are null. + std::vector initialNulls_; + + std::vector accumulators_; + + uint64_t numRows_ = 0; + std::vector rows_{}; +}; + +inline void +RowContainer::extractColumn(const char* const* rows, + int32_t num_rows, + milvus::exec::RowColumn column, + milvus::vector_size_t result_offset, + const milvus::VectorPtr& result) { + MILVUS_DYNAMIC_TYPE_DISPATCH(extractColumnTyped, + result->type(), + rows, + num_rows, + column, + result_offset, + result); +} +} // namespace exec +} // namespace milvus diff --git a/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h new file mode 100644 index 0000000000000..47b36c304abd5 --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/SimpleNumericAggregate.h @@ -0,0 +1,128 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License +#pragma once + +#include "Aggregate.h" +namespace milvus { +namespace exec { +template +class SimpleNumericAggregate : public exec::Aggregate { + protected: + explicit SimpleNumericAggregate(DataType resultType) + : Aggregate(resultType) { + } + + // TData is either TAccumulator or TResult, which in most cases are the same, + // but for sum(real) can differ. + template + void + doExtractValues(char** groups, + int32_t numGroups, + VectorPtr* result, + ExtractOneValue extractOneValue) { + AssertInfo((*result)->elementSize() == sizeof(TData), + "Incorrect type size of input result vector"); + ColumnVectorPtr result_column = + std::dynamic_pointer_cast(*result); + AssertInfo(result_column != nullptr, + "input vector for extracting aggregation must be of Type " + "ColumnVector"); + result_column->resize(numGroups); + TData* rawValues = static_cast(result_column->GetRawData()); + for (auto i = 0; i < numGroups; i++) { + char* group = groups[i]; + if (isNull(group)) { + result_column->nullAt(i); + LOG_INFO("hc===extractValues set null at i:{}", i); + } else { + result_column->clearNullAt(i); + rawValues[i] = extractOneValue(group); + LOG_INFO("hc===try to extractValues, i:{}, rawValues[i]:{}", + i, + rawValues[i]); + } + } + } + + template + void + updateGroups(char** groups, + const TargetBitmapView& rows, + const VectorPtr& vector, + UpdateSingleValue updateSingleValue) { + auto start = -1; + auto column_data = std::dynamic_pointer_cast(vector); + AssertInfo( + column_data != nullptr, + "input column data for upgrading groups should not be nullptr"); + while (true) { + auto next_selected = rows.find_next(start); + if (!next_selected.has_value()) { + return; + } + auto selected_idx = next_selected.value(); + if (column_data->ValidAt(selected_idx)) { + updateNonNullValue( + groups[selected_idx], + TData(column_data->ValueAt(selected_idx)), + updateSingleValue); + } else { + } + start = selected_idx; + } + } + + template + void + updateOneGroup(char* group, + const TargetBitmapView& rows, + const VectorPtr& vector, + UpdateSingle updateSingleValue) { + auto start = -1; + auto column_data = std::dynamic_pointer_cast(vector); + AssertInfo( + column_data != nullptr, + "input column data for upgrading groups should not be nullptr"); + while (true) { + auto next_selected = rows.find_next(start); + if (!next_selected.has_value()) { + return; + } + auto selected_idx = next_selected.value(); + if (column_data->ValidAt(selected_idx)) { + updateNonNullValue( + group, + TData(column_data->ValueAt(selected_idx)), + updateSingleValue); + } + start = selected_idx; + } + } + + template + inline void + updateNonNullValue(char* group, TDataType value, Update updateValue) { + if constexpr (tableHasNulls) { + Aggregate::clearNull(group); + } + updateValue(*Aggregate::value(group), value); + } +}; + +} // namespace exec +} // namespace milvus \ No newline at end of file diff --git a/internal/core/src/exec/operator/query-agg/SumAggregate.cpp b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp new file mode 100644 index 0000000000000..bf4c70ff0f0cc --- /dev/null +++ b/internal/core/src/exec/operator/query-agg/SumAggregate.cpp @@ -0,0 +1,89 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "SumAggregateBase.h" +#include "RegisterAggregateFunctions.h" +#include "expr/FunctionSignature.h" + +namespace milvus { +namespace exec { + +template +using SumAggregate = SumAggregateBase; + +template