From 594f3928d120dc9ef71f567837c276425abb5053 Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Thu, 19 Dec 2024 10:42:50 -0500 Subject: [PATCH] anthropic: less pydantic for client --- .../langchain_anthropic/chat_models.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index ba54802751f7d..6eb9dc4bca61b 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1,6 +1,7 @@ import copy import re import warnings +from functools import cached_property from operator import itemgetter from typing import ( Any, @@ -68,11 +69,10 @@ BaseModel, ConfigDict, Field, - PrivateAttr, SecretStr, model_validator, ) -from typing_extensions import NotRequired, Self +from typing_extensions import NotRequired from langchain_anthropic.output_parsers import extract_tool_calls @@ -541,9 +541,6 @@ class Joke(BaseModel): populate_by_name=True, ) - _client: anthropic.Client = PrivateAttr(default=None) # type: ignore[assignment] - _async_client: anthropic.AsyncClient = PrivateAttr(default=None) # type: ignore[assignment] - model: str = Field(alias="model_name") """Model name to use.""" @@ -661,13 +658,11 @@ def build_extra(cls, values: Dict) -> Any: values = _build_model_kwargs(values, all_required_field_names) return values - @model_validator(mode="after") - def post_init(self) -> Self: - api_key = self.anthropic_api_key.get_secret_value() - api_url = self.anthropic_api_url + @cached_property + def _client_params(self) -> Dict[str, Any]: client_params: Dict[str, Any] = { - "api_key": api_key, - "base_url": api_url, + "api_key": self.anthropic_api_key.get_secret_value(), + "base_url": self.anthropic_api_url, "max_retries": self.max_retries, "default_headers": (self.default_headers or None), } @@ -677,9 +672,15 @@ def post_init(self) -> Self: if self.default_request_timeout is None or self.default_request_timeout > 0: client_params["timeout"] = self.default_request_timeout - self._client = anthropic.Client(**client_params) - self._async_client = anthropic.AsyncClient(**client_params) - return self + return client_params + + @cached_property + def _client(self) -> anthropic.Client: + return anthropic.Client(**self._client_params) + + @cached_property + def _async_client(self) -> anthropic.AsyncClient: + return anthropic.AsyncClient(**self._client_params) def _get_request_payload( self,