diff --git a/google/auth/crypt/_cryptography_rsa.py b/google/auth/crypt/_cryptography_rsa.py index 4f2d61166..1a3e9ff52 100644 --- a/google/auth/crypt/_cryptography_rsa.py +++ b/google/auth/crypt/_cryptography_rsa.py @@ -134,3 +134,18 @@ def from_string(cls, key, key_id=None): key, password=None, backend=_BACKEND ) return cls(private_key, key_id=key_id) + + def __getstate__(self): + """Pickle helper that serializes the _key attribute.""" + state = self.__dict__.copy() + state["_key"] = self._key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return state + + def __setstate__(self, state): + """Pickle helper that deserializes the _key attribute.""" + state["_key"] = serialization.load_pem_private_key(state["_key"], None) + self.__dict__.update(state) diff --git a/google/auth/crypt/es256.py b/google/auth/crypt/es256.py index 7920cc7ff..820e4becc 100644 --- a/google/auth/crypt/es256.py +++ b/google/auth/crypt/es256.py @@ -158,3 +158,18 @@ def from_string(cls, key, key_id=None): key, password=None, backend=_BACKEND ) return cls(private_key, key_id=key_id) + + def __getstate__(self): + """Pickle helper that serializes the _key attribute.""" + state = self.__dict__.copy() + state["_key"] = self._key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + return state + + def __setstate__(self, state): + """Pickle helper that deserializes the _key attribute.""" + state["_key"] = serialization.load_pem_private_key(state["_key"], None) + self.__dict__.update(state) diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index cae671c34..1e35ceb60 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/crypt/test__cryptography_rsa.py b/tests/crypt/test__cryptography_rsa.py index 99d8fc37c..1199f8d1b 100644 --- a/tests/crypt/test__cryptography_rsa.py +++ b/tests/crypt/test__cryptography_rsa.py @@ -14,6 +14,7 @@ import json import os +import pickle from cryptography.hazmat.primitives.asymmetric import rsa import pytest # type: ignore @@ -159,3 +160,17 @@ def test_from_service_account_file(self): assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] assert isinstance(signer._key, rsa.RSAPrivateKey) + + def test_pickle(self): + signer = _cryptography_rsa.RSASigner.from_service_account_file( + SERVICE_ACCOUNT_JSON_FILE + ) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, rsa.RSAPrivateKey) + + pickled_signer = pickle.dumps(signer) + signer = pickle.loads(pickled_signer) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, rsa.RSAPrivateKey) diff --git a/tests/crypt/test_es256.py b/tests/crypt/test_es256.py index 33465ce6d..f87648db4 100644 --- a/tests/crypt/test_es256.py +++ b/tests/crypt/test_es256.py @@ -15,6 +15,7 @@ import base64 import json import os +import pickle from cryptography.hazmat.primitives.asymmetric import ec import pytest # type: ignore @@ -141,3 +142,15 @@ def test_from_service_account_file(self): assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + def test_pickle(self): + signer = es256.ES256Signer.from_service_account_file(SERVICE_ACCOUNT_JSON_FILE) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey) + + pickled_signer = pickle.dumps(signer) + signer = pickle.loads(pickled_signer) + + assert signer.key_id == SERVICE_ACCOUNT_INFO[base._JSON_FILE_PRIVATE_KEY_ID] + assert isinstance(signer._key, ec.EllipticCurvePrivateKey)