diff --git a/libs/community/langchain_community/vectorstores/databricks_vector_search.py b/libs/community/langchain_community/vectorstores/databricks_vector_search.py index cd52db0180dd6..2c0727647a8a1 100644 --- a/libs/community/langchain_community/vectorstores/databricks_vector_search.py +++ b/libs/community/langchain_community/vectorstores/databricks_vector_search.py @@ -341,7 +341,11 @@ def similarity_search_with_score( query_vector = None else: assert self.embeddings is not None, "embedding model is required." - query_text = None + # The value for `query_text` needs to be specified only for hybrid search. + if query_type is not None and query_type.upper() == "HYBRID": + query_text = query + else: + query_text = None query_vector = self.embeddings.embed_query(query) search_resp = self.index.similarity_search( columns=self.columns, @@ -487,6 +491,7 @@ def similarity_search_by_vector( filter: Optional[Any] = None, *, query_type: Optional[str] = None, + query: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """Return docs most similar to embedding vector. @@ -505,6 +510,7 @@ def similarity_search_by_vector( k=k, filter=filter, query_type=query_type, + query=query, **kwargs, ) return [doc for doc, _ in docs_with_score] @@ -516,6 +522,7 @@ def similarity_search_by_vector_with_score( filter: Optional[Any] = None, *, query_type: Optional[str] = None, + query: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to embedding vector, along with scores. @@ -534,9 +541,25 @@ def similarity_search_by_vector_with_score( "`similarity_search_by_vector` is not supported for index with " "Databricks-managed embeddings." ) + if query_type is not None and query_type.upper() == "HYBRID": + if query is None: + raise ValueError( + "A value for `query` must be specified for hybrid search." + ) + query_text = query + else: + if query is not None: + raise ValueError( + ( + "Cannot specify both `embedding` and " + '`query` unless `query_type="HYBRID"' + ) + ) + query_text = None search_resp = self.index.similarity_search( columns=self.columns, query_vector=embedding, + query_text=query_text, filters=filter or _alias_filters(kwargs), num_results=k, query_type=query_type, diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py index 508bf0ac1ba41..fcf0bfc9bc9c7 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py @@ -482,7 +482,7 @@ def test_delete_fail_no_ids() -> None: @pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.parametrize( - "index_details, query_type", itertools.product(ALL_INDEXES, ALL_QUERY_TYPES) + "index_details, query_type", itertools.product(ALL_INDEXES, [None, "ANN"]) ) def test_similarity_search(index_details: dict, query_type: Optional[str]) -> None: index = mock_index(index_details) @@ -518,6 +518,42 @@ def test_similarity_search(index_details: dict, query_type: Optional[str]) -> No assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize("index_details", ALL_INDEXES) +def test_similarity_search_hybrid(index_details: dict) -> None: + index = mock_index(index_details) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query = "foo" + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search( + query, k=limit, filter=filters, query_type="HYBRID" + ) + if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_text=query, + query_vector=None, + filters=filters, + num_results=limit, + query_type="HYBRID", + ) + else: + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_text=query, + query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), + filters=filters, + num_results=limit, + query_type="HYBRID", + ) + assert len(search_result) == len(fake_texts) + assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) + assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) + + @pytest.mark.requires("databricks", "databricks.vector_search") def test_similarity_search_both_filter_and_filters_passed() -> None: index = mock_index(DIRECT_ACCESS_INDEX) @@ -655,11 +691,44 @@ def test_standard_params() -> None: } +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details, query_type", + itertools.product( + [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], [None, "ANN"] + ), +) +def test_similarity_search_by_vector( + index_details: dict, query_type: Optional[str] +) -> None: + index = mock_index(index_details) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filter=filters, query_type=query_type + ) + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_vector=query_embedding, + filters=filters, + num_results=limit, + query_type=query_type, + query_text=None, + ) + assert len(search_result) == len(fake_texts) + assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) + assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) + + @pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.parametrize( "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] ) -def test_similarity_search_by_vector(index_details: dict) -> None: +def test_similarity_search_by_vector_hybrid(index_details: dict) -> None: index = mock_index(index_details) index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE vectorsearch = default_databricks_vector_search(index) @@ -668,14 +737,15 @@ def test_similarity_search_by_vector(index_details: dict) -> None: limit = 7 search_result = vectorsearch.similarity_search_by_vector( - query_embedding, k=limit, filter=filters + query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo" ) index.similarity_search.assert_called_once_with( columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], query_vector=query_embedding, filters=filters, num_results=limit, - query_type=None, + query_type="HYBRID", + query_text="foo", ) assert len(search_result) == len(fake_texts) assert sorted([d.page_content for d in search_result]) == sorted(fake_texts)