diff --git a/google/cloud/firestore_v1/async_query.py b/google/cloud/firestore_v1/async_query.py index e697594ffc..ca83c26306 100644 --- a/google/cloud/firestore_v1/async_query.py +++ b/google/cloud/firestore_v1/async_query.py @@ -20,7 +20,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type, Union +from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -230,8 +230,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, distance_result_field: Optional[str] = None, - distance_threshold: Optional[Union[int, float]] = None, + distance_threshold: Optional[float] = None, ) -> AsyncVectorQuery: """ Finds the closest vector embeddings to the given query vector. @@ -244,8 +245,9 @@ def find_nearest( limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. distance_result_field (Optional[str]): - Name of the field to output the result of the vector distance calculation - distance_threshold (Optional[Union[int, float]]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): A threshold for which no less similar documents will be returned. Returns: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index f999f485fa..18c62aa33b 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -550,8 +550,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, distance_result_field: Optional[str] = None, - distance_threshold: Optional[Union[int, float]] = None, + distance_threshold: Optional[float] = None, ) -> VectorQuery: """ Finds the closest vector embeddings to the given query vector. @@ -565,7 +566,7 @@ def find_nearest( distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. distance_result_field (Optional[str]): Name of the field to output the result of the vector distance calculation - distance_threshold (Optional[Union[int, float]]): + distance_threshold (Optional[float]): A threshold for which no less similar documents will be returned. Returns: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 86284640f5..cfed454b93 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -982,8 +982,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, distance_result_field: Optional[str] = None, - distance_threshold: Optional[Union[int, float]] = None, + distance_threshold: Optional[float] = None, ) -> BaseVectorQuery: raise NotImplementedError diff --git a/google/cloud/firestore_v1/base_vector_query.py b/google/cloud/firestore_v1/base_vector_query.py index fd3bcdfb83..26cd5b1997 100644 --- a/google/cloud/firestore_v1/base_vector_query.py +++ b/google/cloud/firestore_v1/base_vector_query.py @@ -46,7 +46,7 @@ def __init__(self, nested_query) -> None: self._limit: Optional[int] = None self._distance_measure: Optional[DistanceMeasure] = None self._distance_result_field: Optional[str] = None - self._distance_threshold: Optional[Union[int, float]] = None + self._distance_threshold: Optional[float] = None @property def _client(self): @@ -73,7 +73,7 @@ def _to_protobuf(self) -> query.StructuredQuery: # Coerce ints to floats as required by the protobuf. distance_threshold_proto = None - if self._distance_threshold: + if self._distance_threshold is not None: distance_threshold_proto = float(self._distance_threshold) pb = self._nested_query._to_protobuf() @@ -120,8 +120,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, distance_result_field: Optional[str] = None, - distance_threshold: Optional[Union[int, float]] = None, + distance_threshold: Optional[float] = None, ): """Finds the closest vector embeddings to the given query vector.""" self._vector_field = vector_field diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 4e7437d8ab..eb8f51dc8d 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -20,7 +20,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type from google.api_core import exceptions, gapic_v1 from google.api_core import retry as retries @@ -251,8 +251,9 @@ def find_nearest( query_vector: Vector, limit: int, distance_measure: DistanceMeasure, + *, distance_result_field: Optional[str] = None, - distance_threshold: Optional[Union[int, float]] = None, + distance_threshold: Optional[float] = None, ) -> Type["firestore_v1.vector_query.VectorQuery"]: """ Finds the closest vector embeddings to the given query vector. @@ -265,8 +266,9 @@ def find_nearest( limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. distance_result_field (Optional[str]): - Name of the field to output the result of the vector distance calculation - distance_threshold (Optional[Union[int, float]]): + Name of the field to output the result of the vector distance + calculation. If unset then the distance will not be returned. + distance_threshold (Optional[float]): A threshold for which no less similar documents will be returned. diff --git a/tests/system/test_system.py b/tests/system/test_system.py index d8ae90f4d8..b67b8aecca 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -176,7 +176,14 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection(client, database): +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -184,7 +191,7 @@ def test_vector_search_collection(client, database): vector_query = collection.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -198,7 +205,14 @@ def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_with_filter(client, database): +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -206,7 +220,7 @@ def test_vector_search_collection_with_filter(client, database): vector_query = collection.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -220,7 +234,7 @@ def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_with_distance_parameters(client, database): +def test_vector_search_collection_with_distance_parameters_euclid(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -250,7 +264,44 @@ def test_vector_search_collection_with_distance_parameters(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group(client, database): +def test_vector_search_collection_with_distance_parameters_cosine(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +def test_vector_search_collection_group(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) @@ -258,7 +309,7 @@ def test_vector_search_collection_group(client, database): vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -271,8 +322,15 @@ def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group_with_filter(client, database): +def test_vector_search_collection_group_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) @@ -280,7 +338,7 @@ def test_vector_search_collection_group_with_filter(client, database): vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = vector_query.get() @@ -294,7 +352,9 @@ def test_vector_search_collection_group_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -def test_vector_search_collection_group_with_distance_parameters(client, database): +def test_vector_search_collection_group_with_distance_parameters_euclid( + client, database +): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) @@ -322,6 +382,38 @@ def test_vector_search_collection_group_with_distance_parameters(client, databas } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index c71a132ff3..78bd64c5c5 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -341,7 +341,14 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection(client, database): +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -349,7 +356,7 @@ async def test_vector_search_collection(client, database): vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -362,7 +369,14 @@ async def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_with_filter(client, database): +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -370,7 +384,7 @@ async def test_vector_search_collection_with_filter(client, database): vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), limit=1, - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, ) returned = await vector_query.get() assert isinstance(returned, list) @@ -383,7 +397,9 @@ async def test_vector_search_collection_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_with_distance_parameters(client, database): +async def test_vector_search_collection_with_distance_parameters_euclid( + client, database +): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -413,7 +429,46 @@ async def test_vector_search_collection_with_distance_parameters(client, databas @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group(client, database): +async def test_vector_search_collection_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) @@ -421,7 +476,7 @@ async def test_vector_search_collection_group(client, database): vector_query = collection_group.find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -435,7 +490,16 @@ async def test_vector_search_collection_group(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group_with_filter(client, database): +@pytest.mark.parametrize( + "distance_measure", + [ + DistanceMeasure.EUCLIDEAN, + DistanceMeasure.COSINE, + ], +) +async def test_vector_search_collection_group_with_filter( + client, database, distance_measure +): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) @@ -443,7 +507,7 @@ async def test_vector_search_collection_group_with_filter(client, database): vector_query = collection_group.where("color", "==", "red").find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), - distance_measure=DistanceMeasure.EUCLIDEAN, + distance_measure=distance_measure, limit=1, ) returned = await vector_query.get() @@ -457,7 +521,7 @@ async def test_vector_search_collection_group_with_filter(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) -async def test_vector_search_collection_group_with_distance_parameters( +async def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py @@ -487,6 +551,38 @@ async def test_vector_search_collection_group_with_distance_parameters( } +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_distance_parameters_cosine( + client, database +): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.COSINE, + limit=3, + distance_result_field="vector_distance", + distance_threshold=0.02, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 2 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + "vector_distance": 0.0, + } + assert returned[1].to_dict() == { + "embedding": Vector([3.0, 4.0, 5.0]), + "color": "yellow", + "vector_distance": 0.017292370176009153, + } + + @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_update_document(client, cleanup, database):