Skip to content

Commit

Permalink
feat: Add custom tls signer for ECP Provider. (#1402)
Browse files Browse the repository at this point in the history
feat: Add custom tls signer for ECP Provider.
  • Loading branch information
clundin25 authored Nov 30, 2023
1 parent 9b46ee3 commit 39eb287
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 68 deletions.
79 changes: 58 additions & 21 deletions google/auth/transport/_custom_tls_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def load_signer_lib(signer_lib_path):
return lib


def load_provider_lib(provider_lib_path):
_LOGGER.debug("loading provider library from %s", provider_lib_path)

# winmode parameter is only available for python 3.8+.
lib = (
ctypes.CDLL(provider_lib_path, winmode=0)
if sys.version_info >= (3, 8) and os.name == "nt"
else ctypes.CDLL(provider_lib_path)
)

lib.ECP_attach_to_ctx.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
lib.ECP_attach_to_ctx.restype = ctypes.c_int

return lib


# Computes SHA256 hash.
def _compute_sha256_digest(to_be_signed, to_be_signed_len):
from cryptography.hazmat.primitives import hashes
Expand Down Expand Up @@ -199,21 +215,31 @@ def __init__(self, enterprise_cert_file_path):
self._enterprise_cert_file_path = enterprise_cert_file_path
self._cert = None
self._sign_callback = None
self._provider_lib = None

def load_libraries(self):
try:
with open(self._enterprise_cert_file_path, "r") as f:
enterprise_cert_json = json.load(f)
libs = enterprise_cert_json["libs"]
signer_library = libs["ecp_client"]
offload_library = libs["tls_offload"]
except (KeyError, ValueError) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(
"enterprise cert file is invalid", caught_exc
)
raise new_exc from caught_exc
self._offload_lib = load_offload_lib(offload_library)
self._signer_lib = load_signer_lib(signer_library)
with open(self._enterprise_cert_file_path, "r") as f:
enterprise_cert_json = json.load(f)
libs = enterprise_cert_json.get("libs", {})

signer_library = libs.get("ecp_client", None)
offload_library = libs.get("tls_offload", None)
provider_library = libs.get("ecp_provider", None)

# Using newer provider implementation. This is mutually exclusive to the
# offload implementation.
if provider_library:
self._provider_lib = load_provider_lib(provider_library)
return

# Using old offload implementation
if offload_library and signer_library:
self._offload_lib = load_offload_lib(offload_library)
self._signer_lib = load_signer_lib(signer_library)
self.set_up_custom_key()
return

raise exceptions.MutualTLSChannelError("enterprise cert file is invalid")

def set_up_custom_key(self):
# We need to keep a reference of the cert and sign callback so it won't
Expand All @@ -224,11 +250,22 @@ def set_up_custom_key(self):
)

def attach_to_ssl_context(self, ctx):
# In the TLS handshake, the signing operation will be done by the
# sign_callback.
if not self._offload_lib.ConfigureSslContext(
self._sign_callback,
ctypes.c_char_p(self._cert),
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
):
raise exceptions.MutualTLSChannelError("failed to configure SSL context")
if self._provider_lib:
if not self._provider_lib.ECP_attach_to_ctx(
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
self._enterprise_cert_file_path.encode("ascii"),
):
raise exceptions.MutualTLSChannelError(
"failed to configure ECP Provider SSL context"
)
elif self._offload_lib and self._signer_lib:
if not self._offload_lib.ConfigureSslContext(
self._sign_callback,
ctypes.c_char_p(self._cert),
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
):
raise exceptions.MutualTLSChannelError(
"failed to configure ECP Offload SSL context"
)
else:
raise exceptions.MutualTLSChannelError("Invalid ECP configuration.")
1 change: 0 additions & 1 deletion google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ def __init__(self, enterprise_cert_file_path):

self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path)
self.signer.load_libraries()
self.signer.set_up_custom_key()

poolmanager = create_urllib3_context()
poolmanager.load_verify_locations(cafile=certifi.where())
Expand Down
6 changes: 6 additions & 0 deletions tests/data/enterprise_cert_valid_provider.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"libs": {
"ecp_client": "/path/to/signer/lib",
"ecp_provider": "/path/to/provider/lib"
}
}
108 changes: 67 additions & 41 deletions tests/transport/test__custom_tls_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import ctypes
import os
Expand All @@ -30,11 +29,19 @@
ENTERPRISE_CERT_FILE = os.path.join(
os.path.dirname(__file__), "../data/enterprise_cert_valid.json"
)
ENTERPRISE_CERT_FILE_PROVIDER = os.path.join(
os.path.dirname(__file__), "../data/enterprise_cert_valid_provider.json"
)
INVALID_ENTERPRISE_CERT_FILE = os.path.join(
os.path.dirname(__file__), "../data/enterprise_cert_invalid.json"
)


def test_load_provider_lib():
with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()):
_custom_tls_signer.load_provider_lib("/path/to/provider/lib")


def test_load_offload_lib():
with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()):
lib = _custom_tls_signer.load_offload_lib("/path/to/offload/lib")
Expand Down Expand Up @@ -173,62 +180,81 @@ def test_custom_tls_signer():
) as load_offload_lib:
load_offload_lib.return_value = offload_lib
load_signer_lib.return_value = signer_lib
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
signer_object.load_libraries()
assert signer_object._cert is None
with mock.patch(
"google.auth.transport._custom_tls_signer.get_cert"
) as get_cert:
with mock.patch(
"google.auth.transport._custom_tls_signer.get_sign_callback"
) as get_sign_callback:
get_cert.return_value = b"mock_cert"
signer_object = _custom_tls_signer.CustomTlsSigner(
ENTERPRISE_CERT_FILE
)
signer_object.load_libraries()
signer_object.attach_to_ssl_context(create_urllib3_context())
get_cert.assert_called_once()
get_sign_callback.assert_called_once()
offload_lib.ConfigureSslContext.assert_called_once()
assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE
assert signer_object._offload_lib == offload_lib
assert signer_object._signer_lib == signer_lib
load_signer_lib.assert_called_with("/path/to/signer/lib")
load_offload_lib.assert_called_with("/path/to/offload/lib")

# Test set_up_custom_key and set_up_ssl_context methods
with mock.patch("google.auth.transport._custom_tls_signer.get_cert") as get_cert:
with mock.patch(
"google.auth.transport._custom_tls_signer.get_sign_callback"
) as get_sign_callback:
get_cert.return_value = b"mock_cert"
signer_object.set_up_custom_key()
signer_object.attach_to_ssl_context(create_urllib3_context())
get_cert.assert_called_once()
get_sign_callback.assert_called_once()
offload_lib.ConfigureSslContext.assert_called_once()

def test_custom_tls_signer_provider():
provider_lib = mock.MagicMock()

def test_custom_tls_signer_failed_to_load_libraries():
# Test load_libraries method
with mock.patch(
"google.auth.transport._custom_tls_signer.load_provider_lib"
) as load_provider_lib:
load_provider_lib.return_value = provider_lib
signer_object = _custom_tls_signer.CustomTlsSigner(
ENTERPRISE_CERT_FILE_PROVIDER
)
signer_object.load_libraries()
signer_object.attach_to_ssl_context(mock.MagicMock())

assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE_PROVIDER
assert signer_object._provider_lib == provider_lib
load_provider_lib.assert_called_with("/path/to/provider/lib")


def test_custom_tls_signer_failed_to_load_libraries():
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
signer_object = _custom_tls_signer.CustomTlsSigner(INVALID_ENTERPRISE_CERT_FILE)
signer_object.load_libraries()
assert excinfo.match("enterprise cert file is invalid")


def test_custom_tls_signer_fail_to_offload():
offload_lib = mock.MagicMock()
signer_lib = mock.MagicMock()
def test_custom_tls_signer_failed_to_attach():
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
signer_object._offload_lib = mock.MagicMock()
signer_object._signer_lib = mock.MagicMock()
signer_object._sign_callback = mock.MagicMock()
signer_object._cert = b"mock cert"
signer_object._offload_lib.ConfigureSslContext.return_value = False
signer_object.attach_to_ssl_context(mock.MagicMock())
assert excinfo.match("failed to configure ECP Offload SSL context")

with mock.patch(
"google.auth.transport._custom_tls_signer.load_signer_lib"
) as load_signer_lib:
with mock.patch(
"google.auth.transport._custom_tls_signer.load_offload_lib"
) as load_offload_lib:
load_offload_lib.return_value = offload_lib
load_signer_lib.return_value = signer_lib
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
signer_object.load_libraries()

# set the return value to be 0 which indicts offload fails
offload_lib.ConfigureSslContext.return_value = 0
def test_custom_tls_signer_failed_to_attach_provider():
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
signer_object = _custom_tls_signer.CustomTlsSigner(
ENTERPRISE_CERT_FILE_PROVIDER
)
signer_object._provider_lib = mock.MagicMock()
signer_object._provider_lib.ECP_attach_to_ctx.return_value = False
signer_object.attach_to_ssl_context(mock.MagicMock())
assert excinfo.match("failed to configure ECP Provider SSL context")


def test_custom_tls_signer_failed_to_attach_no_libs():
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
with mock.patch(
"google.auth.transport._custom_tls_signer.get_cert"
) as get_cert:
with mock.patch(
"google.auth.transport._custom_tls_signer.get_sign_callback"
):
get_cert.return_value = b"mock_cert"
signer_object.set_up_custom_key()
signer_object.attach_to_ssl_context(create_urllib3_context())
assert excinfo.match("failed to configure SSL context")
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
signer_object._offload_lib = None
signer_object._signer_lib = None
signer_object.attach_to_ssl_context(mock.MagicMock())
assert excinfo.match("Invalid ECP configuration.")
5 changes: 0 additions & 5 deletions tests/transport/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,17 +544,13 @@ class TestMutualTlsOffloadAdapter(object):
@mock.patch.object(
google.auth.transport._custom_tls_signer.CustomTlsSigner, "load_libraries"
)
@mock.patch.object(
google.auth.transport._custom_tls_signer.CustomTlsSigner, "set_up_custom_key"
)
@mock.patch.object(
google.auth.transport._custom_tls_signer.CustomTlsSigner,
"attach_to_ssl_context",
)
def test_success(
self,
mock_attach_to_ssl_context,
mock_set_up_custom_key,
mock_load_libraries,
mock_proxy_manager_for,
mock_init_poolmanager,
Expand All @@ -565,7 +561,6 @@ def test_success(
)

mock_load_libraries.assert_called_once()
mock_set_up_custom_key.assert_called_once()
assert mock_attach_to_ssl_context.call_count == 2

adapter.init_poolmanager()
Expand Down

0 comments on commit 39eb287

Please sign in to comment.