diff --git a/docs/docs/integrations/chat/writer.ipynb b/docs/docs/integrations/chat/writer.ipynb
index a4a875b927cc9e..e0d73dff5f8157 100644
--- a/docs/docs/integrations/chat/writer.ipynb
+++ b/docs/docs/integrations/chat/writer.ipynb
@@ -78,9 +78,7 @@
    "cell_type": "code",
    "id": "2113471c-75d7-45df-b784-d78da4ef7aba",
    "metadata": {},
-   "source": [
-    "%pip install -qU langchain-community writer-sdk"
-   ],
+   "source": "%pip install -qU langchain-community writer-sdk",
    "outputs": [],
    "execution_count": null
   },
@@ -102,13 +100,14 @@
   },
   "source": [
    "from langchain_community.chat_models.writer import ChatWriter\n",
+   "from writerai import AsyncWriter, Writer\n",
    "\n",
    "llm = ChatWriter(\n",
+   "    client=Writer(),\n",
+   "    async_client=AsyncWriter(),\n",
    "    model=\"palmyra-x-004\",\n",
    "    temperature=0.7,\n",
    "    max_tokens=1000,\n",
-   "    # api_key=\"...\", # if you prefer to pass api key in directly instaed of using env vars\n",
-   "    # base_url=\"...\",\n",
    "    # other params...\n",
    ")"
  ],
@@ -152,6 +151,31 @@
    "outputs": [],
    "execution_count": null
   },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": "## Streaming",
+   "id": "35b3a5b3dabef65"
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "ai_stream = llm.stream(messages)",
+   "id": "2725770182bf96dc",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": [
+    "for chunk in ai_stream:\n",
+    "    print(chunk.content, end=\"\")"
+   ],
+   "id": "a48410d9488162e3",
+   "outputs": [],
+   "execution_count": null
+  },
   {
    "cell_type": "markdown",
    "id": "778f912a-66ea-4a5d-b3de-6c7db4baba26",
diff --git a/libs/community/langchain_community/chat_models/writer.py b/libs/community/langchain_community/chat_models/writer.py
index 7e0040f6aef593..cf38744505ff39 100644
--- a/libs/community/langchain_community/chat_models/writer.py
+++ b/libs/community/langchain_community/chat_models/writer.py
@@ -57,12 +57,20 @@ class ChatWriter(BaseChatModel):
     .. code-block:: python

         from langchain_community.chat_models import ChatWriter
+        from writerai import Writer, AsyncWriter

-        chat = ChatWriter(model="palmyra-x-004")
+        client = Writer()
+        async_client = AsyncWriter()
+
+        chat = ChatWriter(
+            client=client,
+            async_client=async_client,
+            model="palmyra-x-004"
+        )
     """

-    client: Any = Field(default=None, exclude=True)  #: :meta private:
-    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
+    client: Any = Field(exclude=True)  #: :meta private:
+    async_client: Any = Field(exclude=True)  #: :meta private:
     model_name: str = Field(default="palmyra-x-004", alias="model")
     """Model name to use."""
     temperature: float = 0.7
diff --git a/libs/community/tests/unit_tests/chat_models/test_writer.py b/libs/community/tests/unit_tests/chat_models/test_writer.py
index 55b92b3ccaf456..b9d93041700eea 100644
--- a/libs/community/tests/unit_tests/chat_models/test_writer.py
+++ b/libs/community/tests/unit_tests/chat_models/test_writer.py
@@ -1,8 +1,6 @@
-"""Unit tests for Writer chat model integration."""
-
 import json
 from typing import Any, Dict, List, Optional
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock

 import pytest
 from langchain_core.callbacks.manager import CallbackManager
@@ -11,6 +9,8 @@
 from langchain_community.chat_models.writer import ChatWriter
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

+"""Classes for mocking Writer responses."""
+

 class ChoiceDelta:
     def __init__(self, content: str):
@@ -104,16 +104,33 @@ def __init__(
         self.choices = choices


+"""Unit tests for Writer chat model integration."""
+
+
 class TestChatWriter:
     def test_writer_model_param(self) -> None:
         """Test different ways to initialize the chat model."""
         test_cases: List[dict] = [
-            {"model_name": "palmyra-x-004"},
-            {"model": "palmyra-x-004"},
-            {"model_name": "palmyra-x-004"},
+            {
+                "model_name": "palmyra-x-004",
+                "client": MagicMock(),
+                "async_client": AsyncMock(),
+            },
+            {
+                "model": "palmyra-x-004",
+                "client": MagicMock(),
+                "async_client": AsyncMock(),
+            },
+            {
+                "model_name": "palmyra-x-004",
+                "client": MagicMock(),
+                "async_client": AsyncMock(),
+            },
             {
                 "model": "palmyra-x-004",
                 "temperature": 0.5,
+                "client": MagicMock(),
+                "async_client": AsyncMock(),
             },
         ]
@@ -183,7 +200,6 @@ def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
     @pytest.fixture(autouse=True)
     def mock_unstreaming_completion(self) -> Chat:
         """Fixture providing a mock API response."""
-
         return Chat(
             id="chat-12345",
             object="chat.completion",
@@ -270,29 +286,29 @@ def test_sync_completion(
         self, mock_unstreaming_completion: List[ChatCompletionChunk]
     ) -> None:
         """Test basic chat completion with mocked response."""
-        chat = ChatWriter()
         mock_client = MagicMock()
         mock_client.chat.chat.return_value = mock_unstreaming_completion

-        with patch.object(chat, "client", mock_client):
-            message = HumanMessage(content="Hi there!")
-            response = chat.invoke([message])
-            assert isinstance(response, AIMessage)
-            assert response.content == "Hello! How can I help you?"
+        chat = ChatWriter(client=mock_client, async_client=AsyncMock())
+
+        message = HumanMessage(content="Hi there!")
+        response = chat.invoke([message])
+        assert isinstance(response, AIMessage)
+        assert response.content == "Hello! How can I help you?"

     async def test_async_completion(
         self, mock_unstreaming_completion: List[ChatCompletionChunk]
     ) -> None:
         """Test async chat completion with mocked response."""
-        chat = ChatWriter()
         mock_client = AsyncMock()
         mock_client.chat.chat.return_value = mock_unstreaming_completion

-        with patch.object(chat, "async_client", mock_client):
-            message = HumanMessage(content="Hi there!")
-            response = await chat.ainvoke([message])
-            assert isinstance(response, AIMessage)
-            assert response.content == "Hello! How can I help you?"
+        chat = ChatWriter(client=MagicMock(), async_client=mock_client)
+
+        message = HumanMessage(content="Hi there!")
+        response = await chat.ainvoke([message])
+        assert isinstance(response, AIMessage)
+        assert response.content == "Hello! How can I help you?"

     def test_sync_streaming(
         self, mock_streaming_chunks: List[ChatCompletionChunk]
     ) -> None:
@@ -301,27 +317,25 @@
         callback_handler = FakeCallbackHandler()
         callback_manager = CallbackManager([callback_handler])

-        chat = ChatWriter(
-            callback_manager=callback_manager,
-            max_tokens=10,
-        )
-
         mock_client = MagicMock()
         mock_response = MagicMock()
         mock_response.__iter__.return_value = mock_streaming_chunks
         mock_client.chat.chat.return_value = mock_response

-        with patch.object(chat, "client", mock_client):
-            message = HumanMessage(content="Hi")
-            response = chat.stream([message])
-
-            response_message = ""
-
-            for chunk in response:
-                response_message += str(chunk.content)
+        chat = ChatWriter(
+            client=mock_client,
+            async_client=AsyncMock(),
+            callback_manager=callback_manager,
+            max_tokens=10,
+        )

-            assert callback_handler.llm_streams > 0
-            assert response_message == "Hello! How can I help you?"
+        message = HumanMessage(content="Hi")
+        response = chat.stream([message])
+        response_message = ""
+        for chunk in response:
+            response_message += str(chunk.content)
+        assert callback_handler.llm_streams > 0
+        assert response_message == "Hello! How can I help you?"

     async def test_async_streaming(
         self, mock_streaming_chunks: List[ChatCompletionChunk]
     ) -> None:
@@ -330,27 +344,25 @@
         callback_handler = FakeCallbackHandler()
         callback_manager = CallbackManager([callback_handler])

-        chat = ChatWriter(
-            callback_manager=callback_manager,
-            max_tokens=10,
-        )
-
         mock_client = AsyncMock()
         mock_response = AsyncMock()
         mock_response.__aiter__.return_value = mock_streaming_chunks
         mock_client.chat.chat.return_value = mock_response

-        with patch.object(chat, "async_client", mock_client):
-            message = HumanMessage(content="Hi")
-            response = chat.astream([message])
-
-            response_message = ""
-
-            async for chunk in response:
-                response_message += str(chunk.content)
+        chat = ChatWriter(
+            client=MagicMock(),
+            async_client=mock_client,
+            callback_manager=callback_manager,
+            max_tokens=10,
+        )

-            assert callback_handler.llm_streams > 0
-            assert response_message == "Hello! How can I help you?"
+        message = HumanMessage(content="Hi")
+        response = chat.astream([message])
+        response_message = ""
+        async for chunk in response:
+            response_message += str(chunk.content)
+        assert callback_handler.llm_streams > 0
+        assert response_message == "Hello! How can I help you?"

     def test_sync_tool_calling(
         self, mock_tool_call_choice_response: Dict[str, Any]
     ) -> None:
@@ -366,7 +378,7 @@ class GetWeather(BaseModel):
         mock_client = MagicMock()
         mock_client.chat.chat.return_value = mock_tool_call_choice_response

-        chat = ChatWriter(client=mock_client)
+        chat = ChatWriter(client=mock_client, async_client=AsyncMock())

         chat_with_tools = chat.bind_tools(
             tools=[GetWeather],
@@ -393,7 +405,7 @@ class GetWeather(BaseModel):
         mock_client = AsyncMock()
         mock_client.chat.chat.return_value = mock_tool_call_choice_response

-        chat = ChatWriter(async_client=mock_client)
+        chat = ChatWriter(client=MagicMock(), async_client=mock_client)

         chat_with_tools = chat.bind_tools(
             tools=[GetWeather],
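For reference, the notebook changes above boil down to the following usage pattern. This is a minimal sketch, assuming the `writerai` clients pick up their API key from the environment; the `HumanMessage` content is illustrative (the real notebook defines `messages` in an earlier cell not shown in this diff).

```python
from langchain_core.messages import HumanMessage
from writerai import AsyncWriter, Writer

from langchain_community.chat_models.writer import ChatWriter

# Explicit sync and async clients are now passed in, replacing the removed
# api_key/base_url constructor comments.
llm = ChatWriter(
    client=Writer(),
    async_client=AsyncWriter(),
    model="palmyra-x-004",
    temperature=0.7,
    max_tokens=1000,
)

messages = [HumanMessage(content="Write a haiku about rain.")]  # illustrative input

# Mirrors the new "## Streaming" cells: stream() yields chunks whose
# .content fragments concatenate into the full reply.
for chunk in llm.stream(messages):
    print(chunk.content, end="")
```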
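On the library side, dropping `default=None` from `client` and `async_client` makes callers responsible for supplying both, which is why every test now builds the model with mocks injected through the constructor rather than patching attributes afterwards. A minimal sketch of that injection pattern (the mocked return value is elided; in the tests it comes from fixtures such as `mock_unstreaming_completion`):

```python
from unittest.mock import AsyncMock, MagicMock

from langchain_community.chat_models.writer import ChatWriter

# Before: chat = ChatWriter(); with patch.object(chat, "client", mock_client): ...
# After: both clients are plain constructor arguments.
mock_client = MagicMock()
# mock_client.chat.chat.return_value = <mocked Chat response fixture>

chat = ChatWriter(client=mock_client, async_client=AsyncMock())
```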
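The tool-calling tests exercise `bind_tools` with a pydantic schema. A hedged end-to-end sketch against a live client follows; the `GetWeather` docstring, field description, and prompt are illustrative, the tests themselves run against mocked responses, and any extra `bind_tools` arguments the tests pass (the call is truncated in this diff) are omitted here.

```python
from pydantic import BaseModel, Field
from writerai import AsyncWriter, Writer

from langchain_community.chat_models.writer import ChatWriter


class GetWeather(BaseModel):
    """Get the current weather in a given location."""

    location: str = Field(..., description="City name, e.g. London")


llm = ChatWriter(client=Writer(), async_client=AsyncWriter(), model="palmyra-x-004")

# bind_tools() attaches the schema so the model can emit structured tool calls.
llm_with_tools = llm.bind_tools(tools=[GetWeather])

response = llm_with_tools.invoke("What's the weather in London?")
print(response.tool_calls)  # parsed tool-call dicts on the returned AIMessage
```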