Skip to content

Commit

Permalink
feat: making iam endpoint universe-aware (#1604)
Browse files Browse the repository at this point in the history
* feat: making iam endpoint universe-aware

* feat: make sign and idtoken endpooints universe aware

* add universe_domain parameter for the iam request

* fix: test updates

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
TimurSadykov and gcf-owl-bot[bot] authored Oct 19, 2024
1 parent f070de0 commit 16c728d
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 21 deletions.
13 changes: 6 additions & 7 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,19 @@
http_client.GATEWAY_TIMEOUT,
}


_IAM_SCOPE = ["https://www.googleapis.com/auth/iam"]

_IAM_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
"https://iamcredentials.{}/v1/projects/-"
+ "/serviceAccounts/{}:generateAccessToken"
)

_IAM_SIGN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/projects/-"
+ "/serviceAccounts/{}:signBlob"
"https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:signBlob"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.googleapis.com/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
"https://iamcredentials.{}/v1/" + "projects/-/serviceAccounts/{}:generateIdToken"
)


Expand Down Expand Up @@ -90,7 +87,9 @@ def _make_signing_request(self, message):
message = _helpers.to_bytes(message)

method = "POST"
url = _IAM_SIGN_ENDPOINT.format(self._service_account_email)
url = _IAM_SIGN_ENDPOINT.format(
self._credentials.universe_domain, self._service_account_email
)
headers = {"Content-Type": "application/json"}
body = json.dumps(
{"payload": base64.b64encode(message).decode("utf-8")}
Expand Down
16 changes: 12 additions & 4 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@


def _make_iam_token_request(
request, principal, headers, body, iam_endpoint_override=None
request, principal, headers, body, universe_domain, iam_endpoint_override=None
):
"""Makes a request to the Google Cloud IAM service for an access token.
Args:
Expand All @@ -67,7 +67,9 @@ def _make_iam_token_request(
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(principal)
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(
universe_domain, principal
)

body = json.dumps(body).encode("utf-8")

Expand Down Expand Up @@ -219,6 +221,8 @@ def __init__(
and self._source_credentials._always_use_jwt_access
):
self._source_credentials._create_self_signed_jwt(None)

self._universe_domain = source_credentials.universe_domain
self._target_principal = target_principal
self._target_scopes = target_scopes
self._delegates = delegates
Expand Down Expand Up @@ -271,13 +275,16 @@ def _update_token(self, request):
principal=self._target_principal,
headers=headers,
body=body,
universe_domain=self.universe_domain,
iam_endpoint_override=self._iam_endpoint_override,
)

def sign_bytes(self, message):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(self._target_principal)
iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(
self.universe_domain, self._target_principal
)

body = {
"payload": base64.b64encode(message).decode("utf-8"),
Expand Down Expand Up @@ -428,7 +435,8 @@ def refresh(self, request):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format(
self._target_credentials.signer_email
self._target_credentials.universe_domain,
self._target_credentials.signer_email,
)

body = {
Expand Down
9 changes: 7 additions & 2 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ def jwt_grant(request, token_uri, assertion, can_retry=True):


def call_iam_generate_id_token_endpoint(
request, iam_id_token_endpoint, signer_email, audience, access_token
request,
iam_id_token_endpoint,
signer_email,
audience,
access_token,
universe_domain,
):
"""Call iam.generateIdToken endpoint to get ID token.
Expand All @@ -339,7 +344,7 @@ def call_iam_generate_id_token_endpoint(

response_data = _token_endpoint_request(
request,
iam_id_token_endpoint.format(signer_email),
iam_id_token_endpoint.format(universe_domain, signer_email),
body,
access_token=access_token,
use_json=True,
Expand Down
1 change: 1 addition & 0 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def _refresh_with_iam_endpoint(self, request):
self.signer_email,
self._target_audience,
jwt_credentials.token.decode(),
self._universe_domain,
)

@_helpers.copy_docstring(credentials.Credentials)
Expand Down
20 changes: 20 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,16 @@ def test_with_target_audience_integration(self):
},
)

# mock information about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock token for credentials
responses.add(
responses.GET,
Expand Down Expand Up @@ -659,6 +669,16 @@ def test_with_quota_project_integration(self):
},
)

# stubby response about universe_domain
responses.add(
responses.GET,
"http://metadata.google.internal/computeMetadata/v1/universe/"
"universe_domain",
status=200,
content_type="application/json",
json={},
)

# mock sign blob endpoint
signature = base64.b64encode(b"some-signature").decode("utf-8")
responses.add(
Expand Down
2 changes: 2 additions & 0 deletions tests/oauth2/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def test_call_iam_generate_id_token_endpoint():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com",
)

assert (
Expand Down Expand Up @@ -361,6 +362,7 @@ def test_call_iam_generate_id_token_endpoint_no_id_token():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com",
)
assert excinfo.match("No ID token in response")

Expand Down
8 changes: 5 additions & 3 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
Expand All @@ -798,6 +798,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "googleapis.com"

@mock.patch(
"google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True
Expand All @@ -811,18 +812,19 @@ def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
assert (
iam_endpoint
== "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken"
== "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}:generateIdToken"
)
assert signer_email == "[email protected]"
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "fake-universe"

@mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
def test_before_request_refreshes(self, id_token_jwt_grant):
Expand Down
130 changes: 125 additions & 5 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def test_get_cred_info(self):
"principal": "[email protected]",
}

def test_universe_domain_matching_source(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert credentials.universe_domain == "foo.bar"

def test__make_copy_get_cred_info(self):
credentials = self.make_credentials()
credentials._cred_file_path = "/path/to/file"
Expand Down Expand Up @@ -231,6 +238,38 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
== ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
token = "token"

expire_time = (
_helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
).isoformat("T") + "Z"
response_body = {"accessToken": token, "expireTime": expire_time}

request = self.make_request(
data=json.dumps(response_body),
status=http_client.OK,
use_data_bytes=use_data_bytes,
)

credentials.refresh(request)

assert credentials.valid
assert not credentials.expired
# Confirm override endpoint used.
request_kwargs = request.call_args[1]
assert (
request_kwargs["url"]
== "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:generateAccessToken"
)

@pytest.mark.parametrize("use_data_bytes", [True, False])
def test_refresh_success_iam_endpoint_override(
self, use_data_bytes, mock_donor_credentials
Expand Down Expand Up @@ -397,6 +436,38 @@ def test_service_account_email(self):

def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
credentials = self.make_credentials(lifetime=None)
expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/[email protected]:signBlob"
self._sign_bytes_helper(
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
)

def test_sign_bytes_nonGdu(
self, mock_donor_credentials, mock_authorizedsession_sign
):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:signBlob"
self._sign_bytes_helper(
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
)

def _sign_bytes_helper(
self,
credentials,
mock_donor_credentials,
mock_authorizedsession_sign,
expected_url,
):
token = "token"

expire_time = (
Expand All @@ -412,11 +483,19 @@ def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
request.return_value = response

credentials.refresh(request)

assert credentials.valid
assert not credentials.expired

signature = credentials.sign_bytes(b"signed bytes")
mock_authorizedsession_sign.assert_called_with(
mock.ANY,
"POST",
expected_url,
None,
json={"payload": "c2lnbmVkIGJ5dGVz", "delegates": []},
headers={"Content-Type": "application/json"},
)

assert signature == b"signature"

def test_sign_bytes_failure(self):
Expand Down Expand Up @@ -563,6 +642,45 @@ def test_id_token_from_credential(
self, mock_donor_credentials, mock_authorizedsession_idtoken
):
credentials = self.make_credentials(lifetime=None)
target_credentials = self.make_credentials(lifetime=None)
expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/[email protected]:generateIdToken"
self._test_id_token_helper(
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
)

def test_id_token_from_credential_nonGdu(
self, mock_donor_credentials, mock_authorizedsession_idtoken
):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
target_credentials = self.make_credentials(
lifetime=None, source_credentials=source_credentials
)
expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:generateIdToken"
self._test_id_token_helper(
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
)

def _test_id_token_helper(
self,
credentials,
target_credentials,
mock_donor_credentials,
mock_authorizedsession_idtoken,
expected_url,
):
token = "token"
target_audience = "https://foo.bar"

Expand All @@ -580,17 +698,19 @@ def test_id_token_from_credential(
assert credentials.valid
assert not credentials.expired

new_credentials = self.make_credentials(lifetime=None)

id_creds = impersonated_credentials.IDTokenCredentials(
credentials, target_audience=target_audience, include_email=True
)
id_creds = id_creds.from_credentials(target_credentials=new_credentials)
id_creds = id_creds.from_credentials(target_credentials=target_credentials)
id_creds.refresh(request)

args = mock_authorizedsession_idtoken.call_args.args

assert args[2] == expected_url

assert id_creds.token == ID_TOKEN_DATA
assert id_creds._include_email is True
assert id_creds._target_credentials is new_credentials
assert id_creds._target_credentials is target_credentials

def test_id_token_with_target_audience(
self, mock_donor_credentials, mock_authorizedsession_idtoken
Expand Down

0 comments on commit 16c728d

Please sign in to comment.