Skip to content

Commit

Permalink
tests: add showcase tests for reading grpc/rest call metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
parthea committed Dec 17, 2024
1 parent fbaff28 commit a48fc54
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 13 deletions.
140 changes: 131 additions & 9 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

import grpc
from unittest import mock
import os
import pytest

from typing import Sequence, Tuple

from google.api_core.client_options import ClientOptions # type: ignore
from google.showcase_v1beta1.services.echo.transports import EchoRestInterceptor

try:
from google.auth.aio import credentials as ga_credentials_async
Expand All @@ -38,13 +41,20 @@
import asyncio
from google.showcase import EchoAsyncClient
from google.showcase import IdentityAsyncClient

try:
from google.showcase_v1beta1.services.echo.transports import AsyncEchoRestTransport
from google.showcase_v1beta1.services.echo.transports import (
AsyncEchoRestTransport,
)

HAS_ASYNC_REST_ECHO_TRANSPORT = True
except:
HAS_ASYNC_REST_ECHO_TRANSPORT = False
try:
from google.showcase_v1beta1.services.identity.transports import AsyncIdentityRestTransport
from google.showcase_v1beta1.services.identity.transports import (
AsyncIdentityRestTransport,
)

HAS_ASYNC_REST_IDENTITY_TRANSPORT = True
except:
HAS_ASYNC_REST_IDENTITY_TRANSPORT = False
Expand Down Expand Up @@ -77,7 +87,9 @@ def async_echo(use_mtls, request, event_loop):
EchoAsyncClient,
use_mtls,
transport_name=transport,
channel_creator=aio.insecure_channel if request.param == "grpc_asyncio" else None,
channel_creator=aio.insecure_channel
if request.param == "grpc_asyncio"
else None,
credentials=async_anonymous_credentials(),
)

Expand All @@ -90,7 +102,9 @@ def async_identity(use_mtls, request, event_loop):
IdentityAsyncClient,
use_mtls,
transport_name=transport,
channel_creator=aio.insecure_channel if request.param == "grpc_asyncio" else None,
channel_creator=aio.insecure_channel
if request.param == "grpc_asyncio"
else None,
credentials=async_anonymous_credentials(),
)

Expand Down Expand Up @@ -237,7 +251,7 @@ def messaging(use_mtls, request):
return construct_client(MessagingClient, use_mtls, transport_name=request.param)


class MetadataClientInterceptor(
class MetadataClienGrpcInterceptor(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
Expand All @@ -246,14 +260,19 @@ class MetadataClientInterceptor(
def __init__(self, key, value):
self._key = key
self._value = value
self.request_metadata = []
self.response_metadata = []

def _add_metadata(self, client_call_details):
if client_call_details.metadata is not None:
client_call_details.metadata.append((self._key, self._value))
self.request_metadata = client_call_details.metadata

def intercept_unary_unary(self, continuation, client_call_details, request):
self._add_metadata(client_call_details)
response = continuation(client_call_details, request)
metadata = [(k, str(v)) for k, v in response.trailing_metadata()]
self.response_metadata = metadata
return response

def intercept_unary_stream(self, continuation, client_call_details, request):
Expand All @@ -276,12 +295,80 @@ def intercept_stream_stream(
return response_it


class MetadataClientGrpcAsyncInterceptor(
grpc.aio.UnaryUnaryClientInterceptor,
grpc.aio.UnaryStreamClientInterceptor,
grpc.aio.StreamUnaryClientInterceptor,
grpc.aio.StreamStreamClientInterceptor,
):
def __init__(self, key, value):
self._key = key
self._value = value
self.request_metadata = []
self.response_metadata = []

async def _add_metadata(self, client_call_details):
if client_call_details.metadata is not None:
client_call_details.metadata.append((self._key, self._value))
self.request_metadata = client_call_details.metadata

async def intercept_unary_unary(self, continuation, client_call_details, request):
await self._add_metadata(client_call_details)
response = await continuation(client_call_details, request)
metadata = [(k, str(v)) for k, v in await response.trailing_metadata()]
self.response_metadata = metadata
return response

async def intercept_unary_stream(self, continuation, client_call_details, request):
self._add_metadata(client_call_details)
response_it = continuation(client_call_details, request)
return response_it

async def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
self._add_metadata(client_call_details)
response = continuation(client_call_details, request_iterator)
return response

async def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
self._add_metadata(client_call_details)
response_it = continuation(client_call_details, request_iterator)
return response_it


class MetadataClienRestInterceptor(EchoRestInterceptor):
request_metadata: Sequence[Tuple[str, str]] = []
response_metadata: Sequence[Tuple[str, str]] = []

def pre_echo(self, request, metadata):
self.request_metadata = metadata
return request, metadata

def post_echo_with_metadata(self, request, metadata):
self.response_metadata = metadata
return request, metadata

def pre_expand(self, request, metadata):
self.request_metadata = metadata
return request, metadata

def post_expand_with_metadata(self, request, metadata):
self.response_metadata = metadata
return request, metadata


@pytest.fixture
def intercepted_echo(use_mtls):
def intercepted_echo_grpc(use_mtls):
# The interceptor adds 'showcase-trailer' client metadata. Showcase server
# echos any metadata with key 'showcase-trailer', so the same metadata
# should appear as trailing metadata in the response.
interceptor = MetadataClientInterceptor("showcase-trailer", "intercepted")
interceptor = MetadataClienGrpcInterceptor(
"showcase-trailer",
"intercepted",
)
host = "localhost:7469"
channel = (
grpc.secure_channel(host, ssl_credentials)
Expand All @@ -293,4 +380,39 @@ def intercepted_echo(use_mtls):
credentials=ga_credentials.AnonymousCredentials(),
channel=intercept_channel,
)
return EchoClient(transport=transport)
return EchoClient(transport=transport), interceptor


@pytest.fixture
def intercepted_echo_rest():
transport_name = "rest"
transport_cls = EchoClient.get_transport_class(transport_name)
interceptor = MetadataClienRestInterceptor()

# The custom host explicitly bypasses https.
transport = transport_cls(
credentials=ga_credentials.AnonymousCredentials(),
host="localhost:7469",
url_scheme="http",
interceptor=interceptor,
)
return EchoClient(transport=transport), interceptor


@pytest.fixture
def intercepted_echo_grpc_async():
# The interceptor adds 'showcase-trailer' client metadata. Showcase server
# echos any metadata with key 'showcase-trailer', so the same metadata
# should appear as trailing metadata in the response.
interceptor = MetadataClientGrpcAsyncInterceptor(
"showcase-trailer",
"intercepted",
)
host = "localhost:7469"
channel = grpc.aio.insecure_channel(host, interceptors=[interceptor])
# intercept_channel = grpc.aio.intercept_channel(channel, interceptor)
transport = EchoAsyncClient.get_transport_class("grpc_asyncio")(
credentials=ga_credentials.AnonymousCredentials(),
channel=channel,
)
return EchoAsyncClient(transport=transport), interceptor
12 changes: 8 additions & 4 deletions tests/system/test_grpc_interceptor_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
intercepted_metadata = (('showcase-trailer', 'intercepted'),)


def test_unary_stream(intercepted_echo):
def test_unary_stream(intercepted_echo_grpc):
client, interceptor = intercepted_echo_grpc
content = 'The hail in Wales falls mainly on the snails.'
responses = intercepted_echo.expand({
responses = client.expand({
'content': content,
})

Expand All @@ -36,13 +37,15 @@ def test_unary_stream(intercepted_echo):
for metadata in responses.trailing_metadata()
]
assert intercepted_metadata[0] in response_metadata
interceptor.response_metadata = response_metadata


def test_stream_stream(intercepted_echo):
def test_stream_stream(intercepted_echo_grpc):
client, interceptor = intercepted_echo_grpc
requests = []
requests.append(showcase.EchoRequest(content="hello"))
requests.append(showcase.EchoRequest(content="world!"))
responses = intercepted_echo.chat(iter(requests))
responses = client.chat(iter(requests))

contents = [response.content for response in responses]
assert contents == ['hello', 'world!']
Expand All @@ -52,3 +55,4 @@ def test_stream_stream(intercepted_echo):
for metadata in responses.trailing_metadata()
]
assert intercepted_metadata[0] in response_metadata
interceptor.response_metadata = response_metadata
77 changes: 77 additions & 0 deletions tests/system/test_response_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pytest

from google import showcase


@pytest.mark.parametrize(
"transport,response_metadata",
[
("grpc", ("something1", "something_value1")),
("rest", ("X-Showcase-Request-Something1", "something_value1")),
],
)
def test_metadata_response_unary(
intercepted_echo_rest, intercepted_echo_grpc, transport, response_metadata
):
request_content = "The hail in Wales falls mainly on the snails."
request_metadata = ("something1", "something_value1")
if transport == "grpc":
client, interceptor = intercepted_echo_grpc
else:
client, interceptor = intercepted_echo_rest
response = client.echo(
request=showcase.EchoRequest(content=request_content),
metadata=(request_metadata,),
)
assert response.content == request_content
assert request_metadata in interceptor.request_metadata
assert response_metadata in interceptor.response_metadata


def test_metadata_response_rest_streams(intercepted_echo_rest):
request_content = "The hail in Wales falls mainly on the snails."
request_metadata = ("something2", "something_value2")
response_metadata = ("X-Showcase-Request-Something2", "something_value2")
client, interceptor = intercepted_echo_rest
client.expand(
{
"content": request_content,
},
metadata=(request_metadata,),
)

assert request_metadata in interceptor.request_metadata
assert response_metadata in interceptor.response_metadata


if os.environ.get("GAPIC_PYTHON_ASYNC", "true") == "true":

@pytest.mark.asyncio
async def test_metadata_response_grpc_unary_async(intercepted_echo_grpc_async):
request_content = "The hail in Wales falls mainly on the snails."
request_metadata = ("something3", "something_value3")
response_metadata = ("something3", "something_value3")

client, interceptor = intercepted_echo_grpc_async
response = await client.echo(
request=showcase.EchoRequest(content=request_content),
metadata=(("something3", "something_value3"),),
)
assert response.content == request_content
assert request_metadata in interceptor.request_metadata
assert response_metadata in interceptor.response_metadata

0 comments on commit a48fc54

Please sign in to comment.