diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 6930c8319e4d9..d0aa15e2acbd1 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -42,6 +42,8 @@ logger = logging.getLogger() if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential from azure.search.documents import SearchClient, SearchItemPaged from azure.search.documents.aio import ( AsyncSearchItemPaged, @@ -96,10 +98,13 @@ def _get_search_client( cors_options: Optional[CorsOptions] = None, async_: bool = False, additional_search_client_options: Optional[Dict[str, Any]] = None, + azure_credential: Optional[TokenCredential] = None, + azure_async_credential: Optional[AsyncTokenCredential] = None, ) -> Union[SearchClient, AsyncSearchClient]: from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential + from azure.identity.aio import DefaultAzureCredential as AsyncDefaultAzureCredential from azure.search.documents import SearchClient from azure.search.documents.aio import SearchClient as AsyncSearchClient from azure.search.documents.indexes import SearchIndexClient @@ -143,12 +148,17 @@ def get_token( if key.upper() == "INTERACTIVE": credential = InteractiveBrowserCredential() credential.get_token("https://search.azure.com/.default") + async_credential = credential else: credential = AzureKeyCredential(key) + async_credential = credential elif azure_ad_access_token is not None: credential = AzureBearerTokenCredential(azure_ad_access_token) + async_credential = credential else: - credential = DefaultAzureCredential() + credential = azure_credential or DefaultAzureCredential() + async_credential = azure_async_credential or AsyncDefaultAzureCredential() + index_client: SearchIndexClient = SearchIndexClient( endpoint=endpoint, credential=credential, @@ -266,7 +276,7 @@ def fmt_err(x: str) -> str: return AsyncSearchClient( endpoint=endpoint, index_name=index_name, - credential=credential, + credential=async_credential, user_agent=user_agent, **additional_search_client_options, ) @@ -278,7 +288,7 @@ class AzureSearch(VectorStore): def __init__( self, azure_search_endpoint: str, - azure_search_key: str, + azure_search_key: Optional[str], index_name: str, embedding_function: Union[Callable, Embeddings], search_type: str = "hybrid", @@ -295,6 +305,8 @@ def __init__( vector_search_dimensions: Optional[int] = None, additional_search_client_options: Optional[Dict[str, Any]] = None, azure_ad_access_token: Optional[str] = None, + azure_credential: Optional[TokenCredential] = None, + azure_async_credential: Optional[AsyncTokenCredential] = None, **kwargs: Any, ): try: @@ -361,6 +373,7 @@ def __init__( user_agent=user_agent, cors_options=cors_options, additional_search_client_options=additional_search_client_options, + azure_credential=azure_credential, ) self.async_client = _get_search_client( azure_search_endpoint, @@ -377,6 +390,8 @@ def __init__( user_agent=user_agent, cors_options=cors_options, async_=True, + azure_credential=azure_credential, + azure_async_credential=azure_async_credential, ) self.search_type = search_type self.semantic_configuration_name = semantic_configuration_name