Skip to content

Commit

Permalink
Fix some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
NickChittle committed Aug 24, 2024
1 parent dcfea26 commit 76cf90e
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 34 deletions.
10 changes: 6 additions & 4 deletions google/cloud/firestore_v1/async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type, Union
from typing import TYPE_CHECKING, AsyncGenerator, List, Optional, Type

from google.api_core import gapic_v1
from google.api_core import retry_async as retries
Expand Down Expand Up @@ -230,8 +230,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
distance_threshold: Optional[float] = None,
) -> AsyncVectorQuery:
"""
Finds the closest vector embeddings to the given query vector.
Expand All @@ -244,8 +245,9 @@ def find_nearest(
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[Union[int, float]]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Returns:
Expand Down
5 changes: 3 additions & 2 deletions google/cloud/firestore_v1/base_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,8 +550,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
distance_threshold: Optional[float] = None,
) -> VectorQuery:
"""
Finds the closest vector embeddings to the given query vector.
Expand All @@ -565,7 +566,7 @@ def find_nearest(
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[Union[int, float]]):
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Returns:
Expand Down
3 changes: 2 additions & 1 deletion google/cloud/firestore_v1/base_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,8 +982,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
distance_threshold: Optional[float] = None,
) -> BaseVectorQuery:
raise NotImplementedError

Expand Down
7 changes: 4 additions & 3 deletions google/cloud/firestore_v1/base_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, nested_query) -> None:
self._limit: Optional[int] = None
self._distance_measure: Optional[DistanceMeasure] = None
self._distance_result_field: Optional[str] = None
self._distance_threshold: Optional[Union[int, float]] = None
self._distance_threshold: Optional[float] = None

@property
def _client(self):
Expand All @@ -73,7 +73,7 @@ def _to_protobuf(self) -> query.StructuredQuery:

# Coerce ints to floats as required by the protobuf.
distance_threshold_proto = None
if self._distance_threshold:
if self._distance_threshold is not None:
distance_threshold_proto = float(self._distance_threshold)

pb = self._nested_query._to_protobuf()
Expand Down Expand Up @@ -120,8 +120,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
distance_threshold: Optional[float] = None,
):
"""Finds the closest vector embeddings to the given query vector."""
self._vector_field = vector_field
Expand Down
10 changes: 6 additions & 4 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Type

from google.api_core import exceptions, gapic_v1
from google.api_core import retry as retries
Expand Down Expand Up @@ -251,8 +251,9 @@ def find_nearest(
query_vector: Vector,
limit: int,
distance_measure: DistanceMeasure,
*,
distance_result_field: Optional[str] = None,
distance_threshold: Optional[Union[int, float]] = None,
distance_threshold: Optional[float] = None,
) -> Type["firestore_v1.vector_query.VectorQuery"]:
"""
Finds the closest vector embeddings to the given query vector.
Expand All @@ -265,8 +266,9 @@ def find_nearest(
limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000.
distance_measure (:class:`DistanceMeasure`): The Distance Measure to use.
distance_result_field (Optional[str]):
Name of the field to output the result of the vector distance calculation
distance_threshold (Optional[Union[int, float]]):
Name of the field to output the result of the vector distance
calculation. If unset then the distance will not be returned.
distance_threshold (Optional[float]):
A threshold for which no less similar documents will be returned.
Expand Down
112 changes: 102 additions & 10 deletions tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,22 @@ def on_snapshot(docs, changes, read_time):

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection(client, database):
@pytest.mark.parametrize(
"distance_measure",
[
DistanceMeasure.EUCLIDEAN,
DistanceMeasure.COSINE,
],
)
def test_vector_search_collection(client, database, distance_measure):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

vector_query = collection.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
distance_measure=distance_measure,
limit=1,
)
returned = vector_query.get()
Expand All @@ -198,15 +205,22 @@ def test_vector_search_collection(client, database):

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_filter(client, database):
@pytest.mark.parametrize(
"distance_measure",
[
DistanceMeasure.EUCLIDEAN,
DistanceMeasure.COSINE,
],
)
def test_vector_search_collection_with_filter(client, database, distance_measure):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

vector_query = collection.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
distance_measure=distance_measure,
limit=1,
)
returned = vector_query.get()
Expand All @@ -220,7 +234,7 @@ def test_vector_search_collection_with_filter(client, database):

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_distance_parameters(client, database):
def test_vector_search_collection_with_distance_parameters_euclid(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)
Expand Down Expand Up @@ -250,15 +264,52 @@ def test_vector_search_collection_with_distance_parameters(client, database):

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group(client, database):
def test_vector_search_collection_with_distance_parameters_cosine(client, database):
    """Check a COSINE ``find_nearest`` on a collection with distance options.

    The query asks for up to three neighbours of ``[1.0, 2.0, 3.0]``, requests
    the computed distance in the ``vector_distance`` field, and applies a
    ``distance_threshold`` of ``0.02``. Exactly two documents are expected to
    be returned.

    The distances are computed floats, so they are compared with
    ``pytest.approx`` rather than exact equality to keep the test from failing
    on last-digit rounding differences.
    """
    # Documents and Indexes are a manual step from util/bootstrap_vector_index.py
    collection_id = "vector_search"
    collection = client.collection(collection_id)

    vector_query = collection.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.COSINE,
        limit=3,
        distance_result_field="vector_distance",
        distance_threshold=0.02,
    )
    returned = vector_query.get()
    assert isinstance(returned, list)
    # limit=3, but only two results are expected once the threshold is applied.
    assert len(returned) == 2
    assert returned[0].to_dict() == {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
        # Same vector as the query: distance should be (approximately) zero.
        "vector_distance": pytest.approx(0.0),
    }
    assert returned[1].to_dict() == {
        "embedding": Vector([3.0, 4.0, 5.0]),
        "color": "yellow",
        "vector_distance": pytest.approx(0.017292370176009153),
    }


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
@pytest.mark.parametrize(
"distance_measure",
[
DistanceMeasure.EUCLIDEAN,
DistanceMeasure.COSINE,
],
)
def test_vector_search_collection_group(client, database, distance_measure):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
distance_measure=distance_measure,
limit=1,
)
returned = vector_query.get()
Expand All @@ -271,16 +322,23 @@ def test_vector_search_collection_group(client, database):


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize(
"distance_measure",
[
DistanceMeasure.EUCLIDEAN,
DistanceMeasure.COSINE,
],
)
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_filter(client, database):
def test_vector_search_collection_group_with_filter(client, database, distance_measure):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

vector_query = collection_group.where("color", "==", "red").find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=DistanceMeasure.EUCLIDEAN,
distance_measure=distance_measure,
limit=1,
)
returned = vector_query.get()
Expand All @@ -294,7 +352,9 @@ def test_vector_search_collection_group_with_filter(client, database):

@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_distance_parameters(client, database):
def test_vector_search_collection_group_with_distance_parameters_euclid(
client, database
):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)
Expand Down Expand Up @@ -322,6 +382,38 @@ def test_vector_search_collection_group_with_distance_parameters(client, databas
}


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_distance_parameters_cosine(
    client, database
):
    """Check a COSINE ``find_nearest`` on a collection group with distance options.

    The query asks for up to three neighbours of ``[1.0, 2.0, 3.0]``, requests
    the computed distance in the ``vector_distance`` field, and applies a
    ``distance_threshold`` of ``0.02``. Exactly two documents are expected to
    be returned.

    The distances are computed floats, so they are compared with
    ``pytest.approx`` rather than exact equality to keep the test from failing
    on last-digit rounding differences.
    """
    # Documents and Indexes are a manual step from util/bootstrap_vector_index.py
    collection_id = "vector_search"
    collection_group = client.collection_group(collection_id)

    vector_query = collection_group.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.COSINE,
        limit=3,
        distance_result_field="vector_distance",
        distance_threshold=0.02,
    )
    returned = vector_query.get()
    assert isinstance(returned, list)
    # limit=3, but only two results are expected once the threshold is applied.
    assert len(returned) == 2
    assert returned[0].to_dict() == {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
        # Same vector as the query: distance should be (approximately) zero.
        "vector_distance": pytest.approx(0.0),
    }
    assert returned[1].to_dict() == {
        "embedding": Vector([3.0, 4.0, 5.0]),
        "color": "yellow",
        "vector_distance": pytest.approx(0.017292370176009153),
    }


@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_create_document_w_subcollection(client, cleanup, database):
collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID
Expand Down
Loading

0 comments on commit 76cf90e

Please sign in to comment.