From 06c267c90563fb34940995d49f4a6e21df50a34f Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 2 Oct 2024 19:16:58 +0000 Subject: [PATCH] fix incorrect tests --- .../abstract_operations_async_client.py | 2 +- .../abstract_operations_base_client.py | 55 ++++-- .../abstract_operations_client.py | 38 ++++ .../test_operations_rest_client.py | 177 ++++++++++-------- 4 files changed, 178 insertions(+), 94 deletions(-) diff --git a/google/api_core/operations_v1/abstract_operations_async_client.py b/google/api_core/operations_v1/abstract_operations_async_client.py index 5e6b1521..80b4621a 100644 --- a/google/api_core/operations_v1/abstract_operations_async_client.py +++ b/google/api_core/operations_v1/abstract_operations_async_client.py @@ -98,7 +98,7 @@ def __init__( """ super().__init__( credentials=credentials, # type: ignore - transport=transport, + transport=transport or "rest_asyncio", client_options=client_options, client_info=client_info, ) diff --git a/google/api_core/operations_v1/abstract_operations_base_client.py b/google/api_core/operations_v1/abstract_operations_base_client.py index e66fa15e..7358ac78 100644 --- a/google/api_core/operations_v1/abstract_operations_base_client.py +++ b/google/api_core/operations_v1/abstract_operations_base_client.py @@ -19,6 +19,7 @@ from typing import Dict, Optional, Type, Union from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core.operations_v1.transports.base import ( DEFAULT_CLIENT_INFO, @@ -39,7 +40,6 @@ from google.auth import credentials as ga_credentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.auth.transport import mtls # type: ignore -from google.oauth2 import service_account # type: ignore class AbstractOperationsBaseClientMeta(type): @@ -143,9 +143,7 @@ def from_service_account_info(cls, info: dict, *args, **kwargs): Returns: AbstractOperationsClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_info(info) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) + raise NotImplementedError("`from_service_account_info` is not implemented.") @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): @@ -161,9 +159,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: AbstractOperationsClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) + raise NotImplementedError("`from_service_account_file` is not implemented.") from_service_account_json = from_service_account_file @@ -362,15 +358,38 @@ def __init__( ) self._transport = transport else: - # TODO (WIP): This code block will fail becuase async rest layer does not support all params. Transport = type(self).get_transport_class(transport) - self._transport = Transport( - credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, - client_info=client_info, - always_use_jwt_access=True, - ) + if "async" in str(Transport).lower(): + # TODO(https://github.com/googleapis/gapic-generator-python/issues/2136): Support the following parameters in async rest. + unsupported_params = { + "google.api_core.client_options.ClientOptions.credentials_file": client_options.credentials_file, + "google.api_core.client_options.ClientOptions.scopes": client_options.scopes, + "google.api_core.client_options.ClientOptions.quota_project_id": client_options.quota_project_id, + "google.api_core.client_options.ClientOptions.client_cert_source": client_options.client_cert_source, + "google.api_core.client_options.ClientOptions.api_audience": client_options.api_audience, + } + provided_unsupported_params = [ + name + for name, value in unsupported_params.items() + if value is not None + ] + if provided_unsupported_params: + raise core_exceptions.AsyncRestUnsupportedParameterError( + f"The following provided parameters are not supported for `transport=rest_asyncio`: {', '.join(provided_unsupported_params)}" + ) + self._transport = Transport( + credentials=credentials, + host=api_endpoint, + client_info=client_info, + ) + else: + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + ) diff --git a/google/api_core/operations_v1/abstract_operations_client.py b/google/api_core/operations_v1/abstract_operations_client.py index 64f2cb7d..fc445362 100644 --- a/google/api_core/operations_v1/abstract_operations_client.py +++ b/google/api_core/operations_v1/abstract_operations_client.py @@ -28,6 +28,7 @@ ) from google.auth import credentials as ga_credentials # type: ignore from google.longrunning import operations_pb2 +from google.oauth2 import service_account # type: ignore import grpc OptionalRetry = Union[retries.Retry, object] @@ -98,6 +99,43 @@ def __init__( client_info=client_info, ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + AbstractOperationsClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + AbstractOperationsClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + def list_operations( self, name: str, diff --git a/tests/unit/operations_v1/test_operations_rest_client.py b/tests/unit/operations_v1/test_operations_rest_client.py index f1adbf04..68bd9287 100644 --- a/tests/unit/operations_v1/test_operations_rest_client.py +++ b/tests/unit/operations_v1/test_operations_rest_client.py @@ -197,16 +197,22 @@ def test__get_default_mtls_endpoint(client_class): ) def test_operations_client_from_service_account_info(client_class): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_info" - ) as factory: - factory.return_value = creds - info = {"valid": True} - client = client_class.from_service_account_info(info) - assert client.transport._credentials == creds - assert isinstance(client, client_class) + if "async" in str(client_class): + # TODO(): Add support for service account info to async REST transport. + with pytest.raises(NotImplementedError): + info = {"valid": True} + client_class.from_service_account_info(info) + else: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) - assert client.transport._host == "https://longrunning.googleapis.com" + assert client.transport._host == "https://longrunning.googleapis.com" @pytest.mark.parametrize( @@ -238,20 +244,26 @@ def test_operations_client_service_account_always_use_jwt(transport_class): CLIENTS, ) def test_operations_client_from_service_account_file(client_class): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client.transport._credentials == creds - assert isinstance(client, client_class) - client = client_class.from_service_account_json("dummy/file/path.json") - assert client.transport._credentials == creds - assert isinstance(client, client_class) + if "async" in str(client_class): + # TODO(): Add support for service account creds to async REST transport. + with pytest.raises(NotImplementedError): + client_class.from_service_account_file("dummy/file/path.json") + else: + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) - assert client.transport._host == "https://longrunning.googleapis.com" + assert client.transport._host == "https://longrunning.googleapis.com" @pytest.mark.parametrize( @@ -273,9 +285,10 @@ def test_operations_client_get_transport_class( assert transport == transport_class +# TODO(): Update this test case to include async REST once we have support for MTLS. @pytest.mark.parametrize( "client_class,transport_class,transport_name", - CLIENTS_WITH_TRANSPORT, + [(AbstractOperationsClient, transports.OperationsRestTransport, "rest")], ) @mock.patch.object( AbstractOperationsClient, @@ -501,19 +514,24 @@ def test_operations_client_client_options_scopes( options = client_options.ClientOptions( scopes=["1", "2"], ) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - ) + if "async" in str(client_class): + # TODO(): Add support for scopes to async REST transport. + with pytest.raises(core_exceptions.AsyncRestUnsupportedParameterError): + client_class(client_options=options, transport=transport_name) + else: + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) @pytest.mark.parametrize( @@ -525,19 +543,24 @@ def test_operations_client_client_options_credentials_file( ): # Check the case credentials file is provided. options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - ) + if "async" in str(client_class): + # TODO(): Add support for credentials file to async REST transport. + with pytest.raises(core_exceptions.AsyncRestUnsupportedParameterError): + client_class(client_options=options, transport=transport_name) + else: + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) def test_list_operations_rest(): @@ -1167,12 +1190,22 @@ def test_operations_auth_adc(client_class): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) - client_class() - adc.assert_called_once_with( - scopes=None, - default_scopes=(), - quota_project_id=None, - ) + + if "async" in str(client_class).lower(): + # TODO(): Add support for adc to async REST transport. + # NOTE: Ideally, the logic for adc shouldn't be called if transport + # is set to async REST. If the user does not configure credentials + # of type `google.auth.aio.credentials.Credentials`, + # we should raise an exception to avoid the adc workflow. + with pytest.raises(google.auth.exceptions.InvalidType): + client_class() + else: + client_class() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) # TODO(https://github.com/googleapis/python-api-core/issues/705): Add @@ -1196,12 +1229,12 @@ def test_operations_http_transport_client_cert_source_for_mtls(transport_class): @pytest.mark.parametrize( - "client_class", - CLIENTS, + "client_class,transport_class,credentials", + CLIENTS_WITH_CREDENTIALS, ) -def test_operations_host_no_port(client_class): +def test_operations_host_no_port(client_class, transport_class, credentials): client = client_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=credentials, client_options=client_options.ClientOptions( api_endpoint="longrunning.googleapis.com" ), @@ -1210,12 +1243,12 @@ def test_operations_host_no_port(client_class): @pytest.mark.parametrize( - "client_class", - CLIENTS, + "client_class,transport_class,credentials", + CLIENTS_WITH_CREDENTIALS, ) -def test_operations_host_with_port(client_class): +def test_operations_host_with_port(client_class, transport_class, credentials): client = client_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=credentials, client_options=client_options.ClientOptions( api_endpoint="longrunning.googleapis.com:8000" ), @@ -1367,27 +1400,21 @@ def test_parse_common_location_path(client_class): @pytest.mark.parametrize( - "client_class", - CLIENTS, + "client_class,transport_class,credentials", + CLIENTS_WITH_CREDENTIALS, ) -def test_client_withDEFAULT_CLIENT_INFO(client_class): +def test_client_withDEFAULT_CLIENT_INFO(client_class, transport_class, credentials): client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object( - transports.OperationsTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transport_class, "_prep_wrapped_messages") as prep: client_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=credentials, client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.OperationsTransport, "_prep_wrapped_messages" - ) as prep: - transport_class = client_class.get_transport_class() + with mock.patch.object(transport_class, "_prep_wrapped_messages") as prep: transport_class( - credentials=ga_credentials.AnonymousCredentials(), + credentials=credentials, client_info=client_info, ) prep.assert_called_once_with(client_info)