diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 6d19574e8ce30..6748f81dc2529 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -119,6 +119,21 @@ def _get_search_client( VectorSearchProfile, ) + class AzureBearerTokenCredential(TokenCredential): + def __init__(self, token: str): + # set the expiry to an hour from now. + self._token = AccessToken(token, int(time.time()) + 3600) + + def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + return self._token + additional_search_client_options = additional_search_client_options or {} default_fields = default_fields or [] credential: Union[AzureKeyCredential, TokenCredential, InteractiveBrowserCredential] @@ -131,11 +146,7 @@ def _get_search_client( else: credential = AzureKeyCredential(key) elif azure_ad_access_token is not None: - credential = TokenCredential( - lambda *scopes, **kwargs: AccessToken( - azure_ad_access_token, int(time.time()) + 3600 - ) - ) + credential = AzureBearerTokenCredential(azure_ad_access_token) else: credential = DefaultAzureCredential() index_client: SearchIndexClient = SearchIndexClient(