diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index 1b2f5161a..1c884c3c4 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -156,6 +156,7 @@ def get( recursive=False, retry_count=5, headers=None, + return_none_for_not_found_error=False, ): """Fetch a resource from the metadata server. @@ -173,6 +174,8 @@ def get( retry_count (int): How many times to attempt connecting to metadata server using above timeout. headers (Optional[Mapping[str, str]]): Headers for the request. + return_none_for_not_found_error (Optional[bool]): If True, returns None + for 404 error instead of throwing an exception. Returns: Union[Mapping, str]: If the metadata server returns JSON, a mapping of @@ -216,8 +219,17 @@ def get( "metadata service. Compute Engine Metadata server unavailable".format(url) ) + content = _helpers.from_bytes(response.data) + + if response.status == http_client.NOT_FOUND and return_none_for_not_found_error: + _LOGGER.info( + "Compute Engine Metadata server call to %s returned 404, reason: %s", + path, + content, + ) + return None + if response.status == http_client.OK: - content = _helpers.from_bytes(response.data) if ( _helpers.parse_content_type(response.headers["content-type"]) == "application/json" @@ -232,14 +244,14 @@ def get( raise new_exc from caught_exc else: return content - else: - raise exceptions.TransportError( - "Failed to retrieve {} from the Google Compute Engine " - "metadata service. Status: {} Response:\n{}".format( - url, response.status, response.data - ), - response, - ) + + raise exceptions.TransportError( + "Failed to retrieve {} from the Google Compute Engine " + "metadata service. Status: {} Response:\n{}".format( + url, response.status, response.data + ), + response, + ) def get_project_id(request): @@ -259,6 +271,29 @@ def get_project_id(request): return get(request, "project/project-id") +def get_universe_domain(request): + """Get the universe domain value from the metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + str: The universe domain value. If the universe domain endpoint is not + not found, return the default value, which is googleapis.com + + Raises: + google.auth.exceptions.TransportError: if an error other than + 404 occurs while retrieving metadata. + """ + universe_domain = get( + request, "universe/universe_domain", return_none_for_not_found_error=True + ) + if not universe_domain: + return "googleapis.com" + return universe_domain + + def get_service_account_info(request, service_account="default"): """Get information about a service account from the metadata server. diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index 7ae673880..fa30aa44a 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -73,6 +73,7 @@ def __init__( self._quota_project_id = quota_project_id self._scopes = scopes self._default_scopes = default_scopes + self._universe_domain_cached = False def _retrieve_info(self, request): """Retrieve information about the service account. @@ -131,6 +132,14 @@ def service_account_email(self): def requires_scopes(self): return not self._scopes + @property + def universe_domain(self): + if self._universe_domain_cached: + return self._universe_domain + self._universe_domain = _metadata.get_universe_domain() + self._universe_domain_cached = True + return self._universe_domain + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): return self.__class__( diff --git a/system_tests/secrets.tar.enc b/system_tests/secrets.tar.enc index a8ffdb89f..06696ab2a 100644 Binary files a/system_tests/secrets.tar.enc and b/system_tests/secrets.tar.enc differ diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index f0e432979..3b70c6d0a 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -325,6 +325,18 @@ def test_get_failure(): ) +def test_get_return_none_for_not_found_error(): + request = make_request("Metadata error", status=http_client.NOT_FOUND) + + assert _metadata.get(request, PATH, return_none_for_not_found_error=True) is None + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + PATH, + headers=_metadata._METADATA_HEADERS, + ) + + def test_get_failure_connection_failed(): request = make_request("") request.side_effect = exceptions.TransportError() @@ -371,6 +383,53 @@ def test_get_project_id(): assert project_id == project +def test_get_universe_domain_success(): + request = make_request( + "fake_universe_domain", headers={"content-type": "text/plain"} + ) + + universe_domain = _metadata.get_universe_domain(request) + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + assert universe_domain == "fake_universe_domain" + + +def test_get_universe_domain_not_found(): + # Test that if the universe domain endpoint returns 404 error, we should + # use googleapis.com as the universe domain + request = make_request("not found", status=http_client.NOT_FOUND) + + universe_domain = _metadata.get_universe_domain(request) + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + assert universe_domain == "googleapis.com" + + +def test_get_universe_domain_other_error(): + # Test that if the universe domain endpoint returns an error other than 404 + # we should throw the error + request = make_request("unauthorized", status=http_client.UNAUTHORIZED) + + with pytest.raises(exceptions.TransportError) as excinfo: + _metadata.get_universe_domain(request) + + assert excinfo.match(r"unauthorized") + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_ROOT + "universe/universe_domain", + headers=_metadata._METADATA_HEADERS, + ) + + @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index 507fea9fc..95f9b0efb 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -208,6 +208,26 @@ def test_token_usage_metrics(self): assert headers["authorization"] == "Bearer token" assert headers["x-goog-api-client"] == "cred-type/mds" + @mock.patch( + "google.auth.compute_engine._metadata.get_universe_domain", + return_value="fake_universe_domain", + ) + def test_universe_domain(self, get_universe_domain): + self.credentials._universe_domain_cached = False + self.credentials._universe_domain = "googleapis.com" + + # calling the universe_domain property should trigger a call to + # get_universe_domain to fetch the value. The value should be cached. + assert self.credentials.universe_domain == "fake_universe_domain" + assert self.credentials._universe_domain == "fake_universe_domain" + assert self.credentials._universe_domain_cached + get_universe_domain.assert_called_once() + + # calling the universe_domain property the second time should use the + # cached value instead of calling get_universe_domain + assert self.credentials.universe_domain == "fake_universe_domain" + get_universe_domain.assert_called_once() + class TestIDTokenCredentials(object): credentials = None