From 559d8a4d1385cd9e1a6b1096b5efe15a1abdb021 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 17:41:28 -0700 Subject: [PATCH 1/5] fireworks[major]: switch to pydantic v2 --- .../langchain_fireworks/chat_models.py | 58 ++++++++++--------- .../langchain_fireworks/embeddings.py | 28 +++++---- .../fireworks/langchain_fireworks/llms.py | 49 ++++++++-------- .../integration_tests/test_chat_models.py | 2 +- .../fireworks/tests/unit_tests/test_llms.py | 2 +- 5 files changed, 75 insertions(+), 64 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 7b489cd5d0086..c5e6d9c42fd2e 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -68,12 +68,6 @@ parse_tool_call, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import ( - BaseModel, - Field, - SecretStr, - root_validator, -) from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import ( @@ -85,6 +79,14 @@ ) from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env +from pydantic import ( + BaseModel, + ConfigDict, + Field, + SecretStr, + model_validator, +) +from typing_extensions import Self logger = logging.getLogger(__name__) @@ -354,13 +356,13 @@ def is_lc_serializable(cls) -> bool: max_retries: Optional[int] = None """Maximum number of retries to make when generating.""" - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) - @root_validator(pre=True) - def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) @@ -369,32 +371,32 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" - if values["n"] < 1: + if self.n < 1: raise ValueError("n must be at least 1.") - if values["n"] > 1 and values["streaming"]: + if self.n > 1 and self.streaming: raise ValueError("n must be 1 when streaming.") client_params = { "api_key": ( - values["fireworks_api_key"].get_secret_value() - if values["fireworks_api_key"] + self.fireworks_api_key.get_secret_value() + if self.fireworks_api_key else None ), - "base_url": values["fireworks_api_base"], - "timeout": values["request_timeout"], + "base_url": self.fireworks_api_base, + "timeout": self.request_timeout, } - if not values.get("client"): - values["client"] = Fireworks(**client_params).chat.completions - if not values.get("async_client"): - values["async_client"] = AsyncFireworks(**client_params).chat.completions - if values["max_retries"]: - values["client"]._max_retries = values["max_retries"] - values["async_client"]._max_retries = values["max_retries"] - return values + if not (self.client or None): + self.client = Fireworks(**client_params).chat.completions + if not (self.async_client or None): + self.async_client = AsyncFireworks(**client_params).chat.completions + if self.max_retries: + self.client._max_retries = self.max_retries + self.async_client._max_retries = self.max_retries + return self @property def _default_params(self) -> Dict[str, Any]: diff --git a/libs/partners/fireworks/langchain_fireworks/embeddings.py b/libs/partners/fireworks/langchain_fireworks/embeddings.py index 719bb34a93540..8fd67f116cab3 100644 --- a/libs/partners/fireworks/langchain_fireworks/embeddings.py +++ b/libs/partners/fireworks/langchain_fireworks/embeddings.py @@ -1,9 +1,12 @@ -from typing import Any, Dict, List +from typing import List from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import secret_from_env -from openai import OpenAI # type: ignore +from openai import OpenAI +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator +from typing_extensions import Self + +# type: ignore class FireworksEmbeddings(BaseModel, Embeddings): @@ -65,7 +68,7 @@ class FireworksEmbeddings(BaseModel, Embeddings): [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915] """ - _client: OpenAI = Field(default=None) + client: OpenAI = Field(default=None, exclude=True) #: :meta private: fireworks_api_key: SecretStr = Field( alias="api_key", default_factory=secret_from_env( @@ -79,20 +82,25 @@ class FireworksEmbeddings(BaseModel, Embeddings): """ model: str = "nomic-ai/nomic-embed-text-v1.5" - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) + + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate environment variables.""" - values["_client"] = OpenAI( - api_key=values["fireworks_api_key"].get_secret_value(), + self.client = OpenAI( + api_key=self.fireworks_api_key.get_secret_value(), base_url="https://api.fireworks.ai/inference/v1", ) - return values + return self def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed search docs.""" return [ i.embedding - for i in self._client.embeddings.create(input=texts, model=self.model).data + for i in self.client.embeddings.create(input=texts, model=self.model).data ] def embed_query(self, text: str) -> List[float]: diff --git a/libs/partners/fireworks/langchain_fireworks/llms.py b/libs/partners/fireworks/langchain_fireworks/llms.py index 747c59ecadd43..3189483c914a5 100644 --- a/libs/partners/fireworks/langchain_fireworks/llms.py +++ b/libs/partners/fireworks/langchain_fireworks/llms.py @@ -10,13 +10,9 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, - get_pydantic_field_names, -) -from langchain_core.utils.utils import build_extra_kwargs +from langchain_core.utils import get_pydantic_field_names +from langchain_core.utils.utils import build_extra_kwargs, secret_from_env +from pydantic import ConfigDict, Field, SecretStr, model_validator from langchain_fireworks.version import __version__ @@ -39,8 +35,21 @@ class Fireworks(LLM): base_url: str = "https://api.fireworks.ai/inference/v1/completions" """Base inference API URL.""" - fireworks_api_key: SecretStr = Field(default=None, alias="api_key") - """Fireworks AI API key. Get it here: https://fireworks.ai""" + fireworks_api_key: SecretStr = Field( + alias="api_key", + default_factory=secret_from_env( + "FIREWORKS_API_KEY", + error_message=( + "You must specify an api key. " + "You can pass it an argument as `api_key=...` or " + "set the environment variable `FIREWORKS_API_KEY`." + ), + ), + ) + """Fireworks API key. + + Automatically read from env variable `FIREWORKS_API_KEY` if not provided. + """ model: str """Model name. Available models listed here: https://readme.fireworks.ai/ @@ -74,14 +83,14 @@ class Fireworks(LLM): the response for each token generation step. """ - class Config: - """Configuration for this pydantic object.""" - - extra = "forbid" - allow_population_by_field_name = True + model_config = ConfigDict( + extra="forbid", + populate_by_name=True, + ) - @root_validator(pre=True) - def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + @model_validator(mode="before") + @classmethod + def build_extra(cls, values: Dict[str, Any]) -> Any: """Build extra kwargs from additional params that were passed in.""" all_required_field_names = get_pydantic_field_names(cls) extra = values.get("model_kwargs", {}) @@ -90,14 +99,6 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key exists in environment.""" - values["fireworks_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY") - ) - return values - @property def _llm_type(self) -> str: """Return type of model.""" diff --git a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py index 443f9e5f47b67..88a1cd46cfcbb 100644 --- a/libs/partners/fireworks/tests/integration_tests/test_chat_models.py +++ b/libs/partners/fireworks/tests/integration_tests/test_chat_models.py @@ -7,7 +7,7 @@ from typing import Optional from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel from langchain_fireworks import ChatFireworks diff --git a/libs/partners/fireworks/tests/unit_tests/test_llms.py b/libs/partners/fireworks/tests/unit_tests/test_llms.py index e2fb8a131e4b6..265df7ede83c1 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_llms.py +++ b/libs/partners/fireworks/tests/unit_tests/test_llms.py @@ -2,7 +2,7 @@ from typing import cast -from langchain_core.pydantic_v1 import SecretStr +from pydantic import SecretStr from pytest import CaptureFixture, MonkeyPatch from langchain_fireworks import Fireworks From 6aac2eeab5824660a41b9be21cc005f68e262e55 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 17:42:22 -0700 Subject: [PATCH 2/5] fmt --- libs/partners/fireworks/langchain_fireworks/chat_models.py | 6 +++--- libs/partners/fireworks/scripts/check_pydantic.sh | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index c5e6d9c42fd2e..4983f51daaba2 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -805,7 +805,7 @@ def with_structured_output( from typing import Optional from langchain_fireworks import ChatFireworks - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class AnswerWithJustification(BaseModel): @@ -836,7 +836,7 @@ class AnswerWithJustification(BaseModel): .. code-block:: python from langchain_fireworks import ChatFireworks - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): @@ -923,7 +923,7 @@ class AnswerWithJustification(TypedDict): .. code-block:: from langchain_fireworks import ChatFireworks - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel class AnswerWithJustification(BaseModel): answer: str diff --git a/libs/partners/fireworks/scripts/check_pydantic.sh b/libs/partners/fireworks/scripts/check_pydantic.sh index 06b5bb81ae236..1317f5e53914f 100755 --- a/libs/partners/fireworks/scripts/check_pydantic.sh +++ b/libs/partners/fireworks/scripts/check_pydantic.sh @@ -20,8 +20,8 @@ result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') if [ -n "$result" ]; then echo "ERROR: The following lines need to be updated:" echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." + echo "Please replace the code with an import from pydantic." echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" + echo "with 'from pydantic import BaseModel'" exit 1 fi From 56163481dd692ae85d24762716dc39dea8d5b1a7 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 17:46:41 -0700 Subject: [PATCH 3/5] fmt --- .../unit_tests/test_embeddings_standard.py | 30 +++++++++++++++++++ .../tests/unit_tests/test_standard.py | 16 +++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py diff --git a/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py new file mode 100644 index 0000000000000..ea8d16f92d0a8 --- /dev/null +++ b/libs/partners/fireworks/tests/unit_tests/test_embeddings_standard.py @@ -0,0 +1,30 @@ +"""Standard LangChain interface tests""" + +from typing import Tuple, Type + +from langchain_core.embeddings import Embeddings +from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests + +from langchain_fireworks import FireworksEmbeddings + + +class TestFireworksStandard(EmbeddingsUnitTests): + @property + def embeddings_class(self) -> Type[Embeddings]: + return FireworksEmbeddings + + @property + def embeddings_params(self) -> dict: + return {"api_key": "test_api_key"} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "FIREWORKS_API_KEY": "api_key", + }, + {}, + { + "fireworks_api_key": "api_key", + }, + ) diff --git a/libs/partners/fireworks/tests/unit_tests/test_standard.py b/libs/partners/fireworks/tests/unit_tests/test_standard.py index 9288aeeb9f8f2..61d0d152ba831 100644 --- a/libs/partners/fireworks/tests/unit_tests/test_standard.py +++ b/libs/partners/fireworks/tests/unit_tests/test_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Tuple, Type from langchain_core.language_models import BaseChatModel from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found] @@ -18,3 +18,17 @@ def chat_model_class(self) -> Type[BaseChatModel]: @property def chat_model_params(self) -> dict: return {"api_key": "test_api_key"} + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "FIREWORKS_API_KEY": "api_key", + "FIREWORKS_API_BASE": "https://base.com", + }, + {}, + { + "fireworks_api_key": "api_key", + "fireworks_api_base": "https://base.com", + }, + ) From a91bd2737a5b81a9f6f03f3552e32b675d1205d4 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 3 Sep 2024 23:30:49 -0700 Subject: [PATCH 4/5] Update libs/partners/fireworks/langchain_fireworks/chat_models.py --- libs/partners/fireworks/langchain_fireworks/chat_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 4983f51daaba2..5287157253008 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -389,7 +389,7 @@ def validate_environment(self) -> Self: "timeout": self.request_timeout, } - if not (self.client or None): + if not self.client: self.client = Fireworks(**client_params).chat.completions if not (self.async_client or None): self.async_client = AsyncFireworks(**client_params).chat.completions From d0cc9b022a420c12ba86a74c9d0d84d29cf56e52 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Tue, 3 Sep 2024 23:30:56 -0700 Subject: [PATCH 5/5] Update libs/partners/fireworks/langchain_fireworks/chat_models.py --- libs/partners/fireworks/langchain_fireworks/chat_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 5287157253008..da85524f7d372 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -391,7 +391,7 @@ def validate_environment(self) -> Self: if not self.client: self.client = Fireworks(**client_params).chat.completions - if not (self.async_client or None): + if not self.async_client: self.async_client = AsyncFireworks(**client_params).chat.completions if self.max_retries: self.client._max_retries = self.max_retries