Skip to content

Commit

Permalink
enhance: Add Sparse Index type enum (#723)
Browse files Browse the repository at this point in the history
See also #708
Support SparseInverted & SparseWAND

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Apr 16, 2024
1 parent 150a59f commit a963bd4
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 0 deletions.
1 change: 1 addition & 0 deletions entity/columns_sparse.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type SparseEmbedding interface {
Len() int // the actual items in this vector
Get(idx int) (pos uint32, value float32, ok bool)
Serialize() []byte
FieldType() FieldType
}

var _ SparseEmbedding = sliceSparseEmbedding{}
Expand Down
5 changes: 5 additions & 0 deletions entity/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ const (
GPUCagra IndexType = "GPU_CAGRA"
GPUBruteForce IndexType = "GPU_BRUTE_FORCE"

// Sparse
SparseInverted IndexType = "SPARSE_INVERTED_INDEX"
SparseWAND IndexType = "SPARSE_WAND"

// DEPRECATED
Scalar IndexType = ""

Expand All @@ -66,6 +70,7 @@ const (

// index param field tag
const (
tParams = `params`
tIndexType = `index_type`
tMetricType = `metric_type`
)
Expand Down
115 changes: 115 additions & 0 deletions entity/index_sparse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package entity

import (
"encoding/json"
"fmt"

"github.com/cockroachdb/errors"
)

var _ Index = (*IndexSparseInverted)(nil)

// IndexSparseInverted index type for SPARSE_INVERTED_INDEX
type IndexSparseInverted struct {
metricType MetricType
dropRatio float64
}

func (i *IndexSparseInverted) Name() string {
return "SparseInverted"
}

func (i *IndexSparseInverted) IndexType() IndexType {
return SparseInverted
}

func (i *IndexSparseInverted) Params() map[string]string {
params := map[string]string{
"drop_ratio_build": fmt.Sprintf("%v", i.dropRatio),
}
bs, _ := json.Marshal(params)
return map[string]string{
tParams: string(bs),
tIndexType: string(i.IndexType()),
tMetricType: string(i.metricType),
}
}

type IndexSparseInvertedSearchParam struct {
baseSearchParams
}

func NewIndexSparseInvertedSearchParam(dropRatio float64) (*IndexSparseInvertedSearchParam, error) {
if dropRatio < 0 || dropRatio >= 1 {
return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio)
}
sp := &IndexSparseInvertedSearchParam{
baseSearchParams: newBaseSearchParams(),
}

sp.params["drop_ratio_search"] = dropRatio
return sp, nil
}

// IndexSparseInverted index type for SPARSE_INVERTED_INDEX
func NewIndexSparseInverted(metricType MetricType, dropRatio float64) (*IndexSparseInverted, error) {
if dropRatio < 0 || dropRatio >= 1.0 {
return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio)
}
return &IndexSparseInverted{
metricType: metricType,
dropRatio: dropRatio,
}, nil
}

type IndexSparseWAND struct {
metricType MetricType
dropRatio float64
}

func (i *IndexSparseWAND) Name() string {
return "SparseWAND"
}

func (i *IndexSparseWAND) IndexType() IndexType {
return SparseWAND
}

func (i *IndexSparseWAND) Params() map[string]string {
params := map[string]string{
"drop_ratio_build": fmt.Sprintf("%v", i.dropRatio),
}
bs, _ := json.Marshal(params)
return map[string]string{
tParams: string(bs),
tIndexType: string(i.IndexType()),
tMetricType: string(i.metricType),
}
}

// IndexSparseWAND index type for SPARSE_WAND, weak-and
func NewIndexSparseWAND(metricType MetricType, dropRatio float64) (*IndexSparseWAND, error) {
if dropRatio < 0 || dropRatio >= 1.0 {
return nil, errors.Newf("invalid dropRatio for build: %v, must be in range [0, 1)", dropRatio)
}
return &IndexSparseWAND{
metricType: metricType,
dropRatio: dropRatio,
}, nil
}

type IndexSparseWANDSearchParam struct {
baseSearchParams
}

func NewIndexSparseWANDSearchParam(dropRatio float64) (*IndexSparseWANDSearchParam, error) {
if dropRatio < 0 || dropRatio >= 1 {
return nil, errors.Newf("invalid dropRatio for search: %v, must be in range [0, 1)", dropRatio)
}
sp := &IndexSparseWANDSearchParam{
baseSearchParams: newBaseSearchParams(),
}

sp.params["drop_ratio_search"] = dropRatio
return sp, nil
}
98 changes: 98 additions & 0 deletions entity/index_sparse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package entity

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/suite"
)

type SparseIndexSuite struct {
suite.Suite
}

func (s *SparseIndexSuite) TestSparseInverted() {
s.Run("bad_drop_ratio", func() {
_, err := NewIndexSparseInverted(IP, -1)
s.Error(err)

_, err = NewIndexSparseInverted(IP, 1.0)
s.Error(err)
})

s.Run("normal_case", func() {
idx, err := NewIndexSparseInverted(IP, 0.2)
s.Require().NoError(err)

s.Equal("SparseInverted", idx.Name())
s.Equal(SparseInverted, idx.IndexType())
params := idx.Params()

s.Equal("SPARSE_INVERTED_INDEX", params[tIndexType])
s.Equal("IP", params[tMetricType])
paramsVal, has := params[tParams]
s.True(has)
m := make(map[string]string)
err = json.Unmarshal([]byte(paramsVal), &m)
s.Require().NoError(err)
dropRatio, ok := m["drop_ratio_build"]
s.True(ok)
s.Equal("0.2", dropRatio)
})

s.Run("search_param", func() {
_, err := NewIndexSparseInvertedSearchParam(-1)
s.Error(err)
_, err = NewIndexSparseInvertedSearchParam(1.0)
s.Error(err)

sp, err := NewIndexSparseInvertedSearchParam(0.2)
s.Require().NoError(err)
s.EqualValues(0.2, sp.Params()["drop_ratio_search"])
})
}

func (s *SparseIndexSuite) TestSparseWAND() {
s.Run("bad_drop_ratio", func() {
_, err := NewIndexSparseWAND(IP, -1)
s.Error(err)

_, err = NewIndexSparseWAND(IP, 1.0)
s.Error(err)
})

s.Run("normal_case", func() {
idx, err := NewIndexSparseWAND(IP, 0.2)
s.Require().NoError(err)

s.Equal("SparseWAND", idx.Name())
s.Equal(SparseWAND, idx.IndexType())
params := idx.Params()

s.Equal("SPARSE_WAND", params[tIndexType])
s.Equal("IP", params[tMetricType])
paramsVal, has := params[tParams]
s.True(has)
m := make(map[string]string)
err = json.Unmarshal([]byte(paramsVal), &m)
s.Require().NoError(err)
dropRatio, ok := m["drop_ratio_build"]
s.True(ok)
s.Equal("0.2", dropRatio)
})

s.Run("search_param", func() {
_, err := NewIndexSparseWANDSearchParam(-1)
s.Error(err)
_, err = NewIndexSparseWANDSearchParam(1.0)
s.Error(err)

sp, err := NewIndexSparseWANDSearchParam(0.2)
s.Require().NoError(err)
s.EqualValues(0.2, sp.Params()["drop_ratio_search"])
})
}

func TestSparseIndex(t *testing.T) {
suite.Run(t, new(SparseIndexSuite))
}

0 comments on commit a963bd4

Please sign in to comment.