diff --git a/google/auth/transport/_custom_tls_signer.py b/google/auth/transport/_custom_tls_signer.py index 07f14df02..57a563d03 100644 --- a/google/auth/transport/_custom_tls_signer.py +++ b/google/auth/transport/_custom_tls_signer.py @@ -107,6 +107,22 @@ def load_signer_lib(signer_lib_path): return lib +def load_provider_lib(provider_lib_path): + _LOGGER.debug("loading provider library from %s", provider_lib_path) + + # winmode parameter is only available for python 3.8+. + lib = ( + ctypes.CDLL(provider_lib_path, winmode=0) + if sys.version_info >= (3, 8) and os.name == "nt" + else ctypes.CDLL(provider_lib_path) + ) + + lib.ECP_attach_to_ctx.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + lib.ECP_attach_to_ctx.restype = ctypes.c_int + + return lib + + # Computes SHA256 hash. def _compute_sha256_digest(to_be_signed, to_be_signed_len): from cryptography.hazmat.primitives import hashes @@ -199,21 +215,31 @@ def __init__(self, enterprise_cert_file_path): self._enterprise_cert_file_path = enterprise_cert_file_path self._cert = None self._sign_callback = None + self._provider_lib = None def load_libraries(self): - try: - with open(self._enterprise_cert_file_path, "r") as f: - enterprise_cert_json = json.load(f) - libs = enterprise_cert_json["libs"] - signer_library = libs["ecp_client"] - offload_library = libs["tls_offload"] - except (KeyError, ValueError) as caught_exc: - new_exc = exceptions.MutualTLSChannelError( - "enterprise cert file is invalid", caught_exc - ) - raise new_exc from caught_exc - self._offload_lib = load_offload_lib(offload_library) - self._signer_lib = load_signer_lib(signer_library) + with open(self._enterprise_cert_file_path, "r") as f: + enterprise_cert_json = json.load(f) + libs = enterprise_cert_json.get("libs", {}) + + signer_library = libs.get("ecp_client", None) + offload_library = libs.get("tls_offload", None) + provider_library = libs.get("ecp_provider", None) + + # Using newer provider implementation. This is mutually exclusive to the + # offload implementation. + if provider_library: + self._provider_lib = load_provider_lib(provider_library) + return + + # Using old offload implementation + if offload_library and signer_library: + self._offload_lib = load_offload_lib(offload_library) + self._signer_lib = load_signer_lib(signer_library) + self.set_up_custom_key() + return + + raise exceptions.MutualTLSChannelError("enterprise cert file is invalid") def set_up_custom_key(self): # We need to keep a reference of the cert and sign callback so it won't @@ -224,11 +250,22 @@ def set_up_custom_key(self): ) def attach_to_ssl_context(self, ctx): - # In the TLS handshake, the signing operation will be done by the - # sign_callback. - if not self._offload_lib.ConfigureSslContext( - self._sign_callback, - ctypes.c_char_p(self._cert), - _cast_ssl_ctx_to_void_p(ctx._ctx._context), - ): - raise exceptions.MutualTLSChannelError("failed to configure SSL context") + if self._provider_lib: + if not self._provider_lib.ECP_attach_to_ctx( + _cast_ssl_ctx_to_void_p(ctx._ctx._context), + self._enterprise_cert_file_path.encode("ascii"), + ): + raise exceptions.MutualTLSChannelError( + "failed to configure ECP Provider SSL context" + ) + elif self._offload_lib and self._signer_lib: + if not self._offload_lib.ConfigureSslContext( + self._sign_callback, + ctypes.c_char_p(self._cert), + _cast_ssl_ctx_to_void_p(ctx._ctx._context), + ): + raise exceptions.MutualTLSChannelError( + "failed to configure ECP Offload SSL context" + ) + else: + raise exceptions.MutualTLSChannelError("Invalid ECP configuration.") diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index b9bcad359..aa1611322 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -274,7 +274,6 @@ def __init__(self, enterprise_cert_file_path): self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path) self.signer.load_libraries() - self.signer.set_up_custom_key() poolmanager = create_urllib3_context() poolmanager.load_verify_locations(cafile=certifi.where()) diff --git a/tests/data/enterprise_cert_valid_provider.json b/tests/data/enterprise_cert_valid_provider.json new file mode 100644 index 000000000..9b7adf8bc --- /dev/null +++ b/tests/data/enterprise_cert_valid_provider.json @@ -0,0 +1,6 @@ +{ + "libs": { + "ecp_client": "/path/to/signer/lib", + "ecp_provider": "/path/to/provider/lib" + } +} diff --git a/tests/transport/test__custom_tls_signer.py b/tests/transport/test__custom_tls_signer.py index 5836b325a..d2907bad2 100644 --- a/tests/transport/test__custom_tls_signer.py +++ b/tests/transport/test__custom_tls_signer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import base64 import ctypes import os @@ -30,11 +29,19 @@ ENTERPRISE_CERT_FILE = os.path.join( os.path.dirname(__file__), "../data/enterprise_cert_valid.json" ) +ENTERPRISE_CERT_FILE_PROVIDER = os.path.join( + os.path.dirname(__file__), "../data/enterprise_cert_valid_provider.json" +) INVALID_ENTERPRISE_CERT_FILE = os.path.join( os.path.dirname(__file__), "../data/enterprise_cert_invalid.json" ) +def test_load_provider_lib(): + with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()): + _custom_tls_signer.load_provider_lib("/path/to/provider/lib") + + def test_load_offload_lib(): with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()): lib = _custom_tls_signer.load_offload_lib("/path/to/offload/lib") @@ -173,62 +180,81 @@ def test_custom_tls_signer(): ) as load_offload_lib: load_offload_lib.return_value = offload_lib load_signer_lib.return_value = signer_lib - signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) - signer_object.load_libraries() - assert signer_object._cert is None + with mock.patch( + "google.auth.transport._custom_tls_signer.get_cert" + ) as get_cert: + with mock.patch( + "google.auth.transport._custom_tls_signer.get_sign_callback" + ) as get_sign_callback: + get_cert.return_value = b"mock_cert" + signer_object = _custom_tls_signer.CustomTlsSigner( + ENTERPRISE_CERT_FILE + ) + signer_object.load_libraries() + signer_object.attach_to_ssl_context(create_urllib3_context()) + get_cert.assert_called_once() + get_sign_callback.assert_called_once() + offload_lib.ConfigureSslContext.assert_called_once() assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE assert signer_object._offload_lib == offload_lib assert signer_object._signer_lib == signer_lib load_signer_lib.assert_called_with("/path/to/signer/lib") load_offload_lib.assert_called_with("/path/to/offload/lib") - # Test set_up_custom_key and set_up_ssl_context methods - with mock.patch("google.auth.transport._custom_tls_signer.get_cert") as get_cert: - with mock.patch( - "google.auth.transport._custom_tls_signer.get_sign_callback" - ) as get_sign_callback: - get_cert.return_value = b"mock_cert" - signer_object.set_up_custom_key() - signer_object.attach_to_ssl_context(create_urllib3_context()) - get_cert.assert_called_once() - get_sign_callback.assert_called_once() - offload_lib.ConfigureSslContext.assert_called_once() +def test_custom_tls_signer_provider(): + provider_lib = mock.MagicMock() -def test_custom_tls_signer_failed_to_load_libraries(): # Test load_libraries method + with mock.patch( + "google.auth.transport._custom_tls_signer.load_provider_lib" + ) as load_provider_lib: + load_provider_lib.return_value = provider_lib + signer_object = _custom_tls_signer.CustomTlsSigner( + ENTERPRISE_CERT_FILE_PROVIDER + ) + signer_object.load_libraries() + signer_object.attach_to_ssl_context(mock.MagicMock()) + + assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE_PROVIDER + assert signer_object._provider_lib == provider_lib + load_provider_lib.assert_called_with("/path/to/provider/lib") + + +def test_custom_tls_signer_failed_to_load_libraries(): with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: signer_object = _custom_tls_signer.CustomTlsSigner(INVALID_ENTERPRISE_CERT_FILE) signer_object.load_libraries() assert excinfo.match("enterprise cert file is invalid") -def test_custom_tls_signer_fail_to_offload(): - offload_lib = mock.MagicMock() - signer_lib = mock.MagicMock() +def test_custom_tls_signer_failed_to_attach(): + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) + signer_object._offload_lib = mock.MagicMock() + signer_object._signer_lib = mock.MagicMock() + signer_object._sign_callback = mock.MagicMock() + signer_object._cert = b"mock cert" + signer_object._offload_lib.ConfigureSslContext.return_value = False + signer_object.attach_to_ssl_context(mock.MagicMock()) + assert excinfo.match("failed to configure ECP Offload SSL context") - with mock.patch( - "google.auth.transport._custom_tls_signer.load_signer_lib" - ) as load_signer_lib: - with mock.patch( - "google.auth.transport._custom_tls_signer.load_offload_lib" - ) as load_offload_lib: - load_offload_lib.return_value = offload_lib - load_signer_lib.return_value = signer_lib - signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) - signer_object.load_libraries() - # set the return value to be 0 which indicts offload fails - offload_lib.ConfigureSslContext.return_value = 0 +def test_custom_tls_signer_failed_to_attach_provider(): + with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: + signer_object = _custom_tls_signer.CustomTlsSigner( + ENTERPRISE_CERT_FILE_PROVIDER + ) + signer_object._provider_lib = mock.MagicMock() + signer_object._provider_lib.ECP_attach_to_ctx.return_value = False + signer_object.attach_to_ssl_context(mock.MagicMock()) + assert excinfo.match("failed to configure ECP Provider SSL context") + +def test_custom_tls_signer_failed_to_attach_no_libs(): with pytest.raises(exceptions.MutualTLSChannelError) as excinfo: - with mock.patch( - "google.auth.transport._custom_tls_signer.get_cert" - ) as get_cert: - with mock.patch( - "google.auth.transport._custom_tls_signer.get_sign_callback" - ): - get_cert.return_value = b"mock_cert" - signer_object.set_up_custom_key() - signer_object.attach_to_ssl_context(create_urllib3_context()) - assert excinfo.match("failed to configure SSL context") + signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE) + signer_object._offload_lib = None + signer_object._signer_lib = None + signer_object.attach_to_ssl_context(mock.MagicMock()) + assert excinfo.match("Invalid ECP configuration.") diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index d96281434..aadc1ddbf 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -544,9 +544,6 @@ class TestMutualTlsOffloadAdapter(object): @mock.patch.object( google.auth.transport._custom_tls_signer.CustomTlsSigner, "load_libraries" ) - @mock.patch.object( - google.auth.transport._custom_tls_signer.CustomTlsSigner, "set_up_custom_key" - ) @mock.patch.object( google.auth.transport._custom_tls_signer.CustomTlsSigner, "attach_to_ssl_context", @@ -554,7 +551,6 @@ class TestMutualTlsOffloadAdapter(object): def test_success( self, mock_attach_to_ssl_context, - mock_set_up_custom_key, mock_load_libraries, mock_proxy_manager_for, mock_init_poolmanager, @@ -565,7 +561,6 @@ def test_success( ) mock_load_libraries.assert_called_once() - mock_set_up_custom_key.assert_called_once() assert mock_attach_to_ssl_context.call_count == 2 adapter.init_poolmanager()