Skip to content

Commit

Permalink
also patch async functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Steve Bunting committed Jul 9, 2024
1 parent 5dd1234 commit f827447
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/supergood/vendors/httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def patch(cache_request, cache_response):

_original_handle_async_request = httpx.AsyncHTTPTransport.handle_async_request
_original_response_aread = httpx.Response.aread
_original_response_aiter_bytes = httpx.Response.aiter_bytes
_original_response_aiter_lines = httpx.Response.aiter_lines

def _wrap_handle_request(
httpTransport: httpx.HTTPTransport, request: httpx.Request
Expand Down Expand Up @@ -96,6 +98,25 @@ def _wrap_iter_lines(response: httpx.Response):
status_text,
)

async def _wrap_aiter_lines(response: httpx.Response):
request_id = getattr(response, REQUEST_ID_KEY)
status_text = response.extensions.get("reason_phrase", None)
if status_text:
status_text = status_text.decode("utf-8")
response_parts = []
async for line in _original_response_aiter_lines(response):
if line:
response_parts.append(line)
yield line
response_body = "\n".join(response_parts)
cache_response(
request_id,
response_body,
response.headers,
response.status_code,
status_text,
)

def _parse_sse(chunk: str):
data = []
event = None
Expand Down Expand Up @@ -170,10 +191,51 @@ def _wrap_iter_bytes(response: httpx.Response, chunk_size: Optional[int] = None)
status_text,
)

async def _wrap_aiter_bytes(
response: httpx.Response, chunk_size: Optional[int] = None
):
request_id = getattr(response, REQUEST_ID_KEY)
status_text = response.extensions.get("reason_phrase", None)
if status_text:
status_text = status_text.decode("utf-8")
response_chunks = []
data = b""
async for chunk in _original_response_aiter_bytes(response, chunk_size):
for line in chunk.splitlines(keepends=True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
# assume it's an SSE response
decoded = data.decode("utf-8")
try:
sse = _parse_sse(decoded)
if sse:
response_chunks.append(sse)
except Exception:
# failing SSE parsing, just return as string
response_chunks.append(decoded)
data = b""
yield chunk
if len(data):
# must have been an invalid chunk, append it anyway
response_chunks.append(data)
try:
response_body = json.dumps(response_chunks, cls=DataclassesJSONEncoder)
except Exception:
response_body = str(response_chunks)
cache_response(
request_id,
response_body,
response.headers,
response.status_code,
status_text,
)

httpx.HTTPTransport.handle_request = _wrap_handle_request
httpx.Response.read = _wrap_response_read
httpx.Response.iter_lines = _wrap_iter_lines
httpx.Response.iter_bytes = _wrap_iter_bytes

httpx.AsyncHTTPTransport.handle_async_request = _wrap_handle_async_request
httpx.Response.aread = _wrap_response_aread
httpx.Response.aiter_lines = _wrap_aiter_lines
httpx.Response.aiter_bytes = _wrap_aiter_bytes

0 comments on commit f827447

Please sign in to comment.