Skip to content

Commit

Permalink
community[patch]: Fix Hybrid Search for non-Databricks managed embedd…
Browse files Browse the repository at this point in the history
…ings (#25590)

Description: Send both the query and query_embedding to the Databricks
index for hybrid search.

Issue: When using hybrid search with non-Databricks managed embedding we
currently don't pass both the embedding and query_text to the index.
Hybrid search requires both of these. This change fixes this issue for
both `similarity_search` and `similarity_search_by_vector`.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
erikml-db and efriis authored Aug 23, 2024
1 parent bcd5842 commit 583b044
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,11 @@ def similarity_search_with_score(
query_vector = None
else:
assert self.embeddings is not None, "embedding model is required."
query_text = None
# The value for `query_text` needs to be specified only for hybrid search.
if query_type is not None and query_type.upper() == "HYBRID":
query_text = query
else:
query_text = None
query_vector = self.embeddings.embed_query(query)
search_resp = self.index.similarity_search(
columns=self.columns,
Expand Down Expand Up @@ -487,6 +491,7 @@ def similarity_search_by_vector(
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Expand All @@ -505,6 +510,7 @@ def similarity_search_by_vector(
k=k,
filter=filter,
query_type=query_type,
query=query,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
Expand All @@ -516,6 +522,7 @@ def similarity_search_by_vector_with_score(
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector, along with scores.
Expand All @@ -534,9 +541,25 @@ def similarity_search_by_vector_with_score(
"`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings."
)
if query_type is not None and query_type.upper() == "HYBRID":
if query is None:
raise ValueError(
"A value for `query` must be specified for hybrid search."
)
query_text = query
else:
if query is not None:
raise ValueError(
(
"Cannot specify both `embedding` and "
'`query` unless `query_type="HYBRID"'
)
)
query_text = None
search_resp = self.index.similarity_search(
columns=self.columns,
query_vector=embedding,
query_text=query_text,
filters=filter or _alias_filters(kwargs),
num_results=k,
query_type=query_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def test_delete_fail_no_ids() -> None:

@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details, query_type", itertools.product(ALL_INDEXES, ALL_QUERY_TYPES)
"index_details, query_type", itertools.product(ALL_INDEXES, [None, "ANN"])
)
def test_similarity_search(index_details: dict, query_type: Optional[str]) -> None:
index = mock_index(index_details)
Expand Down Expand Up @@ -518,6 +518,42 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize("index_details", ALL_INDEXES)
def test_similarity_search_hybrid(index_details: dict) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query = "foo"
filters = {"some filter": True}
limit = 7

search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type="HYBRID"
)
if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS:
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type="HYBRID",
)
else:
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_text=query,
query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type="HYBRID",
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_similarity_search_both_filter_and_filters_passed() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
Expand Down Expand Up @@ -655,11 +691,44 @@ def test_standard_params() -> None:
}


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details, query_type",
itertools.product(
[DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], [None, "ANN"]
),
)
def test_similarity_search_by_vector(
index_details: dict, query_type: Optional[str]
) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7

search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type=query_type
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type=query_type,
query_text=None,
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result])


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
)
def test_similarity_search_by_vector(index_details: dict) -> None:
def test_similarity_search_by_vector_hybrid(index_details: dict) -> None:
index = mock_index(index_details)
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
vectorsearch = default_databricks_vector_search(index)
Expand All @@ -668,14 +737,15 @@ def test_similarity_search_by_vector(index_details: dict) -> None:
limit = 7

search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
)
index.similarity_search.assert_called_once_with(
columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type=None,
query_type="HYBRID",
query_text="foo",
)
assert len(search_result) == len(fake_texts)
assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)
Expand Down

0 comments on commit 583b044

Please sign in to comment.