From 4424c9e5e7f4c3c758c8d043281d875119801a29 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 20 Jun 2024 14:42:00 +0800 Subject: [PATCH 01/21] fix: [2.4] Remove loopclosure issue in ChannelManagerImplV2 (#33989) (#34004) Cherry-pick from master pr: #33989 See also #33987 Signed-off-by: Congqi Xia --- internal/datacoord/channel_manager_v2.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/datacoord/channel_manager_v2.go b/internal/datacoord/channel_manager_v2.go index 53504e16b020c..cee393b1b7ddc 100644 --- a/internal/datacoord/channel_manager_v2.go +++ b/internal/datacoord/channel_manager_v2.go @@ -529,6 +529,7 @@ func (m *ChannelManagerImplV2) advanceToNotifies(ctx context.Context, toNotifies if channelCount == 0 { continue } + nodeID := nodeAssign.NodeID var ( succeededChannels = make([]RWChannel, 0, channelCount) @@ -548,7 +549,7 @@ func (m *ChannelManagerImplV2) advanceToNotifies(ctx context.Context, toNotifies tmpWatchInfo.Vchan = m.h.GetDataVChanPositions(innerCh, allPartitionID) future := getOrCreateIOPool().Submit(func() (any, error) { - err := m.Notify(ctx, nodeAssign.NodeID, tmpWatchInfo) + err := m.Notify(ctx, nodeID, tmpWatchInfo) return innerCh, err }) futures = append(futures, future) @@ -591,6 +592,7 @@ func (m *ChannelManagerImplV2) advanceToChecks(ctx context.Context, toChecks []* continue } + nodeID := nodeAssign.NodeID futures := make([]*conc.Future[any], 0, len(nodeAssign.Channels)) chNames := lo.Keys(nodeAssign.Channels) @@ -603,7 +605,7 @@ func (m *ChannelManagerImplV2) advanceToChecks(ctx context.Context, toChecks []* innerCh := ch future := getOrCreateIOPool().Submit(func() (any, error) { - successful, got := m.Check(ctx, nodeAssign.NodeID, innerCh.GetWatchInfo()) + successful, got := m.Check(ctx, nodeID, innerCh.GetWatchInfo()) if got { return poolResult{ successful: successful, From b3d425f50a4fcf119c676835a1a09e6c9949a1c5 Mon Sep 17 00:00:00 2001 From: shaoting-huang <167743503+shaoting-huang@users.noreply.github.com> Date: Thu, 20 Jun 2024 14:52:00 +0800 Subject: [PATCH 02/21] enhance: Upgrade go version from 1.20 to 1.21 (#33940) issue #32982 related pr in master: pr: #33047 #33150 #33176 #33351 #33202 #33192 Signed-off-by: shaoting-huang --- .env | 8 ++++---- .github/workflows/mac.yaml | 2 +- .golangci.yml | 2 +- DEVELOPMENT.md | 12 ++++++------ Makefile | 15 ++++++++------- README.md | 6 +++--- README_CN.md | 2 +- .../docker/builder/cpu/amazonlinux2023/Dockerfile | 15 ++++++++++++--- build/docker/builder/cpu/rockylinux8/Dockerfile | 2 +- build/docker/builder/cpu/ubuntu20.04/Dockerfile | 2 +- build/docker/builder/gpu/ubuntu20.04/Dockerfile | 2 +- build/docker/builder/gpu/ubuntu22.04/Dockerfile | 2 +- build/docker/meta-migration/builder/Dockerfile | 2 +- client/go.mod | 2 +- configs/pgo/default.pgo | 0 go.mod | 2 +- pkg/go.mod | 2 +- scripts/README.md | 2 +- 18 files changed, 45 insertions(+), 35 deletions(-) create mode 100644 configs/pgo/default.pgo diff --git a/.env b/.env index 96cd6e27ed183..6beb24525c5e1 100644 --- a/.env +++ b/.env @@ -5,12 +5,12 @@ IMAGE_ARCH=amd64 OS_NAME=ubuntu20.04 # for services.builder.image in docker-compose.yml -DATE_VERSION=20240429-6289f3a -LATEST_DATE_VERSION=20240429-6289f3a +DATE_VERSION=20240520-d27db99 +LATEST_DATE_VERSION=20240520-d27db99 # for services.gpubuilder.image in docker-compose.yml -GPU_DATE_VERSION=20240409-08bfb43 -LATEST_GPU_DATE_VERSION=20240409-08bfb43 +GPU_DATE_VERSION=20240520-c35eaaa +LATEST_GPU_DATE_VERSION=20240520-c35eaaa # for other services in docker-compose.yml 
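A note on the loopclosure fix in PATCH 01 above: the patch hoists `nodeID := nodeAssign.NodeID` out of the closures submitted to the I/O pool. Before Go 1.22, a `for ... range` loop reuses a single loop variable across all iterations, so a closure that runs asynchronously can observe a later iteration's value instead of the one it was created for. A minimal standalone sketch of the hazard and the fix (names are illustrative, not Milvus APIs):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	nodeIDs := []int64{101, 102, 103}

	var wg sync.WaitGroup
	for _, n := range nodeIDs {
		// Buggy form (Go < 1.22): the closure captures the loop variable n
		// itself, so by the time it runs, n may already hold a later element:
		//     go func() { fmt.Println("notify node", n) }()

		nodeID := n // the fix: pin the current value in a per-iteration variable
		wg.Add(1)
		go func() {
			defer wg.Done()
			fmt.Println("notify node", nodeID) // always the intended node
		}()
	}
	wg.Wait()
}
```

Go 1.22 later changed `for` loops to declare a fresh variable per iteration, which is why this only bites on the Go 1.20/1.21 toolchains these branches build with; this is the pattern the `loopclosure` vet check exists for.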
MINIO_ADDRESS=minio:9000 diff --git a/.github/workflows/mac.yaml b/.github/workflows/mac.yaml index d17125b9d7c86..ccb21ebaab5af 100644 --- a/.github/workflows/mac.yaml +++ b/.github/workflows/mac.yaml @@ -56,7 +56,7 @@ jobs: - name: Setup Go environment uses: actions/setup-go@v2.2.0 with: - go-version: '~1.20.7' + go-version: '~1.21.10' - name: Mac Cache Go Mod Volumes uses: actions/cache@v3 with: diff --git a/.golangci.yml b/.golangci.yml index 09779daf2548c..91895ce0cc115 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,5 @@ run: - go: "1.20" + go: "1.21" skip-dirs: - build - configs diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 12d92600e311e..88de174aa695e 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -104,7 +104,7 @@ You can use Vscode to integrate C++ and Go together. Please replace user.setting Linux systems (Recommend Ubuntu 20.04 or later): ```bash -go: >= 1.20 +go: >= 1.21 cmake: >= 3.18 gcc: 7.5 conan: 1.61 ``` MacOS systems with x86_64 (Big Sur 11.5 or later recommended): ```bash -go: >= 1.20 +go: >= 1.21 cmake: >= 3.18 llvm: >= 15 conan: 1.61 ``` MacOS systems with Apple Silicon (Monterey 12.0.1 or later recommended): ```bash -go: >= 1.20 (Arch=ARM64) +go: >= 1.21 (Arch=ARM64) cmake: >= 3.18 llvm: >= 15 conan: 1.61 ``` @@ -178,7 +178,7 @@ Confirm that your `GOPATH` and `GOBIN` environment variables are correctly set a ```shell $ go version ``` -Note: go >= 1.20 is required to build Milvus. +Note: go >= 1.21 is required to build Milvus. #### Docker & Docker Compose @@ -239,8 +239,8 @@ pip3 install conan==1.61.0 #### Install Go 1.21 ```bash -wget https://go.dev/dl/go1.18.10.linux-arm64.tar.gz -tar zxf go1.18.10.linux-arm64.tar.gz +wget https://go.dev/dl/go1.21.10.linux-arm64.tar.gz +tar zxf go1.21.10.linux-arm64.tar.gz mv ./go /usr/local vi /etc/profile export PATH=$PATH:/usr/local/go/bin diff --git a/Makefile b/Makefile index 268f7f994f251..c1bfc9e901e75 100644 --- a/Makefile +++ b/Makefile @@ -17,6 +17,7 @@ OBJPREFIX := "github.com/milvus-io/milvus/cmd/milvus" INSTALL_PATH := $(PWD)/bin LIBRARY_PATH := $(PWD)/lib +PGO_PATH := $(PWD)/configs/pgo OS := $(shell uname -s) mode = Release @@ -72,14 +73,14 @@ milvus: build-cpp print-build-info @echo "Building Milvus ..." @source $(PWD)/scripts/setenv.sh && \ mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ - GO111MODULE=on $(GO) build -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ -tags dynamic -o $(INSTALL_PATH)/milvus $(PWD)/cmd/main.go 1>/dev/null milvus-gpu: build-cpp-gpu print-gpu-build-info @echo "Building Milvus-gpu ..."
@source $(PWD)/scripts/setenv.sh && \ mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ - GO111MODULE=on $(GO) build -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS_GPU)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS_GPU)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ -tags dynamic -o $(INSTALL_PATH)/milvus $(PWD)/cmd/main.go 1>/dev/null get-build-deps: @@ -106,7 +107,7 @@ getdeps: tools/bin/revive: tools/check/go.mod cd tools/check; \ - $(GO) build -o ../bin/revive github.com/mgechev/revive + $(GO) build -pgo=$(PGO_PATH)/default.pgo -o ../bin/revive github.com/mgechev/revive cppcheck: @#(env bash ${PWD}/scripts/core_build.sh -l) @@ -164,14 +165,14 @@ binlog: @echo "Building binlog ..." @source $(PWD)/scripts/setenv.sh && \ mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ - GO111MODULE=on $(GO) build -ldflags="-r $${RPATH}" -o $(INSTALL_PATH)/binlog $(PWD)/cmd/tools/binlog/main.go 1>/dev/null + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH}" -o $(INSTALL_PATH)/binlog $(PWD)/cmd/tools/binlog/main.go 1>/dev/null MIGRATION_PATH = $(PWD)/cmd/tools/migration meta-migration: @echo "Building migration tool ..." @source $(PWD)/scripts/setenv.sh && \ mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ - GO111MODULE=on $(GO) build -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ -tags dynamic -o $(INSTALL_PATH)/meta-migration $(MIGRATION_PATH)/main.go 1>/dev/null INTERATION_PATH = $(PWD)/tests/integration @@ -366,7 +367,7 @@ clean: milvus-tools: print-build-info @echo "Building tools ..." @mkdir -p $(INSTALL_PATH)/tools && go env -w CGO_ENABLED="1" && GO111MODULE=on $(GO) build \ - -ldflags="-X 'main.BuildTags=$(BUILD_TAGS)' -X 'main.BuildTime=$(BUILD_TIME)' -X 'main.GitCommit=$(GIT_COMMIT)' -X 'main.GoVersion=$(GO_VERSION)'" \ + -pgo=$(PGO_PATH)/default.pgo -ldflags="-X 'main.BuildTags=$(BUILD_TAGS)' -X 'main.BuildTime=$(BUILD_TIME)' -X 'main.GitCommit=$(GIT_COMMIT)' -X 'main.GoVersion=$(GO_VERSION)'" \ -o $(INSTALL_PATH)/tools $(PWD)/cmd/tools/* 1>/dev/null rpm-setup: @@ -514,5 +515,5 @@ mmap-migration: @echo "Building migration tool ..." 
@source $(PWD)/scripts/setenv.sh && \ mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="1" && \ - GO111MODULE=on $(GO) build -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ + GO111MODULE=on $(GO) build -pgo=$(PGO_PATH)/default.pgo -ldflags="-r $${RPATH} -X '$(OBJPREFIX).BuildTags=$(BUILD_TAGS)' -X '$(OBJPREFIX).BuildTime=$(BUILD_TIME)' -X '$(OBJPREFIX).GitCommit=$(GIT_COMMIT)' -X '$(OBJPREFIX).GoVersion=$(GO_VERSION)'" \ -tags dynamic -o $(INSTALL_PATH)/mmap-migration $(MMAP_MIGRATION_PATH)/main.go 1>/dev/null \ No newline at end of file diff --git a/README.md b/README.md index d4de9cb07a564..f352fce1fbda9 100644 --- a/README.md +++ b/README.md @@ -72,21 +72,21 @@ Check the requirements first. Linux systems (Ubuntu 20.04 or later recommended): ```bash -go: >= 1.20 +go: >= 1.21 cmake: >= 3.26.4 gcc: 7.5 ``` MacOS systems with x86_64 (Big Sur 11.5 or later recommended): ```bash -go: >= 1.20 +go: >= 1.21 cmake: >= 3.26.4 llvm: >= 15 ``` MacOS systems with Apple Silicon (Monterey 12.0.1 or later recommended): ```bash -go: >= 1.20 (Arch=ARM64) +go: >= 1.21 (Arch=ARM64) cmake: >= 3.26.4 llvm: >= 15 ``` diff --git a/README_CN.md b/README_CN.md index 2a4e149a72d54..c7fe1f4e7eb13 100644 --- a/README_CN.md +++ b/README_CN.md @@ -68,7 +68,7 @@ Milvus 基于 [Apache 2.0 License](https://github.com/milvus-io/milvus/blob/mast 请先安装相关依赖。 ``` -go: 1.20 +go: 1.21 cmake: >=3.18 gcc: 7.5 protobuf: >=3.7 diff --git a/build/docker/builder/cpu/amazonlinux2023/Dockerfile b/build/docker/builder/cpu/amazonlinux2023/Dockerfile index d5516fd46ab0f..d052c37755b73 100644 --- a/build/docker/builder/cpu/amazonlinux2023/Dockerfile +++ b/build/docker/builder/cpu/amazonlinux2023/Dockerfile @@ -14,10 +14,19 @@ FROM amazonlinux:2023 ARG TARGETARCH RUN dnf install -y wget g++ gcc gdb libatomic libstdc++-static ninja-build git make zip unzip tar which \ - autoconf automake golang python3 python3-pip perl-FindBin texinfo \ + autoconf automake python3 python3-pip perl-FindBin texinfo \ pkg-config libuuid-devel libaio perl-IPC-Cmd libasan openblas-devel && \ rm -rf /var/cache/yum/* +ENV GOPATH /go +ENV GOROOT /usr/local/go +ENV GO111MODULE on +ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ + mkdir -p "$GOPATH/src" "$GOPATH/bin" && \ + go clean --modcache && \ + chmod -R 777 "$GOPATH" && chmod -R a+w $(go env GOTOOLDIR) + RUN pip3 install conan==1.61.0 RUN echo "target arch $TARGETARCH" @@ -35,9 +44,9 @@ RUN /opt/vcpkg/bootstrap-vcpkg.sh -disableMetrics && ln -s /opt/vcpkg/vcpkg /usr RUN vcpkg install azure-identity-cpp azure-storage-blobs-cpp gtest --only-downloads RUN mkdir /tmp/ccache && cd /tmp/ccache &&\ - wget https://dl.fedoraproject.org/pub/epel/9/Everything/`uname -m`/Packages/h/hiredis-1.0.2-1.el9.`uname -m`.rpm &&\ + wget https://dl.fedoraproject.org/pub/epel/9/Everything/`uname -m`/Packages/h/hiredis-1.0.2-2.el9.`uname -m`.rpm &&\ wget https://dl.fedoraproject.org/pub/epel/9/Everything/`uname -m`/Packages/c/ccache-4.5.1-2.el9.`uname -m`.rpm &&\ - rpm -i hiredis-1.0.2-1.el9.`uname -m`.rpm ccache-4.5.1-2.el9.`uname -m`.rpm &&\ + rpm -i hiredis-1.0.2-2.el9.`uname -m`.rpm ccache-4.5.1-2.el9.`uname -m`.rpm &&\ rm -rf /tmp/ccache diff --git a/build/docker/builder/cpu/rockylinux8/Dockerfile b/build/docker/builder/cpu/rockylinux8/Dockerfile index 
74e625b57f9a9..ec1ea089035c3 100644 --- a/build/docker/builder/cpu/rockylinux8/Dockerfile +++ b/build/docker/builder/cpu/rockylinux8/Dockerfile @@ -43,7 +43,7 @@ RUN dnf -y update && \ RUN pip3 install conan==1.61.0 -RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.20.7.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go RUN curl https://sh.rustup.rs -sSf | \ sh -s -- --default-toolchain=1.73 -y diff --git a/build/docker/builder/cpu/ubuntu20.04/Dockerfile b/build/docker/builder/cpu/ubuntu20.04/Dockerfile index 77dd0ba101908..beae59281a370 100644 --- a/build/docker/builder/cpu/ubuntu20.04/Dockerfile +++ b/build/docker/builder/cpu/ubuntu20.04/Dockerfile @@ -40,7 +40,7 @@ ENV GOPATH /go ENV GOROOT /usr/local/go ENV GO111MODULE on ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH -RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.20.7.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ mkdir -p "$GOPATH/src" "$GOPATH/bin" && \ go clean --modcache && \ chmod -R 777 "$GOPATH" && chmod -R a+w $(go env GOTOOLDIR) diff --git a/build/docker/builder/gpu/ubuntu20.04/Dockerfile b/build/docker/builder/gpu/ubuntu20.04/Dockerfile index a9fc65f3a895f..ba86136227817 100644 --- a/build/docker/builder/gpu/ubuntu20.04/Dockerfile +++ b/build/docker/builder/gpu/ubuntu20.04/Dockerfile @@ -51,7 +51,7 @@ ENV GOPATH /go ENV GOROOT /usr/local/go ENV GO111MODULE on ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH -RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.20.7.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ mkdir -p "$GOPATH/src" "$GOPATH/bin" && \ curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b ${GOROOT}/bin v1.46.2 && \ # export GO111MODULE=on && go get github.com/quasilyte/go-ruleguard/cmd/ruleguard@v0.2.1 && \ diff --git a/build/docker/builder/gpu/ubuntu22.04/Dockerfile b/build/docker/builder/gpu/ubuntu22.04/Dockerfile index df5b979eae630..3f487b008561c 100644 --- a/build/docker/builder/gpu/ubuntu22.04/Dockerfile +++ b/build/docker/builder/gpu/ubuntu22.04/Dockerfile @@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-ce # Install go -RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.20.7.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go # Install conan RUN pip3 install conan==1.61.0 # Install rust diff --git a/build/docker/meta-migration/builder/Dockerfile b/build/docker/meta-migration/builder/Dockerfile index cf7832dcdea68..f102266fcfc44 100644 --- a/build/docker/meta-migration/builder/Dockerfile +++ b/build/docker/meta-migration/builder/Dockerfile @@ -1,2 +1,2 @@ -FROM golang:1.20.4-alpine3.17 +FROM golang:1.21.10-alpine3.19 RUN apk add --no-cache make bash \ No newline at end of file diff --git a/client/go.mod b/client/go.mod index e74a11debf035..e542661b64726 100644 --- a/client/go.mod +++ 
b/client/go.mod @@ -1,6 +1,6 @@ module github.com/milvus-io/milvus/client/v2 -go 1.20 +go 1.21 require ( github.com/blang/semver/v4 v4.0.0 diff --git a/configs/pgo/default.pgo b/configs/pgo/default.pgo new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/go.mod b/go.mod index c1c2177f6c256..63b4f1036e417 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/milvus-io/milvus -go 1.20 +go 1.21 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.0 diff --git a/pkg/go.mod b/pkg/go.mod index ed5583366a5e5..6f967a0c22605 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -1,6 +1,6 @@ module github.com/milvus-io/milvus/pkg -go 1.20 +go 1.21 require ( github.com/apache/pulsar-client-go v0.6.1-0.20210728062540-29414db801a7 diff --git a/scripts/README.md b/scripts/README.md index f8c1e787f991f..6b702620fe483 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -4,7 +4,7 @@ ``` OS: Ubuntu 20.04 -go:1.20 +go:1.21 cmake: >=3.18 gcc: 7.5 ``` From 8f6f6dc29f74e3c4757dcbc33f14cffa2cc9842a Mon Sep 17 00:00:00 2001 From: sre-ci-robot <56469371+sre-ci-robot@users.noreply.github.com> Date: Thu, 20 Jun 2024 17:38:01 +0800 Subject: [PATCH 03/21] [automated] Bump milvus version to v2.4.5 (#34028) Bump milvus version to v2.4.5 Signed-off-by: sre-ci-robot <sre-ci-robot@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- deployments/binary/README.md | 2 +- deployments/docker/cluster-distributed-deployment/inventory.ini | 2 +- deployments/docker/gpu/standalone/docker-compose.yml | 2 +- deployments/docker/standalone/docker-compose.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deployments/binary/README.md b/deployments/binary/README.md index 827d68cdd906d..642264bb1391f 100644 --- a/deployments/binary/README.md +++ b/deployments/binary/README.md @@ -30,7 +30,7 @@ $ ./minio server /minio To start Milvus standalone, you need a Milvus binary file. Currently you can get the latest version of Milvus binary file through the Milvus docker image. (We will upload Milvus binary files in the future) ```shell -$ docker run -d --name milvus milvusdb/milvus:v2.4.4 /bin/bash +$ docker run -d --name milvus milvusdb/milvus:v2.4.5 /bin/bash $ docker cp milvus:/milvus . ``` diff --git a/deployments/docker/cluster-distributed-deployment/inventory.ini b/deployments/docker/cluster-distributed-deployment/inventory.ini index 533c16062c350..ea63ea3a0b3e1 100644 --- a/deployments/docker/cluster-distributed-deployment/inventory.ini +++ b/deployments/docker/cluster-distributed-deployment/inventory.ini @@ -33,7 +33,7 @@ dependencies_network= host nodes_network= host ; Setup variable to control what image version of Milvus to use. -image= milvusdb/milvus:v2.4.4 +image= milvusdb/milvus:v2.4.5 ; Setup static IP addresses of the docker hosts as variable for container environment variable config.
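Back on PATCH 02 for a moment: the `-pgo=$(PGO_PATH)/default.pgo` flag added to each `go build` in the Makefile enables profile-guided optimization, which became generally available in Go 1.21, and the newly created `configs/pgo/default.pgo` (index `e69de29...` is git's empty-blob hash) is an empty placeholder the build can point at until a real profile is checked in. A hedged sketch of how a real profile could be produced, assuming a process that exposes `net/http/pprof` on port 6060 (endpoint and paths are illustrative):

```bash
# Capture a 30-second CPU profile from the running service ...
curl -o configs/pgo/default.pgo "http://localhost:6060/debug/pprof/profile?seconds=30"

# ... then rebuild; the Go 1.21+ compiler uses the recorded samples to guide
# inlining and code layout on the hot paths.
go build -pgo=configs/pgo/default.pgo -o bin/milvus ./cmd/main.go
```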
; Before running the playbook, below 4 IP addresses need to be replaced with the IP of your host VM diff --git a/deployments/docker/gpu/standalone/docker-compose.yml b/deployments/docker/gpu/standalone/docker-compose.yml index c35f82951478e..065d1ceed973c 100644 --- a/deployments/docker/gpu/standalone/docker-compose.yml +++ b/deployments/docker/gpu/standalone/docker-compose.yml @@ -38,7 +38,7 @@ services: standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.4.4-gpu + image: milvusdb/milvus:v2.4.5-gpu command: ["milvus", "run", "standalone"] security_opt: - seccomp:unconfined diff --git a/deployments/docker/standalone/docker-compose.yml b/deployments/docker/standalone/docker-compose.yml index c2268a22654da..ba56d9c039c50 100644 --- a/deployments/docker/standalone/docker-compose.yml +++ b/deployments/docker/standalone/docker-compose.yml @@ -38,7 +38,7 @@ services: standalone: container_name: milvus-standalone - image: milvusdb/milvus:v2.4.4 + image: milvusdb/milvus:v2.4.5 command: ["milvus", "run", "standalone"] security_opt: - seccomp:unconfined From e55fee6b04b5d0c77934e0f0ecbbe77b87c184ab Mon Sep 17 00:00:00 2001 From: XuanYang-cn Date: Thu, 20 Jun 2024 19:32:02 +0800 Subject: [PATCH 04/21] enhance: Add deltaRowCount in l0 compaction (#33843) See also: #33998 pr: #33997 Signed-off-by: yangxuan --- internal/datacoord/compaction_l0_view.go | 44 +++++---- internal/datacoord/compaction_l0_view_test.go | 1 + internal/datacoord/compaction_view.go | 27 ++++-- .../datacoord/compaction_view_manager_test.go | 1 + internal/datanode/l0_compactor.go | 96 +++++++++---------- internal/datanode/l0_compactor_test.go | 12 +-- 6 files changed, 95 insertions(+), 86 deletions(-) diff --git a/internal/datacoord/compaction_l0_view.go b/internal/datacoord/compaction_l0_view.go index d59df36c4b369..5f70ef7102ce9 100644 --- a/internal/datacoord/compaction_l0_view.go +++ b/internal/datacoord/compaction_l0_view.go @@ -22,7 +22,13 @@ func (v *LevelZeroSegmentsView) String() string { l0strings := lo.Map(v.segments, func(v *SegmentView, _ int) string { return v.LevelZeroString() }) - return fmt.Sprintf("label=<%s>, posT=<%v>, l0 segments=%v", + + count := lo.SumBy(v.segments, func(v *SegmentView) int { + return v.DeltaRowCount + }) + return fmt.Sprintf("L0SegCount=%d, DeltaRowCount=%d, label=<%s>, posT=<%v>, L0 segments=%v", + len(v.segments), + count, v.label.String(), v.earliestGrowingSegmentPos.GetTimestamp(), l0strings) @@ -116,19 +122,20 @@ func (v *LevelZeroSegmentsView) minCountSizeTrigger(segments []*SegmentView) (pi maxDeltaCount = paramtable.Get().DataCoordCfg.LevelZeroCompactionTriggerDeltalogMaxNum.GetAsInt() ) - curSize := float64(0) + pickedSize := float64(0) + pickedCount := 0 // count >= minDeltaCount if lo.SumBy(segments, func(view *SegmentView) int { return view.DeltalogCount }) >= minDeltaCount { - picked, curSize = pickByMaxCountSize(segments, maxDeltaSize, maxDeltaCount) - reason = fmt.Sprintf("level zero segments count reaches minForceTriggerCountLimit=%d, curDeltaSize=%.2f, curDeltaCount=%d", minDeltaCount, curSize, len(segments)) + picked, pickedSize, pickedCount = pickByMaxCountSize(segments, maxDeltaSize, maxDeltaCount) + reason = fmt.Sprintf("level zero segments count reaches minForceTriggerCountLimit=%d, pickedSize=%.2fB, pickedCount=%d", minDeltaCount, pickedSize, pickedCount) return } // size >= minDeltaSize if lo.SumBy(segments, func(view *SegmentView) float64 { return view.DeltaSize }) >= minDeltaSize { - picked, curSize = pickByMaxCountSize(segments, maxDeltaSize, 
maxDeltaCount) - reason = fmt.Sprintf("level zero segments size reaches minForceTriggerSizeLimit=%.2f, curDeltaSize=%.2f, curDeltaCount=%d", minDeltaSize, curSize, len(segments)) + picked, pickedSize, pickedCount = pickByMaxCountSize(segments, maxDeltaSize, maxDeltaCount) + reason = fmt.Sprintf("level zero segments size reaches minForceTriggerSizeLimit=%.2fB, pickedSize=%.2fB, pickedCount=%d", minDeltaSize, pickedSize, pickedCount) return } @@ -143,30 +150,25 @@ func (v *LevelZeroSegmentsView) forceTrigger(segments []*SegmentView) (picked [] maxDeltaCount = paramtable.Get().DataCoordCfg.LevelZeroCompactionTriggerDeltalogMaxNum.GetAsInt() ) - curSize := float64(0) - picked, curSize = pickByMaxCountSize(segments, maxDeltaSize, maxDeltaCount) - reason = fmt.Sprintf("level zero views force to trigger, curDeltaSize=%.2f, curDeltaCount=%d", curSize, len(segments)) - return + picked, pickedSize, pickedCount := pickByMaxCountSize(segments, maxDeltaSize, maxDeltaCount) + reason = fmt.Sprintf("level zero views force to trigger, pickedSize=%.2fB, pickedCount=%d", pickedSize, pickedCount) + return picked, reason } // pickByMaxCountSize picks segments that count <= maxCount or size <= maxSize -func pickByMaxCountSize(segments []*SegmentView, maxSize float64, maxCount int) ([]*SegmentView, float64) { - var ( - curDeltaCount = 0 - curDeltaSize = float64(0) - ) +func pickByMaxCountSize(segments []*SegmentView, maxSize float64, maxCount int) (picked []*SegmentView, pickedSize float64, pickedCount int) { idx := 0 for _, view := range segments { - targetCount := view.DeltalogCount + curDeltaCount - targetSize := view.DeltaSize + curDeltaSize + targetCount := view.DeltalogCount + pickedCount + targetSize := view.DeltaSize + pickedSize - if (curDeltaCount != 0 && curDeltaSize != float64(0)) && (targetSize > maxSize || targetCount > maxCount) { + if (pickedCount != 0 && pickedSize != float64(0)) && (targetSize > maxSize || targetCount > maxCount) { break } - curDeltaCount = targetCount - curDeltaSize = targetSize + pickedCount = targetCount + pickedSize = targetSize idx += 1 } - return segments[:idx], curDeltaSize + return segments[:idx], pickedSize, pickedCount } diff --git a/internal/datacoord/compaction_l0_view_test.go b/internal/datacoord/compaction_l0_view_test.go index 863dbbe5678be..5fa941397b483 100644 --- a/internal/datacoord/compaction_l0_view_test.go +++ b/internal/datacoord/compaction_l0_view_test.go @@ -150,6 +150,7 @@ func (s *LevelZeroSegmentsViewSuite) TestTrigger() { if view.dmlPos.Timestamp < test.prepEarliestT { view.DeltalogCount = test.prepCountEach view.DeltaSize = test.prepSizeEach + view.DeltaRowCount = 1 } } log.Info("LevelZeroSegmentsView", zap.String("view", s.v.String())) diff --git a/internal/datacoord/compaction_view.go b/internal/datacoord/compaction_view.go index b82106213c77b..0e7a25d334c7f 100644 --- a/internal/datacoord/compaction_view.go +++ b/internal/datacoord/compaction_view.go @@ -88,6 +88,9 @@ type SegmentView struct { BinlogCount int StatslogCount int DeltalogCount int + + // row count + DeltaRowCount int } func (s *SegmentView) Clone() *SegmentView { @@ -104,6 +107,7 @@ func (s *SegmentView) Clone() *SegmentView { BinlogCount: s.BinlogCount, StatslogCount: s.StatslogCount, DeltalogCount: s.DeltalogCount, + DeltaRowCount: s.DeltaRowCount, } } @@ -126,6 +130,7 @@ func GetViewsByInfo(segments ...*SegmentInfo) []*SegmentView { DeltaSize: GetBinlogSizeAsBytes(segment.GetDeltalogs()), DeltalogCount: GetBinlogCount(segment.GetDeltalogs()), + DeltaRowCount: 
GetBinlogEntriesNum(segment.GetDeltalogs()), Size: GetBinlogSizeAsBytes(segment.GetBinlogs()), BinlogCount: GetBinlogCount(segment.GetBinlogs()), @@ -147,13 +152,13 @@ func (v *SegmentView) Equal(other *SegmentView) bool { } func (v *SegmentView) String() string { - return fmt.Sprintf("ID=%d, label=<%s>, state=%s, level=%s, binlogSize=%.2f, binlogCount=%d, deltaSize=%.2f, deltaCount=%d, expireSize=%.2f", - v.ID, v.label, v.State.String(), v.Level.String(), v.Size, v.BinlogCount, v.DeltaSize, v.DeltalogCount, v.ExpireSize) + return fmt.Sprintf("ID=%d, label=<%s>, state=%s, level=%s, binlogSize=%.2f, binlogCount=%d, deltaSize=%.2f, deltalogCount=%d, deltaRowCount=%d, expireSize=%.2f", + v.ID, v.label, v.State.String(), v.Level.String(), v.Size, v.BinlogCount, v.DeltaSize, v.DeltalogCount, v.DeltaRowCount, v.ExpireSize) } func (v *SegmentView) LevelZeroString() string { - return fmt.Sprintf("", - v.ID, v.Level.String(), v.DeltaSize, v.DeltalogCount) + return fmt.Sprintf("", + v.ID, v.Level.String(), v.DeltaSize, v.DeltalogCount, v.DeltaRowCount) } func GetBinlogCount(fieldBinlogs []*datapb.FieldBinlog) int { @@ -164,9 +169,19 @@ func GetBinlogCount(fieldBinlogs []*datapb.FieldBinlog) int { return num } -func GetBinlogSizeAsBytes(deltaBinlogs []*datapb.FieldBinlog) float64 { +func GetBinlogEntriesNum(fieldBinlogs []*datapb.FieldBinlog) int { + var num int + for _, fbinlog := range fieldBinlogs { + for _, binlog := range fbinlog.GetBinlogs() { + num += int(binlog.GetEntriesNum()) + } + } + return num +} + +func GetBinlogSizeAsBytes(fieldBinlogs []*datapb.FieldBinlog) float64 { var deltaSize float64 - for _, deltaLogs := range deltaBinlogs { + for _, deltaLogs := range fieldBinlogs { for _, l := range deltaLogs.GetBinlogs() { deltaSize += float64(l.GetMemorySize()) } diff --git a/internal/datacoord/compaction_view_manager_test.go b/internal/datacoord/compaction_view_manager_test.go index 4567e80aa1567..7a14d7b3b01fb 100644 --- a/internal/datacoord/compaction_view_manager_test.go +++ b/internal/datacoord/compaction_view_manager_test.go @@ -329,6 +329,7 @@ func genTestDeltalogs(logCount int, logSize int64) []*datapb.FieldBinlog { for i := 0; i < logCount; i++ { binlog := &datapb.Binlog{ + EntriesNum: int64(i), LogSize: logSize, MemorySize: logSize, } diff --git a/internal/datanode/l0_compactor.go b/internal/datanode/l0_compactor.go index f00a0768d1b31..a9840415cb0ce 100644 --- a/internal/datanode/l0_compactor.go +++ b/internal/datanode/l0_compactor.go @@ -239,7 +239,7 @@ func (t *levelZeroCompactionTask) serializeUpload(ctx context.Context, segmentWr func (t *levelZeroCompactionTask) splitDelta( ctx context.Context, - allDelta []*storage.DeleteData, + allDelta *storage.DeleteData, targetSegIDs []int64, ) map[int64]*SegmentDeltaWriter { traceCtx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact splitDelta") @@ -259,9 +259,6 @@ func (t *levelZeroCompactionTask) splitDelta( startIdx := value.StartIdx pk2SegmentIDs := value.Segment2Hits - pks := allDelta[value.DeleteDataIdx].Pks - tss := allDelta[value.DeleteDataIdx].Tss - for segmentID, hits := range pk2SegmentIDs { for i, hit := range hits { if hit { @@ -271,23 +268,21 @@ func (t *levelZeroCompactionTask) splitDelta( writer = NewSegmentDeltaWriter(segmentID, segment.GetPartitionID(), t.getCollection()) targetSegBuffer[segmentID] = writer } - writer.Write(pks[startIdx+i], tss[startIdx+i]) + writer.Write(allDelta.Pks[startIdx+i], allDelta.Tss[startIdx+i]) } } } return true }) - return targetSegBuffer } type BatchApplyRet = struct { - 
DeleteDataIdx int - StartIdx int - Segment2Hits map[int64][]bool + StartIdx int + Segment2Hits map[int64][]bool } -func (t *levelZeroCompactionTask) applyBFInParallel(ctx context.Context, deleteDatas []*storage.DeleteData, pool *conc.Pool[any], segmentBfs []*metacache.SegmentInfo) *typeutil.ConcurrentMap[int, *BatchApplyRet] { +func (t *levelZeroCompactionTask) applyBFInParallel(ctx context.Context, deltaData *storage.DeleteData, pool *conc.Pool[any], segmentBfs []*metacache.SegmentInfo) *typeutil.ConcurrentMap[int, *BatchApplyRet] { _, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact applyBFInParallel") defer span.End() batchSize := paramtable.Get().CommonCfg.BloomFilterApplyBatchSize.GetAsInt() @@ -306,42 +301,37 @@ func (t *levelZeroCompactionTask) applyBFInParallel(ctx context.Context, deleteD retIdx := 0 retMap := typeutil.NewConcurrentMap[int, *BatchApplyRet]() var futures []*conc.Future[any] - for didx, data := range deleteDatas { - pks := data.Pks - for idx := 0; idx < len(pks); idx += batchSize { - startIdx := idx - endIdx := startIdx + batchSize - if endIdx > len(pks) { - endIdx = len(pks) - } + pks := deltaData.Pks + for idx := 0; idx < len(pks); idx += batchSize { + startIdx := idx + endIdx := startIdx + batchSize + if endIdx > len(pks) { + endIdx = len(pks) + } - retIdx += 1 - tmpRetIndex := retIdx - deleteDataId := didx - future := pool.Submit(func() (any, error) { - ret := batchPredict(pks[startIdx:endIdx]) - retMap.Insert(tmpRetIndex, &BatchApplyRet{ - DeleteDataIdx: deleteDataId, - StartIdx: startIdx, - Segment2Hits: ret, - }) - return nil, nil + retIdx += 1 + tmpRetIndex := retIdx + future := pool.Submit(func() (any, error) { + ret := batchPredict(pks[startIdx:endIdx]) + retMap.Insert(tmpRetIndex, &BatchApplyRet{ + StartIdx: startIdx, + Segment2Hits: ret, }) - futures = append(futures, future) - } + return nil, nil + }) + futures = append(futures, future) } conc.AwaitAll(futures...) - return retMap } func (t *levelZeroCompactionTask) process(ctx context.Context, batchSize int, targetSegments []int64, deltaLogs ...[]string) ([]*datapb.CompactionSegment, error) { - _, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact process") + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact process") defer span.End() results := make([]*datapb.CompactionSegment, 0) batch := int(math.Ceil(float64(len(targetSegments)) / float64(batchSize))) - log := log.Ctx(t.ctx).With( + log := log.Ctx(ctx).With( zap.Int64("planID", t.plan.GetPlanID()), zap.Int("max conc segment counts", batchSize), zap.Int("total segment counts", len(targetSegments)), @@ -369,7 +359,10 @@ func (t *levelZeroCompactionTask) process(ctx context.Context, batchSize int, ta return nil, err } - log.Info("L0 compaction finished one batch", zap.Int("batch no.", i), zap.Int("batch segment count", len(batchResults))) + log.Info("L0 compaction finished one batch", + zap.Int("batch no.", i), + zap.Int("total deltaRowCount", int(allDelta.RowCount)), + zap.Int("batch segment count", len(batchResults))) results = append(results, batchResults...) 
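// Aside (editorial, not part of the patch): process() above is a plain
// fixed-size windowing loop over targetSegments; after this refactor each
// window downloads and deserializes its delta logs once, and the single
// merged *storage.DeleteData is shared by every segment in the window.
// Reduced to its core, with hypothetical names:
//
//	batches := int(math.Ceil(float64(len(items)) / float64(batchSize)))
//	for i := 0; i < batches; i++ {
//		left, right := i*batchSize, (i+1)*batchSize
//		if right > len(items) {
//			right = len(items)
//		}
//		window := items[left:right]
//		// load shared inputs once per window, then fan out per item
//	}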
} @@ -377,25 +370,22 @@ func (t *levelZeroCompactionTask) process(ctx context.Context, batchSize int, ta return results, nil } -func (t *levelZeroCompactionTask) loadDelta(ctx context.Context, deltaLogs ...[]string) ([]*storage.DeleteData, error) { - _, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact loadDelta") +func (t *levelZeroCompactionTask) loadDelta(ctx context.Context, deltaLogs []string) (*storage.DeleteData, error) { + ctx, span := otel.Tracer(typeutil.DataNodeRole).Start(ctx, "L0Compact loadDelta") defer span.End() - allData := make([]*storage.DeleteData, 0, len(deltaLogs)) - for _, paths := range deltaLogs { - blobBytes, err := t.Download(ctx, paths) - if err != nil { - return nil, err - } - blobs := make([]*storage.Blob, 0, len(blobBytes)) - for _, blob := range blobBytes { - blobs = append(blobs, &storage.Blob{Value: blob}) - } - _, _, dData, err := storage.NewDeleteCodec().Deserialize(blobs) - if err != nil { - return nil, err - } - allData = append(allData, dData) + blobBytes, err := t.Download(ctx, deltaLogs) + if err != nil { + return nil, err } - return allData, nil + blobs := make([]*storage.Blob, 0, len(blobBytes)) + for _, blob := range blobBytes { + blobs = append(blobs, &storage.Blob{Value: blob}) + } + _, _, dData, err := storage.NewDeleteCodec().Deserialize(blobs) + if err != nil { + return nil, err + } + + return dData, nil } diff --git a/internal/datanode/l0_compactor_test.go b/internal/datanode/l0_compactor_test.go index 80d7db06db305..e9dffdd7305eb 100644 --- a/internal/datanode/l0_compactor_test.go +++ b/internal/datanode/l0_compactor_test.go @@ -406,7 +406,7 @@ func (s *LevelZeroCompactionTaskSuite) TestSplitDelta() { s.mockMeta.EXPECT().Collection().Return(1) targetSegIDs := predicted - deltaWriters := s.task.splitDelta(context.TODO(), []*storage.DeleteData{s.dData}, targetSegIDs) + deltaWriters := s.task.splitDelta(context.TODO(), s.dData, targetSegIDs) s.NotEmpty(deltaWriters) s.ElementsMatch(predicted, lo.Keys(deltaWriters)) @@ -449,16 +449,16 @@ func (s *LevelZeroCompactionTaskSuite) TestLoadDelta() { } for _, test := range tests { - dDatas, err := s.task.loadDelta(ctx, test.paths) + dData, err := s.task.loadDelta(ctx, test.paths) if test.expectError { s.Error(err) } else { s.NoError(err) - s.NotEmpty(dDatas) - s.EqualValues(1, len(dDatas)) - s.ElementsMatch(s.dData.Pks, dDatas[0].Pks) - s.Equal(s.dData.RowCount, dDatas[0].RowCount) + s.NotEmpty(dData) + s.NotNil(dData) + s.ElementsMatch(s.dData.Pks, dData.Pks) + s.Equal(s.dData.RowCount, dData.RowCount) } } } From 5952c09925d098f4306920da010849dec5dd009b Mon Sep 17 00:00:00 2001 From: elstic Date: Fri, 21 Jun 2024 10:12:01 +0800 Subject: [PATCH 05/21] test: [cherry-pick] optimizing variable names (#34036) pr: https://github.com/milvus-io/milvus/pull/34035 Signed-off-by: elstic --- tests/python_client/common/common_func.py | 14 +++++++------- tests/python_client/common/common_type.py | 3 ++- tests/python_client/testcases/test_search.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index df0deeb2c3ca8..ec0c591a23c8f 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -365,10 +365,10 @@ def gen_collection_schema_all_datatype(description=ct.default_desc, else: multiple_dim_array.insert(0, dim) for i in range(len(multiple_dim_array)): - if ct.all_float_vector_types[i%3] != ct.sparse_vector: - 
fields.append(gen_float_vec_field(name=f"multiple_vector_{ct.all_float_vector_types[i%3]}", + if ct.append_vector_type[i%3] != ct.sparse_vector: + fields.append(gen_float_vec_field(name=f"multiple_vector_{ct.append_vector_type[i%3]}", dim=multiple_dim_array[i], - vector_data_type=ct.all_float_vector_types[i%3])) + vector_data_type=ct.append_vector_type[i%3])) else: # The field of a sparse vector cannot be dimensioned fields.append(gen_float_vec_field(name=f"multiple_vector_{ct.sparse_vector}", @@ -720,7 +720,7 @@ def gen_dataframe_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0, w df[ct.default_float_vec_field_name] = float_vec_values else: for i in range(len(multiple_dim_array)): - df[multiple_vector_field_name[i]] = gen_vectors(nb, multiple_dim_array[i], ct.all_float_vector_types[i%3]) + df[multiple_vector_field_name[i]] = gen_vectors(nb, multiple_dim_array[i], ct.append_vector_type[i%3]) if with_json is False: df.drop(ct.default_json_field_name, axis=1, inplace=True) @@ -758,7 +758,7 @@ def gen_general_list_all_data_type(nb=ct.default_nb, dim=ct.default_dim, start=0 insert_list.append(float_vec_values) else: for i in range(len(multiple_dim_array)): - insert_list.append(gen_vectors(nb, multiple_dim_array[i], ct.all_float_vector_types[i%3])) + insert_list.append(gen_vectors(nb, multiple_dim_array[i], ct.append_vector_type[i%3])) if with_json is False: # index = insert_list.index(json_values) @@ -803,7 +803,7 @@ def gen_default_rows_data_all_data_type(nb=ct.default_nb, dim=ct.default_dim, st else: for i in range(len(multiple_dim_array)): dict[multiple_vector_field_name[i]] = gen_vectors(nb, multiple_dim_array[i], - ct.all_float_vector_types[i])[0] + ct.append_vector_type[i])[0] if len(multiple_dim_array) != 0: with open(ct.rows_all_data_type_file_path + f'_{partition_id}' + f'_dim{dim}.txt', 'wb') as json_file: pickle.dump(array, json_file) @@ -1795,7 +1795,7 @@ def insert_data(collection_w, nb=ct.default_nb, is_binary=False, is_all_data_typ multiple_vector_field_name=vector_name_list, vector_data_type=vector_data_type, auto_id=auto_id, primary_field=primary_field) - elif vector_data_type in ct.all_float_vector_types: + elif vector_data_type in ct.append_vector_type: default_data = gen_general_default_list_data(nb // num, dim=dim, start=start, with_json=with_json, random_primary_key=random_primary_key, multiple_dim_array=multiple_dim_array, diff --git a/tests/python_client/common/common_type.py b/tests/python_client/common/common_type.py index a2c3c3af2dfd2..b8ca8a265970d 100644 --- a/tests/python_client/common/common_type.py +++ b/tests/python_client/common/common_type.py @@ -45,7 +45,8 @@ float16_type = "FLOAT16_VECTOR" bfloat16_type = "BFLOAT16_VECTOR" sparse_vector = "SPARSE_FLOAT_VECTOR" -all_float_vector_types = [float16_type, bfloat16_type, sparse_vector] +append_vector_type = [float16_type, bfloat16_type, sparse_vector] +all_dense_vector_types = [float_type, float16_type, bfloat16_type] default_sparse_vec_field_name = "sparse_vector" default_partition_name = "_default" default_resource_group_name = '__default_resource_group' diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index dd584c00a8010..e18915a57504b 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -6925,7 +6925,7 @@ def enable_dynamic_field(self, request): ****************************************************************** """ @pytest.mark.tags(CaseLabel.L0) - 
@pytest.mark.parametrize("vector_data_type", ["FLOAT_VECTOR", "FLOAT16_VECTOR", "BFLOAT16_VECTOR"]) + @pytest.mark.parametrize("vector_data_type", ct.all_dense_vector_types) def test_range_search_default(self, index_type, metric, vector_data_type): """ target: verify the range search returns correct results From fbc8fb3cb2967de6e460dbc87b33792447f33f3a Mon Sep 17 00:00:00 2001 From: wei liu Date: Fri, 21 Jun 2024 10:24:12 +0800 Subject: [PATCH 06/21] enhance: Skip return data distribution if no change happen (#32814) (#33985) issue: #32813 pr: #32814 --------- Signed-off-by: Wei Liu --- Makefile | 1 + internal/proto/query_coord.proto | 2 + .../querycoordv2/dist/dist_controller_test.go | 25 +- internal/querycoordv2/dist/dist_handler.go | 66 +- .../querycoordv2/dist/dist_handler_test.go | 127 +++ .../querycoordv2/meta/mock_target_manager.go | 975 ++++++++++++++++++ internal/querycoordv2/meta/target_manager.go | 23 + internal/querynodev2/server.go | 4 + internal/querynodev2/services.go | 49 +- 9 files changed, 1231 insertions(+), 41 deletions(-) create mode 100644 internal/querycoordv2/dist/dist_handler_test.go create mode 100644 internal/querycoordv2/meta/mock_target_manager.go diff --git a/Makefile b/Makefile index c1bfc9e901e75..2bb8ca7728002 100644 --- a/Makefile +++ b/Makefile @@ -429,6 +429,7 @@ generate-mockery-proxy: getdeps generate-mockery-querycoord: getdeps $(INSTALL_PATH)/mockery --name=QueryNodeServer --dir=$(PWD)/internal/proto/querypb/ --output=$(PWD)/internal/querycoordv2/mocks --filename=mock_querynode.go --with-expecter --structname=MockQueryNodeServer $(INSTALL_PATH)/mockery --name=Broker --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_broker.go --with-expecter --structname=MockBroker --outpkg=meta + $(INSTALL_PATH)/mockery --name=TargetManagerInterface --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_target_manager.go --with-expecter --structname=MockTargetManager --inpackage $(INSTALL_PATH)/mockery --name=Scheduler --dir=$(PWD)/internal/querycoordv2/task --output=$(PWD)/internal/querycoordv2/task --filename=mock_scheduler.go --with-expecter --structname=MockScheduler --outpkg=task --inpackage $(INSTALL_PATH)/mockery --name=Cluster --dir=$(PWD)/internal/querycoordv2/session --output=$(PWD)/internal/querycoordv2/session --filename=mock_cluster.go --with-expecter --structname=MockCluster --outpkg=session --inpackage $(INSTALL_PATH)/mockery --name=Balance --dir=$(PWD)/internal/querycoordv2/balance --output=$(PWD)/internal/querycoordv2/balance --filename=mock_balancer.go --with-expecter --structname=MockBalancer --outpkg=balance --inpackage diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index 2f5facab9d86f..808c94d912111 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -589,6 +589,7 @@ message SealedSegmentsChangeInfo { message GetDataDistributionRequest { common.MsgBase base = 1; map checkpoints = 2; + int64 lastUpdateTs = 3; } message GetDataDistributionResponse { @@ -597,6 +598,7 @@ message GetDataDistributionResponse { repeated SegmentVersionInfo segments = 3; repeated ChannelVersionInfo channels = 4; repeated LeaderView leader_views = 5; + int64 lastModifyTs = 6; } message LeaderView { diff --git a/internal/querycoordv2/dist/dist_controller_test.go b/internal/querycoordv2/dist/dist_controller_test.go index d0ee50fad52bf..8ecaa0e410ba4 100644 --- a/internal/querycoordv2/dist/dist_controller_test.go 
+++ b/internal/querycoordv2/dist/dist_controller_test.go @@ -48,6 +48,8 @@ type DistControllerTestSuite struct { kv kv.MetaKv meta *meta.Meta broker *meta.MockBroker + + nodeMgr *session.NodeManager } func (suite *DistControllerTestSuite) SetupTest() { @@ -69,16 +71,17 @@ func (suite *DistControllerTestSuite) SetupTest() { // meta store := querycoord.NewCatalog(suite.kv) idAllocator := RandomIncrementIDAllocator() - suite.meta = meta.NewMeta(idAllocator, store, session.NewNodeManager()) + + suite.nodeMgr = session.NewNodeManager() + suite.meta = meta.NewMeta(idAllocator, store, suite.nodeMgr) suite.mockCluster = session.NewMockCluster(suite.T()) - nodeManager := session.NewNodeManager() distManager := meta.NewDistributionManager() suite.broker = meta.NewMockBroker(suite.T()) targetManager := meta.NewTargetManager(suite.broker, suite.meta) suite.mockScheduler = task.NewMockScheduler(suite.T()) suite.mockScheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(nil).Maybe() - suite.controller = NewDistController(suite.mockCluster, nodeManager, distManager, targetManager, suite.mockScheduler) + suite.controller = NewDistController(suite.mockCluster, suite.nodeMgr, distManager, targetManager, suite.mockScheduler) } func (suite *DistControllerTestSuite) TearDownSuite() { @@ -86,6 +89,11 @@ func (suite *DistControllerTestSuite) TearDownSuite() { } func (suite *DistControllerTestSuite) TestStart() { + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) dispatchCalled := atomic.NewBool(false) suite.mockCluster.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Return( &querypb.GetDataDistributionResponse{Status: merr.Success(), NodeID: 1}, @@ -134,6 +142,17 @@ func (suite *DistControllerTestSuite) TestStop() { } func (suite *DistControllerTestSuite) TestSyncAll() { + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 2, + Address: "localhost", + Hostname: "localhost", + })) suite.controller.StartDistInstance(context.TODO(), 1) suite.controller.StartDistInstance(context.TODO(), 2) diff --git a/internal/querycoordv2/dist/dist_handler.go b/internal/querycoordv2/dist/dist_handler.go index 1729186718da2..4a5e8f92b8f00 100644 --- a/internal/querycoordv2/dist/dist_handler.go +++ b/internal/querycoordv2/dist/dist_handler.go @@ -26,7 +26,6 @@ import ( "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" @@ -40,16 +39,17 @@ import ( ) type distHandler struct { - nodeID int64 - c chan struct{} - wg sync.WaitGroup - client session.Cluster - nodeManager *session.NodeManager - scheduler task.Scheduler - dist *meta.DistributionManager - target *meta.TargetManager - mu sync.Mutex - stopOnce sync.Once + nodeID int64 + c chan struct{} + wg sync.WaitGroup + client session.Cluster + nodeManager *session.NodeManager + scheduler task.Scheduler + dist *meta.DistributionManager + target meta.TargetManagerInterface + mu sync.Mutex + stopOnce sync.Once + lastUpdateTs int64 } func (dh *distHandler) start(ctx context.Context) { @@ -103,21 +103,31 @@ func (dh *distHandler) pullDist(ctx context.Context, failures *int, dispatchTask func 
(dh *distHandler) handleDistResp(resp *querypb.GetDataDistributionResponse, dispatchTask bool) { node := dh.nodeManager.Get(resp.GetNodeID()) - if node != nil { + if node == nil { + return + } + + if time.Since(node.LastHeartbeat()) > paramtable.Get().QueryCoordCfg.HeartBeatWarningLag.GetAsDuration(time.Millisecond) { + log.Warn("node last heart beat time lag too behind", zap.Time("now", time.Now()), + zap.Time("lastHeartBeatTime", node.LastHeartbeat()), zap.Int64("nodeID", node.ID())) + } + node.SetLastHeartbeat(time.Now()) + + // skip update dist if no distribution change happens in query node + if resp.GetLastModifyTs() != 0 && resp.GetLastModifyTs() <= dh.lastUpdateTs { + log.RatedInfo(30, "skip update dist due to no distribution change", zap.Int64("lastModifyTs", resp.GetLastModifyTs()), zap.Int64("lastUpdateTs", dh.lastUpdateTs)) + } else { + dh.lastUpdateTs = resp.GetLastModifyTs() + node.UpdateStats( session.WithSegmentCnt(len(resp.GetSegments())), session.WithChannelCnt(len(resp.GetChannels())), ) - if time.Since(node.LastHeartbeat()) > paramtable.Get().QueryCoordCfg.HeartBeatWarningLag.GetAsDuration(time.Millisecond) { - log.Warn("node last heart beat time lag too behind", zap.Time("now", time.Now()), - zap.Time("lastHeartBeatTime", node.LastHeartbeat()), zap.Int64("nodeID", node.ID())) - } - node.SetLastHeartbeat(time.Now()) - } - dh.updateSegmentsDistribution(resp) - dh.updateChannelsDistribution(resp) - dh.updateLeaderView(resp) + dh.updateSegmentsDistribution(resp) + dh.updateChannelsDistribution(resp) + dh.updateLeaderView(resp) + } if dispatchTask { dh.scheduler.Dispatch(dh.nodeID) @@ -232,23 +242,13 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis dh.mu.Lock() defer dh.mu.Unlock() - channels := make(map[string]*msgpb.MsgPosition) - for _, channel := range dh.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(dh.nodeID)) { - targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) - if targetChannel == nil { - continue - } - - channels[channel.GetChannelName()] = targetChannel.GetSeekPosition() - } - ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.DistributionRequestTimeout.GetAsDuration(time.Millisecond)) defer cancel() resp, err := dh.client.GetDataDistribution(ctx, dh.nodeID, &querypb.GetDataDistributionRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_GetDistribution), ), - Checkpoints: channels, + LastUpdateTs: dh.lastUpdateTs, }) if err != nil { return nil, err @@ -277,7 +277,7 @@ func newDistHandler( nodeManager *session.NodeManager, scheduler task.Scheduler, dist *meta.DistributionManager, - targetMgr *meta.TargetManager, + targetMgr meta.TargetManagerInterface, ) *distHandler { h := &distHandler{ nodeID: nodeID, diff --git a/internal/querycoordv2/dist/dist_handler_test.go b/internal/querycoordv2/dist/dist_handler_test.go new file mode 100644 index 0000000000000..99b2ad47b6c25 --- /dev/null +++ b/internal/querycoordv2/dist/dist_handler_test.go @@ -0,0 +1,127 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dist + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type DistHandlerSuite struct { + suite.Suite + + ctx context.Context + meta *meta.Meta + broker *meta.MockBroker + + nodeID int64 + client *session.MockCluster + nodeManager *session.NodeManager + scheduler *task.MockScheduler + dist *meta.DistributionManager + target *meta.MockTargetManager + + handler *distHandler +} + +func (suite *DistHandlerSuite) SetupSuite() { + paramtable.Init() + suite.nodeID = 1 + suite.client = session.NewMockCluster(suite.T()) + suite.nodeManager = session.NewNodeManager() + suite.scheduler = task.NewMockScheduler(suite.T()) + suite.dist = meta.NewDistributionManager() + + suite.target = meta.NewMockTargetManager(suite.T()) + suite.ctx = context.Background() + + suite.scheduler.EXPECT().Dispatch(mock.Anything).Maybe() + suite.scheduler.EXPECT().GetExecutedFlag(mock.Anything).Return(nil).Maybe() + suite.target.EXPECT().GetSealedSegment(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + suite.target.EXPECT().GetDmChannel(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() +} + +func (suite *DistHandlerSuite) TestBasic() { + suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + suite.client.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.GetDataDistributionResponse{ + Status: merr.Success(), + NodeID: 1, + Channels: []*querypb.ChannelVersionInfo{ + { + Channel: "test-channel-1", + Collection: 1, + Version: 1, + }, + }, + Segments: []*querypb.SegmentVersionInfo{ + { + ID: 1, + Collection: 1, + Partition: 1, + Channel: "test-channel-1", + Version: 1, + }, + }, + + LeaderViews: []*querypb.LeaderView{ + { + Collection: 1, + Channel: "test-channel-1", + }, + }, + LastModifyTs: 1, + }, nil) + + suite.handler = newDistHandler(suite.ctx, suite.nodeID, suite.client, suite.nodeManager, suite.scheduler, suite.dist, suite.target) + defer suite.handler.stop() + + time.Sleep(10 * time.Second) +} + +func (suite *DistHandlerSuite) TestGetDistributionFailed() { + suite.nodeManager.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + suite.client.EXPECT().GetDataDistribution(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("fake error")) + + suite.handler = newDistHandler(suite.ctx, suite.nodeID, suite.client, suite.nodeManager, suite.scheduler, suite.dist, suite.target) + defer suite.handler.stop() + + time.Sleep(10 * time.Second) +} + +func TestDistHandlerSuite(t *testing.T) { + suite.Run(t, new(DistHandlerSuite)) +} diff 
--git a/internal/querycoordv2/meta/mock_target_manager.go b/internal/querycoordv2/meta/mock_target_manager.go new file mode 100644 index 0000000000000..5728fd2903f32 --- /dev/null +++ b/internal/querycoordv2/meta/mock_target_manager.go @@ -0,0 +1,975 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package meta + +import ( + metastore "github.com/milvus-io/milvus/internal/metastore" + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + + mock "github.com/stretchr/testify/mock" + + typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// MockTargetManager is an autogenerated mock type for the TargetManagerInterface type +type MockTargetManager struct { + mock.Mock +} + +type MockTargetManager_Expecter struct { + mock *mock.Mock +} + +func (_m *MockTargetManager) EXPECT() *MockTargetManager_Expecter { + return &MockTargetManager_Expecter{mock: &_m.Mock} +} + +// GetCollectionTargetVersion provides a mock function with given fields: collectionID, scope +func (_m *MockTargetManager) GetCollectionTargetVersion(collectionID int64, scope int32) int64 { + ret := _m.Called(collectionID, scope) + + var r0 int64 + if rf, ok := ret.Get(0).(func(int64, int32) int64); ok { + r0 = rf(collectionID, scope) + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockTargetManager_GetCollectionTargetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionTargetVersion' +type MockTargetManager_GetCollectionTargetVersion_Call struct { + *mock.Call +} + +// GetCollectionTargetVersion is a helper method to define mock.On call +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetCollectionTargetVersion(collectionID interface{}, scope interface{}) *MockTargetManager_GetCollectionTargetVersion_Call { + return &MockTargetManager_GetCollectionTargetVersion_Call{Call: _e.mock.On("GetCollectionTargetVersion", collectionID, scope)} +} + +func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetCollectionTargetVersion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetCollectionTargetVersion_Call) Return(_a0 int64) *MockTargetManager_GetCollectionTargetVersion_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetCollectionTargetVersion_Call) RunAndReturn(run func(int64, int32) int64) *MockTargetManager_GetCollectionTargetVersion_Call { + _c.Call.Return(run) + return _c +} + +// GetDmChannel provides a mock function with given fields: collectionID, channel, scope +func (_m *MockTargetManager) GetDmChannel(collectionID int64, channel string, scope int32) *DmChannel { + ret := _m.Called(collectionID, channel, scope) + + var r0 *DmChannel + if rf, ok := ret.Get(0).(func(int64, string, int32) *DmChannel); ok { + r0 = rf(collectionID, channel, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*DmChannel) + } + } + + return r0 +} + +// MockTargetManager_GetDmChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDmChannel' +type MockTargetManager_GetDmChannel_Call struct { + *mock.Call +} + +// GetDmChannel is a helper method to define mock.On call +// - collectionID int64 +// - channel string +// - scope int32 +func (_e *MockTargetManager_Expecter) GetDmChannel(collectionID interface{}, channel interface{}, scope interface{}) 
*MockTargetManager_GetDmChannel_Call { + return &MockTargetManager_GetDmChannel_Call{Call: _e.mock.On("GetDmChannel", collectionID, channel, scope)} +} + +func (_c *MockTargetManager_GetDmChannel_Call) Run(run func(collectionID int64, channel string, scope int32)) *MockTargetManager_GetDmChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetDmChannel_Call) Return(_a0 *DmChannel) *MockTargetManager_GetDmChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetDmChannel_Call) RunAndReturn(run func(int64, string, int32) *DmChannel) *MockTargetManager_GetDmChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetDmChannelsByCollection provides a mock function with given fields: collectionID, scope +func (_m *MockTargetManager) GetDmChannelsByCollection(collectionID int64, scope int32) map[string]*DmChannel { + ret := _m.Called(collectionID, scope) + + var r0 map[string]*DmChannel + if rf, ok := ret.Get(0).(func(int64, int32) map[string]*DmChannel); ok { + r0 = rf(collectionID, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]*DmChannel) + } + } + + return r0 +} + +// MockTargetManager_GetDmChannelsByCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDmChannelsByCollection' +type MockTargetManager_GetDmChannelsByCollection_Call struct { + *mock.Call +} + +// GetDmChannelsByCollection is a helper method to define mock.On call +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetDmChannelsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetDmChannelsByCollection_Call { + return &MockTargetManager_GetDmChannelsByCollection_Call{Call: _e.mock.On("GetDmChannelsByCollection", collectionID, scope)} +} + +func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetDmChannelsByCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetDmChannelsByCollection_Call) Return(_a0 map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetDmChannelsByCollection_Call) RunAndReturn(run func(int64, int32) map[string]*DmChannel) *MockTargetManager_GetDmChannelsByCollection_Call { + _c.Call.Return(run) + return _c +} + +// GetDroppedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope +func (_m *MockTargetManager) GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope int32) []int64 { + ret := _m.Called(collectionID, channelName, scope) + + var r0 []int64 + if rf, ok := ret.Get(0).(func(int64, string, int32) []int64); ok { + r0 = rf(collectionID, channelName, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + + return r0 +} + +// MockTargetManager_GetDroppedSegmentsByChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDroppedSegmentsByChannel' +type MockTargetManager_GetDroppedSegmentsByChannel_Call struct { + *mock.Call +} + +// GetDroppedSegmentsByChannel is a helper method to define mock.On call +// - collectionID int64 +// - channelName string +// - scope int32 +func (_e *MockTargetManager_Expecter) GetDroppedSegmentsByChannel(collectionID 
interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetDroppedSegmentsByChannel_Call { + return &MockTargetManager_GetDroppedSegmentsByChannel_Call{Call: _e.mock.On("GetDroppedSegmentsByChannel", collectionID, channelName, scope)} +} + +func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetDroppedSegmentsByChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) Return(_a0 []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetDroppedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) []int64) *MockTargetManager_GetDroppedSegmentsByChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetGrowingSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope +func (_m *MockTargetManager) GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope int32) typeutil.Set[int64] { + ret := _m.Called(collectionID, channelName, scope) + + var r0 typeutil.Set[int64] + if rf, ok := ret.Get(0).(func(int64, string, int32) typeutil.Set[int64]); ok { + r0 = rf(collectionID, channelName, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(typeutil.Set[int64]) + } + } + + return r0 +} + +// MockTargetManager_GetGrowingSegmentsByChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetGrowingSegmentsByChannel' +type MockTargetManager_GetGrowingSegmentsByChannel_Call struct { + *mock.Call +} + +// GetGrowingSegmentsByChannel is a helper method to define mock.On call +// - collectionID int64 +// - channelName string +// - scope int32 +func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByChannel_Call { + return &MockTargetManager_GetGrowingSegmentsByChannel_Call{Call: _e.mock.On("GetGrowingSegmentsByChannel", collectionID, channelName, scope)} +} + +func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetGrowingSegmentsByChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) Return(_a0 typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetGrowingSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetGrowingSegmentsByCollection provides a mock function with given fields: collectionID, scope +func (_m *MockTargetManager) GetGrowingSegmentsByCollection(collectionID int64, scope int32) typeutil.Set[int64] { + ret := _m.Called(collectionID, scope) + + var r0 typeutil.Set[int64] + if rf, ok := ret.Get(0).(func(int64, int32) typeutil.Set[int64]); ok { + r0 = rf(collectionID, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(typeutil.Set[int64]) + } + } + + return r0 +} + +// MockTargetManager_GetGrowingSegmentsByCollection_Call is a *mock.Call that shadows Run/Return 
methods with type explicit version for method 'GetGrowingSegmentsByCollection' +type MockTargetManager_GetGrowingSegmentsByCollection_Call struct { + *mock.Call +} + +// GetGrowingSegmentsByCollection is a helper method to define mock.On call +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetGrowingSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetGrowingSegmentsByCollection_Call { + return &MockTargetManager_GetGrowingSegmentsByCollection_Call{Call: _e.mock.On("GetGrowingSegmentsByCollection", collectionID, scope)} +} + +func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetGrowingSegmentsByCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) Return(_a0 typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetGrowingSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) typeutil.Set[int64]) *MockTargetManager_GetGrowingSegmentsByCollection_Call { + _c.Call.Return(run) + return _c +} + +// GetSealedSegment provides a mock function with given fields: collectionID, id, scope +func (_m *MockTargetManager) GetSealedSegment(collectionID int64, id int64, scope int32) *datapb.SegmentInfo { + ret := _m.Called(collectionID, id, scope) + + var r0 *datapb.SegmentInfo + if rf, ok := ret.Get(0).(func(int64, int64, int32) *datapb.SegmentInfo); ok { + r0 = rf(collectionID, id, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*datapb.SegmentInfo) + } + } + + return r0 +} + +// MockTargetManager_GetSealedSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSealedSegment' +type MockTargetManager_GetSealedSegment_Call struct { + *mock.Call +} + +// GetSealedSegment is a helper method to define mock.On call +// - collectionID int64 +// - id int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetSealedSegment(collectionID interface{}, id interface{}, scope interface{}) *MockTargetManager_GetSealedSegment_Call { + return &MockTargetManager_GetSealedSegment_Call{Call: _e.mock.On("GetSealedSegment", collectionID, id, scope)} +} + +func (_c *MockTargetManager_GetSealedSegment_Call) Run(run func(collectionID int64, id int64, scope int32)) *MockTargetManager_GetSealedSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetSealedSegment_Call) Return(_a0 *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetSealedSegment_Call) RunAndReturn(run func(int64, int64, int32) *datapb.SegmentInfo) *MockTargetManager_GetSealedSegment_Call { + _c.Call.Return(run) + return _c +} + +// GetSealedSegmentsByChannel provides a mock function with given fields: collectionID, channelName, scope +func (_m *MockTargetManager) GetSealedSegmentsByChannel(collectionID int64, channelName string, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(collectionID, channelName, scope) + + var r0 map[int64]*datapb.SegmentInfo + if rf, ok := ret.Get(0).(func(int64, string, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(collectionID, channelName, scope) + } else { + if ret.Get(0) != nil { 
+ r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) + } + } + + return r0 +} + +// MockTargetManager_GetSealedSegmentsByChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSealedSegmentsByChannel' +type MockTargetManager_GetSealedSegmentsByChannel_Call struct { + *mock.Call +} + +// GetSealedSegmentsByChannel is a helper method to define mock.On call +// - collectionID int64 +// - channelName string +// - scope int32 +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByChannel(collectionID interface{}, channelName interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByChannel_Call { + return &MockTargetManager_GetSealedSegmentsByChannel_Call{Call: _e.mock.On("GetSealedSegmentsByChannel", collectionID, channelName, scope)} +} + +func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Run(run func(collectionID int64, channelName string, scope int32)) *MockTargetManager_GetSealedSegmentsByChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(string), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) Return(_a0 map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByChannel_Call) RunAndReturn(run func(int64, string, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetSealedSegmentsByCollection provides a mock function with given fields: collectionID, scope +func (_m *MockTargetManager) GetSealedSegmentsByCollection(collectionID int64, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(collectionID, scope) + + var r0 map[int64]*datapb.SegmentInfo + if rf, ok := ret.Get(0).(func(int64, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(collectionID, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) + } + } + + return r0 +} + +// MockTargetManager_GetSealedSegmentsByCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSealedSegmentsByCollection' +type MockTargetManager_GetSealedSegmentsByCollection_Call struct { + *mock.Call +} + +// GetSealedSegmentsByCollection is a helper method to define mock.On call +// - collectionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByCollection(collectionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByCollection_Call { + return &MockTargetManager_GetSealedSegmentsByCollection_Call{Call: _e.mock.On("GetSealedSegmentsByCollection", collectionID, scope)} +} + +func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Run(run func(collectionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) Return(_a0 map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByCollection_Call) RunAndReturn(run func(int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByCollection_Call { + _c.Call.Return(run) + return _c +} + +// GetSealedSegmentsByPartition provides a mock function with given fields: collectionID, 
partitionID, scope +func (_m *MockTargetManager) GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope int32) map[int64]*datapb.SegmentInfo { + ret := _m.Called(collectionID, partitionID, scope) + + var r0 map[int64]*datapb.SegmentInfo + if rf, ok := ret.Get(0).(func(int64, int64, int32) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(collectionID, partitionID, scope) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) + } + } + + return r0 +} + +// MockTargetManager_GetSealedSegmentsByPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSealedSegmentsByPartition' +type MockTargetManager_GetSealedSegmentsByPartition_Call struct { + *mock.Call +} + +// GetSealedSegmentsByPartition is a helper method to define mock.On call +// - collectionID int64 +// - partitionID int64 +// - scope int32 +func (_e *MockTargetManager_Expecter) GetSealedSegmentsByPartition(collectionID interface{}, partitionID interface{}, scope interface{}) *MockTargetManager_GetSealedSegmentsByPartition_Call { + return &MockTargetManager_GetSealedSegmentsByPartition_Call{Call: _e.mock.On("GetSealedSegmentsByPartition", collectionID, partitionID, scope)} +} + +func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Run(run func(collectionID int64, partitionID int64, scope int32)) *MockTargetManager_GetSealedSegmentsByPartition_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64), args[2].(int32)) + }) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) Return(_a0 map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_GetSealedSegmentsByPartition_Call) RunAndReturn(run func(int64, int64, int32) map[int64]*datapb.SegmentInfo) *MockTargetManager_GetSealedSegmentsByPartition_Call { + _c.Call.Return(run) + return _c +} + +// IsCurrentTargetExist provides a mock function with given fields: collectionID +func (_m *MockTargetManager) IsCurrentTargetExist(collectionID int64) bool { + ret := _m.Called(collectionID) + + var r0 bool + if rf, ok := ret.Get(0).(func(int64) bool); ok { + r0 = rf(collectionID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockTargetManager_IsCurrentTargetExist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsCurrentTargetExist' +type MockTargetManager_IsCurrentTargetExist_Call struct { + *mock.Call +} + +// IsCurrentTargetExist is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockTargetManager_Expecter) IsCurrentTargetExist(collectionID interface{}) *MockTargetManager_IsCurrentTargetExist_Call { + return &MockTargetManager_IsCurrentTargetExist_Call{Call: _e.mock.On("IsCurrentTargetExist", collectionID)} +} + +func (_c *MockTargetManager_IsCurrentTargetExist_Call) Run(run func(collectionID int64)) *MockTargetManager_IsCurrentTargetExist_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockTargetManager_IsCurrentTargetExist_Call) Return(_a0 bool) *MockTargetManager_IsCurrentTargetExist_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_IsCurrentTargetExist_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_IsCurrentTargetExist_Call { + _c.Call.Return(run) + return _c +} + +// IsNextTargetExist provides a mock function with given fields: collectionID +func (_m 
*MockTargetManager) IsNextTargetExist(collectionID int64) bool { + ret := _m.Called(collectionID) + + var r0 bool + if rf, ok := ret.Get(0).(func(int64) bool); ok { + r0 = rf(collectionID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockTargetManager_IsNextTargetExist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsNextTargetExist' +type MockTargetManager_IsNextTargetExist_Call struct { + *mock.Call +} + +// IsNextTargetExist is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockTargetManager_Expecter) IsNextTargetExist(collectionID interface{}) *MockTargetManager_IsNextTargetExist_Call { + return &MockTargetManager_IsNextTargetExist_Call{Call: _e.mock.On("IsNextTargetExist", collectionID)} +} + +func (_c *MockTargetManager_IsNextTargetExist_Call) Run(run func(collectionID int64)) *MockTargetManager_IsNextTargetExist_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockTargetManager_IsNextTargetExist_Call) Return(_a0 bool) *MockTargetManager_IsNextTargetExist_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_IsNextTargetExist_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_IsNextTargetExist_Call { + _c.Call.Return(run) + return _c +} + +// PullNextTargetV1 provides a mock function with given fields: broker, collectionID, chosenPartitionIDs +func (_m *MockTargetManager) PullNextTargetV1(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error) { + _va := make([]interface{}, len(chosenPartitionIDs)) + for _i := range chosenPartitionIDs { + _va[_i] = chosenPartitionIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, broker, collectionID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 map[int64]*datapb.SegmentInfo + var r1 map[string]*DmChannel + var r2 error + if rf, ok := ret.Get(0).(func(Broker, int64, ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error)); ok { + return rf(broker, collectionID, chosenPartitionIDs...) + } + if rf, ok := ret.Get(0).(func(Broker, int64, ...int64) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(broker, collectionID, chosenPartitionIDs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) + } + } + + if rf, ok := ret.Get(1).(func(Broker, int64, ...int64) map[string]*DmChannel); ok { + r1 = rf(broker, collectionID, chosenPartitionIDs...) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(map[string]*DmChannel) + } + } + + if rf, ok := ret.Get(2).(func(Broker, int64, ...int64) error); ok { + r2 = rf(broker, collectionID, chosenPartitionIDs...) 
+ } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockTargetManager_PullNextTargetV1_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PullNextTargetV1' +type MockTargetManager_PullNextTargetV1_Call struct { + *mock.Call +} + +// PullNextTargetV1 is a helper method to define mock.On call +// - broker Broker +// - collectionID int64 +// - chosenPartitionIDs ...int64 +func (_e *MockTargetManager_Expecter) PullNextTargetV1(broker interface{}, collectionID interface{}, chosenPartitionIDs ...interface{}) *MockTargetManager_PullNextTargetV1_Call { + return &MockTargetManager_PullNextTargetV1_Call{Call: _e.mock.On("PullNextTargetV1", + append([]interface{}{broker, collectionID}, chosenPartitionIDs...)...)} +} + +func (_c *MockTargetManager_PullNextTargetV1_Call) Run(run func(broker Broker, collectionID int64, chosenPartitionIDs ...int64)) *MockTargetManager_PullNextTargetV1_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(int64) + } + } + run(args[0].(Broker), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockTargetManager_PullNextTargetV1_Call) Return(_a0 map[int64]*datapb.SegmentInfo, _a1 map[string]*DmChannel, _a2 error) *MockTargetManager_PullNextTargetV1_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockTargetManager_PullNextTargetV1_Call) RunAndReturn(run func(Broker, int64, ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error)) *MockTargetManager_PullNextTargetV1_Call { + _c.Call.Return(run) + return _c +} + +// PullNextTargetV2 provides a mock function with given fields: broker, collectionID, chosenPartitionIDs +func (_m *MockTargetManager) PullNextTargetV2(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error) { + _va := make([]interface{}, len(chosenPartitionIDs)) + for _i := range chosenPartitionIDs { + _va[_i] = chosenPartitionIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, broker, collectionID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 map[int64]*datapb.SegmentInfo + var r1 map[string]*DmChannel + var r2 error + if rf, ok := ret.Get(0).(func(Broker, int64, ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error)); ok { + return rf(broker, collectionID, chosenPartitionIDs...) + } + if rf, ok := ret.Get(0).(func(Broker, int64, ...int64) map[int64]*datapb.SegmentInfo); ok { + r0 = rf(broker, collectionID, chosenPartitionIDs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*datapb.SegmentInfo) + } + } + + if rf, ok := ret.Get(1).(func(Broker, int64, ...int64) map[string]*DmChannel); ok { + r1 = rf(broker, collectionID, chosenPartitionIDs...) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(map[string]*DmChannel) + } + } + + if rf, ok := ret.Get(2).(func(Broker, int64, ...int64) error); ok { + r2 = rf(broker, collectionID, chosenPartitionIDs...) 
+ } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockTargetManager_PullNextTargetV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PullNextTargetV2' +type MockTargetManager_PullNextTargetV2_Call struct { + *mock.Call +} + +// PullNextTargetV2 is a helper method to define mock.On call +// - broker Broker +// - collectionID int64 +// - chosenPartitionIDs ...int64 +func (_e *MockTargetManager_Expecter) PullNextTargetV2(broker interface{}, collectionID interface{}, chosenPartitionIDs ...interface{}) *MockTargetManager_PullNextTargetV2_Call { + return &MockTargetManager_PullNextTargetV2_Call{Call: _e.mock.On("PullNextTargetV2", + append([]interface{}{broker, collectionID}, chosenPartitionIDs...)...)} +} + +func (_c *MockTargetManager_PullNextTargetV2_Call) Run(run func(broker Broker, collectionID int64, chosenPartitionIDs ...int64)) *MockTargetManager_PullNextTargetV2_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int64, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(int64) + } + } + run(args[0].(Broker), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockTargetManager_PullNextTargetV2_Call) Return(_a0 map[int64]*datapb.SegmentInfo, _a1 map[string]*DmChannel, _a2 error) *MockTargetManager_PullNextTargetV2_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockTargetManager_PullNextTargetV2_Call) RunAndReturn(run func(Broker, int64, ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error)) *MockTargetManager_PullNextTargetV2_Call { + _c.Call.Return(run) + return _c +} + +// Recover provides a mock function with given fields: catalog +func (_m *MockTargetManager) Recover(catalog metastore.QueryCoordCatalog) error { + ret := _m.Called(catalog) + + var r0 error + if rf, ok := ret.Get(0).(func(metastore.QueryCoordCatalog) error); ok { + r0 = rf(catalog) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTargetManager_Recover_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recover' +type MockTargetManager_Recover_Call struct { + *mock.Call +} + +// Recover is a helper method to define mock.On call +// - catalog metastore.QueryCoordCatalog +func (_e *MockTargetManager_Expecter) Recover(catalog interface{}) *MockTargetManager_Recover_Call { + return &MockTargetManager_Recover_Call{Call: _e.mock.On("Recover", catalog)} +} + +func (_c *MockTargetManager_Recover_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_Recover_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metastore.QueryCoordCatalog)) + }) + return _c +} + +func (_c *MockTargetManager_Recover_Call) Return(_a0 error) *MockTargetManager_Recover_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_Recover_Call) RunAndReturn(run func(metastore.QueryCoordCatalog) error) *MockTargetManager_Recover_Call { + _c.Call.Return(run) + return _c +} + +// RemoveCollection provides a mock function with given fields: collectionID +func (_m *MockTargetManager) RemoveCollection(collectionID int64) { + _m.Called(collectionID) +} + +// MockTargetManager_RemoveCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollection' +type MockTargetManager_RemoveCollection_Call struct { + *mock.Call +} + +// RemoveCollection is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockTargetManager_Expecter) 
RemoveCollection(collectionID interface{}) *MockTargetManager_RemoveCollection_Call { + return &MockTargetManager_RemoveCollection_Call{Call: _e.mock.On("RemoveCollection", collectionID)} +} + +func (_c *MockTargetManager_RemoveCollection_Call) Run(run func(collectionID int64)) *MockTargetManager_RemoveCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockTargetManager_RemoveCollection_Call) Return() *MockTargetManager_RemoveCollection_Call { + _c.Call.Return() + return _c +} + +func (_c *MockTargetManager_RemoveCollection_Call) RunAndReturn(run func(int64)) *MockTargetManager_RemoveCollection_Call { + _c.Call.Return(run) + return _c +} + +// RemovePartition provides a mock function with given fields: collectionID, partitionIDs +func (_m *MockTargetManager) RemovePartition(collectionID int64, partitionIDs ...int64) { + _va := make([]interface{}, len(partitionIDs)) + for _i := range partitionIDs { + _va[_i] = partitionIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, collectionID) + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// MockTargetManager_RemovePartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemovePartition' +type MockTargetManager_RemovePartition_Call struct { + *mock.Call +} + +// RemovePartition is a helper method to define mock.On call +// - collectionID int64 +// - partitionIDs ...int64 +func (_e *MockTargetManager_Expecter) RemovePartition(collectionID interface{}, partitionIDs ...interface{}) *MockTargetManager_RemovePartition_Call { + return &MockTargetManager_RemovePartition_Call{Call: _e.mock.On("RemovePartition", + append([]interface{}{collectionID}, partitionIDs...)...)} +} + +func (_c *MockTargetManager_RemovePartition_Call) Run(run func(collectionID int64, partitionIDs ...int64)) *MockTargetManager_RemovePartition_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int64, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(int64) + } + } + run(args[0].(int64), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockTargetManager_RemovePartition_Call) Return() *MockTargetManager_RemovePartition_Call { + _c.Call.Return() + return _c +} + +func (_c *MockTargetManager_RemovePartition_Call) RunAndReturn(run func(int64, ...int64)) *MockTargetManager_RemovePartition_Call { + _c.Call.Return(run) + return _c +} + +// SaveCurrentTarget provides a mock function with given fields: catalog +func (_m *MockTargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) { + _m.Called(catalog) +} + +// MockTargetManager_SaveCurrentTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCurrentTarget' +type MockTargetManager_SaveCurrentTarget_Call struct { + *mock.Call +} + +// SaveCurrentTarget is a helper method to define mock.On call +// - catalog metastore.QueryCoordCatalog +func (_e *MockTargetManager_Expecter) SaveCurrentTarget(catalog interface{}) *MockTargetManager_SaveCurrentTarget_Call { + return &MockTargetManager_SaveCurrentTarget_Call{Call: _e.mock.On("SaveCurrentTarget", catalog)} +} + +func (_c *MockTargetManager_SaveCurrentTarget_Call) Run(run func(catalog metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metastore.QueryCoordCatalog)) + }) + return _c +} + +func (_c *MockTargetManager_SaveCurrentTarget_Call) Return() *MockTargetManager_SaveCurrentTarget_Call { + _c.Call.Return() + return _c +} + +func (_c *MockTargetManager_SaveCurrentTarget_Call) RunAndReturn(run func(metastore.QueryCoordCatalog)) *MockTargetManager_SaveCurrentTarget_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCollectionCurrentTarget provides a mock function with given fields: collectionID +func (_m *MockTargetManager) UpdateCollectionCurrentTarget(collectionID int64) bool { + ret := _m.Called(collectionID) + + var r0 bool + if rf, ok := ret.Get(0).(func(int64) bool); ok { + r0 = rf(collectionID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockTargetManager_UpdateCollectionCurrentTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCollectionCurrentTarget' +type MockTargetManager_UpdateCollectionCurrentTarget_Call struct { + *mock.Call +} + +// UpdateCollectionCurrentTarget is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockTargetManager_Expecter) UpdateCollectionCurrentTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionCurrentTarget_Call { + return &MockTargetManager_UpdateCollectionCurrentTarget_Call{Call: _e.mock.On("UpdateCollectionCurrentTarget", collectionID)} +} + +func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionCurrentTarget_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) Return(_a0 bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_UpdateCollectionCurrentTarget_Call) RunAndReturn(run func(int64) bool) *MockTargetManager_UpdateCollectionCurrentTarget_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCollectionNextTarget provides a mock function with given fields: collectionID +func (_m *MockTargetManager) UpdateCollectionNextTarget(collectionID int64) error { + ret := _m.Called(collectionID) + + var r0 error + if rf, ok := ret.Get(0).(func(int64) error); 
ok { + r0 = rf(collectionID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockTargetManager_UpdateCollectionNextTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCollectionNextTarget' +type MockTargetManager_UpdateCollectionNextTarget_Call struct { + *mock.Call +} + +// UpdateCollectionNextTarget is a helper method to define mock.On call +// - collectionID int64 +func (_e *MockTargetManager_Expecter) UpdateCollectionNextTarget(collectionID interface{}) *MockTargetManager_UpdateCollectionNextTarget_Call { + return &MockTargetManager_UpdateCollectionNextTarget_Call{Call: _e.mock.On("UpdateCollectionNextTarget", collectionID)} +} + +func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Run(run func(collectionID int64)) *MockTargetManager_UpdateCollectionNextTarget_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) Return(_a0 error) *MockTargetManager_UpdateCollectionNextTarget_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockTargetManager_UpdateCollectionNextTarget_Call) RunAndReturn(run func(int64) error) *MockTargetManager_UpdateCollectionNextTarget_Call { + _c.Call.Return(run) + return _c +} + +// NewMockTargetManager creates a new instance of MockTargetManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockTargetManager(t interface { + mock.TestingT + Cleanup(func()) +}) *MockTargetManager { + mock := &MockTargetManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go index 6808e6ca9a126..376fd37569ec5 100644 --- a/internal/querycoordv2/meta/target_manager.go +++ b/internal/querycoordv2/meta/target_manager.go @@ -50,6 +50,29 @@ const ( NextTargetFirst ) +type TargetManagerInterface interface { + UpdateCollectionCurrentTarget(collectionID int64) bool + UpdateCollectionNextTarget(collectionID int64) error + PullNextTargetV1(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error) + PullNextTargetV2(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (map[int64]*datapb.SegmentInfo, map[string]*DmChannel, error) + RemoveCollection(collectionID int64) + RemovePartition(collectionID int64, partitionIDs ...int64) + GetGrowingSegmentsByCollection(collectionID int64, scope TargetScope) typeutil.UniqueSet + GetGrowingSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) typeutil.UniqueSet + GetSealedSegmentsByCollection(collectionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo + GetSealedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) map[int64]*datapb.SegmentInfo + GetDroppedSegmentsByChannel(collectionID int64, channelName string, scope TargetScope) []int64 + GetSealedSegmentsByPartition(collectionID int64, partitionID int64, scope TargetScope) map[int64]*datapb.SegmentInfo + GetDmChannelsByCollection(collectionID int64, scope TargetScope) map[string]*DmChannel + GetDmChannel(collectionID int64, channel string, scope TargetScope) *DmChannel + GetSealedSegment(collectionID int64, id int64, scope TargetScope) *datapb.SegmentInfo + GetCollectionTargetVersion(collectionID int64, scope 
TargetScope) int64 + IsCurrentTargetExist(collectionID int64) bool + IsNextTargetExist(collectionID int64) bool + SaveCurrentTarget(catalog metastore.QueryCoordCatalog) + Recover(catalog metastore.QueryCoordCatalog) error +} + type TargetManager struct { rwMutex sync.RWMutex broker Broker diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 57f059f53328c..a74c2c14628a5 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -130,6 +130,10 @@ type QueryNode struct { // parameter tuning hook queryHook optimizers.QueryHook + + // record the last modify ts of segment/channel distribution + lastModifyLock sync.RWMutex + lastModifyTs int64 } // NewQueryNode will return a QueryNode with abnormal state. diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 54bf390ba8b53..f7807886ce128 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -21,6 +21,7 @@ import ( "fmt" "strconv" "sync" + "time" "github.com/golang/protobuf/proto" "github.com/samber/lo" @@ -190,6 +191,8 @@ func (node *QueryNode) composeIndexMeta(indexInfos []*indexpb.IndexInfo, schema // WatchDmChannels creates consumers on dmChannels to receive incremental data, which is the important part of real-time query func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (status *commonpb.Status, e error) { + defer node.updateDistributionModifyTS() + channel := req.GetInfos()[0] log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), @@ -339,6 +342,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm } func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { + defer node.updateDistributionModifyTS() log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannelName()), @@ -396,6 +400,7 @@ func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPart // LoadSegments loads historical data into query node, historical data can be vector data or index func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { + defer node.updateDistributionModifyTS() segment := req.GetInfos()[0] log := log.Ctx(ctx).With( @@ -528,6 +533,7 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, req *querypb.Relea // ReleaseSegments removes the specified segments from query node according to segmentIDs, partitionIDs, and collectionID func (node *QueryNode) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { + defer node.updateDistributionModifyTS() log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), zap.String("shard", req.GetShard()), @@ -1175,6 +1181,23 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get } defer node.lifetime.Done() + lastModifyTs := node.getDistributionModifyTS() + distributionChange := func() bool { + if req.GetLastUpdateTs() == 0 { + return true + } + + return req.GetLastUpdateTs() < lastModifyTs + } + + if !distributionChange() { + return &querypb.GetDataDistributionResponse{ + Status: merr.Success(), + NodeID: node.GetNodeID(), + LastModifyTs: lastModifyTs, + }, nil + } + sealedSegments := node.manager.Segment.GetBy(segments.WithType(commonpb.SegmentState_Sealed)) segmentVersionInfos := make([]*querypb.SegmentVersionInfo, 0,
len(sealedSegments)) for _, s := range sealedSegments { @@ -1240,15 +1263,18 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get }) return &querypb.GetDataDistributionResponse{ - Status: merr.Success(), - NodeID: node.GetNodeID(), - Segments: segmentVersionInfos, - Channels: channelVersionInfos, - LeaderViews: leaderViews, + Status: merr.Success(), + NodeID: node.GetNodeID(), + Segments: segmentVersionInfos, + Channels: channelVersionInfos, + LeaderViews: leaderViews, + LastModifyTs: lastModifyTs, }, nil } func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { + defer node.updateDistributionModifyTS() + log := log.Ctx(ctx).With(zap.Int64("collectionID", req.GetCollectionID()), zap.String("channel", req.GetChannel()), zap.Int64("currentNodeID", node.GetNodeID())) // check node healthy @@ -1403,3 +1429,16 @@ func (req *deleteRequestStringer) String() string { tss := req.GetTimestamps() return fmt.Sprintf("%s, timestamp range: [%d-%d]", pkInfo, tss[0], tss[len(tss)-1]) } + +func (node *QueryNode) updateDistributionModifyTS() { + node.lastModifyLock.Lock() + defer node.lastModifyLock.Unlock() + + node.lastModifyTs = time.Now().UnixNano() +} + +func (node *QueryNode) getDistributionModifyTS() int64 { + node.lastModifyLock.RLock() + defer node.lastModifyLock.RUnlock() + return node.lastModifyTs +} From 7d1d5a838a2f497981dd26dbf29a545a39276c5d Mon Sep 17 00:00:00 2001 From: wei liu Date: Fri, 21 Jun 2024 10:26:02 +0800 Subject: [PATCH 07/21] fix: Fix GetReplicas API return nil status (#33715) (#34019) issue: #33702 pr: #33715 Signed-off-by: Wei Liu --- internal/querycoordv2/services.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index dea9817a2777f..b1fb72dd435e1 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -840,9 +840,7 @@ func (s *Server) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque replicas := s.meta.ReplicaManager.GetByCollection(req.GetCollectionID()) if len(replicas) == 0 { - return &milvuspb.GetReplicasResponse{ - Replicas: make([]*milvuspb.ReplicaInfo, 0), - }, nil + return resp, nil } for _, replica := range replicas { From c219dca00109be86f603c51933dbffbc68715f2f Mon Sep 17 00:00:00 2001 From: jaime Date: Fri, 21 Jun 2024 11:46:02 +0800 Subject: [PATCH 08/21] fix: metrics database_num is 0 after restarting rootcoord (#34010) issue: https://github.com/milvus-io/milvus/issues/34041 Signed-off-by: jaime --- internal/rootcoord/meta_table.go | 13 ++++++++++--- internal/rootcoord/root_coord.go | 2 -- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/internal/rootcoord/meta_table.go b/internal/rootcoord/meta_table.go index 3054276271827..49bca49368780 100644 --- a/internal/rootcoord/meta_table.go +++ b/internal/rootcoord/meta_table.go @@ -139,6 +139,7 @@ func (mt *MetaTable) reload() error { metrics.RootCoordNumOfCollections.Reset() metrics.RootCoordNumOfPartitions.Reset() + metrics.RootCoordNumOfDatabases.Set(0) // recover databases. 
dbs, err := mt.catalog.ListDatabases(mt.ctx, typeutil.MaxTimestamp) @@ -184,6 +185,7 @@ func (mt *MetaTable) reload() error { } } + metrics.RootCoordNumOfDatabases.Inc() metrics.RootCoordNumOfCollections.WithLabelValues(dbName).Add(float64(collectionNum)) log.Info("collections recovered from db", zap.String("db_name", dbName), zap.Int64("collection_num", collectionNum), @@ -255,7 +257,11 @@ func (mt *MetaTable) CreateDatabase(ctx context.Context, db *model.Database, ts mt.ddLock.Lock() defer mt.ddLock.Unlock() - return mt.createDatabasePrivate(ctx, db, ts) + if err := mt.createDatabasePrivate(ctx, db, ts); err != nil { + return err + } + metrics.RootCoordNumOfDatabases.Inc() + return nil } func (mt *MetaTable) createDatabasePrivate(ctx context.Context, db *model.Database, ts typeutil.Timestamp) error { @@ -271,8 +277,8 @@ func (mt *MetaTable) createDatabasePrivate(ctx context.Context, db *model.Databa mt.names.createDbIfNotExist(dbName) mt.aliases.createDbIfNotExist(dbName) mt.dbName2Meta[dbName] = db - log.Ctx(ctx).Info("create database", zap.String("db", dbName), zap.Uint64("ts", ts)) + log.Ctx(ctx).Info("create database", zap.String("db", dbName), zap.Uint64("ts", ts)) return nil } @@ -322,8 +328,9 @@ func (mt *MetaTable) DropDatabase(ctx context.Context, dbName string, ts typeuti mt.names.dropDb(dbName) mt.aliases.dropDb(dbName) delete(mt.dbName2Meta, dbName) - log.Ctx(ctx).Info("drop database", zap.String("db", dbName), zap.Uint64("ts", ts)) + metrics.RootCoordNumOfDatabases.Dec() + log.Ctx(ctx).Info("drop database", zap.String("db", dbName), zap.Uint64("ts", ts)) return nil } diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 2f40efa440ddf..62f6cef9fd17f 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -867,7 +867,6 @@ func (c *Core) CreateDatabase(ctx context.Context, in *milvuspb.CreateDatabaseRe metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.RootCoordNumOfDatabases.Inc() log.Ctx(ctx).Info("done to create database", zap.String("role", typeutil.RootCoordRole), zap.String("dbName", in.GetDbName()), zap.Int64("msgID", in.GetBase().GetMsgID()), zap.Uint64("ts", t.GetTs())) @@ -912,7 +911,6 @@ func (c *Core) DropDatabase(ctx context.Context, in *milvuspb.DropDatabaseReques metrics.RootCoordDDLReqCounter.WithLabelValues(method, metrics.SuccessLabel).Inc() metrics.RootCoordDDLReqLatency.WithLabelValues(method).Observe(float64(tr.ElapseSpan().Milliseconds())) - metrics.RootCoordNumOfDatabases.Dec() metrics.CleanupRootCoordDBMetrics(in.GetDbName()) log.Ctx(ctx).Info("done to drop database", zap.String("role", typeutil.RootCoordRole), zap.String("dbName", in.GetDbName()), zap.Int64("msgID", in.GetBase().GetMsgID()), From 891a94ad9e35b69eccd254e4d8083ac08fef34c6 Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 21 Jun 2024 12:04:00 +0800 Subject: [PATCH 09/21] fix: [2.4] Check nodeID wildcard when removing pkOracle (#33895) (#34020) Cherry-pick from master pr: #33895 See also #33894 Signed-off-by: Congqi Xia --- internal/querynodev2/pkoracle/candidate.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/querynodev2/pkoracle/candidate.go b/internal/querynodev2/pkoracle/candidate.go index bb2479702b7de..9f8a8b7daf601 100644 --- a/internal/querynodev2/pkoracle/candidate.go +++ b/internal/querynodev2/pkoracle/candidate.go @@ -52,7 +52,8 @@ func 
WithSegmentType(typ commonpb.SegmentState) CandidateFilter { // WithWorkerID returns CandidateFilter with provided worker id. func WithWorkerID(workerID int64) CandidateFilter { return func(candidate candidateWithWorker) bool { - return candidate.workerID == workerID + return candidate.workerID == workerID || + workerID == -1 // wildcard for offline node } } From e02a95e3c2b3ab8e31f700cde925520d7a8db7cb Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 21 Jun 2024 14:14:01 +0800 Subject: [PATCH 10/21] fix: [2.4] Return record with largest timestamp for entries with same PK (#33936) (#34024) Cherry-pick from master pr: #33936 See also #33883 --------- Signed-off-by: Congqi Xia --- internal/core/src/segcore/InsertRecord.h | 7 +- internal/querynodev2/segments/reducer.go | 48 ++++++ internal/querynodev2/segments/result.go | 48 +++--- internal/querynodev2/segments/result_test.go | 150 ++++++++++++++----- pkg/util/typeutil/schema.go | 52 +++++++ 5 files changed, 245 insertions(+), 60 deletions(-) diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index 7d62e303eeda8..5aa6247d9c7c0 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -115,10 +115,6 @@ class OffsetOrderedMap : public OffsetMap { bool false_filtered_out) const override { std::shared_lock lck(mtx_); - if (limit == Unlimited || limit == NoLimit) { - limit = map_.size(); - } - // TODO: we can't retrieve pk by offset very conveniently. // Selectivity should be done outside. return find_first_by_index(limit, bitset, false_filtered_out); @@ -141,6 +137,9 @@ class OffsetOrderedMap : public OffsetMap { if (!false_filtered_out) { cnt = size - bitset.count(); } + if (limit == Unlimited || limit == NoLimit) { + limit = cnt; + } limit = std::min(limit, cnt); std::vector seg_offsets; seg_offsets.reserve(limit); diff --git a/internal/querynodev2/segments/reducer.go b/internal/querynodev2/segments/reducer.go index 67335ede1d8d0..d5ef51a7df7b3 100644 --- a/internal/querynodev2/segments/reducer.go +++ b/internal/querynodev2/segments/reducer.go @@ -3,10 +3,15 @@ package segments import ( "context" + "github.com/samber/lo" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type internalReducer interface { @@ -30,3 +35,46 @@ func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.Collection } return newDefaultLimitReducerSegcore(req, schema, manager) } + +type TimestampedRetrieveResult[T interface { + typeutil.ResultWithID + GetFieldsData() []*schemapb.FieldData +}] struct { + Result T + Timestamps []int64 +} + +func (r *TimestampedRetrieveResult[T]) GetIds() *schemapb.IDs { + return r.Result.GetIds() +} + +func (r *TimestampedRetrieveResult[T]) GetHasMoreResult() bool { + return r.Result.GetHasMoreResult() +} + +func (r *TimestampedRetrieveResult[T]) GetTimestamps() []int64 { + return r.Timestamps +} + +func NewTimestampedRetrieveResult[T interface { + typeutil.ResultWithID + GetFieldsData() []*schemapb.FieldData +}](result T) (*TimestampedRetrieveResult[T], error) { + tsField, has := lo.Find(result.GetFieldsData(), func(fd *schemapb.FieldData) bool { + return fd.GetFieldId() == common.TimeStampField + }) + if !has { + return nil,
merr.WrapErrServiceInternal("RetrieveResult does not have timestamp field") + } + timestamps := tsField.GetScalars().GetLongData().GetData() + idSize := typeutil.GetSizeOfIDs(result.GetIds()) + + if idSize != len(timestamps) { + return nil, merr.WrapErrServiceInternal("id length is not equal to timestamp length") + } + + return &TimestampedRetrieveResult[T]{ + Result: result, + Timestamps: timestamps, + }, nil +} diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 0ac61d81c9bb7..46b23337f8ac0 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -399,7 +399,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna _, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeInternalRetrieveResult") defer sp.End() - validRetrieveResults := []*internalpb.RetrieveResults{} + validRetrieveResults := []*TimestampedRetrieveResult[*internalpb.RetrieveResults]{} relatedDataSize := int64(0) hasMoreResult := false for _, r := range retrieveResults { @@ -409,7 +409,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna if r == nil || len(r.GetFieldsData()) == 0 || size == 0 { continue } - validRetrieveResults = append(validRetrieveResults, r) + tr, err := NewTimestampedRetrieveResult(r) + if err != nil { + return nil, err + } + validRetrieveResults = append(validRetrieveResults, tr) loopEnd += size hasMoreResult = hasMoreResult || r.GetHasMoreResult() } @@ -423,23 +427,23 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna loopEnd = int(param.limit) } - ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData())) - idTsMap := make(map[interface{}]uint64) + ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].Result.GetFieldsData())) + idTsMap := make(map[interface{}]int64) cursors := make([]int64, len(validRetrieveResults)) var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { - sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel]) - ts := getTS(validRetrieveResults[sel], cursors[sel]) + ts := validRetrieveResults[sel].Timestamps[cursors[sel]] if _, ok := idTsMap[pk]; !ok { typeutil.AppendPKs(ret.Ids, pk) - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel]) idTsMap[pk] = ts j++ } else { @@ -448,7 +452,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna if ts != 0 && ts > idTsMap[pk] { idTsMap[pk] = ts typeutil.DeleteFieldData(ret.FieldsData) - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel]) } } @@ -514,7 +518,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore loopEnd int ) - validRetrieveResults := []*segcorepb.RetrieveResults{} + validRetrieveResults := []*TimestampedRetrieveResult[*segcorepb.RetrieveResults]{} validSegments := 
make([]Segment, 0, len(segments)) selectedOffsets := make([][]int64, 0, len(retrieveResults)) selectedIndexes := make([][]int64, 0, len(retrieveResults)) @@ -526,7 +530,11 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore log.Debug("filter out invalid retrieve result") continue } - validRetrieveResults = append(validRetrieveResults, r) + tr, err := NewTimestampedRetrieveResult(r) + if err != nil { + return nil, err + } + validRetrieveResults = append(validRetrieveResults, tr) if plan.ignoreNonPk { validSegments = append(validSegments, segments[i]) } @@ -548,29 +556,35 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore limit = int(param.limit) } - idSet := make(map[interface{}]struct{}) cursors := make([]int64, len(validRetrieveResults)) + idTsMap := make(map[any]int64) var availableCount int var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ { - sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) + sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) if sel == -1 || (param.mergeStopForBest && drainOneResult) { break } pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel]) - if _, ok := idSet[pk]; !ok { + ts := validRetrieveResults[sel].Timestamps[cursors[sel]] + if _, ok := idTsMap[pk]; !ok { typeutil.AppendPKs(ret.Ids, pk) selected = append(selected, sel) - selectedOffsets[sel] = append(selectedOffsets[sel], validRetrieveResults[sel].GetOffset()[cursors[sel]]) + selectedOffsets[sel] = append(selectedOffsets[sel], validRetrieveResults[sel].Result.GetOffset()[cursors[sel]]) selectedIndexes[sel] = append(selectedIndexes[sel], cursors[sel]) - idSet[pk] = struct{}{} + idTsMap[pk] = ts availableCount++ } else { // primary keys duplicate skipDupCnt++ + if ts != 0 && ts > idTsMap[pk] { + idTsMap[pk] = ts + selectedOffsets[sel][len(selectedOffsets[sel])-1] = validRetrieveResults[sel].Result.GetOffset()[cursors[sel]] + selectedIndexes[sel][len(selectedIndexes[sel])-1] = cursors[sel] + } } cursors[sel]++ @@ -585,11 +599,11 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore // judge the `!plan.ignoreNonPk` condition. _, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData") defer span2.End() - ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].GetFieldsData())) + ret.FieldsData = make([]*schemapb.FieldData, len(validRetrieveResults[0].Result.GetFieldsData())) cursors = make([]int64, len(validRetrieveResults)) for _, sel := range selected { // cannot use `cursors[sel]` directly, since some of them may be skipped. 
- retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), selectedIndexes[sel][cursors[sel]]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), selectedIndexes[sel][cursors[sel]]) // limit retrieve result to avoid oom if retSize > maxOutputSize { diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 6fcaf4196584a..794321ce126d6 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -22,6 +22,7 @@ import ( "sort" "testing" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -33,6 +34,15 @@ import ( "github.com/milvus-io/milvus/pkg/util/typeutil" ) +func getFieldData[T interface { + GetFieldsData() []*schemapb.FieldData +}](rs T, fieldID int64) (*schemapb.FieldData, bool) { + fd, has := lo.Find(rs.GetFieldsData(), func(fd *schemapb.FieldData) bool { + return fd.GetFieldId() == fieldID + }) + return fd, has +} + type ResultSuite struct { suite.Suite } @@ -54,10 +64,12 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} var fieldDataArray1 []*schemapb.FieldData + fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) var fieldDataArray2 []*schemapb.FieldData + fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) @@ -88,10 +100,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) - suite.Equal(Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal(Int64Array, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(FloatVector, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) }) suite.Run("test nil results", func() { @@ -168,11 +184,15 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.Run(test.description, func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, 
NewMergeParam(test.limit, make([]int64, 0), nil, false)) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) - suite.Equal(resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal(resultField0[0:test.limit], intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat[0:test.limit*Dim], vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) } @@ -211,10 +231,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.Run("test int ID", func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 11, 22, 22}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) @@ -238,10 +262,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) - suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 11, 22, 22}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) }) @@ -259,10 +287,12 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} var fieldDataArray1 []*schemapb.FieldData + fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray1 = 
append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) var fieldDataArray2 []*schemapb.FieldData + fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) @@ -291,10 +321,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) - suite.Equal(Int64Array, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(FloatVector, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal(Int64Array, intFieldData.GetScalars().GetLongData().GetData()) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(FloatVector, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) }) suite.Run("test nil results", func() { @@ -389,11 +423,16 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { suite.Run(test.description, func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, NewMergeParam(test.limit, make([]int64, 0), nil, false)) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) - suite.Equal(resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal(resultField0[0:test.limit], intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat[0:test.limit*Dim], vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) } @@ -430,10 +469,15 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { suite.Run("test int ID", func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + 
suite.Equal([]int64{11, 11, 22, 22}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) @@ -457,10 +501,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) - suite.Equal([]int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) - suite.InDeltaSlice(resultFloat, result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 11, 22, 22}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) + suite.InDeltaSlice(resultFloat, vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) suite.NoError(err) }) }) @@ -478,12 +526,14 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0} var fieldDataArray1 []*schemapb.FieldData + fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000, 3000}, 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:3], 1)) fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) var fieldDataArray2 []*schemapb.FieldData + fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000, 4000}, 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:3], 1)) fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, @@ -518,13 +568,17 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) // has more result both, stop reduce when draining one result // here, we can only get best result from 0 to 4 without 6, because result1 has more results suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 22, 11, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 22, 11, 22, 33}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44}, - result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + 
vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) }) suite.Run("merge stop unlimited", func() { result1.HasMoreResult = false @@ -532,13 +586,17 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) // as result1 and result2 don't have better results neither // we can reduce all available result into the reduced result suite.Equal([]int64{0, 1, 2, 3, 4, 6}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 22, 11, 22, 33, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 22, 11, 22, 33, 33}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44, 11, 22, 33, 44}, - result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) }) suite.Run("merge stop one limited", func() { result1.HasMoreResult = true @@ -546,12 +604,16 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) // as result1 may have better results, stop reducing when draining it suite.Equal([]int64{0, 1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 22, 11, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 22, 11, 22, 33}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) suite.InDeltaSlice([]float32{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 11, 22, 33, 44}, - result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) }) }) @@ -581,11 +643,15 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData()) - suite.Equal([]int64{11, 11, 22, 22, 33}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + intFieldData, has := getFieldData(result, Int64FieldID) + suite.Require().True(has) + suite.Equal([]int64{11, 11, 22, 22, 33}, intFieldData.GetScalars().GetLongData().Data) + vectorFieldData, has := getFieldData(result, FloatVectorFieldID) + suite.Require().True(has) suite.InDeltaSlice([]float32{1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 11, 22, 33, 44}, - result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10) + vectorFieldData.GetVectors().GetFloatVector().Data, 10e-10) 
}) suite.Run("test stop internal merge for best with early termination", func() { @@ -599,6 +665,12 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { }, FieldsData: fieldDataArray1, } + var drainDataArray2 []*schemapb.FieldData + drainDataArray2 = append(drainDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000}, 1)) + drainDataArray2 = append(drainDataArray2, genFieldData(Int64FieldName, Int64FieldID, + schemapb.DataType_Int64, Int64Array[0:1], 1)) + drainDataArray2 = append(drainDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + schemapb.DataType_FloatVector, FloatVector[0:4], Dim)) result2 := &internalpb.RetrieveResults{ Ids: &schemapb.IDs{ IdField: &schemapb.IDs_IntId{ @@ -607,7 +679,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { }, }, }, - FieldsData: fieldDataArray2, + FieldsData: drainDataArray2, } suite.Run("test drain one result without more results", func() { result1.HasMoreResult = false @@ -615,7 +687,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData()) }) suite.Run("test drain one result with more results", func() { @@ -624,7 +696,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, NewMergeParam(3, make([]int64, 0), nil, true)) suite.NoError(err) - suite.Equal(2, len(result.GetFieldsData())) + suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData()) }) }) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index dfa35f2109dec..3a9d2392f27c7 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -1326,6 +1326,10 @@ type ResultWithID interface { GetHasMoreResult() bool } +type ResultWithTimestamp interface { + GetTimestamps() []int64 +} + // SelectMinPK select the index of the minPK in results T of the cursors. 
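// It also reports whether some result was drained while it still had more matched rows, so the merge path can stop reducing early when mergeStopForBest is set.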
func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) { var ( @@ -1365,6 +1369,54 @@ func SelectMinPK[T ResultWithID](results []T, cursors []int64) (int, bool) { return sel, drainResult } +func SelectMinPKWithTimestamp[T interface { + ResultWithID + ResultWithTimestamp +}](results []T, cursors []int64) (int, bool) { + var ( + sel = -1 + drainResult = false + maxTimestamp int64 = 0 + minIntPK int64 = math.MaxInt64 + + firstStr = true + minStrPK string + ) + for i, cursor := range cursors { + timestamps := results[i].GetTimestamps() + // if the cursor has run out of results from one source while that source still has more matched results, + // in this case we have to tell reduce to stop, because better results may be retrieved in the following iteration + if int(cursor) >= GetSizeOfIDs(results[i].GetIds()) && (results[i].GetHasMoreResult()) { + drainResult = true + continue + } + + pkInterface := GetPK(results[i].GetIds(), cursor) + + switch pk := pkInterface.(type) { + case string: + ts := timestamps[cursor] + if firstStr || pk < minStrPK || (pk == minStrPK && ts > maxTimestamp) { + firstStr = false + minStrPK = pk + sel = i + maxTimestamp = ts + } + case int64: + ts := timestamps[cursor] + if pk < minIntPK || (pk == minIntPK && ts > maxTimestamp) { + minIntPK = pk + sel = i + maxTimestamp = ts + } + default: + continue + } + } + + return sel, drainResult +} + func AppendGroupByValue(dstResData *schemapb.SearchResultData, groupByVal interface{}, srcDataType schemapb.DataType, ) error { From 89461db5f32f1f0d936c6850223b2dc038a1ba2b Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Fri, 21 Jun 2024 14:20:08 +0800 Subject: [PATCH 11/21] test: update the lib of bf16 (#34044) pr: https://github.com/milvus-io/milvus/pull/34043 Signed-off-by: zhuwenxing --- tests/restful_client_v2/requirements.txt | 3 +-- tests/restful_client_v2/utils/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/restful_client_v2/requirements.txt b/tests/restful_client_v2/requirements.txt index 5f7aa7242368b..624e0f269dbb8 100644 --- a/tests/restful_client_v2/requirements.txt +++ b/tests/restful_client_v2/requirements.txt @@ -12,5 +12,4 @@ pytest-xdist==2.5.0 minio==7.1.14 tenacity==8.1.0 # for bf16 datatype -jax==0.4.13 -jaxlib==0.4.13 +ml-dtypes==0.2.0 diff --git a/tests/restful_client_v2/utils/utils.py b/tests/restful_client_v2/utils/utils.py index 112e26e787e3e..cbd7640edf0eb 100644 --- a/tests/restful_client_v2/utils/utils.py +++ b/tests/restful_client_v2/utils/utils.py @@ -4,7 +4,7 @@ import string from faker import Faker import numpy as np -import jax.numpy as jnp +from ml_dtypes import bfloat16 from sklearn import preprocessing import base64 import requests @@ -191,7 +191,7 @@ def gen_bf16_vectors(num, dim): for _ in range(num): raw_vector = [random.random() for _ in range(dim)] raw_vectors.append(raw_vector) - bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist() + bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist() bf16_vectors.append(bytes(bf16_vector)) return raw_vectors, bf16_vectors From 061a00c58f1c7a80117b8027075a266590946b6d Mon Sep 17 00:00:00 2001 From: wei liu Date: Fri, 21 Jun 2024 16:56:02 +0800 Subject: [PATCH 12/21] enhance: Enable database level replica num and resource groups for loading collection (#33052) (#33981) pr: #33052 issue: #30040 This PR introduces two database level props: 1. database.replica.number 2.
database.resource_groups Users can set these two database props via the AlterDatabase API and then load a collection without specifying replica_num or resource groups; the database level load params will then be used when loading collections. Signed-off-by: Wei Liu --- internal/proxy/task.go | 5 - .../querycoordv2/meta/coordinator_broker.go | 46 +++++ .../meta/coordinator_broker_test.go | 87 ++++++++ internal/querycoordv2/meta/mock_broker.go | 119 +++++++++++ internal/querycoordv2/server_test.go | 12 +- internal/querycoordv2/services.go | 36 ++++ internal/querycoordv2/services_test.go | 2 + pkg/common/common.go | 41 ++++ pkg/common/common_test.go | 50 +++++ tests/integration/replicas/load/load_test.go | 187 ++++++++++++++++++ 10 files changed, 575 insertions(+), 10 deletions(-) create mode 100644 tests/integration/replicas/load/load_test.go diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 469db6ad98ffc..19f3ea4c24804 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1548,11 +1548,6 @@ func (t *loadCollectionTask) PreExecute(ctx context.Context) error { return err } - // To compat with LoadCollcetion before Milvus@2.1 - if t.ReplicaNumber == 0 { - t.ReplicaNumber = 1 - } - return nil } diff --git a/internal/querycoordv2/meta/coordinator_broker.go b/internal/querycoordv2/meta/coordinator_broker.go index cbcb9fced74e5..2df54688affb0 100644 --- a/internal/querycoordv2/meta/coordinator_broker.go +++ b/internal/querycoordv2/meta/coordinator_broker.go @@ -30,7 +30,9 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -47,6 +49,8 @@ type Broker interface { GetSegmentInfo(ctx context.Context, segmentID ...UniqueID) (*datapb.GetSegmentInfoResponse, error) GetIndexInfo(ctx context.Context, collectionID UniqueID, segmentID UniqueID) ([]*querypb.FieldIndexInfo, error) GetRecoveryInfoV2(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error) + DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) + GetCollectionLoadInfo(ctx context.Context, collectionID UniqueID) ([]string, int64, error) } type CoordinatorBroker struct { @@ -83,6 +87,48 @@ func (broker *CoordinatorBroker) DescribeCollection(ctx context.Context, collect return resp, nil } +func (broker *CoordinatorBroker) DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { + ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) + defer cancel() + + req := &rootcoordpb.DescribeDatabaseRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + ), + DbName: dbName, + } + resp, err := broker.rootCoord.DescribeDatabase(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Ctx(ctx).Warn("failed to describe database", zap.Error(err)) + return nil, err + } + return resp, nil +} + +// try to get database level replica_num and resource groups, return (resource_groups, replica_num, error) +func (broker *CoordinatorBroker) GetCollectionLoadInfo(ctx
context.Context, collectionID UniqueID) ([]string, int64, error) { + // to do by weiliu1031: querycoord should cache mappings: collectionID->dbName + collectionInfo, err := broker.DescribeCollection(ctx, collectionID) + if err != nil { + return nil, 0, err + } + + dbInfo, err := broker.DescribeDatabase(ctx, collectionInfo.GetDbName()) + if err != nil { + return nil, 0, err + } + replicaNum, err := common.DatabaseLevelReplicaNumber(dbInfo.GetProperties()) + if err != nil { + return nil, 0, err + } + rgs, err := common.DatabaseLevelResourceGroups(dbInfo.GetProperties()) + if err != nil { + return nil, 0, err + } + + return rgs, replicaNum, nil +} + func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) { ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Millisecond)) defer cancel() diff --git a/internal/querycoordv2/meta/coordinator_broker_test.go b/internal/querycoordv2/meta/coordinator_broker_test.go index 476a997dd2ae9..778268f7ce66b 100644 --- a/internal/querycoordv2/meta/coordinator_broker_test.go +++ b/internal/querycoordv2/meta/coordinator_broker_test.go @@ -18,6 +18,7 @@ package meta import ( "context" + "strings" "testing" "github.com/cockroachdb/errors" @@ -32,6 +33,8 @@ import ( "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -490,6 +493,90 @@ func (s *CoordinatorBrokerDataCoordSuite) TestGetIndexInfo() { }) } +func (s *CoordinatorBrokerRootCoordSuite) TestDescribeDatabase() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + }, nil) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.NoError(err) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Status(errors.New("fake error")), + }, nil) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_unimplemented", func() { + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnimplemented) + _, err := s.broker.DescribeDatabase(ctx, "fake_db1") + s.Error(err) + s.resetMock() + }) +} + +func (s *CoordinatorBrokerRootCoordSuite) TestGetCollectionLoadInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + DbName: "fake_db1", + }, nil) + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). 
+ Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + Properties: []*commonpb.KeyValuePair{ + { + Key: common.DatabaseReplicaNumber, + Value: "3", + }, + { + Key: common.DatabaseResourceGroups, + Value: strings.Join([]string{"rg1", "rg2"}, ","), + }, + }, + }, nil) + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, 1) + s.NoError(err) + s.Equal(int64(3), replicas) + s.Contains(rgs, "rg1") + s.Contains(rgs, "rg2") + s.resetMock() + }) + + s.Run("props not set", func() { + s.rootcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + DbName: "fake_db1", + }, nil) + s.rootcoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything). + Return(&rootcoordpb.DescribeDatabaseResponse{ + Status: merr.Success(), + Properties: []*commonpb.KeyValuePair{}, + }, nil) + _, _, err := s.broker.GetCollectionLoadInfo(ctx, 1) + s.Error(err) + s.resetMock() + }) +} + func TestCoordinatorBroker(t *testing.T) { suite.Run(t, new(CoordinatorBrokerRootCoordSuite)) suite.Run(t, new(CoordinatorBrokerDataCoordSuite)) diff --git a/internal/querycoordv2/meta/mock_broker.go b/internal/querycoordv2/meta/mock_broker.go index ff3548985547f..a940aff58bc91 100644 --- a/internal/querycoordv2/meta/mock_broker.go +++ b/internal/querycoordv2/meta/mock_broker.go @@ -13,6 +13,8 @@ import ( mock "github.com/stretchr/testify/mock" querypb "github.com/milvus-io/milvus/internal/proto/querypb" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" ) // MockBroker is an autogenerated mock type for the Broker type @@ -83,6 +85,123 @@ func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Cont return _c } +// DescribeDatabase provides a mock function with given fields: ctx, dbName +func (_m *MockBroker) DescribeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { + ret := _m.Called(ctx, dbName) + + var r0 *rootcoordpb.DescribeDatabaseResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*rootcoordpb.DescribeDatabaseResponse, error)); ok { + return rf(ctx, dbName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *rootcoordpb.DescribeDatabaseResponse); ok { + r0 = rf(ctx, dbName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rootcoordpb.DescribeDatabaseResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, dbName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_DescribeDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeDatabase' +type MockBroker_DescribeDatabase_Call struct { + *mock.Call +} + +// DescribeDatabase is a helper method to define mock.On call +// - ctx context.Context +// - dbName string +func (_e *MockBroker_Expecter) DescribeDatabase(ctx interface{}, dbName interface{}) *MockBroker_DescribeDatabase_Call { + return &MockBroker_DescribeDatabase_Call{Call: _e.mock.On("DescribeDatabase", ctx, dbName)} +} + +func (_c *MockBroker_DescribeDatabase_Call) Run(run func(ctx context.Context, dbName string)) *MockBroker_DescribeDatabase_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockBroker_DescribeDatabase_Call) Return(_a0 *rootcoordpb.DescribeDatabaseResponse, _a1 error) *MockBroker_DescribeDatabase_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_DescribeDatabase_Call) 
RunAndReturn(run func(context.Context, string) (*rootcoordpb.DescribeDatabaseResponse, error)) *MockBroker_DescribeDatabase_Call { + _c.Call.Return(run) + return _c +} + +// GetCollectionLoadInfo provides a mock function with given fields: ctx, collectionID +func (_m *MockBroker) GetCollectionLoadInfo(ctx context.Context, collectionID int64) ([]string, int64, error) { + ret := _m.Called(ctx, collectionID) + + var r0 []string + var r1 int64 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, int64) ([]string, int64, error)); ok { + return rf(ctx, collectionID) + } + if rf, ok := ret.Get(0).(func(context.Context, int64) []string); ok { + r0 = rf(ctx, collectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64) int64); ok { + r1 = rf(ctx, collectionID) + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func(context.Context, int64) error); ok { + r2 = rf(ctx, collectionID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockBroker_GetCollectionLoadInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCollectionLoadInfo' +type MockBroker_GetCollectionLoadInfo_Call struct { + *mock.Call +} + +// GetCollectionLoadInfo is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +func (_e *MockBroker_Expecter) GetCollectionLoadInfo(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionLoadInfo_Call { + return &MockBroker_GetCollectionLoadInfo_Call{Call: _e.mock.On("GetCollectionLoadInfo", ctx, collectionID)} +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) Run(run func(ctx context.Context, collectionID int64)) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) Return(_a0 []string, _a1 int64, _a2 error) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockBroker_GetCollectionLoadInfo_Call) RunAndReturn(run func(context.Context, int64) ([]string, int64, error)) *MockBroker_GetCollectionLoadInfo_Call { + _c.Call.Return(run) + return _c +} + // GetIndexInfo provides a mock function with given fields: ctx, collectionID, segmentID func (_m *MockBroker) GetIndexInfo(ctx context.Context, collectionID int64, segmentID int64) ([]*querypb.FieldIndexInfo, error) { ret := _m.Called(ctx, collectionID, segmentID) diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index f71172fd89394..78c2fdb89b6f1 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -436,17 +436,19 @@ func (suite *ServerSuite) loadAll() { for _, collection := range suite.collections { if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { req := &querypb.LoadCollectionRequest{ - CollectionID: collection, - ReplicaNumber: suite.replicaNumber[collection], + CollectionID: collection, + ReplicaNumber: suite.replicaNumber[collection], + ResourceGroups: []string{meta.DefaultResourceGroupName}, } resp, err := suite.server.LoadCollection(ctx, req) suite.NoError(err) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) } else { req := &querypb.LoadPartitionsRequest{ - CollectionID: collection, - PartitionIDs: suite.partitions[collection], - ReplicaNumber: suite.replicaNumber[collection], + CollectionID: collection, + PartitionIDs: 
suite.partitions[collection], + ReplicaNumber: suite.replicaNumber[collection], + ResourceGroups: []string{meta.DefaultResourceGroupName}, } resp, err := suite.server.LoadPartitions(ctx, req) suite.NoError(err) diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index b1fb72dd435e1..159b933e80385 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -215,6 +215,24 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection return merr.Status(err), nil } + if req.GetReplicaNumber() <= 0 || len(req.GetResourceGroups()) == 0 { + // when replica number or resource groups is not set, use database level config + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID()) + if err != nil { + log.Warn("failed to get database level load info", zap.Error(err)) + } + + if req.GetReplicaNumber() <= 0 { + log.Info("load collection uses database level replica number", zap.Int64("databaseLevelReplicaNum", replicas)) + req.ReplicaNumber = int32(replicas) + } + + if len(req.GetResourceGroups()) == 0 { + log.Info("load collection uses database level resource groups", zap.Strings("databaseLevelResourceGroups", rgs)) + req.ResourceGroups = rgs + } + } + if err := s.checkResourceGroup(req.GetCollectionID(), req.GetResourceGroups()); err != nil { msg := "failed to load collection" log.Warn(msg, zap.Error(err)) @@ -316,6 +334,24 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions return merr.Status(err), nil } + if req.GetReplicaNumber() <= 0 || len(req.GetResourceGroups()) == 0 { + // when replica number or resource groups is not set, use database level config + rgs, replicas, err := s.broker.GetCollectionLoadInfo(ctx, req.GetCollectionID()) + if err != nil { + log.Warn("failed to get database level load info", zap.Error(err)) + } + + if req.GetReplicaNumber() <= 0 { + log.Info("load partitions uses database level replica number", zap.Int64("databaseLevelReplicaNum", replicas)) + req.ReplicaNumber = int32(replicas) + } + + if len(req.GetResourceGroups()) == 0 { + log.Info("load partitions uses database level resource groups", zap.Strings("databaseLevelResourceGroups", rgs)) + req.ResourceGroups = rgs + } + } + if err := s.checkResourceGroup(req.GetCollectionID(), req.GetResourceGroups()); err != nil { msg := "failed to load partitions" log.Warn(msg, zap.Error(err)) diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 744004fd8f074..e4fb877d0101f 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -207,6 +207,8 @@ func (suite *ServiceSuite) SetupTest() { } suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + + suite.broker.EXPECT().GetCollectionLoadInfo(mock.Anything, mock.Anything).Return([]string{meta.DefaultResourceGroupName}, 1, nil).Maybe() } func (suite *ServiceSuite) TestShowCollections() { diff --git a/pkg/common/common.go b/pkg/common/common.go index b847c627f4f42..4830b8ee92cb3 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -18,6 +18,8 @@ package common import ( "encoding/binary" + "fmt" + "strconv" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -137,6 +139,10 @@ const ( CollectionDiskQuotaKey = "collection.diskProtection.diskQuota.mb" PartitionDiskQuotaKey = "partition.diskProtection.diskQuota.mb" + + // database level properties + DatabaseReplicaNumber = "database.replica.number" + DatabaseResourceGroups =
"database.resource_groups" ) // common properties @@ -208,3 +214,38 @@ const ( // LatestVerision is the magic number for watch latest revision LatestRevision = int64(-1) ) + +func DatabaseLevelReplicaNumber(kvs []*commonpb.KeyValuePair) (int64, error) { + for _, kv := range kvs { + if kv.Key == DatabaseReplicaNumber { + replicaNum, err := strconv.ParseInt(kv.Value, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid database property: [key=%s] [value=%s]", kv.Key, kv.Value) + } + + return replicaNum, nil + } + } + + return 0, fmt.Errorf("database property not found: %s", DatabaseReplicaNumber) +} + +func DatabaseLevelResourceGroups(kvs []*commonpb.KeyValuePair) ([]string, error) { + for _, kv := range kvs { + if kv.Key == DatabaseResourceGroups { + invalidPropValue := fmt.Errorf("invalid database property: [key=%s] [value=%s]", kv.Key, kv.Value) + if len(kv.Value) == 0 { + return nil, invalidPropValue + } + + rgs := strings.Split(kv.Value, ",") + if len(rgs) == 0 { + return nil, invalidPropValue + } + + return rgs, nil + } + } + + return nil, fmt.Errorf("database property not found: %s", DatabaseResourceGroups) +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 7228b1b6ab8e8..2dc31e33fb16a 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -1,9 +1,12 @@ package common import ( + "strings" "testing" "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" ) func TestIsSystemField(t *testing.T) { @@ -38,3 +41,50 @@ func TestIsSystemField(t *testing.T) { }) } } + +func TestDatabaseProperties(t *testing.T) { + props := []*commonpb.KeyValuePair{ + { + Key: DatabaseReplicaNumber, + Value: "3", + }, + { + Key: DatabaseResourceGroups, + Value: strings.Join([]string{"rg1", "rg2"}, ","), + }, + } + + replicaNum, err := DatabaseLevelReplicaNumber(props) + assert.NoError(t, err) + assert.Equal(t, int64(3), replicaNum) + + rgs, err := DatabaseLevelResourceGroups(props) + assert.NoError(t, err) + assert.Contains(t, rgs, "rg1") + assert.Contains(t, rgs, "rg2") + + // test prop not found + _, err = DatabaseLevelReplicaNumber(nil) + assert.Error(t, err) + + _, err = DatabaseLevelResourceGroups(nil) + assert.Error(t, err) + + // test invalid prop value + + props = []*commonpb.KeyValuePair{ + { + Key: DatabaseReplicaNumber, + Value: "xxxx", + }, + { + Key: DatabaseResourceGroups, + Value: "", + }, + } + _, err = DatabaseLevelReplicaNumber(props) + assert.Error(t, err) + + _, err = DatabaseLevelResourceGroups(props) + assert.Error(t, err) +} diff --git a/tests/integration/replicas/load/load_test.go b/tests/integration/replicas/load/load_test.go new file mode 100644 index 0000000000000..837a634c53799 --- /dev/null +++ b/tests/integration/replicas/load/load_test.go @@ -0,0 +1,187 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package balance + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +const ( + dim = 128 + dbName = "" + collectionName = "test_load_collection" +) + +type LoadTestSuite struct { + integration.MiniClusterSuite +} + +func (s *LoadTestSuite) SetupSuite() { + paramtable.Init() + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "1000") + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1") + + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *LoadTestSuite) loadCollection(collectionName string, replica int, rgs []string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // load + loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + ReplicaNumber: int32(replica), + ResourceGroups: rgs, + }) + s.NoError(err) + s.True(merr.Ok(loadStatus)) + s.WaitForLoad(ctx, collectionName) +} + +func (s *LoadTestSuite) releaseCollection(collectionName string) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // load + status, err := s.Cluster.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.True(merr.Ok(status)) +} + +func (s *LoadTestSuite) TestLoadWithDatabaseLevelConfig() { + ctx := context.Background() + s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{ + DBName: dbName, + Dim: dim, + CollectionName: collectionName, + ChannelNum: 1, + SegmentNum: 3, + RowNumPerSegment: 2000, + }) + + // prepare resource groups + rgNum := 3 + rgs := make([]string, 0) + for i := 0; i < rgNum; i++ { + rgs = append(rgs, fmt.Sprintf("rg_%d", i)) + s.Cluster.QueryCoord.CreateResourceGroup(ctx, &milvuspb.CreateResourceGroupRequest{ + ResourceGroup: rgs[i], + Config: &rgpb.ResourceGroupConfig{ + Requests: &rgpb.ResourceGroupLimit{ + NodeNum: 1, + }, + Limits: &rgpb.ResourceGroupLimit{ + NodeNum: 1, + }, + + TransferFrom: []*rgpb.ResourceGroupTransfer{ + { + ResourceGroup: meta.DefaultResourceGroupName, + }, + }, + TransferTo: []*rgpb.ResourceGroupTransfer{ + { + ResourceGroup: meta.DefaultResourceGroupName, + }, + }, + }, + }) + } + + resp, err := s.Cluster.QueryCoord.ListResourceGroups(ctx, &milvuspb.ListResourceGroupsRequest{}) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + s.Len(resp.GetResourceGroups(), rgNum+1) + + for i := 1; i < rgNum; i++ { + s.Cluster.AddQueryNode() + } + + s.Eventually(func() bool { + matchCounter := 0 + for _, rg := range rgs { + resp1, err := s.Cluster.QueryCoord.DescribeResourceGroup(ctx, &querypb.DescribeResourceGroupRequest{ + ResourceGroup: rg, + }) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + if len(resp1.ResourceGroup.Nodes) == 1 { + matchCounter += 1 + } + } + return matchCounter == rgNum + }, 30*time.Second, 
time.Second) + + status, err := s.Cluster.Proxy.AlterDatabase(ctx, &milvuspb.AlterDatabaseRequest{ + DbName: "default", + Properties: []*commonpb.KeyValuePair{ + { + Key: common.DatabaseReplicaNumber, + Value: "3", + }, + { + Key: common.DatabaseResourceGroups, + Value: strings.Join(rgs, ","), + }, + }, + }) + s.NoError(err) + s.True(merr.Ok(status)) + + resp1, err := s.Cluster.Proxy.DescribeDatabase(ctx, &milvuspb.DescribeDatabaseRequest{ + DbName: "default", + }) + s.NoError(err) + s.True(merr.Ok(resp1.Status)) + s.Len(resp1.GetProperties(), 2) + + // load collection without specified replica and rgs + s.loadCollection(collectionName, 0, nil) + resp2, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + s.True(merr.Ok(resp2.Status)) + s.Len(resp2.GetReplicas(), 3) + s.releaseCollection(collectionName) +} + +func TestReplicas(t *testing.T) { + suite.Run(t, new(LoadTestSuite)) +} From 83ff7591954b968ad1b92f50ff6da3dc324942a1 Mon Sep 17 00:00:00 2001 From: "sammy.huang" Date: Sat, 22 Jun 2024 19:40:05 +0800 Subject: [PATCH 13/21] enhance: upgrade build-env to ubuntu 22.04 and gcc12 (#33961) pr: #33959 Signed-off-by: Liang Huang --- .github/workflows/publish-builder.yaml | 4 +- .../docker/builder/cpu/ubuntu22.04/Dockerfile | 66 +++++++++++++++++++ build/docker/milvus/ubuntu22.04/Dockerfile | 37 +++++++++++ 3 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 build/docker/builder/cpu/ubuntu22.04/Dockerfile create mode 100644 build/docker/milvus/ubuntu22.04/Dockerfile diff --git a/.github/workflows/publish-builder.yaml b/.github/workflows/publish-builder.yaml index d8d0029fd44d5..573b8bf352173 100644 --- a/.github/workflows/publish-builder.yaml +++ b/.github/workflows/publish-builder.yaml @@ -28,7 +28,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu20.04, amazonlinux2023, rockylinux8] + os: [ubuntu22.04, amazonlinux2023, rockylinux8] env: OS_NAME: ${{ matrix.os }} IMAGE_ARCH: ${{ matrix.arch }} @@ -83,7 +83,7 @@ jobs: file: build/docker/builder/cpu/${{ matrix.os }}/Dockerfile - name: Bump Builder Version uses: ./.github/actions/bump-builder-version - if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus' && matrix.os == 'ubuntu20.04' + if: success() && github.event_name == 'push' && github.repository == 'milvus-io/milvus' && matrix.os == 'ubuntu22.04' with: tag: "${{ steps.extracter.outputs.version }}-${{ steps.extracter.outputs.sha_short }}" type: cpu diff --git a/build/docker/builder/cpu/ubuntu22.04/Dockerfile b/build/docker/builder/cpu/ubuntu22.04/Dockerfile new file mode 100644 index 0000000000000..bef41afff31e0 --- /dev/null +++ b/build/docker/builder/cpu/ubuntu22.04/Dockerfile @@ -0,0 +1,66 @@ +# Copyright (C) 2019-2022 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. 
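+# Build environment for CPU-only Milvus on Ubuntu 22.04 (jammy): the layers below install the gcc-12 toolchain, CMake 3.27, conan, vcpkg, Go 1.21 and Rust 1.73. +# An illustrative local build (the tag here is hypothetical; CI derives the real one from the date and git SHA): +# docker build -f build/docker/builder/cpu/ubuntu22.04/Dockerfile -t milvus-env:ubuntu22.04-example .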
+ +FROM ubuntu:jammy-20240530 + +ARG TARGETARCH + +RUN apt-get update && apt-get install -y --no-install-recommends wget curl ca-certificates gnupg2 \ + g++ gcc gdb gdbserver ninja-build git make ccache libssl-dev zlib1g-dev zip unzip \ + clang-format-12 clang-tidy-12 lcov libtool m4 autoconf automake python3 python3-pip \ + pkg-config uuid-dev libaio-dev libopenblas-dev && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +# upgrade gcc to 12 +RUN apt-get update && apt-get install -y gcc-12 g++-12 && cd /usr/bin \ + && unlink gcc && ln -s gcc-12 gcc \ + && unlink g++ && ln -s g++-12 g++ + +RUN pip3 install conan==1.61.0 + +RUN echo "target arch $TARGETARCH" +RUN wget -qO- "https://cmake.org/files/v3.27/cmake-3.27.5-linux-`uname -m`.tar.gz" | tar --strip-components=1 -xz -C /usr/local + +RUN mkdir /opt/vcpkg && \ + wget -qO- vcpkg.tar.gz https://github.com/microsoft/vcpkg/archive/master.tar.gz | tar --strip-components=1 -xz -C /opt/vcpkg && \ + rm -rf vcpkg.tar.gz + +ENV VCPKG_FORCE_SYSTEM_BINARIES 1 + +RUN /opt/vcpkg/bootstrap-vcpkg.sh -disableMetrics && ln -s /opt/vcpkg/vcpkg /usr/local/bin/vcpkg && vcpkg version + +RUN vcpkg install azure-identity-cpp azure-storage-blobs-cpp gtest + +# Install Go +ENV GOPATH /go +ENV GOROOT /usr/local/go +ENV GO111MODULE on +ENV PATH $GOPATH/bin:$GOROOT/bin:$PATH +RUN mkdir -p /usr/local/go && wget -qO- "https://go.dev/dl/go1.21.10.linux-$TARGETARCH.tar.gz" | tar --strip-components=1 -xz -C /usr/local/go && \ + mkdir -p "$GOPATH/src" "$GOPATH/bin" && \ + go clean --modcache && \ + chmod -R 777 "$GOPATH" && chmod -R a+w $(go env GOTOOLDIR) + +# refer: https://code.visualstudio.com/docs/remote/containers-advanced#_avoiding-extension-reinstalls-on-container-rebuild +RUN mkdir -p /home/milvus/.vscode-server/extensions \ + /home/milvus/.vscode-server-insiders/extensions \ + && chmod -R 777 /home/milvus + +COPY --chown=0:0 build/docker/builder/entrypoint.sh / + +RUN curl https://sh.rustup.rs -sSf | \ + sh -s -- --default-toolchain=1.73 -y + +ENV PATH=/root/.cargo/bin:$PATH + +ENTRYPOINT [ "/entrypoint.sh" ] +CMD ["tail", "-f", "/dev/null"] diff --git a/build/docker/milvus/ubuntu22.04/Dockerfile b/build/docker/milvus/ubuntu22.04/Dockerfile new file mode 100644 index 0000000000000..40a4e9e0fa79f --- /dev/null +++ b/build/docker/milvus/ubuntu22.04/Dockerfile @@ -0,0 +1,37 @@ +# Copyright (C) 2019-2022 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under the License. 
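+# Runtime image: ships only the prebuilt Milvus binaries, configs and shared libraries, preloads jemalloc via LD_PRELOAD, and uses tini as PID 1 for signal handling.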
+ +FROM ubuntu:jammy-20240530 + +ARG TARGETARCH + +RUN apt-get update && \ + apt-get install -y --no-install-recommends curl ca-certificates libaio-dev libgomp1 libopenblas-dev && \ + apt-get remove --purge -y && \ + rm -rf /var/lib/apt/lists/* + +COPY --chown=root:root --chmod=774 ./bin/ /milvus/bin/ + +COPY --chown=root:root --chmod=774 ./configs/ /milvus/configs/ + +COPY --chown=root:root --chmod=774 ./lib/ /milvus/lib/ + +ENV PATH=/milvus/bin:$PATH +ENV LD_LIBRARY_PATH=/milvus/lib:$LD_LIBRARY_PATH:/usr/lib +ENV LD_PRELOAD=/milvus/lib/libjemalloc.so +ENV MALLOC_CONF=background_thread:true + +# Add Tini +ADD https://github.com/krallin/tini/releases/download/v0.19.0/tini-$TARGETARCH /tini +RUN chmod +x /tini +ENTRYPOINT ["/tini", "--"] + +WORKDIR /milvus/ From 2fda43e49f861aa29b759c93f6af5b8f87e2a60d Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Sat, 22 Jun 2024 19:42:02 +0800 Subject: [PATCH 14/21] fix: Do compressBinlog to fix logID 0 (#34060) (#34062) Do compressBinlog to ensure that reloadFromKV will fill binlogs' logID after datacoord restarts. issue: https://github.com/milvus-io/milvus/issues/34059 pr: https://github.com/milvus-io/milvus/pull/34060 --------- Signed-off-by: bigsheeper --- internal/metastore/kv/datacoord/kv_catalog.go | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/internal/metastore/kv/datacoord/kv_catalog.go b/internal/metastore/kv/datacoord/kv_catalog.go index a7cae134c9ac3..4272458cc3c48 100644 --- a/internal/metastore/kv/datacoord/kv_catalog.go +++ b/internal/metastore/kv/datacoord/kv_catalog.go @@ -219,29 +219,23 @@ func (kc *Catalog) applyBinlogInfo(segments []*datapb.SegmentInfo, insertLogs, d for _, segmentInfo := range segments { if len(segmentInfo.Binlogs) == 0 { segmentInfo.Binlogs = insertLogs[segmentInfo.ID] - } else { - err = binlog.CompressFieldBinlogs(segmentInfo.Binlogs) - if err != nil { - return err - } + } + if err = binlog.CompressFieldBinlogs(segmentInfo.Binlogs); err != nil { + return err } if len(segmentInfo.Deltalogs) == 0 { segmentInfo.Deltalogs = deltaLogs[segmentInfo.ID] - } else { - err = binlog.CompressFieldBinlogs(segmentInfo.Deltalogs) - if err != nil { - return err - } + } + if err = binlog.CompressFieldBinlogs(segmentInfo.Deltalogs); err != nil { + return err } if len(segmentInfo.Statslogs) == 0 { segmentInfo.Statslogs = statsLogs[segmentInfo.ID] - } else { - err = binlog.CompressFieldBinlogs(segmentInfo.Statslogs) - if err != nil { - return err - } + } + if err = binlog.CompressFieldBinlogs(segmentInfo.Statslogs); err != nil { + return err } } return nil From 9ea4aa9bc0fb257ed59f226f6d42f4a63d912201 Mon Sep 17 00:00:00 2001 From: "sammy.huang" Date: Sun, 23 Jun 2024 09:34:14 +0800 Subject: [PATCH 15/21] enhance: get environment variable from .env (#34080) Signed-off-by: Liang Huang --- build/build_image.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/build/build_image.sh b/build/build_image.sh index 68a2cc82f6720..9480dcb368bd2 100755 --- a/build/build_image.sh +++ b/build/build_image.sh @@ -23,6 +23,11 @@ set -x # Absolute path to the toplevel milvus directory. 
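# (resolved from the script's own location, so the build works regardless of the caller's working directory)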
toplevel=$(dirname "$(cd "$(dirname "${0}")"; pwd)") +if [[ -f "$toplevel/.env" ]]; then + set -a # automatically export all variables from .env + source $toplevel/.env + set +a # stop automatically exporting +fi OS_NAME="${OS_NAME:-ubuntu20.04}" MILVUS_IMAGE_REPO="${MILVUS_IMAGE_REPO:-milvusdb/milvus}" From 63c9e6e023a2a70626b0de64b199c47c9e25279b Mon Sep 17 00:00:00 2001 From: sre-ci-robot <56469371+sre-ci-robot@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:36:25 +0800 Subject: [PATCH 16/21] [automated] Update cpu Builder image changes (#34030) Update cpu Builder image changes See changes: https://github.com/milvus-io/milvus/commit/b3d425f50a4fcf119c676835a1a09e6c9949a1c5 Signed-off-by: sre-ci-robot sre-ci-robot@users.noreply.github.com Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .env | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env b/.env index 6beb24525c5e1..44b54ff9b2f21 100644 --- a/.env +++ b/.env @@ -5,8 +5,8 @@ IMAGE_ARCH=amd64 OS_NAME=ubuntu20.04 # for services.builder.image in docker-compose.yml -DATE_VERSION=20240520-d27db99 -LATEST_DATE_VERSION=20240520-d27db99 +DATE_VERSION=20240620-b3d425f +LATEST_DATE_VERSION=20240620-b3d425f # for services.gpubuilder.image in docker-compose.yml GPU_DATE_VERSION=20240520-c35eaaa From f4debe5e5ede1fb9711de1a40e49ca67a5b03a5a Mon Sep 17 00:00:00 2001 From: sre-ci-robot <56469371+sre-ci-robot@users.noreply.github.com> Date: Mon, 24 Jun 2024 10:38:13 +0800 Subject: [PATCH 17/21] [automated] Update gpu Builder image changes (#34031) Update gpu Builder image changes See changes: https://github.com/milvus-io/milvus/commit/b3d425f50a4fcf119c676835a1a09e6c9949a1c5 Signed-off-by: sre-ci-robot sre-ci-robot@users.noreply.github.com Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .env | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.env b/.env index 44b54ff9b2f21..dffe0e27df665 100644 --- a/.env +++ b/.env @@ -9,8 +9,8 @@ DATE_VERSION=20240620-b3d425f LATEST_DATE_VERSION=20240620-b3d425f # for services.gpubuilder.image in docker-compose.yml -GPU_DATE_VERSION=20240520-c35eaaa -LATEST_GPU_DATE_VERSION=20240520-c35eaaa +GPU_DATE_VERSION=20240620-b3d425f +LATEST_GPU_DATE_VERSION=20240620-b3d425f # for other services in docker-compose.yml MINIO_ADDRESS=minio:9000 From 630a726f351d9931772533b1ccf58eda9029835a Mon Sep 17 00:00:00 2001 From: zhuwenxing Date: Mon, 24 Jun 2024 10:40:03 +0800 Subject: [PATCH 18/21] test: refine restful testcases trace (#34065) pr: https://github.com/milvus-io/milvus/pull/34066 --------- Signed-off-by: zhuwenxing --- tests/restful_client_v2/api/milvus.py | 152 +++++++--- tests/restful_client_v2/base/testbase.py | 20 +- .../testcases/test_vector_operations.py | 281 ++++++++++++++++++ 3 files changed, 412 insertions(+), 41 deletions(-) diff --git a/tests/restful_client_v2/api/milvus.py b/tests/restful_client_v2/api/milvus.py index 76807a4d36b37..9c1dabbdbb83f 100644 --- a/tests/restful_client_v2/api/milvus.py +++ b/tests/restful_client_v2/api/milvus.py @@ -8,17 +8,75 @@ from minio.commonconfig import CopySource from tenacity import retry, retry_if_exception_type, stop_after_attempt from requests.exceptions import ConnectionError - - -def logger_request_response(response, url, tt, headers, data, str_data, str_response, method): - if len(data) > 2000: - data = data[:1000] + "..." 
+ data[-1000:] +import urllib.parse + +ENABLE_LOG_SAVE = False + + +def simplify_list(lst): + if len(lst) > 20: + return [lst[0], '...', lst[-1]] + return lst + + +def simplify_dict(d): + if d is None: + d = {} + if len(d) > 20: + keys = list(d.keys()) + d = {keys[0]: d[keys[0]], '...': '...', keys[-1]: d[keys[-1]]} + simplified = {} + for k, v in d.items(): + if isinstance(v, list): + simplified[k] = simplify_list([simplify_dict(item) if isinstance(item, dict) else simplify_list( + item) if isinstance(item, list) else item for item in v]) + elif isinstance(v, dict): + simplified[k] = simplify_dict(v) + else: + simplified[k] = v + return simplified + + +def build_curl_command(method, url, headers, data=None, params=None): + if isinstance(params, dict): + query_string = urllib.parse.urlencode(params) + url = f"{url}?{query_string}" + curl_cmd = [f"curl -X {method} '{url}'"] + + for key, value in headers.items(): + curl_cmd.append(f" -H '{key}: {value}'") + + if data: + # process_and_simplify(data) + data = json.dumps(data, indent=4) + curl_cmd.append(f" -d '{data}'") + + return " \\\n".join(curl_cmd) + + +def logger_request_response(response, url, tt, headers, data, str_data, str_response, method, params=None): + # save data to jsonl file + + data_dict = json.loads(data) if data else {} + data_dict_simple = simplify_dict(data_dict) + if ENABLE_LOG_SAVE: + with open('request_response.jsonl', 'a') as f: + f.write(json.dumps({ + "method": method, + "url": url, + "headers": headers, + "params": params, + "data": data_dict_simple, + "response": response.json() + }) + "\n") + data = json.dumps(data_dict_simple, indent=4) try: if response.status_code == 200: if ('code' in response.json() and response.json()["code"] == 0) or ( 'Code' in response.json() and response.json()["Code"] == 0): logger.debug( - f"\nmethod: {method}, \nurl: {url}, \ncost time: {tt}, \nheader: {headers}, \npayload: {str_data}, \nresponse: {str_response}") + f"\nmethod: {method}, \nurl: {url}, \ncost time: {tt}, \nheader: {headers}, \npayload: {data}, \nresponse: {str_response}") + else: logger.debug( f"\nmethod: {method}, \nurl: {url}, \ncost time: {tt}, \nheader: {headers}, \npayload: {data}, \nresponse: {response.text}") @@ -30,21 +88,31 @@ def logger_request_response(response, url, tt, headers, data, str_data, str_resp f"method: \nmethod: {method}, \nurl: {url}, \ncost time: {tt}, \nheader: {headers}, \npayload: {data}, \nresponse: {response.text}, \nerror: {e}") -class Requests: +class Requests(): + uuid = str(uuid.uuid1()) + api_key = None + def __init__(self, url=None, api_key=None): self.url = url self.api_key = api_key + if self.uuid is None: + self.uuid = str(uuid.uuid1()) self.headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'RequestId': self.uuid } - def update_headers(self): + @classmethod + def update_uuid(cls, _uuid): + cls.uuid = _uuid + + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -59,7 +127,7 @@ def post(self, url, headers=None, data=None, params=None): response = requests.post(url, headers=headers, data=data, params=params) tt = time.time() - t0 str_response = response.text[:200] + '...' 
+ response.text[-200:] if len(response.text) > 400 else response.text - logger_request_response(response, url, tt, headers, data, str_data, str_response, "post") + logger_request_response(response, url, tt, headers, data, str_data, str_response, "post", params=params) return response @retry(retry=retry_if_exception_type(ConnectionError), stop=stop_after_attempt(3)) @@ -74,7 +142,7 @@ def get(self, url, headers=None, params=None, data=None): response = requests.get(url, headers=headers, params=params, data=data) tt = time.time() - t0 str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text - logger_request_response(response, url, tt, headers, data, str_data, str_response, "get") + logger_request_response(response, url, tt, headers, data, str_data, str_response, "get", params=params) return response @retry(retry=retry_if_exception_type(ConnectionError), stop=stop_after_attempt(3)) @@ -111,12 +179,13 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', + 'Authorization': f'Bearer {cls.api_key}', 'Accept-Type-Allow-Int64': "true", - 'RequestId': str(uuid.uuid1()) + 'RequestId': cls.uuid } return headers @@ -195,8 +264,6 @@ def vector_hybrid_search(self, payload, db_name="default", timeout=10): return response.json() - - def vector_query(self, payload, db_name="default", timeout=5): time.sleep(1) url = f'{self.endpoint}/v2/vectordb/entities/query' @@ -269,13 +336,14 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self, headers=None): + @classmethod + def update_headers(cls, headers=None): if headers is not None: return headers headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -415,11 +483,12 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -530,11 +599,12 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -594,11 +664,12 @@ def __init__(self, endpoint, token): self.headers = self.update_headers() self.role_names = [] - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -653,11 +724,12 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer 
{self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -714,11 +786,12 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers @@ -765,11 +838,12 @@ def __init__(self, endpoint, token): self.db_name = None self.headers = self.update_headers() - def update_headers(self): + @classmethod + def update_headers(cls): headers = { 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', - 'RequestId': str(uuid.uuid1()) + 'Authorization': f'Bearer {cls.api_key}', + 'RequestId': cls.uuid } return headers diff --git a/tests/restful_client_v2/base/testbase.py b/tests/restful_client_v2/base/testbase.py index a47239ae96102..c4d0d3f2bb07e 100644 --- a/tests/restful_client_v2/base/testbase.py +++ b/tests/restful_client_v2/base/testbase.py @@ -2,10 +2,11 @@ import sys import pytest import time +import uuid from pymilvus import connections, db from utils.util_log import test_log as logger from api.milvus import (VectorClient, CollectionClient, PartitionClient, IndexClient, AliasClient, - UserClient, RoleClient, ImportJobClient, StorageClient) + UserClient, RoleClient, ImportJobClient, StorageClient, Requests) from utils.utils import get_data_by_payload @@ -35,7 +36,7 @@ class Base: class TestBase(Base): - + req = None def teardown_method(self): self.collection_client.api_key = self.api_key all_collections = self.collection_client.collection_list()['data'] @@ -49,19 +50,34 @@ def teardown_method(self): except Exception as e: logger.error(e) + # def setup_method(self): + # self.req = Requests() + # self.req.uuid = str(uuid.uuid1()) + @pytest.fixture(scope="function", autouse=True) def init_client(self, endpoint, token, minio_host, bucket_name, root_path): + _uuid = str(uuid.uuid1()) + self.req = Requests() + self.req.update_uuid(_uuid) self.endpoint = f"{endpoint}" self.api_key = f"{token}" self.invalid_api_key = "invalid_token" self.vector_client = VectorClient(self.endpoint, self.api_key) + self.vector_client.update_uuid(_uuid) self.collection_client = CollectionClient(self.endpoint, self.api_key) + self.collection_client.update_uuid(_uuid) self.partition_client = PartitionClient(self.endpoint, self.api_key) + self.partition_client.update_uuid(_uuid) self.index_client = IndexClient(self.endpoint, self.api_key) + self.index_client.update_uuid(_uuid) self.alias_client = AliasClient(self.endpoint, self.api_key) + self.alias_client.update_uuid(_uuid) self.user_client = UserClient(self.endpoint, self.api_key) + self.user_client.update_uuid(_uuid) self.role_client = RoleClient(self.endpoint, self.api_key) + self.role_client.update_uuid(_uuid) self.import_job_client = ImportJobClient(self.endpoint, self.api_key) + self.import_job_client.update_uuid(_uuid) self.storage_client = StorageClient(f"{minio_host}:9000", "minioadmin", "minioadmin", bucket_name, root_path) if token is None: self.vector_client.api_key = None diff --git a/tests/restful_client_v2/testcases/test_vector_operations.py b/tests/restful_client_v2/testcases/test_vector_operations.py index bce1f9acca4a9..26ff4b6b51851 100644 --- a/tests/restful_client_v2/testcases/test_vector_operations.py +++ 
b/tests/restful_client_v2/testcases/test_vector_operations.py @@ -238,6 +238,287 @@ def test_insert_entities_with_all_vector_datatype(self, nb, dim, insert_round, a assert rsp['code'] == 0 assert len(rsp['data']) == 50 + @pytest.mark.parametrize("insert_round", [1]) + @pytest.mark.parametrize("auto_id", [True]) + @pytest.mark.parametrize("is_partition_key", [True]) + @pytest.mark.parametrize("enable_dynamic_schema", [True]) + @pytest.mark.parametrize("nb", [3000]) + @pytest.mark.parametrize("dim", [128]) + def test_insert_entities_with_all_vector_datatype_0(self, nb, dim, insert_round, auto_id, + is_partition_key, enable_dynamic_schema): + """ + Insert a vector with a simple payload + """ + # create a collection + name = gen_collection_name() + payload = { + "collectionName": name, + "schema": { + "autoId": auto_id, + "enableDynamicField": enable_dynamic_schema, + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "isPartitionKey": is_partition_key, + "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "book_vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "float_vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "float16_vector", "dataType": "Float16Vector", + "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "bfloat16_vector", "dataType": "BFloat16Vector", + "elementTypeParams": {"dim": f"{dim}"}}, + ] + }, + "indexParams": [ + {"fieldName": "book_vector", "indexName": "book_vector", "metricType": "L2", + "params": {"index_type": "FLAT"}}, + {"fieldName": "float_vector", "indexName": "float_vector", "metricType": "L2", + "params": {"index_type": "IVF_FLAT", "nlist": 128}}, + {"fieldName": "float16_vector", "indexName": "float16_vector", "metricType": "L2", + "params": {"index_type": "IVF_SQ8", "nlist": "128"}}, + {"fieldName": "bfloat16_vector", "indexName": "bfloat16_vector", "metricType": "L2", + "params": {"index_type": "IVF_PQ", "nlist": 128, "m": 16, "nbits": 8}}, + ] + } + + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + rsp = self.collection_client.collection_describe(name) + logger.info(f"rsp: {rsp}") + assert rsp['code'] == 0 + # insert data + for i in range(insert_round): + data = [] + for i in range(nb): + if auto_id: + tmp = { + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "book_vector": gen_vector(datatype="FloatVector", dim=dim), + "float_vector": gen_vector(datatype="FloatVector", dim=dim), + "float16_vector": gen_vector(datatype="Float16Vector", dim=dim), + "bfloat16_vector": gen_vector(datatype="BFloat16Vector", dim=dim), + } + else: + tmp = { + "book_id": i, + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "book_vector": gen_vector(datatype="FloatVector", dim=dim), + "float_vector": gen_vector(datatype="FloatVector", dim=dim), + "float16_vector": gen_vector(datatype="Float16Vector", dim=dim), + "bfloat16_vector": gen_vector(datatype="BFloat16Vector", dim=dim), + } + if enable_dynamic_schema: + tmp.update({f"dynamic_field_{i}": i}) + data.append(tmp) + payload = { + "collectionName": name, + "data": data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + assert rsp['data']['insertCount'] == nb + c = 
Collection(name) + res = c.query( + expr="user_id > 0", + limit=1, + output_fields=["*"], + ) + logger.info(f"res: {res}") + # query data to make sure the data is inserted + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "user_id > 0", "limit": 50}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 50 + + @pytest.mark.parametrize("insert_round", [1]) + @pytest.mark.parametrize("auto_id", [True]) + @pytest.mark.parametrize("is_partition_key", [True]) + @pytest.mark.parametrize("enable_dynamic_schema", [True]) + @pytest.mark.parametrize("nb", [3000]) + @pytest.mark.parametrize("dim", [128]) + def test_insert_entities_with_all_vector_datatype_1(self, nb, dim, insert_round, auto_id, + is_partition_key, enable_dynamic_schema): + """ + Insert a vector with a simple payload + """ + # create a collection + name = gen_collection_name() + payload = { + "collectionName": name, + "schema": { + "autoId": auto_id, + "enableDynamicField": enable_dynamic_schema, + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "isPartitionKey": is_partition_key, + "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "float_vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "float16_vector", "dataType": "Float16Vector", + "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "bfloat16_vector", "dataType": "BFloat16Vector", + "elementTypeParams": {"dim": f"{dim}"}}, + ] + }, + "indexParams": [ + {"fieldName": "float_vector", "indexName": "float_vector", "metricType": "L2", + "params": {"index_type": "HNSW", "M": 32, "efConstruction": 360}}, + {"fieldName": "float16_vector", "indexName": "float16_vector", "metricType": "L2", + "params": {"index_type": "SCANN", "nlist": "128"}}, + {"fieldName": "bfloat16_vector", "indexName": "bfloat16_vector", "metricType": "L2", + "params": {"index_type": "DISKANN"}}, + ] + } + + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + rsp = self.collection_client.collection_describe(name) + logger.info(f"rsp: {rsp}") + assert rsp['code'] == 0 + # insert data + for i in range(insert_round): + data = [] + for i in range(nb): + if auto_id: + tmp = { + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "float_vector": gen_vector(datatype="FloatVector", dim=dim), + "float16_vector": gen_vector(datatype="Float16Vector", dim=dim), + "bfloat16_vector": gen_vector(datatype="BFloat16Vector", dim=dim), + } + else: + tmp = { + "book_id": i, + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "float_vector": gen_vector(datatype="FloatVector", dim=dim), + "float16_vector": gen_vector(datatype="Float16Vector", dim=dim), + "bfloat16_vector": gen_vector(datatype="BFloat16Vector", dim=dim), + } + if enable_dynamic_schema: + tmp.update({f"dynamic_field_{i}": i}) + data.append(tmp) + payload = { + "collectionName": name, + "data": data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + assert rsp['data']['insertCount'] == nb + c = Collection(name) + res = c.query( + expr="user_id > 0", + limit=1, + output_fields=["*"], + ) + logger.info(f"res: {res}") + # query data to make sure the data is inserted + rsp = self.vector_client.vector_query({"collectionName": name, "filter": 
"user_id > 0", "limit": 50}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 50 + + @pytest.mark.parametrize("insert_round", [1]) + @pytest.mark.parametrize("auto_id", [True]) + @pytest.mark.parametrize("is_partition_key", [True]) + @pytest.mark.parametrize("enable_dynamic_schema", [True]) + @pytest.mark.parametrize("nb", [3000]) + @pytest.mark.parametrize("dim", [128]) + def test_insert_entities_with_all_vector_datatype_2(self, nb, dim, insert_round, auto_id, + is_partition_key, enable_dynamic_schema): + """ + Insert a vector with a simple payload + """ + # create a collection + name = gen_collection_name() + payload = { + "collectionName": name, + "schema": { + "autoId": auto_id, + "enableDynamicField": enable_dynamic_schema, + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "isPartitionKey": is_partition_key, + "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "binary_vector_0", "dataType": "BinaryVector", "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "binary_vector_1", "dataType": "BinaryVector", "elementTypeParams": {"dim": f"{dim}"}}, + {"fieldName": "sparse_float_vector_0", "dataType": "SparseFloatVector"}, + {"fieldName": "sparse_float_vector_1", "dataType": "SparseFloatVector"}, + ] + }, + "indexParams": [ + {"fieldName": "binary_vector_0", "indexName": "binary_vector_0_index", "metricType": "HAMMING", + "params": {"index_type": "BIN_FLAT"}}, + {"fieldName": "binary_vector_1", "indexName": "binary_vector_1_index", "metricType": "HAMMING", + "params": {"index_type": "BIN_IVF_FLAT", "nlist": "512"}}, + {"fieldName": "sparse_float_vector_0", "indexName": "sparse_float_vector_0_index", "metricType": "IP", + "params": {"index_type": "SPARSE_INVERTED_INDEX", "drop_ratio_build": "0.2"}}, + {"fieldName": "sparse_float_vector_1", "indexName": "sparse_float_vector_1_index", "metricType": "IP", + "params": {"index_type": "SPARSE_WAND", "drop_ratio_build": "0.2"}} + ] + } + + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + rsp = self.collection_client.collection_describe(name) + logger.info(f"rsp: {rsp}") + assert rsp['code'] == 0 + # insert data + for i in range(insert_round): + data = [] + for i in range(nb): + if auto_id: + tmp = { + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "binary_vector_0": gen_vector(datatype="BinaryVector", dim=dim), + "binary_vector_1": gen_vector(datatype="BinaryVector", dim=dim), + "sparse_float_vector_0": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok"), + "sparse_float_vector_1": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok"), + } + else: + tmp = { + "book_id": i, + "user_id": i, + "word_count": i, + "book_describe": f"book_{i}", + "binary_vector_0": gen_vector(datatype="BinaryVector", dim=dim), + "binary_vector_1": gen_vector(datatype="BinaryVector", dim=dim), + "sparse_float_vector_0": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok"), + "sparse_float_vector_1": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok"), + } + if enable_dynamic_schema: + tmp.update({f"dynamic_field_{i}": i}) + data.append(tmp) + payload = { + "collectionName": name, + "data": data, + } + rsp = self.vector_client.vector_insert(payload) + assert 
rsp['code'] == 0 + assert rsp['data']['insertCount'] == nb + c = Collection(name) + res = c.query( + expr="user_id > 0", + limit=1, + output_fields=["*"], + ) + logger.info(f"res: {res}") + # query data to make sure the data is inserted + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "user_id > 0", "limit": 50}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 50 + @pytest.mark.parametrize("insert_round", [1]) @pytest.mark.parametrize("auto_id", [True, False]) @pytest.mark.parametrize("is_partition_key", [True, False]) From 22e6807e9ab04f356995ce51497ae5ad14490cfa Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Mon, 24 Jun 2024 10:50:03 +0800 Subject: [PATCH 19/21] feat: support inverted index for array (#33452) (#34053) pr: https://github.com/milvus-io/milvus/pull/33184 pr: https://github.com/milvus-io/milvus/pull/33452 pr: https://github.com/milvus-io/milvus/pull/33633 issue: https://github.com/milvus-io/milvus/issues/27704 Co-authored-by: xiaocai2333 --------- Signed-off-by: Cai Zhang Signed-off-by: longjiquan Co-authored-by: cai.zhang --- internal/core/src/common/Schema.h | 9 + internal/core/src/exec/expression/Expr.h | 16 + .../src/exec/expression/JsonContainsExpr.cpp | 98 ++++- .../src/exec/expression/JsonContainsExpr.h | 7 + .../core/src/exec/expression/UnaryExpr.cpp | 166 +++++++- internal/core/src/exec/expression/UnaryExpr.h | 8 + internal/core/src/expr/ITypeExpr.h | 17 +- internal/core/src/index/IndexFactory.cpp | 105 +++-- internal/core/src/index/IndexFactory.h | 16 +- .../core/src/index/InvertedIndexTantivy.cpp | 374 +++++++++--------- .../core/src/index/InvertedIndexTantivy.h | 30 +- internal/core/src/index/ScalarIndex.h | 14 + internal/core/src/index/TantivyConfig.h | 51 --- internal/core/src/indexbuilder/IndexFactory.h | 1 + internal/core/src/indexbuilder/index_c.cpp | 183 ++++++--- internal/core/src/indexbuilder/index_c.h | 8 +- internal/core/src/pb/CMakeLists.txt | 8 +- internal/core/src/segcore/Types.h | 1 + internal/core/src/segcore/load_index_c.cpp | 51 ++- internal/core/src/segcore/load_index_c.h | 5 + internal/core/src/storage/Types.h | 1 + .../core/thirdparty/tantivy/CMakeLists.txt | 6 + internal/core/thirdparty/tantivy/ffi_demo.cpp | 17 + .../tantivy-binding/include/tantivy-binding.h | 18 + .../tantivy/tantivy-binding/src/demo_c.rs | 14 + .../tantivy-binding/src/index_writer.rs | 76 +++- .../tantivy-binding/src/index_writer_c.rs | 74 ++++ .../tantivy/tantivy-binding/src/lib.rs | 1 + .../core/thirdparty/tantivy/tantivy-wrapper.h | 56 +++ internal/core/thirdparty/tantivy/test.cpp | 80 ++++ internal/core/unittest/CMakeLists.txt | 1 + .../unittest/test_array_inverted_index.cpp | 297 ++++++++++++++ internal/core/unittest/test_index_wrapper.cpp | 2 +- .../core/unittest/test_inverted_index.cpp | 27 +- internal/core/unittest/test_scalar_index.cpp | 28 +- internal/core/unittest/test_utils/DataGen.h | 26 +- .../core/unittest/test_utils/GenExprProto.h | 11 +- internal/datacoord/index_builder.go | 40 +- internal/datacoord/index_builder_test.go | 75 +++- internal/indexnode/indexnode_service.go | 2 + internal/indexnode/task.go | 256 ++++++------ internal/indexnode/task_test.go | 14 +- internal/indexnode/util.go | 12 + internal/indexnode/util_test.go | 41 ++ internal/proto/cgo_msg.proto | 23 ++ internal/proto/index_cgo_msg.proto | 50 +++ internal/proto/index_coord.proto | 4 + .../querynodev2/segments/load_index_info.go | 32 ++ internal/querynodev2/segments/segment.go | 50 ++- internal/util/indexcgowrapper/index.go | 25 +- 
pkg/util/indexparamcheck/inverted_checker.go | 3 +- .../indexparamcheck/inverted_checker_test.go | 2 +- scripts/generate_proto.sh | 3 + tests/python_client/testcases/test_index.py | 5 +- 54 files changed, 1980 insertions(+), 560 deletions(-) delete mode 100644 internal/core/src/index/TantivyConfig.h create mode 100644 internal/core/thirdparty/tantivy/ffi_demo.cpp create mode 100644 internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs create mode 100644 internal/core/unittest/test_array_inverted_index.cpp create mode 100644 internal/indexnode/util_test.go create mode 100644 internal/proto/cgo_msg.proto diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index b1068dd650392..754766f54388b 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -51,6 +51,15 @@ class Schema { return field_id; } + FieldId + AddDebugArrayField(const std::string& name, DataType element_type) { + auto field_id = FieldId(debug_id); + debug_id++; + this->AddField( + FieldName(name), field_id, DataType::ARRAY, element_type); + return field_id; + } + // auto gen field_id for convenience FieldId AddDebugField(const std::string& name, diff --git a/internal/core/src/exec/expression/Expr.h b/internal/core/src/exec/expression/Expr.h index ea9eeac92cef9..a300515560b2d 100644 --- a/internal/core/src/exec/expression/Expr.h +++ b/internal/core/src/exec/expression/Expr.h @@ -280,6 +280,22 @@ class SegmentExpr : public Expr { return result; } + template + void + ProcessIndexChunksV2(FUNC func, ValTypes... values) { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + + for (size_t i = current_index_chunk_; i < num_index_chunk_; i++) { + const Index& index = + segment_->chunk_scalar_index(field_id_, i); + auto* index_ptr = const_cast(&index); + func(index_ptr, values...); + } + } + template bool CanUseIndex(OpType op) const { diff --git a/internal/core/src/exec/expression/JsonContainsExpr.cpp b/internal/core/src/exec/expression/JsonContainsExpr.cpp index 72251c301fb14..bbcc852c2a8e2 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.cpp +++ b/internal/core/src/exec/expression/JsonContainsExpr.cpp @@ -23,7 +23,14 @@ namespace exec { void PhyJsonContainsFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { switch (expr_->column_.data_type_) { - case DataType::ARRAY: + case DataType::ARRAY: { + if (is_index_mode_) { + result = EvalArrayContainsForIndexSegment(); + } else { + result = EvalJsonContainsForDataSegment(); + } + break; + } case DataType::JSON: { if (is_index_mode_) { PanicInfo( @@ -94,7 +101,6 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() { return ExecJsonContainsWithDiffType(); } } - break; } case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { if (IsArrayDataType(data_type)) { @@ -145,7 +151,6 @@ PhyJsonContainsFilterExpr::EvalJsonContainsForDataSegment() { return ExecJsonContainsAllWithDiffType(); } } - break; } default: PanicInfo(ExprInvalid, @@ -748,5 +753,92 @@ PhyJsonContainsFilterExpr::ExecJsonContainsWithDiffType() { return res_vec; } +VectorPtr +PhyJsonContainsFilterExpr::EvalArrayContainsForIndexSegment() { + switch (expr_->column_.element_type_) { + case DataType::BOOL: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT8: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT16: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT32: { + return 
ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::INT64: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::FLOAT: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::DOUBLE: { + return ExecArrayContainsForIndexSegmentImpl(); + } + case DataType::VARCHAR: + case DataType::STRING: { + return ExecArrayContainsForIndexSegmentImpl(); + } + default: + PanicInfo(DataTypeInvalid, + fmt::format("unsupported data type for " + "ExecArrayContainsForIndexSegmentImpl: {}", + expr_->column_.element_type_)); + } +} + +template +VectorPtr +PhyJsonContainsFilterExpr::ExecArrayContainsForIndexSegmentImpl() { + typedef std::conditional_t, + std::string, + ExprValueType> + GetType; + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + std::unordered_set elements; + for (auto const& element : expr_->vals_) { + elements.insert(GetValueFromProto(element)); + } + boost::container::vector elems(elements.begin(), elements.end()); + auto execute_sub_batch = + [this](Index* index_ptr, + const boost::container::vector& vals) { + switch (expr_->op_) { + case proto::plan::JSONContainsExpr_JSONOp_Contains: + case proto::plan::JSONContainsExpr_JSONOp_ContainsAny: { + return index_ptr->In(vals.size(), vals.data()); + } + case proto::plan::JSONContainsExpr_JSONOp_ContainsAll: { + TargetBitmap result(index_ptr->Count()); + result.set(); + for (size_t i = 0; i < vals.size(); i++) { + auto sub = index_ptr->In(1, &vals[i]); + result &= sub; + } + return result; + } + default: + PanicInfo( + ExprInvalid, + "unsupported array contains type {}", + proto::plan::JSONContainsExpr_JSONOp_Name(expr_->op_)); + } + }; + auto res = ProcessIndexChunks(execute_sub_batch, elems); + AssertInfo(res.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + res.size(), + real_batch_size); + return std::make_shared(std::move(res)); +} + } //namespace exec } // namespace milvus diff --git a/internal/core/src/exec/expression/JsonContainsExpr.h b/internal/core/src/exec/expression/JsonContainsExpr.h index c757dc0d3fb92..a0cfdfdea0841 100644 --- a/internal/core/src/exec/expression/JsonContainsExpr.h +++ b/internal/core/src/exec/expression/JsonContainsExpr.h @@ -80,6 +80,13 @@ class PhyJsonContainsFilterExpr : public SegmentExpr { VectorPtr ExecJsonContainsWithDiffType(); + VectorPtr + EvalArrayContainsForIndexSegment(); + + template + VectorPtr + ExecArrayContainsForIndexSegmentImpl(); + private: std::shared_ptr expr_; }; diff --git a/internal/core/src/exec/expression/UnaryExpr.cpp b/internal/core/src/exec/expression/UnaryExpr.cpp index f780ec487ba47..b9567133de801 100644 --- a/internal/core/src/exec/expression/UnaryExpr.cpp +++ b/internal/core/src/exec/expression/UnaryExpr.cpp @@ -20,6 +20,66 @@ namespace milvus { namespace exec { +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArrayForIndex() { + return ExecRangeVisitorImplArray(); +} + +template <> +VectorPtr +PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArrayForIndex< + proto::plan::Array>() { + switch (expr_->op_type_) { + case proto::plan::Equal: + case proto::plan::NotEqual: { + switch (expr_->column_.element_type_) { + case DataType::BOOL: { + return ExecArrayEqualForIndex(expr_->op_type_ == + proto::plan::NotEqual); + } + case DataType::INT8: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT16: { + return ExecArrayEqualForIndex( + 
expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT32: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::INT64: { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + case DataType::FLOAT: + case DataType::DOUBLE: { + // not accurate on floating point number, rollback to bruteforce. + return ExecRangeVisitorImplArray(); + } + case DataType::VARCHAR: { + if (segment_->type() == SegmentType::Growing) { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } else { + return ExecArrayEqualForIndex( + expr_->op_type_ == proto::plan::NotEqual); + } + } + default: + PanicInfo(DataTypeInvalid, + "unsupported element type when execute array " + "equal for index: {}", + expr_->column_.element_type_); + } + } + default: + return ExecRangeVisitorImplArray(); + } +} + void PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { switch (expr_->column_.data_type_) { @@ -99,7 +159,13 @@ PhyUnaryRangeFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { result = ExecRangeVisitorImplArray(); break; case proto::plan::GenericValue::ValCase::kArrayVal: - result = ExecRangeVisitorImplArray(); + if (is_index_mode_) { + result = ExecRangeVisitorImplArrayForIndex< + proto::plan::Array>(); + } else { + result = + ExecRangeVisitorImplArray(); + } break; default: PanicInfo( @@ -196,6 +262,104 @@ PhyUnaryRangeFilterExpr::ExecRangeVisitorImplArray() { return res_vec; } +template +VectorPtr +PhyUnaryRangeFilterExpr::ExecArrayEqualForIndex(bool reverse) { + typedef std:: + conditional_t, std::string, T> + IndexInnerType; + using Index = index::ScalarIndex; + auto real_batch_size = GetNextBatchSize(); + if (real_batch_size == 0) { + return nullptr; + } + + // get all elements. + auto val = GetValueFromProto(expr_->val_); + if (val.array_size() == 0) { + // rollback to bruteforce. no candidates will be filtered out via index. + return ExecRangeVisitorImplArray(); + } + + // cache the result to suit the framework. + auto batch_res = + ProcessIndexChunks([this, &val, reverse](Index* _) { + boost::container::vector elems; + for (auto const& element : val.array()) { + auto e = GetValueFromProto(element); + if (std::find(elems.begin(), elems.end(), e) == elems.end()) { + elems.push_back(e); + } + } + + // filtering by index, get candidates. + auto size_per_chunk = segment_->size_per_chunk(); + auto retrieve = [ size_per_chunk, this ](int64_t offset) -> auto { + auto chunk_idx = offset / size_per_chunk; + auto chunk_offset = offset % size_per_chunk; + const auto& chunk = + segment_->template chunk_data(field_id_, + chunk_idx); + return chunk.data() + chunk_offset; + }; + + // compare the array via the raw data. + auto filter = [&retrieve, &val, reverse](size_t offset) -> bool { + auto data_ptr = retrieve(offset); + return data_ptr->is_same_array(val) ^ reverse; + }; + + // collect all candidates. + std::unordered_set candidates; + std::unordered_set tmp_candidates; + auto first_callback = [&candidates](size_t offset) -> void { + candidates.insert(offset); + }; + auto callback = [&candidates, + &tmp_candidates](size_t offset) -> void { + if (candidates.find(offset) != candidates.end()) { + tmp_candidates.insert(offset); + } + }; + auto execute_sub_batch = + [](Index* index_ptr, + const IndexInnerType& val, + const std::function& callback) { + index_ptr->InApplyCallback(1, &val, callback); + }; + + // run in-filter. 
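+            // Intersect the per-element hits: the first element seeds the
+            // candidate set, each later element keeps only the offsets seen
+            // so far, and the loop exits early once the candidates drop below
+            // 1% of the active rows; exact array equality is then verified by
+            // the raw-data post-filter.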
+ for (size_t idx = 0; idx < elems.size(); idx++) { + if (idx == 0) { + ProcessIndexChunksV2( + execute_sub_batch, elems[idx], first_callback); + } else { + ProcessIndexChunksV2( + execute_sub_batch, elems[idx], callback); + candidates = std::move(tmp_candidates); + } + // the size of candidates is small enough. + if (candidates.size() * 100 < active_count_) { + break; + } + } + TargetBitmap res(active_count_); + // run post-filter. The filter will only be executed once in the framework. + for (const auto& candidate : candidates) { + res[candidate] = filter(candidate); + } + return res; + }); + AssertInfo(batch_res.size() == real_batch_size, + "internal error: expr processed rows {} not equal " + "expect batch size {}", + batch_res.size(), + real_batch_size); + + // return the result. + return std::make_shared(std::move(batch_res)); +} + template VectorPtr PhyUnaryRangeFilterExpr::ExecRangeVisitorImplJson() { diff --git a/internal/core/src/exec/expression/UnaryExpr.h b/internal/core/src/exec/expression/UnaryExpr.h index e6342eda86434..40371e0e51f38 100644 --- a/internal/core/src/exec/expression/UnaryExpr.h +++ b/internal/core/src/exec/expression/UnaryExpr.h @@ -310,6 +310,14 @@ class PhyUnaryRangeFilterExpr : public SegmentExpr { VectorPtr ExecRangeVisitorImplArray(); + template + VectorPtr + ExecRangeVisitorImplArrayForIndex(); + + template + VectorPtr + ExecArrayEqualForIndex(bool reverse); + // Check overflow and cache result for performace template ColumnVectorPtr diff --git a/internal/core/src/expr/ITypeExpr.h b/internal/core/src/expr/ITypeExpr.h index 102709aa16b83..6716f8af2f66f 100644 --- a/internal/core/src/expr/ITypeExpr.h +++ b/internal/core/src/expr/ITypeExpr.h @@ -113,11 +113,13 @@ IsMaterializedViewSupported(const DataType& data_type) { struct ColumnInfo { FieldId field_id_; DataType data_type_; + DataType element_type_; std::vector nested_path_; ColumnInfo(const proto::plan::ColumnInfo& column_info) : field_id_(column_info.field_id()), data_type_(static_cast(column_info.data_type())), + element_type_(static_cast(column_info.element_type())), nested_path_(column_info.nested_path().begin(), column_info.nested_path().end()) { } @@ -127,6 +129,7 @@ struct ColumnInfo { std::vector nested_path = {}) : field_id_(field_id), data_type_(data_type), + element_type_(DataType::NONE), nested_path_(std::move(nested_path)) { } @@ -140,6 +143,10 @@ struct ColumnInfo { return false; } + if (element_type_ != other.element_type_) { + return false; + } + for (int i = 0; i < nested_path_.size(); ++i) { if (nested_path_[i] != other.nested_path_[i]) { return false; @@ -151,10 +158,12 @@ struct ColumnInfo { std::string ToString() const { - return fmt::format("[FieldId:{}, data_type:{}, nested_path:{}]", - std::to_string(field_id_.get()), - data_type_, - milvus::Join(nested_path_, ",")); + return fmt::format( + "[FieldId:{}, data_type:{}, element_type:{}, nested_path:{}]", + std::to_string(field_id_.get()), + data_type_, + element_type_, + milvus::Join(nested_path_, ",")); } }; diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index a593d087eb270..8c0ada968aab8 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -34,13 +34,9 @@ template ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, - const storage::FileManagerContext& file_manager_context, - DataType d_type) { + const storage::FileManagerContext& file_manager_context) { if (index_type == INVERTED_INDEX_TYPE) { - 
TantivyConfig cfg; - cfg.data_type_ = d_type; - return std::make_unique>(cfg, - file_manager_context); + return std::make_unique>(file_manager_context); } return CreateScalarIndexSort(file_manager_context); } @@ -56,14 +52,11 @@ template <> ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, - const storage::FileManagerContext& file_manager_context, - DataType d_type) { + const storage::FileManagerContext& file_manager_context) { #if defined(__linux__) || defined(__APPLE__) if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; return std::make_unique>( - cfg, file_manager_context); + file_manager_context); } return CreateStringIndexMarisa(file_manager_context); #else @@ -76,13 +69,10 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType d_type) { + std::shared_ptr space) { if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; - return std::make_unique>( - cfg, file_manager_context, space); + return std::make_unique>(file_manager_context, + space); } return CreateScalarIndexSort(file_manager_context, space); } @@ -92,14 +82,11 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType d_type) { + std::shared_ptr space) { #if defined(__linux__) || defined(__APPLE__) if (index_type == INVERTED_INDEX_TYPE) { - TantivyConfig cfg; - cfg.data_type_ = d_type; return std::make_unique>( - cfg, file_manager_context, space); + file_manager_context, space); } return CreateStringIndexMarisa(file_manager_context, space); #else @@ -132,41 +119,32 @@ IndexFactory::CreateIndex( } IndexBasePtr -IndexFactory::CreateScalarIndex( - const CreateIndexInfo& create_index_info, +IndexFactory::CreatePrimitiveScalarIndex( + DataType data_type, + IndexType index_type, const storage::FileManagerContext& file_manager_context) { - auto data_type = create_index_info.field_type; - auto index_type = create_index_info.index_type; - switch (data_type) { // create scalar index case DataType::BOOL: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT8: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT16: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT32: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::INT64: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::FLOAT: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); case DataType::DOUBLE: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, file_manager_context); // create string index case DataType::STRING: case DataType::VARCHAR: - return CreateScalarIndex( - index_type, file_manager_context, data_type); + return CreateScalarIndex(index_type, + file_manager_context); default: throw 
SegcoreError( DataTypeInvalid, @@ -174,6 +152,24 @@ IndexFactory::CreateScalarIndex( } } +IndexBasePtr +IndexFactory::CreateScalarIndex( + const CreateIndexInfo& create_index_info, + const storage::FileManagerContext& file_manager_context) { + switch (create_index_info.field_type) { + case DataType::ARRAY: + return CreatePrimitiveScalarIndex( + static_cast( + file_manager_context.fieldDataMeta.schema.element_type()), + create_index_info.index_type, + file_manager_context); + default: + return CreatePrimitiveScalarIndex(create_index_info.field_type, + create_index_info.index_type, + file_manager_context); + } +} + IndexBasePtr IndexFactory::CreateVectorIndex( const CreateIndexInfo& create_index_info, @@ -249,32 +245,25 @@ IndexFactory::CreateScalarIndex(const CreateIndexInfo& create_index_info, switch (data_type) { // create scalar index case DataType::BOOL: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT8: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT16: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT32: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::INT64: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::FLOAT: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); case DataType::DOUBLE: - return CreateScalarIndex( - index_type, file_manager, space, data_type); + return CreateScalarIndex(index_type, file_manager, space); // create string index case DataType::STRING: case DataType::VARCHAR: return CreateScalarIndex( - index_type, file_manager, space, data_type); + index_type, file_manager, space); default: throw SegcoreError( DataTypeInvalid, diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index 75bd090292907..47b255ab4e912 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -65,6 +65,13 @@ class IndexFactory { CreateVectorIndex(const CreateIndexInfo& create_index_info, const storage::FileManagerContext& file_manager_context); + IndexBasePtr + CreatePrimitiveScalarIndex( + DataType data_type, + IndexType index_type, + const storage::FileManagerContext& file_manager_context = + storage::FileManagerContext()); + IndexBasePtr CreateScalarIndex(const CreateIndexInfo& create_index_info, const storage::FileManagerContext& file_manager_context = @@ -89,15 +96,13 @@ class IndexFactory { ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, const storage::FileManagerContext& file_manager = - storage::FileManagerContext(), - DataType d_type = DataType::NONE); + storage::FileManagerContext()); template ScalarIndexPtr CreateScalarIndex(const IndexType& index_type, const storage::FileManagerContext& file_manager, - std::shared_ptr space, - DataType d_type = DataType::NONE); + std::shared_ptr space); }; // template <> @@ -112,6 +117,5 @@ ScalarIndexPtr IndexFactory::CreateScalarIndex( const IndexType& index_type, const storage::FileManagerContext& file_manager_context, - std::shared_ptr space, - DataType 
d_type); + std::shared_ptr space); } // namespace milvus::index diff --git a/internal/core/src/index/InvertedIndexTantivy.cpp b/internal/core/src/index/InvertedIndexTantivy.cpp index 5bb8ba3b16103..3b9a54fae940b 100644 --- a/internal/core/src/index/InvertedIndexTantivy.cpp +++ b/internal/core/src/index/InvertedIndexTantivy.cpp @@ -23,12 +23,50 @@ #include "InvertedIndexTantivy.h" namespace milvus::index { +inline TantivyDataType +get_tantivy_data_type(proto::schema::DataType data_type) { + switch (data_type) { + case proto::schema::DataType::Bool: { + return TantivyDataType::Bool; + } + + case proto::schema::DataType::Int8: + case proto::schema::DataType::Int16: + case proto::schema::DataType::Int32: + case proto::schema::DataType::Int64: { + return TantivyDataType::I64; + } + + case proto::schema::DataType::Float: + case proto::schema::DataType::Double: { + return TantivyDataType::F64; + } + + case proto::schema::DataType::VarChar: { + return TantivyDataType::Keyword; + } + + default: + PanicInfo(ErrorCode::NotImplemented, + fmt::format("not implemented data type: {}", data_type)); + } +} + +inline TantivyDataType +get_tantivy_data_type(const proto::schema::FieldSchema& schema) { + switch (schema.data_type()) { + case proto::schema::Array: + return get_tantivy_data_type(schema.element_type()); + default: + return get_tantivy_data_type(schema.data_type()); + } +} + template InvertedIndexTantivy::InvertedIndexTantivy( - const TantivyConfig& cfg, const storage::FileManagerContext& ctx, std::shared_ptr space) - : cfg_(cfg), space_(space) { + : space_(space), schema_(ctx.fieldDataMeta.schema) { mem_file_manager_ = std::make_shared(ctx, ctx.space_); disk_file_manager_ = std::make_shared(ctx, ctx.space_); auto field = @@ -36,7 +74,7 @@ InvertedIndexTantivy::InvertedIndexTantivy( auto prefix = disk_file_manager_->GetLocalIndexObjectPrefix(); path_ = prefix; boost::filesystem::create_directories(path_); - d_type_ = cfg_.to_tantivy_data_type(); + d_type_ = get_tantivy_data_type(schema_); if (tantivy_index_exist(path_.c_str())) { LOG_INFO( "index {} already exists, which should happen in loading progress", @@ -114,83 +152,7 @@ InvertedIndexTantivy::Build(const Config& config) { AssertInfo(insert_files.has_value(), "insert_files were empty"); auto field_datas = mem_file_manager_->CacheRawDataToMemory(insert_files.value()); - switch (cfg_.data_type_) { - case DataType::BOOL: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data(static_cast(data->Data()), - n); - } - break; - } - - case DataType::INT8: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT16: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT32: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT64: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::FLOAT: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::DOUBLE: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - 
static_cast(data->Data()), n); - } - break; - } - - case DataType::VARCHAR: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - default: - PanicInfo(ErrorCode::NotImplemented, - fmt::format("todo: not supported, {}", cfg_.data_type_)); - } + build_index(field_datas); } template @@ -211,84 +173,7 @@ InvertedIndexTantivy::BuildV2(const Config& config) { field_data->FillFieldData(col_data); field_datas.push_back(field_data); } - - switch (cfg_.data_type_) { - case DataType::BOOL: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data(static_cast(data->Data()), - n); - } - break; - } - - case DataType::INT8: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT16: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT32: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::INT64: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::FLOAT: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::DOUBLE: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - case DataType::VARCHAR: { - for (const auto& data : field_datas) { - auto n = data->get_num_rows(); - wrapper_->add_data( - static_cast(data->Data()), n); - } - break; - } - - default: - PanicInfo(ErrorCode::NotImplemented, - fmt::format("todo: not supported, {}", cfg_.data_type_)); - } + build_index(field_datas); } template @@ -319,6 +204,25 @@ apply_hits(TargetBitmap& bitset, const RustArrayWrapper& w, bool v) { } } +inline void +apply_hits_with_filter(TargetBitmap& bitset, + const RustArrayWrapper& w, + const std::function& filter) { + for (size_t j = 0; j < w.array_.len; j++) { + auto the_offset = w.array_.array[j]; + bitset[the_offset] = filter(the_offset); + } +} + +inline void +apply_hits_with_callback( + const RustArrayWrapper& w, + const std::function& callback) { + for (size_t j = 0; j < w.array_.len; j++) { + callback(w.array_.array[j]); + } +} + template const TargetBitmap InvertedIndexTantivy::In(size_t n, const T* values) { @@ -330,10 +234,33 @@ InvertedIndexTantivy::In(size_t n, const T* values) { return bitset; } +template +const TargetBitmap +InvertedIndexTantivy::InApplyFilter( + size_t n, const T* values, const std::function& filter) { + TargetBitmap bitset(Count()); + for (size_t i = 0; i < n; ++i) { + auto array = wrapper_->term_query(values[i]); + apply_hits_with_filter(bitset, array, filter); + } + return bitset; +} + +template +void +InvertedIndexTantivy::InApplyCallback( + size_t n, const T* values, const std::function& callback) { + for (size_t i = 0; i < n; ++i) { + auto array = wrapper_->term_query(values[i]); + apply_hits_with_callback(array, callback); + } +} + template const TargetBitmap InvertedIndexTantivy::NotIn(size_t n, const T* values) { - TargetBitmap bitset(Count(), true); + TargetBitmap bitset(Count()); + 
bitset.set(); for (size_t i = 0; i < n; ++i) { auto array = wrapper_->term_query(values[i]); apply_hits(bitset, array, false); @@ -425,25 +352,118 @@ void InvertedIndexTantivy::BuildWithRawData(size_t n, const void* values, const Config& config) { - if constexpr (!std::is_same_v) { - PanicInfo(Unsupported, - "InvertedIndex.BuildWithRawData only support string"); + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Bool); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int8); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int16); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int32); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Int64); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Float); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::Double); + } + if constexpr (std::is_same_v) { + schema_.set_data_type(proto::schema::DataType::VarChar); + } + boost::uuids::random_generator generator; + auto uuid = generator(); + auto prefix = boost::uuids::to_string(uuid); + path_ = fmt::format("/tmp/{}", prefix); + boost::filesystem::create_directories(path_); + d_type_ = get_tantivy_data_type(schema_); + std::string field = "test_inverted_index"; + wrapper_ = std::make_shared( + field.c_str(), d_type_, path_.c_str()); + if (config.find("is_array") != config.end()) { + // only used in ut. + auto arr = static_cast*>(values); + for (size_t i = 0; i < n; i++) { + wrapper_->template add_multi_data(arr[i].data(), arr[i].size()); + } } else { - boost::uuids::random_generator generator; - auto uuid = generator(); - auto prefix = boost::uuids::to_string(uuid); - path_ = fmt::format("/tmp/{}", prefix); - boost::filesystem::create_directories(path_); - cfg_ = TantivyConfig{ - .data_type_ = DataType::VARCHAR, - }; - d_type_ = cfg_.to_tantivy_data_type(); - std::string field = "test_inverted_index"; - wrapper_ = std::make_shared( - field.c_str(), d_type_, path_.c_str()); - wrapper_->add_data(static_cast(values), - n); - finish(); + wrapper_->add_data(static_cast(values), n); + } + finish(); +} + +template +void +InvertedIndexTantivy::build_index( + const std::vector>& field_datas) { + switch (schema_.data_type()) { + case proto::schema::DataType::Bool: + case proto::schema::DataType::Int8: + case proto::schema::DataType::Int16: + case proto::schema::DataType::Int32: + case proto::schema::DataType::Int64: + case proto::schema::DataType::Float: + case proto::schema::DataType::Double: + case proto::schema::DataType::String: + case proto::schema::DataType::VarChar: { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + wrapper_->add_data(static_cast(data->Data()), n); + } + break; + } + + case proto::schema::DataType::Array: { + build_index_for_array(field_datas); + break; + } + + default: + PanicInfo(ErrorCode::NotImplemented, + fmt::format("Inverted index not supported on {}", + schema_.data_type())); + } +} + +template +void +InvertedIndexTantivy::build_index_for_array( + const std::vector>& field_datas) { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + auto array_column = static_cast(data->Data()); + for (int64_t i = 0; i < n; i++) { + assert(array_column[i].get_element_type() == + static_cast(schema_.element_type())); + wrapper_->template add_multi_data( + reinterpret_cast(array_column[i].data()), 
+ array_column[i].length()); + } + } +} + +template <> +void +InvertedIndexTantivy::build_index_for_array( + const std::vector>& field_datas) { + for (const auto& data : field_datas) { + auto n = data->get_num_rows(); + auto array_column = static_cast(data->Data()); + for (int64_t i = 0; i < n; i++) { + assert(array_column[i].get_element_type() == + static_cast(schema_.element_type())); + std::vector output; + for (int64_t j = 0; j < array_column[i].length(); j++) { + output.push_back( + array_column[i].template get_data(j)); + } + wrapper_->template add_multi_data(output.data(), output.size()); + } } } diff --git a/internal/core/src/index/InvertedIndexTantivy.h b/internal/core/src/index/InvertedIndexTantivy.h index 0ea2f64d869d3..53fb9c2d687ac 100644 --- a/internal/core/src/index/InvertedIndexTantivy.h +++ b/internal/core/src/index/InvertedIndexTantivy.h @@ -18,7 +18,6 @@ #include "tantivy-binding.h" #include "tantivy-wrapper.h" #include "index/StringIndex.h" -#include "index/TantivyConfig.h" #include "storage/space.h" namespace milvus::index { @@ -36,13 +35,11 @@ class InvertedIndexTantivy : public ScalarIndex { InvertedIndexTantivy() = default; - explicit InvertedIndexTantivy(const TantivyConfig& cfg, - const storage::FileManagerContext& ctx) - : InvertedIndexTantivy(cfg, ctx, nullptr) { + explicit InvertedIndexTantivy(const storage::FileManagerContext& ctx) + : InvertedIndexTantivy(ctx, nullptr) { } - explicit InvertedIndexTantivy(const TantivyConfig& cfg, - const storage::FileManagerContext& ctx, + explicit InvertedIndexTantivy(const storage::FileManagerContext& ctx, std::shared_ptr space); ~InvertedIndexTantivy(); @@ -114,6 +111,18 @@ class InvertedIndexTantivy : public ScalarIndex { const TargetBitmap In(size_t n, const T* values) override; + const TargetBitmap + InApplyFilter( + size_t n, + const T* values, + const std::function& filter) override; + + void + InApplyCallback( + size_t n, + const T* values, + const std::function& callback) override; + const TargetBitmap NotIn(size_t n, const T* values) override; @@ -160,11 +169,18 @@ class InvertedIndexTantivy : public ScalarIndex { void finish(); + void + build_index(const std::vector>& field_datas); + + void + build_index_for_array( + const std::vector>& field_datas); + private: std::shared_ptr wrapper_; - TantivyConfig cfg_; TantivyDataType d_type_; std::string path_; + proto::schema::FieldSchema schema_; /* * To avoid IO amplification, we use both mem file manager & disk file manager diff --git a/internal/core/src/index/ScalarIndex.h b/internal/core/src/index/ScalarIndex.h index aacef521f5db3..37d22a288d80b 100644 --- a/internal/core/src/index/ScalarIndex.h +++ b/internal/core/src/index/ScalarIndex.h @@ -50,6 +50,20 @@ class ScalarIndex : public IndexBase { virtual const TargetBitmap In(size_t n, const T* values) = 0; + virtual const TargetBitmap + InApplyFilter(size_t n, + const T* values, + const std::function& filter) { + PanicInfo(ErrorCode::Unsupported, "InApplyFilter is not implemented"); + } + + virtual void + InApplyCallback(size_t n, + const T* values, + const std::function& callback) { + PanicInfo(ErrorCode::Unsupported, "InApplyCallback is not implemented"); + } + virtual const TargetBitmap NotIn(size_t n, const T* values) = 0; diff --git a/internal/core/src/index/TantivyConfig.h b/internal/core/src/index/TantivyConfig.h deleted file mode 100644 index 355b4c76efc9d..0000000000000 --- a/internal/core/src/index/TantivyConfig.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. 
diff --git a/internal/core/src/index/TantivyConfig.h b/internal/core/src/index/TantivyConfig.h
deleted file mode 100644
index 355b4c76efc9d..0000000000000
--- a/internal/core/src/index/TantivyConfig.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright (C) 2019-2020 Zilliz. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software distributed under the License
-// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
-// or implied. See the License for the specific language governing permissions and limitations under the License
-
-#pragma once
-
-#include "storage/Types.h"
-#include "tantivy-binding.h"
-
-namespace milvus::index {
-struct TantivyConfig {
-    DataType data_type_;
-
-    TantivyDataType
-    to_tantivy_data_type() {
-        switch (data_type_) {
-            case DataType::BOOL: {
-                return TantivyDataType::Bool;
-            }
-
-            case DataType::INT8:
-            case DataType::INT16:
-            case DataType::INT32:
-            case DataType::INT64: {
-                return TantivyDataType::I64;
-            }
-
-            case DataType::FLOAT:
-            case DataType::DOUBLE: {
-                return TantivyDataType::F64;
-            }
-
-            case DataType::VARCHAR: {
-                return TantivyDataType::Keyword;
-            }
-
-            default:
-                PanicInfo(
-                    ErrorCode::NotImplemented,
-                    fmt::format("not implemented data type: {}", data_type_));
-        }
-    }
-};
-}  // namespace milvus::index
\ No newline at end of file
diff --git a/internal/core/src/indexbuilder/IndexFactory.h b/internal/core/src/indexbuilder/IndexFactory.h
index cd361499b4065..1380a6e9817d3 100644
--- a/internal/core/src/indexbuilder/IndexFactory.h
+++ b/internal/core/src/indexbuilder/IndexFactory.h
@@ -60,6 +60,7 @@ class IndexFactory {
         case DataType::DOUBLE:
         case DataType::VARCHAR:
         case DataType::STRING:
+        case DataType::ARRAY:
             return CreateScalarIndex(type, config, context);
 
         case DataType::VECTOR_FLOAT:
diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp
index 28a629052cad7..7ccaf7c414a24 100644
--- a/internal/core/src/indexbuilder/index_c.cpp
+++ b/internal/core/src/indexbuilder/index_c.cpp
@@ -84,29 +84,95 @@ CreateIndexV0(enum CDataType dtype,
     return status;
 }
 
+milvus::storage::StorageConfig
+get_storage_config(const milvus::proto::indexcgo::StorageConfig& config) {
+    auto storage_config = milvus::storage::StorageConfig();
+    storage_config.address = std::string(config.address());
+    storage_config.bucket_name = std::string(config.bucket_name());
+    storage_config.access_key_id = std::string(config.access_keyid());
+    storage_config.access_key_value = std::string(config.secret_access_key());
+    storage_config.root_path = std::string(config.root_path());
+    storage_config.storage_type = std::string(config.storage_type());
+    storage_config.cloud_provider = std::string(config.cloud_provider());
+    storage_config.iam_endpoint = std::string(config.iamendpoint());
+    storage_config.useSSL = config.usessl();
+    storage_config.sslCACert = config.sslcacert();
+    storage_config.useIAM = config.useiam();
+    storage_config.region = config.region();
+    storage_config.useVirtualHost = config.use_virtual_host();
+    storage_config.requestTimeoutMs = config.request_timeout_ms();
+    return storage_config;
+}
+
+milvus::OptFieldT
+get_opt_field(const ::google::protobuf::RepeatedPtrField<
+              milvus::proto::indexcgo::OptionalFieldInfo>& field_infos) {
+    milvus::OptFieldT opt_fields_map;
+    for (const auto& field_info : field_infos) {
+        auto field_id = field_info.fieldid();
+        if (opt_fields_map.find(field_id) == opt_fields_map.end()) {
+            opt_fields_map[field_id] = {
+                field_info.field_name(),
+                static_cast<milvus::DataType>(field_info.field_type()),
+                {}};
+        }
+        for (const auto& str : field_info.data_paths()) {
+            std::get<2>(opt_fields_map[field_id]).emplace_back(str);
+        }
+    }
+
+    return opt_fields_map;
+}
+
+milvus::Config
+get_config(std::unique_ptr<milvus::proto::indexcgo::BuildIndexInfo>& info) {
+    milvus::Config config;
+    for (auto i = 0; i < info->index_params().size(); ++i) {
+        const auto& param = info->index_params(i);
+        config[param.key()] = param.value();
+    }
+
+    for (auto i = 0; i < info->type_params().size(); ++i) {
+        const auto& param = info->type_params(i);
+        config[param.key()] = param.value();
+    }
+
+    config["insert_files"] = info->insert_files();
+    if (info->opt_fields().size()) {
+        config["opt_fields"] = get_opt_field(info->opt_fields());
+    }
+
+    return config;
+}
+
 CStatus
-CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) {
+CreateIndex(CIndex* res_index,
+            const uint8_t* serialized_build_index_info,
+            const uint64_t len) {
     try {
-        auto build_index_info = (BuildIndexInfo*)c_build_index_info;
-        auto field_type = build_index_info->field_type;
+        auto build_index_info =
+            std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
+        auto res =
+            build_index_info->ParseFromArray(serialized_build_index_info, len);
+        AssertInfo(res, "Unmarshall build index info failed");
 
-        milvus::index::CreateIndexInfo index_info;
-        index_info.field_type = build_index_info->field_type;
+        auto field_type = static_cast<milvus::DataType>(
+            build_index_info->field_schema().data_type());
 
-        auto& config = build_index_info->config;
-        config["insert_files"] = build_index_info->insert_files;
-        if (build_index_info->opt_fields.size()) {
-            config["opt_fields"] = build_index_info->opt_fields;
-        }
+        milvus::index::CreateIndexInfo index_info;
+        index_info.field_type = field_type;
 
+        auto storage_config =
+            get_storage_config(build_index_info->storage_config());
+        auto config = get_config(build_index_info);
         // get index type
         auto index_type = milvus::index::GetValueFromConfig<std::string>(
             config, "index_type");
         AssertInfo(index_type.has_value(), "index type is empty");
         index_info.index_type = index_type.value();
 
-        auto engine_version = build_index_info->index_engine_version;
-
+        auto engine_version = build_index_info->current_index_version();
         index_info.index_engine_version = engine_version;
         config[milvus::index::INDEX_ENGINE_VERSION] =
             std::to_string(engine_version);
@@ -121,24 +187,31 @@ CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) {
 
         // init file manager
         milvus::storage::FieldDataMeta field_meta{
-            build_index_info->collection_id,
-            build_index_info->partition_id,
-            build_index_info->segment_id,
-            build_index_info->field_id};
-
-        milvus::storage::IndexMeta index_meta{build_index_info->segment_id,
-                                              build_index_info->field_id,
-                                              build_index_info->index_build_id,
-                                              build_index_info->index_version};
-        auto chunk_manager = milvus::storage::CreateChunkManager(
-            build_index_info->storage_config);
+            build_index_info->collectionid(),
+            build_index_info->partitionid(),
+            build_index_info->segmentid(),
+            build_index_info->field_schema().fieldid(),
+            build_index_info->field_schema()};
+
+        milvus::storage::IndexMeta index_meta{
+            build_index_info->segmentid(),
+            build_index_info->field_schema().fieldid(),
+            build_index_info->buildid(),
+            build_index_info->index_version(),
+            "",
+            build_index_info->field_schema().name(),
+            field_type,
+            build_index_info->dim(),
+        };
+        auto chunk_manager =
+            milvus::storage::CreateChunkManager(storage_config);
 
         milvus::storage::FileManagerContext fileManagerContext(
             field_meta, index_meta, chunk_manager);
 
         auto index =
             milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex(
-                build_index_info->field_type, config, fileManagerContext);
+                field_type, config, fileManagerContext);
         index->Build();
         *res_index = index.release();
         auto status = CStatus();
@@ -159,22 +232,32 @@ CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info) {
 }
 
 CStatus
-CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) {
+CreateIndexV2(CIndex* res_index,
+              const uint8_t* serialized_build_index_info,
+              const uint64_t len) {
     try {
-        auto build_index_info = (BuildIndexInfo*)c_build_index_info;
-        auto field_type = build_index_info->field_type;
+        auto build_index_info =
+            std::make_unique<milvus::proto::indexcgo::BuildIndexInfo>();
+        auto res =
+            build_index_info->ParseFromArray(serialized_build_index_info, len);
+        AssertInfo(res, "Unmarshall build index info failed");
+        auto field_type = static_cast<milvus::DataType>(
+            build_index_info->field_schema().data_type());
+
         milvus::index::CreateIndexInfo index_info;
-        index_info.field_type = build_index_info->field_type;
-        index_info.dim = build_index_info->dim;
+        index_info.field_type = field_type;
+        index_info.dim = build_index_info->dim();
 
-        auto& config = build_index_info->config;
+        auto storage_config =
+            get_storage_config(build_index_info->storage_config());
+        auto config = get_config(build_index_info);
         // get index type
         auto index_type = milvus::index::GetValueFromConfig<std::string>(
             config, "index_type");
         AssertInfo(index_type.has_value(), "index type is empty");
         index_info.index_type = index_type.value();
 
-        auto engine_version = build_index_info->index_engine_version;
+        auto engine_version = build_index_info->current_index_version();
         index_info.index_engine_version = engine_version;
         config[milvus::index::INDEX_ENGINE_VERSION] =
             std::to_string(engine_version);
@@ -188,39 +271,39 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) {
         }
 
         milvus::storage::FieldDataMeta field_meta{
-            build_index_info->collection_id,
-            build_index_info->partition_id,
-            build_index_info->segment_id,
-            build_index_info->field_id};
+            build_index_info->collectionid(),
+            build_index_info->partitionid(),
+            build_index_info->segmentid(),
+            build_index_info->field_schema().fieldid()};
         milvus::storage::IndexMeta index_meta{
-            build_index_info->segment_id,
-            build_index_info->field_id,
-            build_index_info->index_build_id,
-            build_index_info->index_version,
-            build_index_info->field_name,
+            build_index_info->segmentid(),
+            build_index_info->field_schema().fieldid(),
+            build_index_info->buildid(),
+            build_index_info->index_version(),
             "",
-            build_index_info->field_type,
-            build_index_info->dim,
+            build_index_info->field_schema().name(),
+            field_type,
+            build_index_info->dim(),
         };
 
         auto store_space = milvus_storage::Space::Open(
-            build_index_info->data_store_path,
+            build_index_info->store_path(),
             milvus_storage::Options{nullptr,
-                                    build_index_info->data_store_version});
+                                    build_index_info->store_version()});
         AssertInfo(store_space.ok() && store_space.has_value(),
                    "create space failed: {}",
                    store_space.status().ToString());
 
         auto index_space = milvus_storage::Space::Open(
-            build_index_info->index_store_path,
+            build_index_info->index_store_path(),
             milvus_storage::Options{.schema = store_space.value()->schema()});
         AssertInfo(index_space.ok() && index_space.has_value(),
                    "create space failed: {}",
                    index_space.status().ToString());
 
         LOG_INFO("init space success");
-        auto chunk_manager = milvus::storage::CreateChunkManager(
-            build_index_info->storage_config);
+        auto chunk_manager =
+            milvus::storage::CreateChunkManager(storage_config);
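+        // NOTE: both CreateIndex and CreateIndexV2 now receive a serialized
+        // milvus.proto.indexcgo.BuildIndexInfo rather than a CBuildIndexInfo
+        // handle, so the Go caller simply marshals the message and passes raw
+        // bytes across cgo. A minimal sketch of that call, with illustrative
+        // names rather than the exact wrapper code:
+        //
+        //   payload, _ := proto.Marshal(buildIndexInfo)
+        //   status := C.CreateIndexV2(&index,
+        //       (*C.uint8_t)(unsafe.Pointer(&payload[0])), C.uint64_t(len(payload)))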
milvus::storage::FileManagerContext fileManagerContext( field_meta, index_meta, @@ -229,9 +312,9 @@ CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info) { auto index = milvus::indexbuilder::IndexFactory::GetInstance().CreateIndex( - build_index_info->field_type, - build_index_info->field_name, - build_index_info->dim, + field_type, + build_index_info->field_schema().name(), + build_index_info->dim(), config, fileManagerContext, std::move(store_space.value())); diff --git a/internal/core/src/indexbuilder/index_c.h b/internal/core/src/indexbuilder/index_c.h index 16cd76e4531ce..53ce5552fef0a 100644 --- a/internal/core/src/indexbuilder/index_c.h +++ b/internal/core/src/indexbuilder/index_c.h @@ -28,7 +28,9 @@ CreateIndexV0(enum CDataType dtype, CIndex* res_index); CStatus -CreateIndex(CIndex* res_index, CBuildIndexInfo c_build_index_info); +CreateIndex(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len); CStatus DeleteIndex(CIndex index); @@ -130,7 +132,9 @@ CStatus SerializeIndexAndUpLoadV2(CIndex index, CBinarySet* c_binary_set); CStatus -CreateIndexV2(CIndex* res_index, CBuildIndexInfo c_build_index_info); +CreateIndexV2(CIndex* res_index, + const uint8_t* serialized_build_index_info, + const uint64_t len); CStatus AppendIndexStorageInfo(CBuildIndexInfo c_build_index_info, diff --git a/internal/core/src/pb/CMakeLists.txt b/internal/core/src/pb/CMakeLists.txt index 3c00203cf4c25..35726d9c24c65 100644 --- a/internal/core/src/pb/CMakeLists.txt +++ b/internal/core/src/pb/CMakeLists.txt @@ -11,12 +11,10 @@ find_package(Protobuf REQUIRED) +file(GLOB_RECURSE milvus_proto_srcs + "${CMAKE_CURRENT_SOURCE_DIR}/*.cc") add_library(milvus_proto STATIC - common.pb.cc - index_cgo_msg.pb.cc - plan.pb.cc - schema.pb.cc - segcore.pb.cc + ${milvus_proto_srcs} ) message(STATUS "milvus proto sources: " ${milvus_proto_srcs}) diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 73ba7fcb188b6..106799ce2610f 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -46,6 +46,7 @@ struct LoadIndexInfo { std::string uri; int64_t index_store_version; IndexVersion index_engine_version; + proto::schema::FieldSchema schema; }; } // namespace milvus::segcore diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 7f851948545d3..3df3a92879751 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -25,6 +25,7 @@ #include "storage/Util.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/LocalChunkManagerSingleton.h" +#include "pb/cgo_msg.pb.h" bool IsLoadWithDisk(const char* index_type, int index_engine_version) { @@ -258,7 +259,8 @@ AppendIndexV2(CTraceContext c_trace, CLoadIndexInfo c_load_index_info) { load_index_info->collection_id, load_index_info->partition_id, load_index_info->segment_id, - load_index_info->field_id}; + load_index_info->field_id, + load_index_info->schema}; milvus::storage::IndexMeta index_meta{load_index_info->segment_id, load_index_info->field_id, load_index_info->index_build_id, @@ -484,3 +486,50 @@ AppendStorageInfo(CLoadIndexInfo c_load_index_info, load_index_info->uri = uri; load_index_info->index_store_version = version; } + +CStatus +FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info, + const uint8_t* serialized_load_index_info, + const uint64_t len) { + try { + auto info_proto = std::make_unique(); + info_proto->ParseFromArray(serialized_load_index_info, 
+                                       len);
+        auto load_index_info =
+            static_cast<LoadIndexInfo*>(c_load_index_info);
+        // TODO: keep this since LoadIndexInfo is used by SegmentSealed.
+        {
+            load_index_info->collection_id = info_proto->collectionid();
+            load_index_info->partition_id = info_proto->partitionid();
+            load_index_info->segment_id = info_proto->segmentid();
+            load_index_info->field_id = info_proto->field().fieldid();
+            load_index_info->field_type =
+                static_cast<milvus::DataType>(info_proto->field().data_type());
+            load_index_info->enable_mmap = info_proto->enable_mmap();
+            load_index_info->mmap_dir_path = info_proto->mmap_dir_path();
+            load_index_info->index_id = info_proto->indexid();
+            load_index_info->index_build_id = info_proto->index_buildid();
+            load_index_info->index_version = info_proto->index_version();
+            for (const auto& [k, v] : info_proto->index_params()) {
+                load_index_info->index_params[k] = v;
+            }
+            load_index_info->index_files.assign(
+                info_proto->index_files().begin(),
+                info_proto->index_files().end());
+            load_index_info->uri = info_proto->uri();
+            load_index_info->index_store_version =
+                info_proto->index_store_version();
+            load_index_info->index_engine_version =
+                info_proto->index_engine_version();
+            load_index_info->schema = info_proto->field();
+        }
+        auto status = CStatus();
+        status.error_code = milvus::Success;
+        status.error_msg = "";
+        return status;
+    } catch (std::exception& e) {
+        auto status = CStatus();
+        status.error_code = milvus::UnexpectedError;
+        status.error_msg = strdup(e.what());
+        return status;
+    }
+}
diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h
index 7a3d89b797670..8755aa7396162 100644
--- a/internal/core/src/segcore/load_index_c.h
+++ b/internal/core/src/segcore/load_index_c.h
@@ -76,6 +76,11 @@ void
 AppendStorageInfo(CLoadIndexInfo c_load_index_info,
                   const char* uri,
                   int64_t version);
+
+CStatus
+FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info,
+                    const uint8_t* serialized_load_index_info,
+                    const uint64_t len);
 #ifdef __cplusplus
 }
 #endif
diff --git a/internal/core/src/storage/Types.h b/internal/core/src/storage/Types.h
index 924873dccda64..fbd72d0a59a78 100644
--- a/internal/core/src/storage/Types.h
+++ b/internal/core/src/storage/Types.h
@@ -64,6 +64,7 @@ struct FieldDataMeta {
     int64_t partition_id;
     int64_t segment_id;
     int64_t field_id;
+    proto::schema::FieldSchema schema;
 };
 
 enum CodecType {
diff --git a/internal/core/thirdparty/tantivy/CMakeLists.txt b/internal/core/thirdparty/tantivy/CMakeLists.txt
index f4d928922874f..c1435a032a85e 100644
--- a/internal/core/thirdparty/tantivy/CMakeLists.txt
+++ b/internal/core/thirdparty/tantivy/CMakeLists.txt
@@ -71,3 +71,9 @@ target_link_libraries(bench_tantivy
         boost_filesystem
         dl
         )
+
+add_executable(ffi_demo ffi_demo.cpp)
+target_link_libraries(ffi_demo
+        tantivy_binding
+        dl
+        )
diff --git a/internal/core/thirdparty/tantivy/ffi_demo.cpp b/internal/core/thirdparty/tantivy/ffi_demo.cpp
new file mode 100644
index 0000000000000..1626d655f175d
--- /dev/null
+++ b/internal/core/thirdparty/tantivy/ffi_demo.cpp
@@ -0,0 +1,17 @@
+#include <string>
+#include <vector>
+
+#include "tantivy-binding.h"
+
+int
+main(int argc, char* argv[]) {
+    std::vector<std::string> data{"data1", "data2", "data3"};
+    std::vector<const char*> datas{};
+    for (auto& s : data) {
+        datas.push_back(s.c_str());
+    }
+
+    print_vector_of_strings(datas.data(), datas.size());
+
+    return 0;
+}
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h
index
3b22018bf047e..045d4a50e6a2c 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h +++ b/internal/core/thirdparty/tantivy/tantivy-binding/include/tantivy-binding.h @@ -97,6 +97,24 @@ void tantivy_index_add_bools(void *ptr, const bool *array, uintptr_t len); void tantivy_index_add_keyword(void *ptr, const char *s); +void tantivy_index_add_multi_int8s(void *ptr, const int8_t *array, uintptr_t len); + +void tantivy_index_add_multi_int16s(void *ptr, const int16_t *array, uintptr_t len); + +void tantivy_index_add_multi_int32s(void *ptr, const int32_t *array, uintptr_t len); + +void tantivy_index_add_multi_int64s(void *ptr, const int64_t *array, uintptr_t len); + +void tantivy_index_add_multi_f32s(void *ptr, const float *array, uintptr_t len); + +void tantivy_index_add_multi_f64s(void *ptr, const double *array, uintptr_t len); + +void tantivy_index_add_multi_bools(void *ptr, const bool *array, uintptr_t len); + +void tantivy_index_add_multi_keywords(void *ptr, const char *const *array, uintptr_t len); + bool tantivy_index_exist(const char *path); +void print_vector_of_strings(const char *const *ptr, uintptr_t len); + } // extern "C" diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs new file mode 100644 index 0000000000000..257a41f17a891 --- /dev/null +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/demo_c.rs @@ -0,0 +1,14 @@ +use std::{ffi::{c_char, CStr}, slice}; + +#[no_mangle] +pub extern "C" fn print_vector_of_strings(ptr: *const *const c_char, len: usize) { + let arr : &[*const c_char] = unsafe { + slice::from_raw_parts(ptr, len) + }; + for element in arr { + let c_str = unsafe { + CStr::from_ptr(*element) + }; + println!("{}", c_str.to_str().unwrap()); + } +} \ No newline at end of file diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs index ce96a5b4d5a30..2c8d56bf38694 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer.rs @@ -1,10 +1,11 @@ -use futures::executor::block_on; +use std::ffi::CStr; +use libc::c_char; use tantivy::schema::{Field, IndexRecordOption, Schema, TextFieldIndexing, TextOptions, INDEXED}; -use tantivy::{doc, tokenizer, Index, IndexWriter, SingleSegmentIndexWriter}; +use tantivy::{doc, tokenizer, Index, SingleSegmentIndexWriter, Document}; use crate::data_type::TantivyDataType; -use crate::index_writer; + use crate::log::init_log; pub struct IndexWriterWrapper { @@ -98,7 +99,74 @@ impl IndexWriterWrapper { .unwrap(); } - pub fn finish(mut self) { + pub fn add_multi_i8s(&mut self, datas: &[i8]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i16s(&mut self, datas: &[i16]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i32s(&mut self, datas: &[i32]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as i64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_i64s(&mut self, datas: &[i64]) { + let mut document = Document::default(); + for data in 
datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_f32s(&mut self, datas: &[f32]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data as f64); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_f64s(&mut self, datas: &[f64]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_bools(&mut self, datas: &[bool]) { + let mut document = Document::default(); + for data in datas { + document.add_field_value(self.field, *data); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn add_multi_keywords(&mut self, datas: &[*const c_char]) { + let mut document = Document::default(); + for element in datas { + let data = unsafe { + CStr::from_ptr(*element) + }; + document.add_field_value(self.field, data.to_str().unwrap()); + } + self.index_writer.add_document(document).unwrap(); + } + + pub fn finish(self) { self.index_writer .finalize() .expect("failed to build inverted index"); diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs index c8822781158e8..b13f550d7cb00 100644 --- a/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs +++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/index_writer_c.rs @@ -122,3 +122,77 @@ pub extern "C" fn tantivy_index_add_keyword(ptr: *mut c_void, s: *const c_char) let c_str = unsafe { CStr::from_ptr(s) }; unsafe { (*real).add_keyword(c_str.to_str().unwrap()) } } + +// --------------------------------------------- array ------------------------------------------ + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int8s(ptr: *mut c_void, array: *const i8, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len); + (*real).add_multi_i8s(arr) + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int16s(ptr: *mut c_void, array: *const i16, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i16s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int32s(ptr: *mut c_void, array: *const i32, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i32s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_int64s(ptr: *mut c_void, array: *const i64, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_i64s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_f32s(ptr: *mut c_void, array: *const f32, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_f32s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_f64s(ptr: *mut c_void, array: *const f64, len: usize) { + let real = ptr as *mut IndexWriterWrapper; + unsafe { + let arr = slice::from_raw_parts(array, len) ; + (*real).add_multi_f64s(arr); + } +} + +#[no_mangle] +pub extern "C" fn tantivy_index_add_multi_bools(ptr: *mut c_void, array: *const bool, len: usize) { + let real = ptr as 
*mut IndexWriterWrapper;
+    unsafe {
+        let arr = slice::from_raw_parts(array, len);
+        (*real).add_multi_bools(arr);
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn tantivy_index_add_multi_keywords(ptr: *mut c_void, array: *const *const c_char, len: usize) {
+    let real = ptr as *mut IndexWriterWrapper;
+    unsafe {
+        let arr = slice::from_raw_parts(array, len);
+        (*real).add_multi_keywords(arr)
+    }
+}
diff --git a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
index aa069cb3b32b6..c6193de3f6908 100644
--- a/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
+++ b/internal/core/thirdparty/tantivy/tantivy-binding/src/lib.rs
@@ -10,6 +10,7 @@ mod log;
 mod util;
 mod util_c;
 mod vec_collector;
+mod demo_c;
 
 pub fn add(left: usize, right: usize) -> usize {
     left + right
diff --git a/internal/core/thirdparty/tantivy/tantivy-wrapper.h b/internal/core/thirdparty/tantivy/tantivy-wrapper.h
index 358f14ea49ed0..7574d3875ca24 100644
--- a/internal/core/thirdparty/tantivy/tantivy-wrapper.h
+++ b/internal/core/thirdparty/tantivy/tantivy-wrapper.h
@@ -1,5 +1,7 @@
 #include <sstream>
 #include <fmt/format.h>
+#include <string>
+#include <vector>
 #include "tantivy-binding.h"
 
 namespace milvus::tantivy {
@@ -186,6 +188,60 @@ struct TantivyIndexWrapper {
                 typeid(T).name());
     }
 
+    template <typename T>
+    void
+    add_multi_data(const T* array, uintptr_t len) {
+        assert(!finished_);
+
+        if constexpr (std::is_same_v<T, bool>) {
+            tantivy_index_add_multi_bools(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, int8_t>) {
+            tantivy_index_add_multi_int8s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, int16_t>) {
+            tantivy_index_add_multi_int16s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, int32_t>) {
+            tantivy_index_add_multi_int32s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, int64_t>) {
+            tantivy_index_add_multi_int64s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, float>) {
+            tantivy_index_add_multi_f32s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, double>) {
+            tantivy_index_add_multi_f64s(writer_, array, len);
+            return;
+        }
+
+        if constexpr (std::is_same_v<T, std::string>) {
+            std::vector<const char*> views;
+            for (uintptr_t i = 0; i < len; i++) {
+                views.push_back(array[i].c_str());
+            }
+            tantivy_index_add_multi_keywords(writer_, views.data(), len);
+            return;
+        }
+
+        throw fmt::format(
+            "InvertedIndex.add_multi_data: unsupported data type: {}",
+            typeid(T).name());
+    }
+
     inline void
     finish() {
         if (!finished_) {
diff --git a/internal/core/thirdparty/tantivy/test.cpp b/internal/core/thirdparty/tantivy/test.cpp
index 1c67a69673a5c..a380481042487 100644
--- a/internal/core/thirdparty/tantivy/test.cpp
+++ b/internal/core/thirdparty/tantivy/test.cpp
@@ -200,6 +200,83 @@ test_32717() {
     }
 }
 
+std::set<uint32_t>
+to_set(const RustArrayWrapper& w) {
+    std::set<uint32_t> s(w.array_.array, w.array_.array + w.array_.len);
+    return s;
+}
+
+template <typename T>
+std::map<T, std::set<uint32_t>>
+build_inverted_index(const std::vector<std::vector<T>>& vec_of_array) {
+    std::map<T, std::set<uint32_t>> inverted_index;
+    for (uint32_t i = 0; i < vec_of_array.size(); i++) {
+        for (const auto& term : vec_of_array[i]) {
+            inverted_index[term].insert(i);
+        }
+    }
+    return inverted_index;
+}
+
+void
+test_array_int() {
+    using T = int64_t;
+
+    auto path = "/tmp/inverted-index/test-binding/";
+    boost::filesystem::remove_all(path);
+    boost::filesystem::create_directories(path);
+    auto w = TantivyIndexWrapper("test_field_name", guess_data_type<T>(), path);
+
+    std::vector<std::vector<T>> vec_of_array{
+        {10, 40, 50},
+        {20, 50},
+        {10, 50, 60},
+    };
+
+    for (const auto& arr : vec_of_array) {
+        w.add_multi_data(arr.data(), arr.size());
+    }
+    w.finish();
+
+    assert(w.count() == vec_of_array.size());
+
+    auto inverted_index = build_inverted_index(vec_of_array);
+    for (const auto& [term, posting_list] : inverted_index) {
+        auto hits = to_set(w.term_query(term));
+        assert(posting_list == hits);
+    }
+}
+
+void
+test_array_string() {
+    using T = std::string;
+
+    auto path = "/tmp/inverted-index/test-binding/";
+    boost::filesystem::remove_all(path);
+    boost::filesystem::create_directories(path);
+    auto w =
+        TantivyIndexWrapper("test_field_name", TantivyDataType::Keyword, path);
+
+    std::vector<std::vector<T>> vec_of_array{
+        {"10", "40", "50"},
+        {"20", "50"},
+        {"10", "50", "60"},
+    };
+
+    for (const auto& arr : vec_of_array) {
+        w.add_multi_data(arr.data(), arr.size());
+    }
+    w.finish();
+
+    assert(w.count() == vec_of_array.size());
+
+    auto inverted_index = build_inverted_index(vec_of_array);
+    for (const auto& [term, posting_list] : inverted_index) {
+        auto hits = to_set(w.term_query(term));
+        assert(posting_list == hits);
+    }
+}
+
 int
 main(int argc, char* argv[]) {
     test_32717();
@@ -216,5 +293,8 @@ main(int argc, char* argv[]) {
 
     run();
 
+    test_array_int();
+    test_array_string();
+
     return 0;
 }
diff --git a/internal/core/unittest/CMakeLists.txt b/internal/core/unittest/CMakeLists.txt
index 657198c9b88c2..e742e25a5a2bb 100644
--- a/internal/core/unittest/CMakeLists.txt
+++ b/internal/core/unittest/CMakeLists.txt
@@ -66,6 +66,7 @@ set(MILVUS_TEST_FILES
         test_group_by.cpp
         test_regex_query_util.cpp
         test_regex_query.cpp
+        test_array_inverted_index.cpp
     )
 
 if ( BUILD_DISK_ANN STREQUAL "ON" )
diff --git a/internal/core/unittest/test_array_inverted_index.cpp b/internal/core/unittest/test_array_inverted_index.cpp
new file mode 100644
index 0000000000000..cd4833b52bf38
--- /dev/null
+++ b/internal/core/unittest/test_array_inverted_index.cpp
@@ -0,0 +1,297 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License
+
+#include <gtest/gtest.h>
+#include <boost/container/vector.hpp>
+
+#include "pb/plan.pb.h"
+#include "index/InvertedIndexTantivy.h"
+#include "common/Schema.h"
+#include "segcore/SegmentSealedImpl.h"
+#include "test_utils/DataGen.h"
+#include "test_utils/GenExprProto.h"
+#include "query/PlanProto.h"
+#include "query/generated/ExecPlanNodeVisitor.h"
+
+using namespace milvus;
+using namespace milvus::query;
+using namespace milvus::segcore;
+
+template <typename T>
+SchemaPtr
+GenTestSchema() {
+    auto schema_ = std::make_shared<Schema>();
+    schema_->AddDebugField(
+        "fvec", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2);
+    auto pk = schema_->AddDebugField("pk", DataType::INT64);
+    schema_->set_primary_field_id(pk);
+
+    if constexpr (std::is_same_v<T, bool>) {
+        schema_->AddDebugArrayField("array", DataType::BOOL);
+    } else if constexpr (std::is_same_v<T, int8_t>) {
+        schema_->AddDebugArrayField("array", DataType::INT8);
+    } else if constexpr (std::is_same_v<T, int16_t>) {
+        schema_->AddDebugArrayField("array", DataType::INT16);
+    } else if constexpr (std::is_same_v<T, int32_t>) {
+        schema_->AddDebugArrayField("array", DataType::INT32);
+    } else if constexpr (std::is_same_v<T, int64_t>) {
+        schema_->AddDebugArrayField("array", DataType::INT64);
+    } else if constexpr (std::is_same_v<T, float>) {
+        schema_->AddDebugArrayField("array", DataType::FLOAT);
+    } else if constexpr (std::is_same_v<T, double>) {
+        schema_->AddDebugArrayField("array", DataType::DOUBLE);
+    } else if constexpr (std::is_same_v<T, std::string>) {
+        schema_->AddDebugArrayField("array", DataType::VARCHAR);
+    }
+
+    return schema_;
+}
+
+template <typename T>
+class ArrayInvertedIndexTest : public ::testing::Test {
+ public:
+    void
+    SetUp() override {
+        schema_ = GenTestSchema<T>();
+        seg_ = CreateSealedSegment(schema_);
+        N_ = 3000;
+        uint64_t seed = 19190504;
+        auto raw_data = DataGen(schema_, N_, seed);
+        auto array_col =
+            raw_data.get_col(schema_->get_field_id(FieldName("array")))
+                ->scalars()
+                .array_data()
+                .data();
+        for (size_t i = 0; i < N_; i++) {
+            boost::container::vector<T> array;
+            if constexpr (std::is_same_v<T, bool>) {
+                for (size_t j = 0; j < array_col[i].bool_data().data_size();
+                     j++) {
+                    array.push_back(array_col[i].bool_data().data(j));
+                }
+            } else if constexpr (std::is_same_v<T, int64_t>) {
+                for (size_t j = 0; j < array_col[i].long_data().data_size();
+                     j++) {
+                    array.push_back(array_col[i].long_data().data(j));
+                }
+            } else if constexpr (std::is_integral_v<T>) {
+                for (size_t j = 0; j < array_col[i].int_data().data_size();
+                     j++) {
+                    array.push_back(array_col[i].int_data().data(j));
+                }
+            } else if constexpr (std::is_floating_point_v<T>) {
+                for (size_t j = 0; j < array_col[i].float_data().data_size();
+                     j++) {
+                    array.push_back(array_col[i].float_data().data(j));
+                }
+            } else if constexpr (std::is_same_v<T, std::string>) {
+                for (size_t j = 0; j < array_col[i].string_data().data_size();
+                     j++) {
+                    array.push_back(array_col[i].string_data().data(j));
+                }
+            }
+            vec_of_array_.push_back(array);
+        }
+        SealedLoadFieldData(raw_data, *seg_);
+        LoadInvertedIndex();
+    }
+
+    void
+    TearDown() override {
+    }
+
+    void
+    LoadInvertedIndex() {
+        auto index = std::make_unique<index::InvertedIndexTantivy<T>>();
+        Config cfg;
+        cfg["is_array"] = true;
+        index->BuildWithRawData(N_, vec_of_array_.data(), cfg);
+        LoadIndexInfo info{
+            .field_id = schema_->get_field_id(FieldName("array")).get(),
+            .index = std::move(index),
+        };
+        seg_->LoadIndex(info);
+    }
+
+ public:
+    SchemaPtr schema_;
+    SegmentSealedUPtr seg_;
+    int64_t N_;
+    std::vector<boost::container::vector<T>> vec_of_array_;
+};
+
+TYPED_TEST_SUITE_P(ArrayInvertedIndexTest);
+
+TYPED_TEST_P(ArrayInvertedIndexTest,
ArrayContainsAny) { + const auto& meta = this->schema_->operator[](FieldName("array")); + auto column_info = test::GenColumnInfo( + meta.get_id().get(), + static_cast(meta.get_data_type()), + false, + false, + static_cast(meta.get_element_type())); + auto contains_expr = std::make_unique(); + contains_expr->set_allocated_column_info(column_info); + contains_expr->set_op(proto::plan::JSONContainsExpr_JSONOp:: + JSONContainsExpr_JSONOp_ContainsAny); + contains_expr->set_elements_same_type(true); + for (const auto& elem : this->vec_of_array_[0]) { + auto t = test::GenGenericValue(elem); + contains_expr->mutable_elements()->AddAllocated(t); + } + auto expr = test::GenExpr(); + expr->set_allocated_json_contains_expr(contains_expr.release()); + + auto parser = ProtoParser(*this->schema_); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(this->seg_.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + + std::unordered_set elems(this->vec_of_array_[0].begin(), + this->vec_of_array_[0].end()); + auto ref = [this, &elems](size_t offset) -> bool { + std::unordered_set row(this->vec_of_array_[offset].begin(), + this->vec_of_array_[offset].end()); + for (const auto& elem : elems) { + if (row.find(elem) != row.end()) { + return true; + } + } + return false; + }; + ASSERT_EQ(final.size(), this->N_); + for (size_t i = 0; i < this->N_; i++) { + ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i] + << ", ref(i): " << ref(i); + } +} + +TYPED_TEST_P(ArrayInvertedIndexTest, ArrayContainsAll) { + const auto& meta = this->schema_->operator[](FieldName("array")); + auto column_info = test::GenColumnInfo( + meta.get_id().get(), + static_cast(meta.get_data_type()), + false, + false, + static_cast(meta.get_element_type())); + auto contains_expr = std::make_unique(); + contains_expr->set_allocated_column_info(column_info); + contains_expr->set_op(proto::plan::JSONContainsExpr_JSONOp:: + JSONContainsExpr_JSONOp_ContainsAll); + contains_expr->set_elements_same_type(true); + for (const auto& elem : this->vec_of_array_[0]) { + auto t = test::GenGenericValue(elem); + contains_expr->mutable_elements()->AddAllocated(t); + } + auto expr = test::GenExpr(); + expr->set_allocated_json_contains_expr(contains_expr.release()); + + auto parser = ProtoParser(*this->schema_); + auto typed_expr = parser.ParseExprs(*expr); + auto parsed = + std::make_shared(DEFAULT_PLANNODE_ID, typed_expr); + + auto segpromote = dynamic_cast(this->seg_.get()); + query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP); + BitsetType final; + visitor.ExecuteExprNode(parsed, segpromote, this->N_, final); + + std::unordered_set elems(this->vec_of_array_[0].begin(), + this->vec_of_array_[0].end()); + auto ref = [this, &elems](size_t offset) -> bool { + std::unordered_set row(this->vec_of_array_[offset].begin(), + this->vec_of_array_[offset].end()); + for (const auto& elem : elems) { + if (row.find(elem) == row.end()) { + return false; + } + } + return true; + }; + ASSERT_EQ(final.size(), this->N_); + for (size_t i = 0; i < this->N_; i++) { + ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i] + << ", ref(i): " << ref(i); + } +} + +TYPED_TEST_P(ArrayInvertedIndexTest, ArrayEqual) { + if (std::is_floating_point_v) { + GTEST_SKIP() << "not accurate to perform equal comparison on floating " + "point number"; + } 
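+    // NOTE: exact equality on float/double elements is brittle once values
+    // have passed through random data generation and expression evaluation,
+    // so only the bool, integral and string instantiations exercise the
+    // exact-match assertions below.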
+
+    const auto& meta = this->schema_->operator[](FieldName("array"));
+    auto column_info = test::GenColumnInfo(
+        meta.get_id().get(),
+        static_cast<proto::schema::DataType>(meta.get_data_type()),
+        false,
+        false,
+        static_cast<proto::schema::DataType>(meta.get_element_type()));
+    auto unary_range_expr = std::make_unique<proto::plan::UnaryRangeExpr>();
+    unary_range_expr->set_allocated_column_info(column_info);
+    unary_range_expr->set_op(proto::plan::OpType::Equal);
+    auto arr = new proto::plan::GenericValue;
+    arr->mutable_array_val()->set_element_type(
+        static_cast<proto::schema::DataType>(meta.get_element_type()));
+    arr->mutable_array_val()->set_same_type(true);
+    for (const auto& elem : this->vec_of_array_[0]) {
+        auto e = test::GenGenericValue(elem);
+        arr->mutable_array_val()->mutable_array()->AddAllocated(e);
+    }
+    unary_range_expr->set_allocated_value(arr);
+    auto expr = test::GenExpr();
+    expr->set_allocated_unary_range_expr(unary_range_expr.release());
+
+    auto parser = ProtoParser(*this->schema_);
+    auto typed_expr = parser.ParseExprs(*expr);
+    auto parsed =
+        std::make_shared<plan::FilterBitsNode>(DEFAULT_PLANNODE_ID, typed_expr);
+
+    auto segpromote = dynamic_cast<SegmentSealedImpl*>(this->seg_.get());
+    query::ExecPlanNodeVisitor visitor(*segpromote, MAX_TIMESTAMP);
+    BitsetType final;
+    visitor.ExecuteExprNode(parsed, segpromote, this->N_, final);
+
+    auto ref = [this](size_t offset) -> bool {
+        if (this->vec_of_array_[0].size() !=
+            this->vec_of_array_[offset].size()) {
+            return false;
+        }
+        auto size = this->vec_of_array_[0].size();
+        for (size_t i = 0; i < size; i++) {
+            if (this->vec_of_array_[0][i] != this->vec_of_array_[offset][i]) {
+                return false;
+            }
+        }
+        return true;
+    };
+    ASSERT_EQ(final.size(), this->N_);
+    for (size_t i = 0; i < this->N_; i++) {
+        ASSERT_EQ(final[i], ref(i)) << "i: " << i << ", final[i]: " << final[i]
+                                    << ", ref(i): " << ref(i);
+    }
+}
+
+using ElementType = testing::Types<bool,
+                                   int8_t,
+                                   int16_t,
+                                   int32_t,
+                                   int64_t,
+                                   float,
+                                   double,
+                                   std::string>;
+
+REGISTER_TYPED_TEST_CASE_P(ArrayInvertedIndexTest,
+                           ArrayContainsAny,
+                           ArrayContainsAll,
+                           ArrayEqual);
+
+INSTANTIATE_TYPED_TEST_SUITE_P(Naive, ArrayInvertedIndexTest, ElementType);
diff --git a/internal/core/unittest/test_index_wrapper.cpp b/internal/core/unittest/test_index_wrapper.cpp
index 39f6841957dc4..79581bc96947b 100644
--- a/internal/core/unittest/test_index_wrapper.cpp
+++ b/internal/core/unittest/test_index_wrapper.cpp
@@ -23,7 +23,7 @@
 
 using namespace milvus;
 using namespace milvus::segcore;
-using namespace milvus::proto::indexcgo;
+using namespace milvus::proto;
 
 using Param = std::pair<knowhere::IndexType, knowhere::MetricType>;
diff --git a/internal/core/unittest/test_inverted_index.cpp b/internal/core/unittest/test_inverted_index.cpp
index eeddfe6e9d81a..c8b9bf3663235 100644
--- a/internal/core/unittest/test_inverted_index.cpp
+++ b/internal/core/unittest/test_inverted_index.cpp
@@ -25,20 +25,25 @@
 
 using namespace milvus;
 
-// TODO: I would suggest that our all indexes use this test to simulate the real production environment.
- namespace milvus::test { auto gen_field_meta(int64_t collection_id = 1, int64_t partition_id = 2, int64_t segment_id = 3, - int64_t field_id = 101) -> storage::FieldDataMeta { - return storage::FieldDataMeta{ + int64_t field_id = 101, + DataType data_type = DataType::NONE, + DataType element_type = DataType::NONE) + -> storage::FieldDataMeta { + auto meta = storage::FieldDataMeta{ .collection_id = collection_id, .partition_id = partition_id, .segment_id = segment_id, .field_id = field_id, }; + meta.schema.set_data_type(static_cast(data_type)); + meta.schema.set_element_type( + static_cast(element_type)); + return meta; } auto @@ -86,7 +91,7 @@ struct ChunkManagerWrapper { }; } // namespace milvus::test -template +template void test_run() { int64_t collection_id = 1; @@ -96,8 +101,8 @@ test_run() { int64_t index_build_id = 1000; int64_t index_version = 10000; - auto field_meta = - test::gen_field_meta(collection_id, partition_id, segment_id, field_id); + auto field_meta = test::gen_field_meta( + collection_id, partition_id, segment_id, field_id, dtype, element_type); auto index_meta = test::gen_index_meta( segment_id, field_id, index_build_id, index_version); @@ -305,8 +310,12 @@ test_string() { int64_t index_build_id = 1000; int64_t index_version = 10000; - auto field_meta = - test::gen_field_meta(collection_id, partition_id, segment_id, field_id); + auto field_meta = test::gen_field_meta(collection_id, + partition_id, + segment_id, + field_id, + dtype, + DataType::NONE); auto index_meta = test::gen_index_meta( segment_id, field_id, index_build_id, index_version); diff --git a/internal/core/unittest/test_scalar_index.cpp b/internal/core/unittest/test_scalar_index.cpp index 8b11c89530e9b..f7becf13b492f 100644 --- a/internal/core/unittest/test_scalar_index.cpp +++ b/internal/core/unittest/test_scalar_index.cpp @@ -49,6 +49,14 @@ TYPED_TEST_P(TypedScalarIndexTest, Dummy) { std::cout << milvus::GetDType() << std::endl; } +auto +GetTempFileManagerCtx(CDataType data_type) { + auto ctx = milvus::storage::FileManagerContext(); + ctx.fieldDataMeta.schema.set_data_type( + static_cast(data_type)); + return ctx; +} + TYPED_TEST_P(TypedScalarIndexTest, Constructor) { using T = TypeParam; auto dtype = milvus::GetDType(); @@ -59,7 +67,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Constructor) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); } } @@ -73,7 +81,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Count) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -92,7 +100,7 @@ TYPED_TEST_P(TypedScalarIndexTest, HasRawData) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -112,7 +120,7 @@ TYPED_TEST_P(TypedScalarIndexTest, In) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -131,7 +139,7 @@ 
TYPED_TEST_P(TypedScalarIndexTest, NotIn) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -150,7 +158,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Reverse) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -169,7 +177,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Range) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -188,7 +196,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) { create_index_info.index_type = index_type; auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); auto scalar_index = dynamic_cast*>(index.get()); auto arr = GenSortedArr(nb); @@ -197,7 +205,7 @@ TYPED_TEST_P(TypedScalarIndexTest, Codec) { auto binary_set = index->Serialize(nullptr); auto copy_index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( - create_index_info); + create_index_info, GetTempFileManagerCtx(dtype)); copy_index->Load(binary_set); auto copy_scalar_index = @@ -368,6 +376,8 @@ TYPED_TEST_P(TypedScalarIndexTestV2, Base) { auto space = TestSpace(temp_path, vec_size, dataset, scalars); milvus::storage::FileManagerContext file_manager_context( {}, {.field_name = "scalar"}, chunk_manager, space); + file_manager_context.fieldDataMeta.schema.set_data_type( + static_cast(dtype)); auto index = milvus::index::IndexFactory::GetInstance().CreateScalarIndex( create_index_info, file_manager_context, space); diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 37c3d6f27676d..3b69ed98e8ec0 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -480,8 +480,30 @@ inline GeneratedData DataGen(SchemaPtr schema, } break; } - case DataType::INT8: - case DataType::INT16: + case DataType::INT8: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_int_data()->add_data( + static_cast(random())); + } + data[i] = field_data; + } + break; + } + case DataType::INT16: { + for (int i = 0; i < N / repeat_count; i++) { + milvus::proto::schema::ScalarField field_data; + + for (int j = 0; j < array_len; j++) { + field_data.mutable_int_data()->add_data( + static_cast(random())); + } + data[i] = field_data; + } + break; + } case DataType::INT32: { for (int i = 0; i < N / repeat_count; i++) { milvus::proto::schema::ScalarField field_data; diff --git a/internal/core/unittest/test_utils/GenExprProto.h b/internal/core/unittest/test_utils/GenExprProto.h index 171273b1fc7fd..77f0a4964e4bb 100644 --- a/internal/core/unittest/test_utils/GenExprProto.h +++ b/internal/core/unittest/test_utils/GenExprProto.h @@ -15,15 +15,18 @@ namespace milvus::test { inline auto -GenColumnInfo(int64_t field_id, - proto::schema::DataType field_type, - bool auto_id, - bool is_pk) { 
+GenColumnInfo( + int64_t field_id, + proto::schema::DataType field_type, + bool auto_id, + bool is_pk, + proto::schema::DataType element_type = proto::schema::DataType::None) { auto column_info = new proto::plan::ColumnInfo(); column_info->set_field_id(field_id); column_info->set_data_type(field_type); column_info->set_is_autoid(auto_id); column_info->set_is_primary_key(is_pk); + column_info->set_element_type(element_type); return column_info; } diff --git a/internal/datacoord/index_builder.go b/internal/datacoord/index_builder.go index be03a613ef634..3c87b94d23f60 100644 --- a/internal/datacoord/index_builder.go +++ b/internal/datacoord/index_builder.go @@ -347,28 +347,29 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { } } var req *indexpb.CreateJobRequest - if Params.CommonCfg.EnableStorageV2.GetAsBool() { - collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID()) - if err != nil { - log.Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) - return false - } + collectionInfo, err := ib.handler.GetCollection(ib.ctx, segment.GetCollectionID()) + if err != nil { + log.Ctx(ib.ctx).Info("index builder get collection info failed", zap.Int64("collectionID", segment.GetCollectionID()), zap.Error(err)) + return false + } - schema := collectionInfo.Schema - var field *schemapb.FieldSchema + schema := collectionInfo.Schema + var field *schemapb.FieldSchema - for _, f := range schema.Fields { - if f.FieldID == fieldID { - field = f - break - } - } - - dim, err := storage.GetDimFromParams(field.TypeParams) - if err != nil { - return false + for _, f := range schema.Fields { + if f.FieldID == fieldID { + field = f + break } + } + dim, err := storage.GetDimFromParams(field.TypeParams) + if err != nil { + log.Ctx(ib.ctx).Warn("failed to get dim from field type params", + zap.String("field type", field.GetDataType().String()), zap.Error(err)) + // don't return, maybe field is scalar field or sparseFloatVector + } + if Params.CommonCfg.EnableStorageV2.GetAsBool() { storePath, err := itypeutil.GetStorageURI(params.Params.CommonCfg.StorageScheme.GetValue(), params.Params.CommonCfg.StoragePathPrefix.GetValue(), segment.GetID()) if err != nil { log.Ctx(ib.ctx).Warn("failed to get storage uri", zap.Error(err)) @@ -402,6 +403,7 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { CurrentIndexVersion: ib.indexEngineVersionManager.GetCurrentIndexEngineVersion(), DataIds: binlogIDs, OptionalScalarFields: optionalFields, + Field: field, } } else { req = &indexpb.CreateJobRequest{ @@ -420,6 +422,8 @@ func (ib *indexBuilder) process(buildID UniqueID) bool { SegmentID: segment.GetID(), FieldID: fieldID, OptionalScalarFields: optionalFields, + Dim: int64(dim), + Field: field, } } diff --git a/internal/datacoord/index_builder_test.go b/internal/datacoord/index_builder_test.go index 46d8c7fe3f43e..9488c70f5e818 100644 --- a/internal/datacoord/index_builder_test.go +++ b/internal/datacoord/index_builder_test.go @@ -675,7 +675,30 @@ func TestIndexBuilder(t *testing.T) { chunkManager := &mocks.ChunkManager{} chunkManager.EXPECT().RootPath().Return("root") - ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), nil) + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + 
DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + + ib := newIndexBuilder(ctx, mt, nodeManager, chunkManager, newIndexEngineVersionManager(), handler) assert.Equal(t, 6, len(ib.tasks)) assert.Equal(t, indexTaskInit, ib.tasks[buildID]) @@ -741,6 +764,30 @@ func TestIndexBuilder_Error(t *testing.T) { chunkManager := &mocks.ChunkManager{} chunkManager.EXPECT().RootPath().Return("root") + + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + ib := &indexBuilder{ ctx: context.Background(), tasks: map[int64]indexTaskState{ @@ -749,6 +796,7 @@ func TestIndexBuilder_Error(t *testing.T) { meta: createMetaTable(ec), chunkManager: chunkManager, indexEngineVersionManager: newIndexEngineVersionManager(), + handler: handler, } t.Run("meta not exist", func(t *testing.T) { @@ -1414,9 +1462,32 @@ func TestVecIndexWithOptionalScalarField(t *testing.T) { mt.collections[collID].Schema.Fields[1].DataType = schemapb.DataType_VarChar } + handler := NewNMockHandler(t) + handler.EXPECT().GetCollection(mock.Anything, mock.Anything).Return(&collectionInfo{ + ID: collID, + Schema: &schemapb.CollectionSchema{ + Name: "coll", + Fields: []*schemapb.FieldSchema{ + { + FieldID: fieldID, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + }, + }, + EnableDynamicField: false, + Properties: nil, + }, + }, nil) + paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true") defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false") - ib := newIndexBuilder(ctx, &mt, nodeManager, cm, newIndexEngineVersionManager(), nil) + ib := newIndexBuilder(ctx, &mt, nodeManager, cm, newIndexEngineVersionManager(), handler) t.Run("success to get opt field on startup", func(t *testing.T) { ic.EXPECT().CreateJob(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( diff --git a/internal/indexnode/indexnode_service.go b/internal/indexnode/indexnode_service.go index a690e35e4a10a..fb9d5a0cc19a1 100644 --- a/internal/indexnode/indexnode_service.go +++ b/internal/indexnode/indexnode_service.go @@ -55,6 +55,8 @@ func (i *IndexNode) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest defer i.lifetime.Done() log.Info("IndexNode building index ...", zap.Int64("collectionID", req.GetCollectionID()), + zap.Int64("partitionID", req.GetPartitionID()), + zap.Int64("segmentID", req.GetSegmentID()), zap.Int64("indexID", req.GetIndexID()), zap.String("indexName", req.GetIndexName()), zap.String("indexFilePrefix", req.GetIndexFilePrefix()), diff --git a/internal/indexnode/task.go b/internal/indexnode/task.go index b14343900d99c..54c8b3fe45a66 100644 --- a/internal/indexnode/task.go +++ b/internal/indexnode/task.go @@ -18,7 +18,6 @@ package indexnode import ( "context" - "encoding/json" "fmt" "runtime/debug" "strconv" @@ -30,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + 
"github.com/milvus-io/milvus/internal/proto/indexcgopb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" @@ -84,12 +84,21 @@ type indexBuildTaskV2 struct { } func (it *indexBuildTaskV2) parseParams(ctx context.Context) error { - it.collectionID = it.req.CollectionID - it.partitionID = it.req.PartitionID - it.segmentID = it.req.SegmentID - it.fieldType = it.req.FieldType - it.fieldID = it.req.FieldID - it.fieldName = it.req.FieldName + it.collectionID = it.req.GetCollectionID() + it.partitionID = it.req.GetPartitionID() + it.segmentID = it.req.GetSegmentID() + it.fieldType = it.req.GetFieldType() + if it.fieldType == schemapb.DataType_None { + it.fieldType = it.req.GetField().GetDataType() + } + it.fieldID = it.req.GetFieldID() + if it.fieldID == 0 { + it.fieldID = it.req.GetField().GetFieldID() + } + it.fieldName = it.req.GetFieldName() + if it.fieldName == "" { + it.fieldName = it.req.GetField().GetName() + } return nil } @@ -138,61 +147,66 @@ func (it *indexBuildTaskV2) BuildIndex(ctx context.Context) error { } } - var buildIndexInfo *indexcgowrapper.BuildIndexInfo - buildIndexInfo, err = indexcgowrapper.NewBuildIndexInfo(it.req.GetStorageConfig()) - defer indexcgowrapper.DeleteBuildIndexInfo(buildIndexInfo) - if err != nil { - log.Ctx(ctx).Warn("create build index info failed", zap.Error(err)) - return err - } - err = buildIndexInfo.AppendFieldMetaInfoV2(it.collectionID, it.partitionID, it.segmentID, it.fieldID, it.fieldType, it.fieldName, it.req.Dim) - if err != nil { - log.Ctx(ctx).Warn("append field meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexMetaInfo(it.req.IndexID, it.req.BuildID, it.req.IndexVersion) - if err != nil { - log.Ctx(ctx).Warn("append index meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendBuildIndexParam(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Warn("append index params failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexStorageInfo(it.req.StorePath, it.req.IndexStorePath, it.req.StoreVersion) - if err != nil { - log.Ctx(ctx).Warn("append storage info failed", zap.Error(err)) - return err - } - - jsonIndexParams, err := json.Marshal(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Error("failed to json marshal index params", zap.Error(err)) - return err - } - - log.Ctx(ctx).Info("index params are ready", - zap.Int64("buildID", it.BuildID), - zap.String("index params", string(jsonIndexParams))) - - err = buildIndexInfo.AppendBuildTypeParam(it.newTypeParams) - if err != nil { - log.Ctx(ctx).Warn("append type params failed", zap.Error(err)) - return err + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: 
it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) } - for _, optField := range it.req.GetOptionalScalarFields() { - if err := buildIndexInfo.AppendOptionalField(optField); err != nil { - log.Ctx(ctx).Warn("append optional field failed", zap.Error(err)) - return err + it.currentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) + field := it.req.GetField() + if field == nil || field.GetDataType() == schemapb.DataType_None { + field = &schemapb.FieldSchema{ + FieldID: it.fieldID, + Name: it.fieldName, + DataType: it.fieldType, } } - it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexInfo) + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.ClusterID, + BuildID: it.BuildID, + CollectionID: it.collectionID, + PartitionID: it.partitionID, + SegmentID: it.segmentID, + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: it.currentIndexVersion, + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: field, + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: it.req.GetIndexStorePath(), + OptFields: optFields, + } + + it.index, err = indexcgowrapper.CreateIndexV2(ctx, buildIndexParams) if err != nil { if it.index != nil && it.index.CleanLocalData() != nil { log.Ctx(ctx).Error("failed to clean cached data on disk after build index failed", @@ -328,7 +342,7 @@ func (it *indexBuildTask) Prepare(ctx context.Context) error { if len(it.req.DataPaths) == 0 { for _, id := range it.req.GetDataIds() { - path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetFieldID(), id) + path := metautil.BuildInsertLogPath(it.req.GetStorageConfig().RootPath, it.req.GetCollectionID(), it.req.GetPartitionID(), it.req.GetSegmentID(), it.req.GetField().GetFieldID(), id) it.req.DataPaths = append(it.req.DataPaths, path) } } @@ -362,16 +376,10 @@ func (it *indexBuildTask) Prepare(ctx context.Context) error { } it.newTypeParams = typeParams it.newIndexParams = indexParams + it.statistic.IndexParams = it.req.GetIndexParams() - // ugly codes to get dimension - if dimStr, ok := typeParams[common.DimKey]; ok { - var err error - it.statistic.Dim, err = strconv.ParseInt(dimStr, 10, 64) - if err != nil { - log.Ctx(ctx).Error("parse dimesion failed", zap.Error(err)) - // ignore error - } - } + it.statistic.Dim = it.req.GetDim() + log.Ctx(ctx).Info("Successfully prepare indexBuildTask", zap.Int64("buildID", it.BuildID), zap.Int64("Collection", it.collectionID), zap.Int64("SegmentID", it.segmentID)) return nil @@ -482,69 +490,65 @@ func (it *indexBuildTask) BuildIndex(ctx context.Context) error { } } - var buildIndexInfo *indexcgowrapper.BuildIndexInfo - buildIndexInfo, err = indexcgowrapper.NewBuildIndexInfo(it.req.GetStorageConfig()) - defer indexcgowrapper.DeleteBuildIndexInfo(buildIndexInfo) - if err != nil { - log.Ctx(ctx).Warn("create build index info 
failed", zap.Error(err)) - return err - } - err = buildIndexInfo.AppendFieldMetaInfo(it.collectionID, it.partitionID, it.segmentID, it.fieldID, it.fieldType) - if err != nil { - log.Ctx(ctx).Warn("append field meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendIndexMetaInfo(it.req.IndexID, it.req.BuildID, it.req.IndexVersion) - if err != nil { - log.Ctx(ctx).Warn("append index meta failed", zap.Error(err)) - return err - } - - err = buildIndexInfo.AppendBuildIndexParam(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Warn("append index params failed", zap.Error(err)) - return err - } - - jsonIndexParams, err := json.Marshal(it.newIndexParams) - if err != nil { - log.Ctx(ctx).Error("failed to json marshal index params", zap.Error(err)) - return err - } - - log.Ctx(ctx).Info("index params are ready", - zap.Int64("buildID", it.BuildID), - zap.String("index params", string(jsonIndexParams))) - - err = buildIndexInfo.AppendBuildTypeParam(it.newTypeParams) - if err != nil { - log.Ctx(ctx).Warn("append type params failed", zap.Error(err)) - return err - } - - for _, path := range it.req.GetDataPaths() { - err = buildIndexInfo.AppendInsertFile(path) - if err != nil { - log.Ctx(ctx).Warn("append insert binlog path failed", zap.Error(err)) - return err - } + storageConfig := &indexcgopb.StorageConfig{ + Address: it.req.GetStorageConfig().GetAddress(), + AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), + SecretAccessKey: it.req.GetStorageConfig().GetSecretAccessKey(), + UseSSL: it.req.GetStorageConfig().GetUseSSL(), + BucketName: it.req.GetStorageConfig().GetBucketName(), + RootPath: it.req.GetStorageConfig().GetRootPath(), + UseIAM: it.req.GetStorageConfig().GetUseIAM(), + IAMEndpoint: it.req.GetStorageConfig().GetIAMEndpoint(), + StorageType: it.req.GetStorageConfig().GetStorageType(), + UseVirtualHost: it.req.GetStorageConfig().GetUseVirtualHost(), + Region: it.req.GetStorageConfig().GetRegion(), + CloudProvider: it.req.GetStorageConfig().GetCloudProvider(), + RequestTimeoutMs: it.req.GetStorageConfig().GetRequestTimeoutMs(), + SslCACert: it.req.GetStorageConfig().GetSslCACert(), + } + + optFields := make([]*indexcgopb.OptionalFieldInfo, 0, len(it.req.GetOptionalScalarFields())) + for _, optField := range it.req.GetOptionalScalarFields() { + optFields = append(optFields, &indexcgopb.OptionalFieldInfo{ + FieldID: optField.GetFieldID(), + FieldName: optField.GetFieldName(), + FieldType: optField.GetFieldType(), + DataPaths: optField.GetDataPaths(), + }) } it.currentIndexVersion = getCurrentIndexVersion(it.req.GetCurrentIndexVersion()) - if err := buildIndexInfo.AppendIndexEngineVersion(it.currentIndexVersion); err != nil { - log.Ctx(ctx).Warn("append index engine version failed", zap.Error(err)) - return err - } - - for _, optField := range it.req.GetOptionalScalarFields() { - if err := buildIndexInfo.AppendOptionalField(optField); err != nil { - log.Ctx(ctx).Warn("append optional field failed", zap.Error(err)) - return err + field := it.req.GetField() + if field == nil || field.GetDataType() == schemapb.DataType_None { + field = &schemapb.FieldSchema{ + FieldID: it.fieldID, + Name: it.fieldName, + DataType: it.fieldType, } } - - it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexInfo) + buildIndexParams := &indexcgopb.BuildIndexInfo{ + ClusterID: it.ClusterID, + BuildID: it.BuildID, + CollectionID: it.collectionID, + PartitionID: it.partitionID, + SegmentID: it.segmentID, + IndexVersion: it.req.GetIndexVersion(), + CurrentIndexVersion: 
it.currentIndexVersion, + NumRows: it.req.GetNumRows(), + Dim: it.req.GetDim(), + IndexFilePrefix: it.req.GetIndexFilePrefix(), + InsertFiles: it.req.GetDataPaths(), + FieldSchema: field, + StorageConfig: storageConfig, + IndexParams: mapToKVPairs(it.newIndexParams), + TypeParams: mapToKVPairs(it.newTypeParams), + StorePath: it.req.GetStorePath(), + StoreVersion: it.req.GetStoreVersion(), + IndexStorePath: it.req.GetIndexStorePath(), + OptFields: optFields, + } + + it.index, err = indexcgowrapper.CreateIndex(ctx, buildIndexParams) if err != nil { if it.index != nil && it.index.CleanLocalData() != nil { log.Ctx(ctx).Error("failed to clean cached data on disk after build index failed", @@ -653,8 +657,6 @@ func (it *indexBuildTask) decodeBlobs(ctx context.Context, blobs []*storage.Blob deserializeDur := it.tr.RecordSpan() log.Ctx(ctx).Info("IndexNode deserialize data success", - zap.Int64("index id", it.req.IndexID), - zap.String("index name", it.req.IndexName), zap.Int64("collectionID", it.collectionID), zap.Int64("partitionID", it.partitionID), zap.Int64("segmentID", it.segmentID), diff --git a/internal/indexnode/task_test.go b/internal/indexnode/task_test.go index dc30abd800eec..6450c3e504a71 100644 --- a/internal/indexnode/task_test.go +++ b/internal/indexnode/task_test.go @@ -283,12 +283,14 @@ func (suite *IndexBuildTaskV2Suite) TestBuildIndex() { RootPath: "/tmp/milvus/data", StorageType: "local", }, - CollectionID: 1, - PartitionID: 1, - SegmentID: 1, - FieldID: 3, - FieldName: "vec", - FieldType: schemapb.DataType_FloatVector, + CollectionID: 1, + PartitionID: 1, + SegmentID: 1, + Field: &schemapb.FieldSchema{ + FieldID: 3, + Name: "vec", + DataType: schemapb.DataType_FloatVector, + }, StorePath: "file://" + suite.space.Path(), StoreVersion: suite.space.GetCurrentVersion(), IndexStorePath: "file://" + suite.space.Path(), diff --git a/internal/indexnode/util.go b/internal/indexnode/util.go index 9186f9855a81b..8aaa92910503f 100644 --- a/internal/indexnode/util.go +++ b/internal/indexnode/util.go @@ -19,6 +19,7 @@ package indexnode import ( "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) @@ -36,3 +37,14 @@ func estimateFieldDataSize(dim int64, numRows int64, dataType schemapb.DataType) return 0, nil } } + +func mapToKVPairs(m map[string]string) []*commonpb.KeyValuePair { + kvs := make([]*commonpb.KeyValuePair, 0, len(m)) + for k, v := range m { + kvs = append(kvs, &commonpb.KeyValuePair{ + Key: k, + Value: v, + }) + } + return kvs +} diff --git a/internal/indexnode/util_test.go b/internal/indexnode/util_test.go new file mode 100644 index 0000000000000..6d7d98e823240 --- /dev/null +++ b/internal/indexnode/util_test.go @@ -0,0 +1,41 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package indexnode
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/suite"
+)
+
+type utilSuite struct {
+	suite.Suite
+}
+
+func (s *utilSuite) Test_mapToKVPairs() {
+	indexParams := map[string]string{
+		"index_type": "IVF_FLAT",
+		"dim":        "128",
+		"nlist":      "1024",
+	}
+
+	s.Equal(3, len(mapToKVPairs(indexParams)))
+}
+
+func Test_utilSuite(t *testing.T) {
+	suite.Run(t, new(utilSuite))
+}
diff --git a/internal/proto/cgo_msg.proto b/internal/proto/cgo_msg.proto
new file mode 100644
index 0000000000000..6d851e95e0550
--- /dev/null
+++ b/internal/proto/cgo_msg.proto
@@ -0,0 +1,23 @@
+syntax = "proto3";
+
+package milvus.proto.cgo;
+option go_package="github.com/milvus-io/milvus/internal/proto/cgopb";
+
+import "schema.proto";
+
+message LoadIndexInfo {
+    int64 collectionID = 1;
+    int64 partitionID = 2;
+    int64 segmentID = 3;
+    schema.FieldSchema field = 5;
+    bool enable_mmap = 6;
+    string mmap_dir_path = 7;
+    int64 indexID = 8;
+    int64 index_buildID = 9;
+    int64 index_version = 10;
+    map<string, string> index_params = 11;
+    repeated string index_files = 12;
+    string uri = 13;
+    int64 index_store_version = 14;
+    int32 index_engine_version = 15;
+}
diff --git a/internal/proto/index_cgo_msg.proto b/internal/proto/index_cgo_msg.proto
index 50b1ea5dde5a5..688f871f55aed 100644
--- a/internal/proto/index_cgo_msg.proto
+++ b/internal/proto/index_cgo_msg.proto
@@ -4,6 +4,7 @@ package milvus.proto.indexcgo;
 option go_package="github.com/milvus-io/milvus/internal/proto/indexcgopb";
 
 import "common.proto";
+import "schema.proto";
 
 message TypeParams {
   repeated common.KeyValuePair params = 1;
@@ -30,3 +31,52 @@ message Binary {
 message BinarySet {
   repeated Binary datas = 1;
 }
+
+// Synchronously modify StorageConfig in index_coord.proto file
+message StorageConfig {
+    string address = 1;
+    string access_keyID = 2;
+    string secret_access_key = 3;
+    bool useSSL = 4;
+    string bucket_name = 5;
+    string root_path = 6;
+    bool useIAM = 7;
+    string IAMEndpoint = 8;
+    string storage_type = 9;
+    bool use_virtual_host = 10;
+    string region = 11;
+    string cloud_provider = 12;
+    int64 request_timeout_ms = 13;
+    string sslCACert = 14;
+}
+
+// Synchronously modify OptionalFieldInfo in index_coord.proto file
+message OptionalFieldInfo {
+    int64 fieldID = 1;
+    string field_name = 2;
+    int32 field_type = 3;
+    repeated string data_paths = 4;
+}
+
+message BuildIndexInfo {
+    string clusterID = 1;
+    int64 buildID = 2;
+    int64 collectionID = 3;
+    int64 partitionID = 4;
+    int64 segmentID = 5;
+    int64 index_version = 6;
+    int32 current_index_version = 7;
+    int64 num_rows = 8;
+    int64 dim = 9;
+    string index_file_prefix = 10;
+    repeated string insert_files = 11;
+//    repeated int64 data_ids = 12;
+    schema.FieldSchema field_schema = 12;
+    StorageConfig storage_config = 13;
+    repeated common.KeyValuePair index_params = 14;
+    repeated common.KeyValuePair type_params = 15;
+    string store_path = 16;
+    int64 store_version = 17;
+    string index_store_path = 18;
+    repeated OptionalFieldInfo opt_fields = 19;
+}
diff --git a/internal/proto/index_coord.proto b/internal/proto/index_coord.proto
index d59452b17d2de..9204d7da2a9c7 100644
--- a/internal/proto/index_coord.proto
+++ b/internal/proto/index_coord.proto
@@ -8,6 +8,7 @@ import "common.proto";
 import "internal.proto";
 import "milvus.proto";
 import "schema.proto";
+import "index_cgo_msg.proto";
 
 service IndexCoord {
   rpc GetComponentStates(milvus.GetComponentStatesRequest) returns
(milvus.ComponentStates) {} @@ -226,6 +227,7 @@ message GetIndexBuildProgressResponse { int64 pending_index_rows = 4; } +// Synchronously modify StorageConfig in index_cgo_msg.proto file message StorageConfig { string address = 1; string access_keyID = 2; @@ -243,6 +245,7 @@ message StorageConfig { string sslCACert = 14; } +// Synchronously modify OptionalFieldInfo in index_cgo_msg.proto file message OptionalFieldInfo { int64 fieldID = 1; string field_name = 2; @@ -276,6 +279,7 @@ message CreateJobRequest { int64 dim = 22; repeated int64 data_ids = 23; repeated OptionalFieldInfo optional_scalar_fields = 24; + schema.FieldSchema field = 25; } message QueryJobsRequest { diff --git a/internal/querynodev2/segments/load_index_info.go b/internal/querynodev2/segments/load_index_info.go index c5c1572475c40..04632bed95f2d 100644 --- a/internal/querynodev2/segments/load_index_info.go +++ b/internal/querynodev2/segments/load_index_info.go @@ -29,11 +29,13 @@ import ( "runtime" "unsafe" + "github.com/golang/protobuf/proto" "github.com/pingcap/log" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datacoord" + "github.com/milvus-io/milvus/internal/proto/cgopb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/pkg/common" @@ -245,3 +247,33 @@ func (li *LoadIndexInfo) appendIndexEngineVersion(ctx context.Context, indexEngi return HandleCStatus(ctx, &status, "AppendIndexEngineVersion failed") } + +func (li *LoadIndexInfo) finish(ctx context.Context, info *cgopb.LoadIndexInfo) error { + marshaled, err := proto.Marshal(info) + if err != nil { + return err + } + + var status C.CStatus + _, _ = GetDynamicPool().Submit(func() (any, error) { + status = C.FinishLoadIndexInfo(li.cLoadIndexInfo, (*C.uint8_t)(unsafe.Pointer(&marshaled[0])), (C.uint64_t)(len(marshaled))) + return nil, nil + }).Await() + + if err := HandleCStatus(ctx, &status, "FinishLoadIndexInfo failed"); err != nil { + return err + } + + _, _ = GetLoadPool().Submit(func() (any, error) { + if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() { + status = C.AppendIndexV3(li.cLoadIndexInfo) + } else { + traceCtx := ParseCTraceContext(ctx) + status = C.AppendIndexV2(traceCtx.ctx, li.cLoadIndexInfo) + runtime.KeepAlive(traceCtx) + } + return nil, nil + }).Await() + + return HandleCStatus(ctx, &status, "AppendIndex failed") +} diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index ea864607e0090..075111e7b2b04 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -45,6 +45,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" milvus_storage "github.com/milvus-io/milvus-storage/go/storage" "github.com/milvus-io/milvus-storage/go/storage/options" + "github.com/milvus-io/milvus/internal/proto/cgopb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" @@ -56,6 +57,9 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" + "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" 
"github.com/milvus-io/milvus/pkg/util/paramtable" @@ -1262,18 +1266,58 @@ func (s *LocalSegment) LoadIndex(ctx context.Context, indexInfo *querypb.FieldIn return err } defer deleteLoadIndexInfo(loadIndexInfo) + + schema, err := typeutil.CreateSchemaHelper(s.GetCollection().Schema()) + if err != nil { + return err + } + fieldSchema, err := schema.GetFieldFromID(indexInfo.GetFieldID()) + if err != nil { + return err + } + + indexParams := funcutil.KeyValuePair2Map(indexInfo.IndexParams) + // as Knowhere reports error if encounter an unknown param, we need to delete it + delete(indexParams, common.MmapEnabledKey) + + // some build params also exist in indexParams, which are useless during loading process + if indexParams["index_type"] == indexparamcheck.IndexDISKANN { + if err := indexparams.SetDiskIndexLoadParams(paramtable.Get(), indexParams, indexInfo.GetNumRows()); err != nil { + return err + } + } + + if err := indexparams.AppendPrepareLoadParams(paramtable.Get(), indexParams); err != nil { + return err + } + + indexInfoProto := &cgopb.LoadIndexInfo{ + CollectionID: s.Collection(), + PartitionID: s.Partition(), + SegmentID: s.ID(), + Field: fieldSchema, + EnableMmap: isIndexMmapEnable(indexInfo), + MmapDirPath: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(), + IndexID: indexInfo.GetIndexID(), + IndexBuildID: indexInfo.GetBuildID(), + IndexVersion: indexInfo.GetIndexVersion(), + IndexParams: indexParams, + IndexFiles: indexInfo.GetIndexFilePaths(), + IndexEngineVersion: indexInfo.GetCurrentIndexVersion(), + IndexStoreVersion: indexInfo.GetIndexStoreVersion(), + } + if paramtable.Get().CommonCfg.EnableStorageV2.GetAsBool() { uri, err := typeutil_internal.GetStorageURI(paramtable.Get().CommonCfg.StorageScheme.GetValue(), paramtable.Get().CommonCfg.StoragePathPrefix.GetValue(), s.ID()) if err != nil { return err } - loadIndexInfo.appendStorageInfo(uri, indexInfo.IndexStoreVersion) + indexInfoProto.Uri = uri } newLoadIndexInfoSpan := tr.RecordSpan() // 2. 
- err = loadIndexInfo.appendLoadIndexInfo(ctx, indexInfo, s.Collection(), s.Partition(), s.ID(), fieldType) - if err != nil { + if err := loadIndexInfo.finish(ctx, indexInfoProto); err != nil { if loadIndexInfo.cleanLocalData(ctx) != nil { log.Warn("failed to clean cached data on disk after append index failed", zap.Int64("buildID", indexInfo.BuildID), diff --git a/internal/util/indexcgowrapper/index.go b/internal/util/indexcgowrapper/index.go index f0850b3b916de..a7cc7d0e9b21c 100644 --- a/internal/util/indexcgowrapper/index.go +++ b/internal/util/indexcgowrapper/index.go @@ -16,6 +16,7 @@ import ( "unsafe" "github.com/golang/protobuf/proto" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -94,9 +95,17 @@ func NewCgoIndex(dtype schemapb.DataType, typeParams, indexParams map[string]str return index, nil } -func CreateIndex(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecIndex, error) { +func CreateIndex(ctx context.Context, buildIndexInfo *indexcgopb.BuildIndexInfo) (CodecIndex, error) { + buildIndexInfoBlob, err := proto.Marshal(buildIndexInfo) + if err != nil { + log.Ctx(ctx).Warn("marshal buildIndexInfo failed", + zap.String("clusterID", buildIndexInfo.GetClusterID()), + zap.Int64("buildID", buildIndexInfo.GetBuildID()), + zap.Error(err)) + return nil, err + } var indexPtr C.CIndex - status := C.CreateIndex(&indexPtr, buildIndexInfo.cBuildIndexInfo) + status := C.CreateIndex(&indexPtr, (*C.uint8_t)(unsafe.Pointer(&buildIndexInfoBlob[0])), (C.uint64_t)(len(buildIndexInfoBlob))) if err := HandleCStatus(&status, "failed to create index"); err != nil { return nil, err } @@ -109,9 +118,17 @@ func CreateIndex(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecInde return index, nil } -func CreateIndexV2(ctx context.Context, buildIndexInfo *BuildIndexInfo) (CodecIndex, error) { +func CreateIndexV2(ctx context.Context, buildIndexInfo *indexcgopb.BuildIndexInfo) (CodecIndex, error) { + buildIndexInfoBlob, err := proto.Marshal(buildIndexInfo) + if err != nil { + log.Ctx(ctx).Warn("marshal buildIndexInfo failed", + zap.String("clusterID", buildIndexInfo.GetClusterID()), + zap.Int64("buildID", buildIndexInfo.GetBuildID()), + zap.Error(err)) + return nil, err + } var indexPtr C.CIndex - status := C.CreateIndexV2(&indexPtr, buildIndexInfo.cBuildIndexInfo) + status := C.CreateIndexV2(&indexPtr, (*C.uint8_t)(unsafe.Pointer(&buildIndexInfoBlob[0])), (C.uint64_t)(len(buildIndexInfoBlob))) if err := HandleCStatus(&status, "failed to create index"); err != nil { return nil, err } diff --git a/pkg/util/indexparamcheck/inverted_checker.go b/pkg/util/indexparamcheck/inverted_checker.go index b15549cd4b7a6..dfc24127d3569 100644 --- a/pkg/util/indexparamcheck/inverted_checker.go +++ b/pkg/util/indexparamcheck/inverted_checker.go @@ -17,7 +17,8 @@ func (c *INVERTEDChecker) CheckTrain(params map[string]string) error { } func (c *INVERTEDChecker) CheckValidDataType(dType schemapb.DataType) error { - if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) { + if !typeutil.IsBoolType(dType) && !typeutil.IsArithmetic(dType) && !typeutil.IsStringType(dType) && + !typeutil.IsArrayType(dType) { return fmt.Errorf("INVERTED are not supported on %s field", dType.String()) } return nil diff --git a/pkg/util/indexparamcheck/inverted_checker_test.go b/pkg/util/indexparamcheck/inverted_checker_test.go index afe41f89f1193..7a31290061490 100644 --- 
a/pkg/util/indexparamcheck/inverted_checker_test.go +++ b/pkg/util/indexparamcheck/inverted_checker_test.go @@ -18,8 +18,8 @@ func Test_INVERTEDIndexChecker(t *testing.T) { assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Bool)) assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Int64)) assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Float)) + assert.NoError(t, c.CheckValidDataType(schemapb.DataType_Array)) assert.Error(t, c.CheckValidDataType(schemapb.DataType_JSON)) - assert.Error(t, c.CheckValidDataType(schemapb.DataType_Array)) assert.Error(t, c.CheckValidDataType(schemapb.DataType_FloatVector)) } diff --git a/scripts/generate_proto.sh b/scripts/generate_proto.sh index 2551f586c9f9c..286570b842aa8 100755 --- a/scripts/generate_proto.sh +++ b/scripts/generate_proto.sh @@ -44,6 +44,7 @@ pushd ${PROTO_DIR} mkdir -p etcdpb mkdir -p indexcgopb +mkdir -p cgopb mkdir -p internalpb mkdir -p rootcoordpb @@ -62,6 +63,7 @@ protoc_opt="${PROTOC_BIN} --proto_path=${API_PROTO_DIR} --proto_path=." ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./etcdpb etcd_meta.proto || { echo 'generate etcd_meta.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./indexcgopb index_cgo_msg.proto || { echo 'generate index_cgo_msg failed '; exit 1; } +${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./cgopb cgo_msg.proto || { echo 'generate cgo_msg failed '; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./rootcoordpb root_coord.proto || { echo 'generate root_coord.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./internalpb internal.proto || { echo 'generate internal.proto failed'; exit 1; } ${protoc_opt} --go_out=plugins=grpc,paths=source_relative:./proxypb proxy.proto|| { echo 'generate proxy.proto failed'; exit 1; } @@ -78,6 +80,7 @@ ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb schema.proto|| { echo 'generate sche ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb common.proto|| { echo 'generate common.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb segcore.proto|| { echo 'generate segcore.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb index_cgo_msg.proto|| { echo 'generate index_cgo_msg.proto failed'; exit 1; } +${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb cgo_msg.proto|| { echo 'generate cgo_msg.proto failed'; exit 1; } ${protoc_opt} --cpp_out=$CPP_SRC_DIR/src/pb plan.proto|| { echo 'generate plan.proto failed'; exit 1; } popd diff --git a/tests/python_client/testcases/test_index.py b/tests/python_client/testcases/test_index.py index 21962385028d1..6e9d914625e67 100644 --- a/tests/python_client/testcases/test_index.py +++ b/tests/python_client/testcases/test_index.py @@ -1313,10 +1313,7 @@ def test_create_inverted_index_on_array_field(self): collection_w = self.init_collection_wrap(schema=schema) # 2. 
create index
        scalar_index_params = {"index_type": "INVERTED"}
-        collection_w.create_index(ct.default_int32_array_field_name, index_params=scalar_index_params,
-                                  check_task=CheckTasks.err_res,
-                                  check_items={ct.err_code: 1100,
-                                               ct.err_msg: "create index on Array field is not supported"})
+        collection_w.create_index(ct.default_int32_array_field_name, index_params=scalar_index_params)
 
     @pytest.mark.tags(CaseLabel.L1)
     def test_create_inverted_index_no_vector_index(self):

From 59d910320d3d5a056e1c610a93a95e44c4acb517 Mon Sep 17 00:00:00 2001
From: yanliang567 <82361606+yanliang567@users.noreply.github.com>
Date: Mon, 24 Jun 2024 11:34:03 +0800
Subject: [PATCH 20/21] test:[cherry-pick]Update tests for range search and add test for query with dup ids (#34069)

related issue: https://github.com/milvus-io/milvus/issues/33883
pr: #34057

Signed-off-by: yanliang567 
---
 tests/python_client/testcases/test_query.py  | 35 +++++++
 tests/python_client/testcases/test_search.py | 98 ++++++++++----------
 2 files changed, 84 insertions(+), 49 deletions(-)

diff --git a/tests/python_client/testcases/test_query.py b/tests/python_client/testcases/test_query.py
index baeff46d5817f..b8f55a3a7efba 100644
--- a/tests/python_client/testcases/test_query.py
+++ b/tests/python_client/testcases/test_query.py
@@ -2275,6 +2275,41 @@ def test_query_dup_ids_dup_term_array(self):
         collection_w.query(term_expr, output_fields=["*"],
                            check_items=CheckTasks.check_query_results,
                            check_task={exp_res: res})
 
+    @pytest.mark.tags(CaseLabel.L1)
+    @pytest.mark.parametrize("with_growing", [True])
+    def test_query_to_get_latest_entity_with_dup_ids(self, with_growing):
+        """
+        target: test query to get the latest entity with duplicate primary keys
+        method: 1. create collection and insert dup primary key = 0
+                2. query with expr=dup_id
+        expected: return the latest entity; verify the result is the same as dedup entities
+        """
+        collection_w = self.init_collection_general(prefix, dim=16, is_flush=False, insert_data=False, is_index=False,
+                                                    vector_data_type=ct.float_type, with_json=False)[0]
+        nb = 50
+        rounds = 10
+        for i in range(rounds):
+            df = cf.gen_default_dataframe_data(dim=16, nb=nb, start=i * nb, with_json=False)
+            df[ct.default_int64_field_name] = i
+            collection_w.insert(df)
+            # re-insert the last piece of data in df to refresh the timestamp
+            last_piece = df.iloc[-1:]
+            collection_w.insert(last_piece)
+
+        if not with_growing:
+            collection_w.flush()
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_index)
+        collection_w.load()
+        # verify the result returns the latest entity if there are duplicate primary keys
+        expr = f'{ct.default_int64_field_name} == 0'
+        res = collection_w.query(expr=expr, output_fields=[ct.default_int64_field_name, ct.default_float_field_name])[0]
+        assert len(res) == 1 and res[0][ct.default_float_field_name] == (nb - 1) * 1.0
+
+        # verify the result is the same as dedup entities
+        expr = f'{ct.default_int64_field_name} >= 0'
+        res = collection_w.query(expr=expr, output_fields=[ct.default_int64_field_name, ct.default_float_field_name])[0]
+        assert len(res) == rounds
+
     @pytest.mark.tags(CaseLabel.L0)
     def test_query_after_index(self):
         """
diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py
index e18915a57504b..e980f220ef435 100644
--- a/tests/python_client/testcases/test_search.py
+++ b/tests/python_client/testcases/test_search.py
@@ -6926,20 +6926,22 @@ def enable_dynamic_field(self, request):
     """
 
     @pytest.mark.tags(CaseLabel.L0)
@pytest.mark.parametrize("vector_data_type", ct.all_dense_vector_types) - def test_range_search_default(self, index_type, metric, vector_data_type): + @pytest.mark.parametrize("with_growing", [False, True]) + def test_range_search_default(self, index_type, metric, vector_data_type, with_growing): """ target: verify the range search returns correct results - method: 1. create collection, insert 8000 vectors, + method: 1. create collection, insert 10k vectors, 2. search with topk=1000 3. range search from the 30th-330th distance as filter 4. verified the range search results is same as the search results in the range """ collection_w = self.init_collection_general(prefix, auto_id=True, insert_data=False, is_index=False, vector_data_type=vector_data_type, with_json=False)[0] - nb = 2000 - for i in range(3): - data = cf.gen_general_default_list_data(nb=nb, auto_id=True, - vector_data_type=vector_data_type, with_json=False) + nb = 1000 + rounds = 10 + for i in range(rounds): + data = cf.gen_general_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type, + with_json=False, start=i*nb) collection_w.insert(data) collection_w.flush() @@ -6947,51 +6949,49 @@ def test_range_search_default(self, index_type, metric, vector_data_type): collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params) collection_w.load() - for i in range(2): - with_growing = bool(i % 2) - if with_growing is True: - # add some growing segments - for _ in range(2): - data = cf.gen_general_default_list_data(nb=nb, auto_id=True, - vector_data_type=vector_data_type, with_json=False) - collection_w.insert(data) + if with_growing is True: + # add some growing segments + for j in range(rounds//2): + data = cf.gen_general_default_list_data(nb=nb, auto_id=True, vector_data_type=vector_data_type, + with_json=False, start=(rounds+j)*nb) + collection_w.insert(data) - search_params = {"params": {}} - nq = 1 - search_vectors = cf.gen_vectors(nq, ct.default_dim, vector_data_type=vector_data_type) - search_res = collection_w.search(search_vectors, default_search_field, - search_params, limit=1000)[0] - assert len(search_res[0].ids) == 1000 - log.debug(f"search topk=1000 returns {len(search_res[0].ids)}") - check_topk = 300 - check_from = 30 - ids = search_res[0].ids[check_from:check_from + check_topk] - radius = search_res[0].distances[check_from + check_topk] - range_filter = search_res[0].distances[check_from] - - # rebuild the collection with test target index - collection_w.release() - collection_w.indexes[0].drop() - _index_params = {"index_type": index_type, "metric_type": metric, - "params": cf.get_index_params_params(index_type)} - collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params) - collection_w.load() + search_params = {"params": {}} + nq = 1 + search_vectors = cf.gen_vectors(nq, ct.default_dim, vector_data_type=vector_data_type) + search_res = collection_w.search(search_vectors, default_search_field, + search_params, limit=1000)[0] + assert len(search_res[0].ids) == 1000 + log.debug(f"search topk=1000 returns {len(search_res[0].ids)}") + check_topk = 300 + check_from = 30 + ids = search_res[0].ids[check_from:check_from + check_topk] + radius = search_res[0].distances[check_from + check_topk] + range_filter = search_res[0].distances[check_from] + + # rebuild the collection with test target index + collection_w.release() + collection_w.indexes[0].drop() + _index_params = {"index_type": index_type, "metric_type": metric, + "params": 
cf.get_index_params_params(index_type)}
+        collection_w.create_index(ct.default_float_vec_field_name, index_params=_index_params)
+        collection_w.load()
 
-            params = cf.get_search_params_params(index_type)
-            params.update({"radius": radius, "range_filter": range_filter})
-            if index_type == "HNSW":
-                params.update({"ef": check_topk+100})
-            if index_type == "IVF_PQ":
-                params.update({"max_empty_result_buckets": 100})
-            range_search_params = {"params": params}
-            range_res = collection_w.search(search_vectors, default_search_field,
-                                            range_search_params, limit=check_topk)[0]
-            range_ids = range_res[0].ids
-            # assert len(range_ids) == check_topk
-            log.debug(f"range search radius={radius}, range_filter={range_filter}, range results num: {len(range_ids)}")
-            hit_rate = round(len(set(ids).intersection(set(range_ids))) / len(set(ids)), 2)
-            log.debug(f"range search results with growing {bool(i % 2)} hit rate: {hit_rate}")
-            assert hit_rate >= 0.2    # issue #32630 to improve the accuracy
+        params = cf.get_search_params_params(index_type)
+        params.update({"radius": radius, "range_filter": range_filter})
+        if index_type == "HNSW":
+            params.update({"ef": check_topk+100})
+        if index_type == "IVF_PQ":
+            params.update({"max_empty_result_buckets": 100})
+        range_search_params = {"params": params}
+        range_res = collection_w.search(search_vectors, default_search_field,
+                                        range_search_params, limit=check_topk)[0]
+        range_ids = range_res[0].ids
+        # assert len(range_ids) == check_topk
+        log.debug(f"range search radius={radius}, range_filter={range_filter}, range results num: {len(range_ids)}")
+        hit_rate = round(len(set(ids).intersection(set(range_ids))) / len(set(ids)), 2)
+        log.debug(f"{vector_data_type} range search results {index_type} {metric} with_growing {with_growing} hit_rate: {hit_rate}")
+        assert hit_rate >= 0.2    # issue #32630 to improve the accuracy
 
     @pytest.mark.tags(CaseLabel.L2)
     @pytest.mark.parametrize("range_filter", [1000, 1000.0])

From 059aaad3bb78c8c416daec99e35d2887926b929e Mon Sep 17 00:00:00 2001
From: GenkenWei <247910571@qq.com>
Date: Mon, 24 Jun 2024 14:16:02 +0800
Subject: [PATCH 21/21] fix: update ubuntu base image version (#33927)

related to #33945
Fixes CVEs in the Milvus base image: LOW: 32, MEDIUM: 50, **82 fixed in total**

![image](https://github.com/milvus-io/milvus/assets/27683687/020fbb73-9a1f-42cb-8224-a595179b3533)
![image](https://github.com/milvus-io/milvus/assets/27683687/fef403a5-c658-4ee6-baf4-ace3a0387799)

Signed-off-by: weizhenkun 
Co-authored-by: weizhenkun 
---
 build/docker/milvus/ubuntu20.04/Dockerfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/build/docker/milvus/ubuntu20.04/Dockerfile b/build/docker/milvus/ubuntu20.04/Dockerfile
index 842a948466698..670f89b3d2042 100644
--- a/build/docker/milvus/ubuntu20.04/Dockerfile
+++ b/build/docker/milvus/ubuntu20.04/Dockerfile
@@ -9,7 +9,7 @@
 # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 # or implied. See the License for the specific language governing permissions and limitations under the License.
 
-FROM ubuntu:focal-20220426
+FROM ubuntu:focal-20240530
 
 ARG TARGETARCH