Skip to content

Commit

Permalink
feat: Add REST Interceptors to support reading metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
parthea committed Dec 17, 2024
1 parent 1fb1c76 commit fbaff28
Show file tree
Hide file tree
Showing 13 changed files with 996 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,20 @@ class {{ async_method_name_prefix }}{{ service.name }}RestInterceptor:
"""
return response

{% if not method.server_streaming %}
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: {{method.output.ident}}, {{ client_method_metadata_argument() }}) -> Tuple[{{method.output.ident}}, {{ client_method_metadata_type() }}]:
{% else %}
{{ async_prefix }}def post_{{ method.name|snake_case }}_with_metadata(self, response: rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_argument() }}) -> Tuple[rest_streaming{{ async_suffix }}.{{ async_method_name_prefix }}ResponseIterator, {{ client_method_metadata_type() }}]:
{% endif %}
"""Post-rpc interceptor for {{ method.name|snake_case }}

Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the {{ service.name }} server but before
it is returned to user code.
"""
return response, metadata

{% endif %}{# not method.void #}
{% endfor %}

{% for name, signature in api.mixin_api_signatures.items() %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ class {{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
{% endif %}{# method.lro #}
{#- TODO(https://github.com/googleapis/gapic-generator-python/issues/2274): Add debug log before intercepting a request #}
resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
{% if not method.server_streaming %}
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ class Async{{service.name}}RestTransport(_Base{{ service.name }}RestTransport):
json_format.Parse(content, pb_resp, ignore_unknown_fields=True)
{% endif %}{# if method.server_streaming #}
resp = await self._interceptor.post_{{ method.name|snake_case }}(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = await self._interceptor.post_{{ method.name|snake_case }}_with_metadata(resp, response_metadata)
{# TODO(https://github.com/googleapis/gapic-generator-python/issues/2279): Add logging support for rest streaming. #}
{% if not method.server_streaming %}
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2218,11 +2218,13 @@ def test_initialize_client_w_{{transport_name}}():
{% endif %}
{% if not method.void %}
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}_with_metadata") as post_with_metadata, \
{% endif %}
mock.patch.object(transports.{{async_method_prefix}}{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre:
pre.assert_not_called()
{% if not method.void %}
post.assert_not_called()
post_with_metadata.assert_not_called()
{% endif %}
{% if method.input.ident.is_proto_plus_type %}
pb_message = {{ method.input.ident }}.pb({{ method.input.ident }}())
Expand Down Expand Up @@ -2265,13 +2267,15 @@ def test_initialize_client_w_{{transport_name}}():
pre.return_value = request, metadata
{% if not method.void %}
post.return_value = {{ method.output.ident }}()
post_with_metadata.return_value = {{ method.output.ident }}(), metadata
{% endif %}

{{await_prefix}}client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
{% if not method.void %}
post.assert_called_once()
post_with_metadata.assert_called_once()
{% endif %}
{% endif %}{# end 'grpc' in transport #}
{% endmacro%}{# inteceptor_class_test #}

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def post_generate_access_token(self, response: common.GenerateAccessTokenRespons
"""
return response

def post_generate_access_token_with_metadata(self, response: common.GenerateAccessTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateAccessTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for generate_access_token
Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_generate_id_token(self, request: common.GenerateIdTokenRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for generate_id_token
Expand All @@ -144,6 +153,15 @@ def post_generate_id_token(self, response: common.GenerateIdTokenResponse) -> co
"""
return response

def post_generate_id_token_with_metadata(self, response: common.GenerateIdTokenResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.GenerateIdTokenResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for generate_id_token
Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_sign_blob(self, request: common.SignBlobRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for sign_blob
Expand All @@ -161,6 +179,15 @@ def post_sign_blob(self, response: common.SignBlobResponse) -> common.SignBlobRe
"""
return response

def post_sign_blob_with_metadata(self, response: common.SignBlobResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignBlobResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for sign_blob
Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata

def pre_sign_jwt(self, request: common.SignJwtRequest, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtRequest, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Pre-rpc interceptor for sign_jwt
Expand All @@ -178,6 +205,15 @@ def post_sign_jwt(self, response: common.SignJwtResponse) -> common.SignJwtRespo
"""
return response

def post_sign_jwt_with_metadata(self, response: common.SignJwtResponse, metadata: Sequence[Tuple[str, Union[str, bytes]]]) -> Tuple[common.SignJwtResponse, Sequence[Tuple[str, Union[str, bytes]]]]:
"""Post-rpc interceptor for sign_jwt
Override in a subclass to either manipulate or read, either the response
or metadata after it is returned by the IAMCredentials server but before
it is returned to user code.
"""
return response, metadata


@dataclasses.dataclass
class IAMCredentialsRestStub:
Expand Down Expand Up @@ -375,6 +411,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_generate_access_token(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_generate_access_token_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.GenerateAccessTokenResponse.to_json(response)
Expand Down Expand Up @@ -495,6 +533,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_generate_id_token(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_generate_id_token_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.GenerateIdTokenResponse.to_json(response)
Expand Down Expand Up @@ -615,6 +655,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_sign_blob(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_sign_blob_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.SignBlobResponse.to_json(response)
Expand Down Expand Up @@ -735,6 +777,8 @@ def __call__(self,
json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)

resp = self._interceptor.post_sign_jwt(resp)
response_metadata = [(k, str(v)) for k, v in response.headers.items()]
resp, _ = self._interceptor.post_sign_jwt_with_metadata(resp, response_metadata)
if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(logging.DEBUG): # pragma: NO COVER
try:
response_payload = common.SignJwtResponse.to_json(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3162,9 +3162,11 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_access_token_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_access_token") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.GenerateAccessTokenRequest.pb(common.GenerateAccessTokenRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3186,11 +3188,13 @@ def test_generate_access_token_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.GenerateAccessTokenResponse()
post_with_metadata.return_value = common.GenerateAccessTokenResponse(), metadata

client.generate_access_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_generate_id_token_rest_bad_request(request_type=common.GenerateIdTokenRequest):
Expand Down Expand Up @@ -3264,9 +3268,11 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_generate_id_token_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_generate_id_token") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.GenerateIdTokenRequest.pb(common.GenerateIdTokenRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3288,11 +3294,13 @@ def test_generate_id_token_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.GenerateIdTokenResponse()
post_with_metadata.return_value = common.GenerateIdTokenResponse(), metadata

client.generate_id_token(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_sign_blob_rest_bad_request(request_type=common.SignBlobRequest):
Expand Down Expand Up @@ -3368,9 +3376,11 @@ def test_sign_blob_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_blob_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_blob") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.SignBlobRequest.pb(common.SignBlobRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3392,11 +3402,13 @@ def test_sign_blob_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.SignBlobResponse()
post_with_metadata.return_value = common.SignBlobResponse(), metadata

client.sign_blob(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()


def test_sign_jwt_rest_bad_request(request_type=common.SignJwtRequest):
Expand Down Expand Up @@ -3472,9 +3484,11 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
with mock.patch.object(type(client.transport._session), "request") as req, \
mock.patch.object(path_template, "transcode") as transcode, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt") as post, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "post_sign_jwt_with_metadata") as post_with_metadata, \
mock.patch.object(transports.IAMCredentialsRestInterceptor, "pre_sign_jwt") as pre:
pre.assert_not_called()
post.assert_not_called()
post_with_metadata.assert_not_called()
pb_message = common.SignJwtRequest.pb(common.SignJwtRequest())
transcode.return_value = {
"method": "post",
Expand All @@ -3496,11 +3510,13 @@ def test_sign_jwt_rest_interceptors(null_interceptor):
]
pre.return_value = request, metadata
post.return_value = common.SignJwtResponse()
post_with_metadata.return_value = common.SignJwtResponse(), metadata

client.sign_jwt(request, metadata=[("key", "val"), ("cephalopod", "squid"),])

pre.assert_called_once()
post.assert_called_once()
post_with_metadata.assert_called_once()

def test_initialize_client_w_rest():
client = IAMCredentialsClient(
Expand Down
Loading

0 comments on commit fbaff28

Please sign in to comment.