Skip to content

Commit

Permalink
Add missing gRPC methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Jotschi committed May 24, 2023
1 parent 8396b74 commit 6bd91de
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.metaloom.qdrant.client.grpc.proto.Points.CountPoints;
import io.metaloom.qdrant.client.grpc.proto.Points.CountResponse;
import io.metaloom.qdrant.client.grpc.proto.Points.DeletePayloadPoints;
import io.metaloom.qdrant.client.grpc.proto.Points.DeletePointVectors;
import io.metaloom.qdrant.client.grpc.proto.Points.DeletePoints;
import io.metaloom.qdrant.client.grpc.proto.Points.Filter;
import io.metaloom.qdrant.client.grpc.proto.Points.GetPoints;
Expand All @@ -25,9 +26,14 @@
import io.metaloom.qdrant.client.grpc.proto.Points.PointsIdsList;
import io.metaloom.qdrant.client.grpc.proto.Points.PointsOperationResponse;
import io.metaloom.qdrant.client.grpc.proto.Points.PointsSelector;
import io.metaloom.qdrant.client.grpc.proto.Points.RecommendGroupsResponse;
import io.metaloom.qdrant.client.grpc.proto.Points.RecommendPointGroups;
import io.metaloom.qdrant.client.grpc.proto.Points.ScrollPoints;
import io.metaloom.qdrant.client.grpc.proto.Points.ScrollResponse;
import io.metaloom.qdrant.client.grpc.proto.Points.SearchGroupsResponse;
import io.metaloom.qdrant.client.grpc.proto.Points.SearchPointGroups;
import io.metaloom.qdrant.client.grpc.proto.Points.SetPayloadPoints;
import io.metaloom.qdrant.client.grpc.proto.Points.UpdatePointVectors;
import io.metaloom.qdrant.client.grpc.proto.Points.UpsertPoints;
import io.metaloom.qdrant.client.grpc.proto.Points.WithPayloadSelector;
import io.metaloom.qdrant.client.grpc.proto.Points.WithVectorsSelector;
Expand Down Expand Up @@ -459,4 +465,28 @@ default GrpcClientRequest<PointsOperationResponse> upsertPoints(String collectio
() -> pointsAsyncStub(this).upsert(request.build()));
}

default GrpcClientRequest<PointsOperationResponse> updateVectors(UpdatePointVectors request) {
return request(
() -> pointsStub(this).updateVectors(request),
() -> pointsAsyncStub(this).updateVectors(request));
}

default GrpcClientRequest<PointsOperationResponse> deleteVectors(DeletePointVectors request) {
return request(
() -> pointsStub(this).deleteVectors(request),
() -> pointsAsyncStub(this).deleteVectors(request));
}

default GrpcClientRequest<SearchGroupsResponse> searchGroupPoints(SearchPointGroups request) {
return request(
() -> pointsStub(this).searchGroups(request),
() -> pointsAsyncStub(this).searchGroups(request));
}

default GrpcClientRequest<RecommendGroupsResponse> recommendGroupPoints(RecommendPointGroups request) {
return request(
() -> pointsStub(this).recommendGroups(request),
() -> pointsAsyncStub(this).recommendGroups(request));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,23 @@ public interface SearchMethods extends ClientSettings {
* Retrieve closest points based on vector similarity.
*
* @param collectionName
* @param vectorName
* @param vector
* @param limit
* @param scoreThreshold
* @return
*/
default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, float[] vector, long limit, Float scoreThreshold) {
return searchPoints(collectionName, vector, null, null, limit, null, null, null, scoreThreshold);
default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, String vectorName, float[] vector, long limit, Float scoreThreshold) {
return searchPoints(collectionName, vectorName, vector, null, null, limit, null, null, null, scoreThreshold);
}

/**
* Retrieve closest points based on vector similarity and given filtering conditions.
*
* @param collectionName
* Name of the collection to search in
* @param vectorName
* Optional name of the vector to be searched
* @param vector
* Vector data
* @param filter
Expand All @@ -62,7 +65,8 @@ default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, fl
* higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned.
* @return
*/
default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, float[] vector, Filter filter, SearchParams params, long limit,
default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, String vectorName, float[] vector, Filter filter,
SearchParams params, long limit,
Long offset,
WithPayloadSelector withPayloadSelector, WithVectorsSelector withVectorsSelector, Float scoreThreshold) {
Objects.requireNonNull(collectionName, "A collection name must be specified");
Expand All @@ -78,6 +82,10 @@ default GrpcClientRequest<SearchResponse> searchPoints(String collectionName, fl
.addAllVector(vectorList)
.setCollectionName(collectionName);

if (vectorName != null) {
request.setVectorName(vectorName);
}

if (filter != null) {
request.setFilter(filter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

import io.metaloom.qdrant.client.grpc.proto.JsonWithInt;
import io.metaloom.qdrant.client.grpc.proto.JsonWithInt.Value;
import io.metaloom.qdrant.client.grpc.proto.Points.NamedVectors;
import io.metaloom.qdrant.client.grpc.proto.Points.PointId;
import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct;
import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct.Builder;
import io.metaloom.qdrant.client.grpc.proto.Points.PointVectors;
import io.metaloom.qdrant.client.grpc.proto.Points.PointsIdsList;
import io.metaloom.qdrant.client.grpc.proto.Points.PointsSelector;
import io.metaloom.qdrant.client.grpc.proto.Points.Vector;
import io.metaloom.qdrant.client.grpc.proto.Points.Vectors;
import io.metaloom.qdrant.client.grpc.proto.Points.WithPayloadSelector;
Expand All @@ -35,6 +39,14 @@ public static Vector vector(float[] vector) {
return builder.build();
}

public static List<Float> vectorList(Float... vectors) {
List<Float> vectorList = new ArrayList<>(vectors.length);
for (float f : vectors) {
vectorList.add(Float.valueOf(f));
}
return vectorList;
}

/**
* Convert the string into a value model.
*
Expand Down Expand Up @@ -113,7 +125,7 @@ public static PointStruct point(long id, float[] vectorData, Map<String, Value>
*/
public static PointStruct point(PointId id, float[] vectorData, Map<String, Value> payload) {
Objects.requireNonNull(id, "A pointId must be provided.");
Vector vector = ModelHelper.vector(vectorData);
Vector vector = vector(vectorData);
Builder builder = PointStruct.newBuilder()
.setId(id)
.setVectors(Vectors.newBuilder().setVector(vector));
Expand All @@ -123,11 +135,52 @@ public static PointStruct point(PointId id, float[] vectorData, Map<String, Valu
return builder.build();
}

public static PointStruct namedPoint(Long id, String vectorName, float[] vectorData, Map<String, Value> payload) {
return namedPoint(pointId(id), vectorName, vectorData, payload);
}

public static PointStruct namedPoint(PointId id, String vectorName, float[] vectorData, Map<String, Value> payload) {
Objects.requireNonNull(id, "A pointId must be provided.");
NamedVectors vectors = namedVector(vectorName, vectorData);
Builder builder = PointStruct.newBuilder()
.setId(id)
.setVectors(Vectors.newBuilder().setVectors(vectors));
if (payload != null) {
builder.putAllPayload(payload);
}
return builder.build();
}

public static PointsSelector selectByIds(Long... ids) {
PointsIdsList.Builder pointList = PointsIdsList.newBuilder();
for (Long id : ids) {
pointList.addIds(pointId(id));
}

return PointsSelector.newBuilder().setPoints(pointList.build()).build();
}

public static WithPayloadSelector withPayload() {
return WithPayloadSelector.newBuilder().setEnable(true).build();
}

public static WithVectorsSelector withVector() {
return WithVectorsSelector.newBuilder().setEnable(true).build();
}

public static PointVectors pointVector(long id, String name, float[] vector) {
Vectors vectors = Vectors.newBuilder()
.setVectors(namedVector(name, vector))
.build();
return PointVectors.newBuilder()
.setId(pointId(id))
.setVectors(vectors)
.build();
}

public static NamedVectors namedVector(String name, float[] vector) {
return NamedVectors.newBuilder()
.putVectors(name, vector(vector))
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import org.junit.After;
import org.junit.Before;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;

import com.google.protobuf.GeneratedMessageV3;

Expand All @@ -15,24 +15,29 @@
import io.metaloom.qdrant.client.grpc.proto.Collections.CreateAlias;
import io.metaloom.qdrant.client.grpc.proto.Collections.Distance;
import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParams;
import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParamsMap;

public abstract class AbstractGRPCClientTest extends AbstractContainerTest {

public static final String TEST_COLLECTION_NAME = "the-test-collection";

public static final String TEST_VECTOR_NAME = "color";

public static final String TEST_VECTOR_NAME_2 = "color-2";

public static final String TEST_ALIAS_NAME = "new-alias-name";

protected QDrantGRPCClient client;

@Before
@BeforeEach
public void setupClient() {
client = QDrantGRPCClient.builder()
.setHostname(qdrant.getHost())
.setPort(qdrant.grpcPort())
.build();
}

@After
@AfterEach
public void closeClient() {
if (client != null) {
client.close();
Expand All @@ -45,8 +50,14 @@ protected void createCollection(String collectionName) {
.setDistance(Distance.Euclid)
.build();

// Add the params to a map
VectorParamsMap paramsMap = VectorParamsMap.newBuilder()
.putMap(TEST_VECTOR_NAME, params)
.putMap(TEST_VECTOR_NAME_2, params)
.build();

// Create new collections
assertTrue(client.createCollection(collectionName, params).sync().getResult());
assertTrue(client.createCollection(collectionName, paramsMap).sync().getResult());
}

protected void createAlias(String collectionName, String aliasName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import java.util.List;
import java.util.Map;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import io.metaloom.qdrant.client.AbstractContainerTest;
import io.metaloom.qdrant.client.grpc.proto.Collections.Distance;
import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParams;
import io.metaloom.qdrant.client.grpc.proto.Collections.VectorParamsMap;
import io.metaloom.qdrant.client.grpc.proto.JsonWithInt.Value;
import io.metaloom.qdrant.client.grpc.proto.Points.PointStruct;
import io.metaloom.qdrant.client.grpc.proto.Points.ScoredPoint;
Expand All @@ -33,8 +34,14 @@ public void testExample() throws Exception {
.setDistance(Distance.Euclid)
.build();

// Add the params to a map
VectorParamsMap paramsMap = VectorParamsMap.newBuilder()
.putMap("firstVector", params)
.putMap("secondVector", params)
.build();

// Create new collections - blocking
client.createCollection("test1", params).sync();
client.createCollection("test1", paramsMap).sync();
// .. or via Future API
client.createCollection("test2", params).async().get();
// .. or via RxJava API
Expand All @@ -51,7 +58,7 @@ public void testExample() throws Exception {
payload.put("color", ModelHelper.value("blue"));

// Now construct the point
PointStruct point = ModelHelper.point(42L + i, vector, payload);
PointStruct point = ModelHelper.namedPoint(42L + i, "firstVector", vector, payload);
// .. and insert it
client.upsertPoint("test1", point, true).sync();
}
Expand All @@ -61,7 +68,7 @@ public void testExample() throws Exception {

// Now run KNN search
float[] searchVector = new float[] { 0.43f, 0.09f, 0.41f, 1.35f };
List<ScoredPoint> searchResults = client.searchPoints("test1", searchVector, 2, null).sync().getResultList();
List<ScoredPoint> searchResults = client.searchPoints("test1", "firstVector", searchVector, 2, null).sync().getResultList();
for (ScoredPoint result : searchResults) {
System.out.println("Found: [" + result.getId().getNum() + "] " + result.getScore());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,36 +1,36 @@
package io.metaloom.qdrant.client.grpc;

import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import io.metaloom.qdrant.client.testcases.ClusterClientTestcases;

public class ClusterGRPCClientTest extends AbstractGRPCClientTest implements ClusterClientTestcases {

@Test
@Override
@Ignore("Not supported for gRPC")
@Disabled("Not supported for gRPC")
public void testGetClusterStatusInfo() throws Exception {

}

@Test
@Override
@Ignore("Not supported for gRPC")
@Disabled("Not supported for gRPC")
public void testRemovePeerFromCluster() throws Exception {

}

@Test
@Override
@Ignore("Not supported for gRPC")
@Disabled("Not supported for gRPC")
public void testCollectionClusterInfo() throws Exception {

}

@Test
@Override
@Ignore("Not supported for gRPC")
@Disabled("Not supported for gRPC")
public void testUpdateCollectionClusterSetup() throws Exception {

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import java.util.List;

import org.junit.Test;
import org.junit.jupiter.api.Test;

import io.metaloom.qdrant.client.grpc.AbstractGRPCClientTest;
import io.metaloom.qdrant.client.grpc.proto.Collections.AliasDescription;
Expand All @@ -29,7 +29,10 @@ public void testCreateCollectionWithNamedVectorParams() throws Exception {
.build();

// Add the params to a map
VectorParamsMap paramsMap = VectorParamsMap.newBuilder().putMap("colors", params).build();
VectorParamsMap paramsMap = VectorParamsMap.newBuilder()
.putMap(TEST_VECTOR_NAME, params)
.putMap(TEST_VECTOR_NAME_2, params)
.build();

// Create new collections
assertTrue(client.createCollection(TEST_COLLECTION_NAME, paramsMap).sync().getResult());
Expand All @@ -39,7 +42,8 @@ public void testCreateCollectionWithNamedVectorParams() throws Exception {
System.out.println(collection.getName());
GetCollectionInfoResponse info = client.loadCollections(collection.getName()).sync();
VectorsConfig config = info.getResult().getConfig().getParams().getVectorsConfig();
assertTrue("The config did not contain the colors vector parameters.", config.getParamsMap().containsMap("colors"));
assertTrue("The config did not contain the colors vector parameters.", config.getParamsMap().containsMap(TEST_VECTOR_NAME));
assertTrue("The config did not contain the colors-2 vector parameters.", config.getParamsMap().containsMap(TEST_VECTOR_NAME_2));
}
}

Expand Down Expand Up @@ -100,9 +104,9 @@ public void testGetCollectionInfo() throws Exception {

GetCollectionInfoResponse info = client.loadCollections(TEST_COLLECTION_NAME).sync();
assertEquals("Could not load correct distance from collection", Distance.Euclid,
info.getResult().getConfig().getParams().getVectorsConfig().getParams().getDistance());
info.getResult().getConfig().getParams().getVectorsConfig().getParamsMap().getMapMap().get(TEST_VECTOR_NAME).getDistance());
assertEquals("Could not load correct dimension from collection", 4,
info.getResult().getConfig().getParams().getVectorsConfig().getParams().getSize());
info.getResult().getConfig().getParams().getVectorsConfig().getParamsMap().getMapMap().get(TEST_VECTOR_NAME).getSize());
}

@Test
Expand Down
Loading

0 comments on commit 6bd91de

Please sign in to comment.