diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index ce251198b6f6e..c5fe7f72905b7 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -344,28 +344,25 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) growing, growingExist := mgr.growingSegments[id] sealed, sealedExist := mgr.sealedSegments[id] - growingExist = growingExist && filter(growing, filters...) - sealedExist = sealedExist && filter(sealed, filters...) + if !growingExist && !sealedExist { + err = merr.WrapErrSegmentNotLoaded(id, "segment not found") + return nil, err + } - if growingExist { + if growingExist && filter(growing, filters...) { err = growing.RLock() if err != nil { return nil, err } lockedSegments = append(lockedSegments, growing) } - if sealedExist { + if sealedExist && filter(sealed, filters...) { err = sealed.RLock() if err != nil { return nil, err } lockedSegments = append(lockedSegments, sealed) } - - if !growingExist && !sealedExist { - err = merr.WrapErrSegmentNotLoaded(id, "segment not found") - return nil, err - } } return lockedSegments, nil } diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index 5a8a374355213..070de4ead9c5e 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -8,6 +8,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -32,11 +33,11 @@ func (s *ManagerSuite) SetupSuite() { s.partitionIDs = []int64{10, 11, 12} s.channels = []string{"dml1", "dml2", "dml3"} s.types = []SegmentType{SegmentTypeSealed, SegmentTypeGrowing, SegmentTypeSealed} -} -func (s *ManagerSuite) SetupTest() { s.mgr = NewSegmentManager() +} +func (s *ManagerSuite) SetupTest() { for i, id := range s.segmentIDs { schema := GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64) segment, err := NewSegment( @@ -145,6 +146,20 @@ func (s *ManagerSuite) TestIncreaseVersion() { segment.AssertExpectations(s.T()) } +func (s *ManagerSuite) TestGetAndPin() { + // test pin loaded empty segment + segments, err := s.mgr.GetAndPin(s.segmentIDs, WithType(SegmentTypeGrowing)) + s.NoError(err) + s.Len(segments, 0) + segments, err = s.mgr.GetAndPin(s.segmentIDs, WithType(SegmentTypeSealed)) + s.NoError(err) + s.Len(segments, 0) + + // test pin not loaded segment + _, err = s.mgr.GetAndPin([]int64{11, 22, 33}) + s.ErrorIs(err, merr.ErrSegmentNotLoaded) +} + func TestManager(t *testing.T) { suite.Run(t, new(ManagerSuite)) } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index c356948b52739..5a7df04371726 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -235,7 +235,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { } res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) - suite.Error(err) + suite.ErrorIs(err, merr.ErrSegmentNotLoaded) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) } @@ -255,7 +255,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() { } res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) - suite.ErrorIs(err, merr.ErrSegmentNotLoaded) + suite.NoError(err) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) }