From af11fbfbf6ae3fae9a2fd0cf6e51b4e8f38c4886 Mon Sep 17 00:00:00 2001
From: "Friso H. Kingma"
Date: Wed, 4 Sep 2024 15:26:48 +0200
Subject: [PATCH] langchain_openai: Make sure the response from the async
 client in the astream method of ChatOpenAI is properly awaited in case of
 "include_response_headers=True" (#26031)

- **Description:** This is a **one line change**: the
  `self.async_client.with_raw_response.create(**payload)` call is not
  properly awaited within the `_astream` method. In `_agenerate` this is
  already done, but it was likely forgotten in `_astream`.
- **Issue:** Not applicable
- **Dependencies:** No dependencies required.

---------

Co-authored-by: Chester Curme
---
 .../langchain_openai/chat_models/base.py      |  2 +-
 .../chat_models/test_base.py                  | 36 +++++++++++++++++--
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 66b7e75edecf6..b67bcafe60d53 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -757,7 +757,7 @@ async def _astream(
             )
             return
         if self.include_response_headers:
-            raw_response = self.async_client.with_raw_response.create(**payload)
+            raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
             base_generation_info = {"headers": dict(raw_response.headers)}
         else:
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 2e235421f9348..551588976c127 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -686,15 +686,47 @@ def test_openai_proxy() -> None:
     assert proxy.port == 8080
 
 
-def test_openai_response_headers_invoke() -> None:
+def test_openai_response_headers() -> None:
     """Test ChatOpenAI response headers."""
     chat_openai = ChatOpenAI(include_response_headers=True)
-    result = chat_openai.invoke("I'm Pickle Rick")
+    query = "I'm Pickle Rick"
+    result = chat_openai.invoke(query, max_tokens=10)
     headers = result.response_metadata["headers"]
     assert headers
     assert isinstance(headers, dict)
     assert "content-type" in headers
 
+    # Stream
+    full: Optional[BaseMessageChunk] = None
+    for chunk in chat_openai.stream(query, max_tokens=10):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessage)
+    headers = full.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers
+
+
+async def test_openai_response_headers_async() -> None:
+    """Test ChatOpenAI response headers."""
+    chat_openai = ChatOpenAI(include_response_headers=True)
+    query = "I'm Pickle Rick"
+    result = await chat_openai.ainvoke(query, max_tokens=10)
+    headers = result.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers
+
+    # Stream
+    full: Optional[BaseMessageChunk] = None
+    async for chunk in chat_openai.astream(query, max_tokens=10):
+        full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessage)
+    headers = full.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers
+
 
 def test_image_token_counting_jpeg() -> None:
     model = ChatOpenAI(model="gpt-4o", temperature=0)
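
For context on the failure mode the patch fixes: calling a coroutine function without `await` returns a coroutine object rather than the response, so the very next line, `raw_response.parse()`, fails with `AttributeError: 'coroutine' object has no attribute 'parse'`, and Python additionally emits a "coroutine was never awaited" RuntimeWarning. The snippet below is a minimal sketch of that behavior; the `create` coroutine here is a hypothetical stand-in for `self.async_client.with_raw_response.create(**payload)`, not the real OpenAI SDK client:

```python
import asyncio


# Hypothetical stand-in for the SDK's async with_raw_response.create call,
# used only to illustrate the await semantics.
async def create() -> str:
    return "raw response"


async def buggy() -> None:
    raw_response = create()  # missing await: this binds a coroutine object
    print(type(raw_response))  # <class 'coroutine'>
    # Any attribute access such as raw_response.parse() would raise:
    # AttributeError: 'coroutine' object has no attribute 'parse'
    raw_response.close()  # close it to suppress the "never awaited" warning


async def fixed() -> None:
    raw_response = await create()  # awaited: this binds the actual result
    print(type(raw_response))  # <class 'str'>


asyncio.run(buggy())
asyncio.run(fixed())
```

This also explains why the added integration tests exercise `astream` with `include_response_headers=True`: that is the only code path that hit the un-awaited call.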