Skip to content

Commit

Permalink
Merge branch 'main' into add-python-3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
parthea authored Nov 28, 2023
2 parents 61890c3 + 7ab0fce commit 391fa52
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 9 deletions.
53 changes: 44 additions & 9 deletions google/auth/compute_engine/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def get(
recursive=False,
retry_count=5,
headers=None,
return_none_for_not_found_error=False,
):
"""Fetch a resource from the metadata server.
Expand All @@ -173,6 +174,8 @@ def get(
retry_count (int): How many times to attempt connecting to metadata
server using above timeout.
headers (Optional[Mapping[str, str]]): Headers for the request.
return_none_for_not_found_error (Optional[bool]): If True, returns None
for 404 error instead of throwing an exception.
Returns:
Union[Mapping, str]: If the metadata server returns JSON, a mapping of
Expand Down Expand Up @@ -216,8 +219,17 @@ def get(
"metadata service. Compute Engine Metadata server unavailable".format(url)
)

content = _helpers.from_bytes(response.data)

if response.status == http_client.NOT_FOUND and return_none_for_not_found_error:
_LOGGER.info(
"Compute Engine Metadata server call to %s returned 404, reason: %s",
path,
content,
)
return None

if response.status == http_client.OK:
content = _helpers.from_bytes(response.data)
if (
_helpers.parse_content_type(response.headers["content-type"])
== "application/json"
Expand All @@ -232,14 +244,14 @@ def get(
raise new_exc from caught_exc
else:
return content
else:
raise exceptions.TransportError(
"Failed to retrieve {} from the Google Compute Engine "
"metadata service. Status: {} Response:\n{}".format(
url, response.status, response.data
),
response,
)

raise exceptions.TransportError(
"Failed to retrieve {} from the Google Compute Engine "
"metadata service. Status: {} Response:\n{}".format(
url, response.status, response.data
),
response,
)


def get_project_id(request):
Expand All @@ -259,6 +271,29 @@ def get_project_id(request):
return get(request, "project/project-id")


def get_universe_domain(request):
"""Get the universe domain value from the metadata server.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
Returns:
str: The universe domain value. If the universe domain endpoint is not
not found, return the default value, which is googleapis.com
Raises:
google.auth.exceptions.TransportError: if an error other than
404 occurs while retrieving metadata.
"""
universe_domain = get(
request, "universe/universe_domain", return_none_for_not_found_error=True
)
if not universe_domain:
return "googleapis.com"
return universe_domain


def get_service_account_info(request, service_account="default"):
"""Get information about a service account from the metadata server.
Expand Down
9 changes: 9 additions & 0 deletions google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self._quota_project_id = quota_project_id
self._scopes = scopes
self._default_scopes = default_scopes
self._universe_domain_cached = False

def _retrieve_info(self, request):
"""Retrieve information about the service account.
Expand Down Expand Up @@ -131,6 +132,14 @@ def service_account_email(self):
def requires_scopes(self):
return not self._scopes

@property
def universe_domain(self):
if self._universe_domain_cached:
return self._universe_domain
self._universe_domain = _metadata.get_universe_domain()
self._universe_domain_cached = True
return self._universe_domain

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
return self.__class__(
Expand Down
Binary file modified system_tests/secrets.tar.enc
Binary file not shown.
59 changes: 59 additions & 0 deletions tests/compute_engine/test__metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,18 @@ def test_get_failure():
)


def test_get_return_none_for_not_found_error():
request = make_request("Metadata error", status=http_client.NOT_FOUND)

assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + PATH,
headers=_metadata._METADATA_HEADERS,
)


def test_get_failure_connection_failed():
request = make_request("")
request.side_effect = exceptions.TransportError()
Expand Down Expand Up @@ -371,6 +383,53 @@ def test_get_project_id():
assert project_id == project


def test_get_universe_domain_success():
request = make_request(
"fake_universe_domain", headers={"content-type": "text/plain"}
)

universe_domain = _metadata.get_universe_domain(request)

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)
assert universe_domain == "fake_universe_domain"


def test_get_universe_domain_not_found():
# Test that if the universe domain endpoint returns 404 error, we should
# use googleapis.com as the universe domain
request = make_request("not found", status=http_client.NOT_FOUND)

universe_domain = _metadata.get_universe_domain(request)

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)
assert universe_domain == "googleapis.com"


def test_get_universe_domain_other_error():
# Test that if the universe domain endpoint returns an error other than 404
# we should throw the error
request = make_request("unauthorized", status=http_client.UNAUTHORIZED)

with pytest.raises(exceptions.TransportError) as excinfo:
_metadata.get_universe_domain(request)

assert excinfo.match(r"unauthorized")

request.assert_called_once_with(
method="GET",
url=_metadata._METADATA_ROOT + "universe/universe_domain",
headers=_metadata._METADATA_HEADERS,
)


@mock.patch(
"google.auth.metrics.token_request_access_token_mds",
return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE,
Expand Down
20 changes: 20 additions & 0 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,26 @@ def test_token_usage_metrics(self):
assert headers["authorization"] == "Bearer token"
assert headers["x-goog-api-client"] == "cred-type/mds"

@mock.patch(
"google.auth.compute_engine._metadata.get_universe_domain",
return_value="fake_universe_domain",
)
def test_universe_domain(self, get_universe_domain):
self.credentials._universe_domain_cached = False
self.credentials._universe_domain = "googleapis.com"

# calling the universe_domain property should trigger a call to
# get_universe_domain to fetch the value. The value should be cached.
assert self.credentials.universe_domain == "fake_universe_domain"
assert self.credentials._universe_domain == "fake_universe_domain"
assert self.credentials._universe_domain_cached
get_universe_domain.assert_called_once()

# calling the universe_domain property the second time should use the
# cached value instead of calling get_universe_domain
assert self.credentials.universe_domain == "fake_universe_domain"
get_universe_domain.assert_called_once()


class TestIDTokenCredentials(object):
credentials = None
Expand Down

0 comments on commit 391fa52

Please sign in to comment.