diff --git a/internal/storage/stats.go b/internal/storage/stats.go index 9f4a2ea4dba59..3ab2e026d477e 100644 --- a/internal/storage/stats.go +++ b/internal/storage/stats.go @@ -458,9 +458,9 @@ func (m *BM25Stats) Deserialize(bs []byte) error { } func (m *BM25Stats) BuildIDF(tf []byte) (idf []byte) { - dim := typeutil.SparseFloatRowElementCount(tf) + numElements := typeutil.SparseFloatRowElementCount(tf) idf = make([]byte, len(tf)) - for idx := 0; idx < dim; idx++ { + for idx := 0; idx < numElements; idx++ { key := typeutil.SparseFloatRowIndexAt(tf, idx) value := typeutil.SparseFloatRowValueAt(tf, idx) nq := m.rowsWithToken[key] diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go index b7d04987abba8..d4c484922164d 100644 --- a/internal/util/function/bm25_function.go +++ b/internal/util/function/bm25_function.go @@ -107,7 +107,7 @@ func (v *BM25FunctionRunner) run(data []string, dst []map[uint32]float32) error func (v *BM25FunctionRunner) BatchRun(inputs ...any) ([]any, error) { if len(inputs) > 1 { - return nil, fmt.Errorf("BM25 function receieve more than one input") + return nil, fmt.Errorf("BM25 function received more than one input column") } text, ok := inputs[0].([]string) @@ -158,16 +158,18 @@ func (v *BM25FunctionRunner) GetOutputFields() []*schemapb.FieldSchema { } func buildSparseFloatArray(mapdata []map[uint32]float32) *schemapb.SparseFloatArray { - dim := 0 + dim := int64(0) bytes := lo.Map(mapdata, func(sparseMap map[uint32]float32, _ int) []byte { - if len(sparseMap) > dim { - dim = len(sparseMap) + row := typeutil.CreateAndSortSparseFloatRow(sparseMap) + rowDim := typeutil.SparseFloatRowDim(row) + if rowDim > dim { + dim = rowDim } - return typeutil.CreateAndSortSparseFloatRow(sparseMap) + return row }) return &schemapb.SparseFloatArray{ Contents: bytes, - Dim: int64(dim), + Dim: dim, } }