diff --git a/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/PointMethods.java b/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/PointMethods.java index f15d063..5f6074a 100644 --- a/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/PointMethods.java +++ b/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/PointMethods.java @@ -16,6 +16,7 @@ import io.metaloom.qdrant.client.grpc.proto.Points.CountPoints; import io.metaloom.qdrant.client.grpc.proto.Points.CountResponse; import io.metaloom.qdrant.client.grpc.proto.Points.DeletePayloadPoints; +import io.metaloom.qdrant.client.grpc.proto.Points.DeletePointVectors; import io.metaloom.qdrant.client.grpc.proto.Points.DeletePoints; import io.metaloom.qdrant.client.grpc.proto.Points.Filter; import io.metaloom.qdrant.client.grpc.proto.Points.GetPoints; @@ -25,9 +26,14 @@ import io.metaloom.qdrant.client.grpc.proto.Points.PointsIdsList; import io.metaloom.qdrant.client.grpc.proto.Points.PointsOperationResponse; import io.metaloom.qdrant.client.grpc.proto.Points.PointsSelector; +import io.metaloom.qdrant.client.grpc.proto.Points.RecommendGroupsResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.RecommendPointGroups; import io.metaloom.qdrant.client.grpc.proto.Points.ScrollPoints; import io.metaloom.qdrant.client.grpc.proto.Points.ScrollResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.SearchGroupsResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.SearchPointGroups; import io.metaloom.qdrant.client.grpc.proto.Points.SetPayloadPoints; +import io.metaloom.qdrant.client.grpc.proto.Points.UpdatePointVectors; import io.metaloom.qdrant.client.grpc.proto.Points.UpsertPoints; import io.metaloom.qdrant.client.grpc.proto.Points.WithPayloadSelector; import io.metaloom.qdrant.client.grpc.proto.Points.WithVectorsSelector; @@ -459,4 +465,28 @@ default GrpcClientRequest upsertPoints(String collectio () -> pointsAsyncStub(this).upsert(request.build())); } + default GrpcClientRequest updateVectors(UpdatePointVectors request) { + return request( + () -> pointsStub(this).updateVectors(request), + () -> pointsAsyncStub(this).updateVectors(request)); + } + + default GrpcClientRequest deleteVectors(DeletePointVectors request) { + return request( + () -> pointsStub(this).deleteVectors(request), + () -> pointsAsyncStub(this).deleteVectors(request)); + } + + default GrpcClientRequest searchGroupPoints(SearchPointGroups request) { + return request( + () -> pointsStub(this).searchGroups(request), + () -> pointsAsyncStub(this).searchGroups(request)); + } + + default GrpcClientRequest recommendGroupPoints(RecommendPointGroups request) { + return request( + () -> pointsStub(this).recommendGroups(request), + () -> pointsAsyncStub(this).recommendGroups(request)); + } + } diff --git a/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/SearchMethods.java b/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/SearchMethods.java index 19d4573..dced963 100644 --- a/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/SearchMethods.java +++ b/grpc/src/main/java/io/metaloom/qdrant/client/grpc/method/SearchMethods.java @@ -29,13 +29,14 @@ public interface SearchMethods extends ClientSettings { * Retrieve closest points based on vector similarity. * * @param collectionName + * @param vectorName * @param vector * @param limit * @param scoreThreshold * @return */ - default GrpcClientRequest searchPoints(String collectionName, float[] vector, long limit, Float scoreThreshold) { - return searchPoints(collectionName, vector, null, null, limit, null, null, null, scoreThreshold); + default GrpcClientRequest searchPoints(String collectionName, String vectorName, float[] vector, long limit, Float scoreThreshold) { + return searchPoints(collectionName, vectorName, vector, null, null, limit, null, null, null, scoreThreshold); } /** @@ -43,6 +44,8 @@ default GrpcClientRequest searchPoints(String collectionName, fl * * @param collectionName * Name of the collection to search in + * @param vectorName + * Optional name of the vector to be searched * @param vector * Vector data * @param filter @@ -62,7 +65,8 @@ default GrpcClientRequest searchPoints(String collectionName, fl * higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. * @return */ - default GrpcClientRequest searchPoints(String collectionName, float[] vector, Filter filter, SearchParams params, long limit, + default GrpcClientRequest searchPoints(String collectionName, String vectorName, float[] vector, Filter filter, + SearchParams params, long limit, Long offset, WithPayloadSelector withPayloadSelector, WithVectorsSelector withVectorsSelector, Float scoreThreshold) { Objects.requireNonNull(collectionName, "A collection name must be specified"); @@ -78,6 +82,10 @@ default GrpcClientRequest searchPoints(String collectionName, fl .addAllVector(vectorList) .setCollectionName(collectionName); + if (vectorName != null) { + request.setVectorName(vectorName); + } + if (filter != null) { request.setFilter(filter); } diff --git a/grpc/src/main/java/io/metaloom/qdrant/client/util/ModelHelper.java b/grpc/src/main/java/io/metaloom/qdrant/client/util/ModelHelper.java index 3fe96ac..da53d2e 100644 --- a/grpc/src/main/java/io/metaloom/qdrant/client/util/ModelHelper.java +++ b/grpc/src/main/java/io/metaloom/qdrant/client/util/ModelHelper.java @@ -8,9 +8,13 @@ import io.metaloom.qdrant.client.grpc.proto.JsonWithInt; import io.metaloom.qdrant.client.grpc.proto.JsonWithInt.Value; +import io.metaloom.qdrant.client.grpc.proto.Points.NamedVectors; import io.metaloom.qdrant.client.grpc.proto.Points.PointId; import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct; import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct.Builder; +import io.metaloom.qdrant.client.grpc.proto.Points.PointVectors; +import io.metaloom.qdrant.client.grpc.proto.Points.PointsIdsList; +import io.metaloom.qdrant.client.grpc.proto.Points.PointsSelector; import io.metaloom.qdrant.client.grpc.proto.Points.Vector; import io.metaloom.qdrant.client.grpc.proto.Points.Vectors; import io.metaloom.qdrant.client.grpc.proto.Points.WithPayloadSelector; @@ -35,6 +39,14 @@ public static Vector vector(float[] vector) { return builder.build(); } + public static List vectorList(Float... vectors) { + List vectorList = new ArrayList<>(vectors.length); + for (float f : vectors) { + vectorList.add(Float.valueOf(f)); + } + return vectorList; + } + /** * Convert the string into a value model. * @@ -113,7 +125,7 @@ public static PointStruct point(long id, float[] vectorData, Map */ public static PointStruct point(PointId id, float[] vectorData, Map payload) { Objects.requireNonNull(id, "A pointId must be provided."); - Vector vector = ModelHelper.vector(vectorData); + Vector vector = vector(vectorData); Builder builder = PointStruct.newBuilder() .setId(id) .setVectors(Vectors.newBuilder().setVector(vector)); @@ -123,6 +135,31 @@ public static PointStruct point(PointId id, float[] vectorData, Map payload) { + return namedPoint(pointId(id), vectorName, vectorData, payload); + } + + public static PointStruct namedPoint(PointId id, String vectorName, float[] vectorData, Map payload) { + Objects.requireNonNull(id, "A pointId must be provided."); + NamedVectors vectors = namedVector(vectorName, vectorData); + Builder builder = PointStruct.newBuilder() + .setId(id) + .setVectors(Vectors.newBuilder().setVectors(vectors)); + if (payload != null) { + builder.putAllPayload(payload); + } + return builder.build(); + } + + public static PointsSelector selectByIds(Long... ids) { + PointsIdsList.Builder pointList = PointsIdsList.newBuilder(); + for (Long id : ids) { + pointList.addIds(pointId(id)); + } + + return PointsSelector.newBuilder().setPoints(pointList.build()).build(); + } + public static WithPayloadSelector withPayload() { return WithPayloadSelector.newBuilder().setEnable(true).build(); } @@ -130,4 +167,20 @@ public static WithPayloadSelector withPayload() { public static WithVectorsSelector withVector() { return WithVectorsSelector.newBuilder().setEnable(true).build(); } + + public static PointVectors pointVector(long id, String name, float[] vector) { + Vectors vectors = Vectors.newBuilder() + .setVectors(namedVector(name, vector)) + .build(); + return PointVectors.newBuilder() + .setId(pointId(id)) + .setVectors(vectors) + .build(); + } + + public static NamedVectors namedVector(String name, float[] vector) { + return NamedVectors.newBuilder() + .putVectors(name, vector(vector)) + .build(); + } } diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/AbstractGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/AbstractGRPCClientTest.java index 3e25dcf..54db138 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/AbstractGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/AbstractGRPCClientTest.java @@ -4,8 +4,8 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import org.junit.After; -import org.junit.Before; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import com.google.protobuf.GeneratedMessageV3; @@ -15,16 +15,21 @@ import io.metaloom.qdrant.client.grpc.proto.Collections.CreateAlias; import io.metaloom.qdrant.client.grpc.proto.Collections.Distance; import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParams; +import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParamsMap; public abstract class AbstractGRPCClientTest extends AbstractContainerTest { public static final String TEST_COLLECTION_NAME = "the-test-collection"; + public static final String TEST_VECTOR_NAME = "color"; + + public static final String TEST_VECTOR_NAME_2 = "color-2"; + public static final String TEST_ALIAS_NAME = "new-alias-name"; protected QDrantGRPCClient client; - @Before + @BeforeEach public void setupClient() { client = QDrantGRPCClient.builder() .setHostname(qdrant.getHost()) @@ -32,7 +37,7 @@ public void setupClient() { .build(); } - @After + @AfterEach public void closeClient() { if (client != null) { client.close(); @@ -45,8 +50,14 @@ protected void createCollection(String collectionName) { .setDistance(Distance.Euclid) .build(); + // Add the params to a map + VectorParamsMap paramsMap = VectorParamsMap.newBuilder() + .putMap(TEST_VECTOR_NAME, params) + .putMap(TEST_VECTOR_NAME_2, params) + .build(); + // Create new collections - assertTrue(client.createCollection(collectionName, params).sync().getResult()); + assertTrue(client.createCollection(collectionName, paramsMap).sync().getResult()); } protected void createAlias(String collectionName, String aliasName) { diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/BasicUsageExampleTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/BasicUsageExampleTest.java index 373f6f0..dbe63e3 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/BasicUsageExampleTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/BasicUsageExampleTest.java @@ -4,11 +4,12 @@ import java.util.List; import java.util.Map; -import org.junit.Test; +import org.junit.jupiter.api.Test; import io.metaloom.qdrant.client.AbstractContainerTest; import io.metaloom.qdrant.client.grpc.proto.Collections.Distance; import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParams; +import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParamsMap; import io.metaloom.qdrant.client.grpc.proto.JsonWithInt.Value; import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct; import io.metaloom.qdrant.client.grpc.proto.Points.ScoredPoint; @@ -33,8 +34,14 @@ public void testExample() throws Exception { .setDistance(Distance.Euclid) .build(); + // Add the params to a map + VectorParamsMap paramsMap = VectorParamsMap.newBuilder() + .putMap("firstVector", params) + .putMap("secondVector", params) + .build(); + // Create new collections - blocking - client.createCollection("test1", params).sync(); + client.createCollection("test1", paramsMap).sync(); // .. or via Future API client.createCollection("test2", params).async().get(); // .. or via RxJava API @@ -51,7 +58,7 @@ public void testExample() throws Exception { payload.put("color", ModelHelper.value("blue")); // Now construct the point - PointStruct point = ModelHelper.point(42L + i, vector, payload); + PointStruct point = ModelHelper.namedPoint(42L + i, "firstVector", vector, payload); // .. and insert it client.upsertPoint("test1", point, true).sync(); } @@ -61,7 +68,7 @@ public void testExample() throws Exception { // Now run KNN search float[] searchVector = new float[] { 0.43f, 0.09f, 0.41f, 1.35f }; - List searchResults = client.searchPoints("test1", searchVector, 2, null).sync().getResultList(); + List searchResults = client.searchPoints("test1", "firstVector", searchVector, 2, null).sync().getResultList(); for (ScoredPoint result : searchResults) { System.out.println("Found: [" + result.getId().getNum() + "] " + result.getScore()); } diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/ClusterGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/ClusterGRPCClientTest.java index 644d913..9e9a709 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/ClusterGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/ClusterGRPCClientTest.java @@ -1,7 +1,7 @@ package io.metaloom.qdrant.client.grpc; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import io.metaloom.qdrant.client.testcases.ClusterClientTestcases; @@ -9,28 +9,28 @@ public class ClusterGRPCClientTest extends AbstractGRPCClientTest implements Clu @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testGetClusterStatusInfo() throws Exception { } @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testRemovePeerFromCluster() throws Exception { } @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testCollectionClusterInfo() throws Exception { } @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testUpdateCollectionClusterSetup() throws Exception { } diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/CollectionGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/CollectionGRPCClientTest.java index a27e6d4..ae692e6 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/CollectionGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/CollectionGRPCClientTest.java @@ -6,7 +6,7 @@ import java.util.List; -import org.junit.Test; +import org.junit.jupiter.api.Test; import io.metaloom.qdrant.client.grpc.AbstractGRPCClientTest; import io.metaloom.qdrant.client.grpc.proto.Collections.AliasDescription; @@ -29,7 +29,10 @@ public void testCreateCollectionWithNamedVectorParams() throws Exception { .build(); // Add the params to a map - VectorParamsMap paramsMap = VectorParamsMap.newBuilder().putMap("colors", params).build(); + VectorParamsMap paramsMap = VectorParamsMap.newBuilder() + .putMap(TEST_VECTOR_NAME, params) + .putMap(TEST_VECTOR_NAME_2, params) + .build(); // Create new collections assertTrue(client.createCollection(TEST_COLLECTION_NAME, paramsMap).sync().getResult()); @@ -39,7 +42,8 @@ public void testCreateCollectionWithNamedVectorParams() throws Exception { System.out.println(collection.getName()); GetCollectionInfoResponse info = client.loadCollections(collection.getName()).sync(); VectorsConfig config = info.getResult().getConfig().getParams().getVectorsConfig(); - assertTrue("The config did not contain the colors vector parameters.", config.getParamsMap().containsMap("colors")); + assertTrue("The config did not contain the colors vector parameters.", config.getParamsMap().containsMap(TEST_VECTOR_NAME)); + assertTrue("The config did not contain the colors-2 vector parameters.", config.getParamsMap().containsMap(TEST_VECTOR_NAME_2)); } } @@ -100,9 +104,9 @@ public void testGetCollectionInfo() throws Exception { GetCollectionInfoResponse info = client.loadCollections(TEST_COLLECTION_NAME).sync(); assertEquals("Could not load correct distance from collection", Distance.Euclid, - info.getResult().getConfig().getParams().getVectorsConfig().getParams().getDistance()); + info.getResult().getConfig().getParams().getVectorsConfig().getParamsMap().getMapMap().get(TEST_VECTOR_NAME).getDistance()); assertEquals("Could not load correct dimension from collection", 4, - info.getResult().getConfig().getParams().getVectorsConfig().getParams().getSize()); + info.getResult().getConfig().getParams().getVectorsConfig().getParamsMap().getMapMap().get(TEST_VECTOR_NAME).getSize()); } @Test diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/PointGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/PointGRPCClientTest.java index afaa6c5..958b5c3 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/PointGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/PointGRPCClientTest.java @@ -1,9 +1,11 @@ package io.metaloom.qdrant.client.grpc.method; -import static io.metaloom.qdrant.client.util.ModelHelper.point; +import static io.metaloom.qdrant.client.util.ModelHelper.namedPoint; import static io.metaloom.qdrant.client.util.ModelHelper.pointId; import static io.metaloom.qdrant.client.util.ModelHelper.pointIds; +import static io.metaloom.qdrant.client.util.ModelHelper.selectByIds; import static io.metaloom.qdrant.client.util.ModelHelper.value; +import static io.metaloom.qdrant.client.util.ModelHelper.vectorList; import static io.metaloom.qdrant.client.util.ModelHelper.withPayload; import static io.metaloom.qdrant.client.util.ModelHelper.withVector; import static io.metaloom.qdrant.client.util.VectorUtil.toList; @@ -19,43 +21,49 @@ import java.util.Map; import java.util.UUID; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.testcontainers.shaded.com.google.common.collect.Sets; import io.metaloom.qdrant.client.grpc.AbstractGRPCClientTest; import io.metaloom.qdrant.client.grpc.proto.JsonWithInt.Value; +import io.metaloom.qdrant.client.grpc.proto.Points.DeletePointVectors; import io.metaloom.qdrant.client.grpc.proto.Points.GetResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.NamedVectors; import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct; +import io.metaloom.qdrant.client.grpc.proto.Points.PointVectors; import io.metaloom.qdrant.client.grpc.proto.Points.PointsOperationResponse; import io.metaloom.qdrant.client.grpc.proto.Points.RecommendBatchResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.RecommendPointGroups; import io.metaloom.qdrant.client.grpc.proto.Points.RecommendPoints; import io.metaloom.qdrant.client.grpc.proto.Points.RecommendResponse; import io.metaloom.qdrant.client.grpc.proto.Points.RetrievedPoint; import io.metaloom.qdrant.client.grpc.proto.Points.ScoredPoint; import io.metaloom.qdrant.client.grpc.proto.Points.ScrollResponse; import io.metaloom.qdrant.client.grpc.proto.Points.SearchBatchResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.SearchPointGroups; import io.metaloom.qdrant.client.grpc.proto.Points.SearchPoints; import io.metaloom.qdrant.client.grpc.proto.Points.SearchResponse; +import io.metaloom.qdrant.client.grpc.proto.Points.UpdatePointVectors; import io.metaloom.qdrant.client.grpc.proto.Points.UpdateStatus; -import io.metaloom.qdrant.client.grpc.proto.Points.Vector; import io.metaloom.qdrant.client.grpc.proto.Points.Vectors; +import io.metaloom.qdrant.client.grpc.proto.Points.VectorsSelector; import io.metaloom.qdrant.client.testcases.PointClientTestcases; import io.metaloom.qdrant.client.util.ModelHelper; public class PointGRPCClientTest extends AbstractGRPCClientTest implements PointClientTestcases { - @Before + @BeforeEach public void setupTestData() { createCollection(TEST_COLLECTION_NAME); Map values = new HashMap<>(); - values.put("color", value("blue")); - PointStruct p1 = point(42L, new float[] { 7.43f, 0.1f, 0.25f, 1.5f }, values); - PointStruct p2 = point(43L, new float[] { 0.45f, 2.61f, 0.88f, 6.25f }, values); - PointStruct p3 = point(44L, new float[] { 2.41f, 0.9f, 0.81f, 2.45f }, values); - PointStruct p4 = point(45L, new float[] { 0.42f, 1.0f, 0.51f, 5.85f }, values); + values.put(TEST_VECTOR_NAME, value("blue")); + PointStruct p1 = namedPoint(pointId(42L), TEST_VECTOR_NAME, new float[] { 7.43f, 0.1f, 0.25f, 1.5f }, values); + PointStruct p2 = namedPoint(pointId(43L), TEST_VECTOR_NAME, new float[] { 0.45f, 2.61f, 0.88f, 6.25f }, values); + PointStruct p3 = namedPoint(pointId(44L), TEST_VECTOR_NAME, new float[] { 2.41f, 0.9f, 0.81f, 2.45f }, values); + PointStruct p4 = namedPoint(pointId(45L), TEST_VECTOR_NAME, new float[] { 0.42f, 1.0f, 0.51f, 5.85f }, values); client.upsertPoint(TEST_COLLECTION_NAME, p1, true).sync(); client.upsertPoint(TEST_COLLECTION_NAME, p2, true).sync(); @@ -86,7 +94,7 @@ public void testUpsertPointViaList() throws Exception { float[] vec = new float[] { 0, 0, 0, 0 }; List list = new ArrayList<>(); for (int i = 0; i < 10; i++) { - list.add(point(80L + i, vec, null)); + list.add(namedPoint(80L + i, TEST_VECTOR_NAME, vec, null)); } client.upsertPoints(TEST_COLLECTION_NAME, list, true).sync(); @@ -95,20 +103,20 @@ public void testUpsertPointViaList() throws Exception { @Test public void testUpsertPointWithUuid() { - PointStruct point = point(UUID.randomUUID(), new float[] { 0.2f, 0.1f, 0.3f, 0.4f }, null); + PointStruct point = namedPoint(pointId(UUID.randomUUID()), TEST_VECTOR_NAME, new float[] { 0.2f, 0.1f, 0.3f, 0.4f }, null); client.upsertPoint(TEST_COLLECTION_NAME, point, true).sync(); assertPointCount(4 + 1, TEST_COLLECTION_NAME); } @Test @Override - @Ignore("Not supported via gRPC") + @Disabled("Not supported via gRPC") public void testUpsertPointsViaListBatch() throws Exception { } @Test @Override - @Ignore("Not supported via gRPC") + @Disabled("Not supported via gRPC") public void testUpsertPointsViaNamedBatch() throws Exception { } @@ -163,10 +171,10 @@ public void testDeletePointPayload() throws Exception { HashSet keys = Sets.newHashSet(extraKey); PointsOperationResponse response2 = client.deletePayload(TEST_COLLECTION_NAME, true, keys, null, pointId(42L)).sync(); assertEquals(UpdateStatus.Completed, response2.getResult().getStatus()); - + // And assert the operation Map after = client.getPoint(TEST_COLLECTION_NAME, true, false, pointId(42)).sync().getResult(0).getPayloadMap(); - assertEquals("There should only be one remaining prop",1, after.size()); + assertEquals("There should only be one remaining prop", 1, after.size()); assertFalse("The prop should have been deleted", after.containsKey(extraKey)); assertTrue("The prop should be still there", after.containsKey("color")); } @@ -197,7 +205,7 @@ public void testScrollPoints() throws Exception { @Override public void testSearchPoints() throws Exception { float[] vector = new float[] { 2.41f, 0.9f, 0.81f, 2.45f }; - SearchResponse response = client.searchPoints(TEST_COLLECTION_NAME, vector, 2, 100f).sync(); + SearchResponse response = client.searchPoints(TEST_COLLECTION_NAME, TEST_VECTOR_NAME, vector, 2, 100f).sync(); List list = response.getResultList(); assertFalse(list.isEmpty()); assertEquals("The first result should be exactly 0 scrore since it used the stored vector of a point for the search", 0f, @@ -215,6 +223,7 @@ public void testSearchBatchPoints() throws Exception { SearchPoints.Builder request = SearchPoints.newBuilder() .setLimit(10) .addAllVector(vectorList) + .setVectorName(TEST_VECTOR_NAME) .setCollectionName(TEST_COLLECTION_NAME); searches.add(request.build()); @@ -226,7 +235,7 @@ public void testSearchBatchPoints() throws Exception { @Test @Override public void testRecommendPoints() throws Exception { - RecommendResponse result = client.recommendPoints(TEST_COLLECTION_NAME, pointIds(42L), 2, null).sync(); + RecommendResponse result = client.recommendPoints(TEST_COLLECTION_NAME, pointIds(42L), 2, TEST_VECTOR_NAME).sync(); List list = result.getResultList(); assertFalse(list.isEmpty()); assertEquals("The result was not limited to two results.", 2, list.size()); @@ -240,6 +249,7 @@ public void testRecommendBatchPoints() throws Exception { RecommendPoints.Builder request = RecommendPoints.newBuilder() .setCollectionName(TEST_COLLECTION_NAME) .addAllPositive(pointIds(42L)) + .setUsing(TEST_VECTOR_NAME) .setLimit(10); searches.add(request.build()); @@ -253,15 +263,94 @@ public void testRecommendBatchPoints() throws Exception { public void testCountPoints() throws Exception { // Insert a new vector for (int i = 0; i < 10; i++) { - Vector vector = ModelHelper.vector(new float[] { 0.43f + i, 0.1f, 0.61f, 1.45f }); + NamedVectors vectors = ModelHelper.namedVector(TEST_VECTOR_NAME, new float[] { 0.43f + i, 0.1f, 0.61f, 1.45f }); PointStruct point = PointStruct.newBuilder() - .putPayload("color", ModelHelper.value("blue")) + .putPayload(TEST_VECTOR_NAME, ModelHelper.value("blue")) .setId(ModelHelper.pointId(82L + i)) - .setVectors(Vectors.newBuilder().setVector(vector)) + .setVectors(Vectors.newBuilder().setVectors(vectors).build()) .build(); assertEquals(UpdateStatus.Completed, client.upsertPoint(TEST_COLLECTION_NAME, point, true).sync().getResult().getStatus()); } assertPointCount(14, TEST_COLLECTION_NAME); } + @Test + @Override + public void testUpdateVectors() throws Exception { + GetResponse response = client.getPoint(TEST_COLLECTION_NAME, true, true, pointId(42L)).sync(); + assertFalse(response.getResult(0).getVectors().getVectors().containsVectors(TEST_VECTOR_NAME_2)); + + UpdatePointVectors request = UpdatePointVectors + .newBuilder() + .setCollectionName(TEST_COLLECTION_NAME) + .addPoints(PointVectors.newBuilder() + .setId(ModelHelper.pointId(42L)) + .setVectors(Vectors.newBuilder().setVectors(ModelHelper.namedVector(TEST_VECTOR_NAME_2, new float[] { 0.4f, 0.1f, 0.2f, 0.3f }))) + .build()) + .build(); + client.updateVectors(request).sync(); + + GetResponse response2 = client.getPoint(TEST_COLLECTION_NAME, true, true, pointId(42L)).sync(); + assertTrue(response2.getResult(0).getVectors().getVectors().containsVectors(TEST_VECTOR_NAME_2)); + } + + @Test + @Override + public void testDeleteVectors() throws Exception { + UpdatePointVectors request = UpdatePointVectors + .newBuilder() + .setCollectionName(TEST_COLLECTION_NAME) + .addPoints(PointVectors.newBuilder() + .setId(ModelHelper.pointId(42L)) + .setVectors(Vectors.newBuilder().setVectors(ModelHelper.namedVector(TEST_VECTOR_NAME_2, new float[] { 0.4f, 0.1f, 0.2f, 0.3f }))) + .build()) + .build(); + client.updateVectors(request).sync(); + + DeletePointVectors request2 = DeletePointVectors.newBuilder() + .setCollectionName(TEST_COLLECTION_NAME) + .setWait(true) + .setPointsSelector(selectByIds(42L)) + .setVectors(VectorsSelector.newBuilder().addNames(TEST_VECTOR_NAME_2)) + .build(); + client.deleteVectors(request2).sync(); + + GetResponse response = client.getPoint(TEST_COLLECTION_NAME, true, true, pointId(42L)).sync(); + assertFalse(response.getResult(0).getVectors().getVectors().containsVectors(TEST_VECTOR_NAME_2)); + + } + + @Test + @Override + public void testSearchGroupPoints() throws Exception { + SearchPointGroups request = SearchPointGroups.newBuilder() + .setCollectionName(TEST_COLLECTION_NAME) + .setGroupBy(TEST_VECTOR_NAME) + .setGroupSize(10) + .setWithVectors(withVector()) + .setWithPayload(withPayload()) + .setLimit(10) + .setVectorName(TEST_VECTOR_NAME) + .addAllVector(vectorList(0.1f, 0.2f, 0.3f, 0.4f)) + .build(); + client.searchGroupPoints(request).sync(); + } + + @Test + @Override + public void testRecommendGroupPoints() throws Exception { + RecommendPointGroups request = RecommendPointGroups.newBuilder() + .setCollectionName(TEST_COLLECTION_NAME) + .setGroupBy("color") + .setGroupSize(10) + .setWithVectors(withVector()) + .setWithPayload(withPayload()) + .addAllPositive(pointIds(42L)) + .setUsing(TEST_VECTOR_NAME) + .setLimit(10) + .build(); + client.recommendGroupPoints(request).sync(); + + } + } diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/ServiceGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/ServiceGRPCClientTest.java index 32d8a8e..d5d446a 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/ServiceGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/ServiceGRPCClientTest.java @@ -1,7 +1,7 @@ package io.metaloom.qdrant.client.grpc.method; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import io.metaloom.qdrant.client.grpc.AbstractGRPCClientTest; import io.metaloom.qdrant.client.testcases.ServiceClientTestcases; @@ -10,14 +10,14 @@ public class ServiceGRPCClientTest extends AbstractGRPCClientTest implements Ser @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testCollectTelemetryData() throws Exception { } @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testLockOptions() throws Exception { } diff --git a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/SnapshotGRPCClientTest.java b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/SnapshotGRPCClientTest.java index ab12b98..0386c5c 100644 --- a/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/SnapshotGRPCClientTest.java +++ b/grpc/src/test/java/io/metaloom/qdrant/client/grpc/method/SnapshotGRPCClientTest.java @@ -6,8 +6,8 @@ import java.util.List; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import io.metaloom.qdrant.client.grpc.AbstractGRPCClientTest; import io.metaloom.qdrant.client.grpc.proto.SnapshotsService.CreateSnapshotResponse; @@ -53,14 +53,14 @@ public void testDeleteCollectionSnapshot() throws Exception { @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testDownloadCollectionSnapshot() throws Exception { } @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testRecoverCollectionSnapshot() throws Exception { } @@ -88,7 +88,7 @@ public void testListStorageSnapshot() throws Exception { @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testDownloadStorageSnapshot() throws Exception { // TODO Auto-generated method stub @@ -96,7 +96,7 @@ public void testDownloadStorageSnapshot() throws Exception { @Test @Override - @Ignore("Not supported for gRPC") + @Disabled("Not supported for gRPC") public void testRecoverStorageSnapshot() throws Exception { // TODO Auto-generated method stub