Skip to content

Commit

Permalink
feat: add clients initialization in validator
Browse files Browse the repository at this point in the history
  • Loading branch information
yanomaly committed Nov 13, 2024
1 parent 6532517 commit 8897656
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
27 changes: 24 additions & 3 deletions libs/community/langchain_community/chat_models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator

logger = logging.getLogger(__name__)

Expand All @@ -66,8 +67,10 @@ class ChatWriter(BaseChatModel):
)
"""

client: Any = Field(exclude=True) #: :meta private:
async_client: Any = Field(exclude=True) #: :meta private:
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
writer_api_key: Optional[SecretStr] = Field(default=None)
"""Writer API key."""
model_name: str = Field(default="palmyra-x-004", alias="model")
"""Model name to use."""
temperature: float = 0.7
Expand Down Expand Up @@ -106,6 +109,24 @@ def _default_params(self) -> Dict[str, Any]:
**self.model_kwargs,
}

@model_validator(mode="before")
def validate_environment(self, values: Dict) -> Any:
"""Validates that api key is passed and creates Writer clients."""
try:
from writerai import AsyncClient, Client
except ImportError as e:
raise ImportError(
"Could not import writerai python package. "
"Please install it with `pip install writerai`."
) from e

if not (values["client"] and values["async_client"]):
api_key = get_from_dict_or_env(values, "api_key", "WRITER_API_KEY")
values["client"] = Client(api_key=api_key)
values["async_client"] = AsyncClient(api_key=api_key)

return values

def _create_chat_result(self, response: Any) -> ChatResult:
generations = []
for choice in response.choices:
Expand Down
10 changes: 2 additions & 8 deletions libs/community/tests/unit_tests/chat_models/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
self.choices = choices


@pytest.mark.requires("writer-sdk")
class TestChatWriterCustom:
"""Test case for ChatWriter"""

Expand All @@ -114,24 +115,16 @@ def test_writer_model_param(self) -> None:
test_cases: List[dict] = [
{
"model_name": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model_name": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model": "palmyra-x-004",
"temperature": 0.5,
"client": MagicMock(),
"async_client": AsyncMock(),
},
]

Expand Down Expand Up @@ -423,6 +416,7 @@ class GetWeather(BaseModel):
assert response.tool_calls[0]["args"]["location"] == "London"


@pytest.mark.requires("writer-sdk")
class TestChatWriterStandart(ChatModelUnitTests):
"""Test case for ChatWriter that inherits from standard LangChain tests."""

Expand Down

0 comments on commit 8897656

Please sign in to comment.