Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fireworks[major]: switch to pydantic v2 #26004

Merged
merged 7 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 33 additions & 31 deletions libs/partners/fireworks/langchain_fireworks/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import (
Expand All @@ -85,6 +79,14 @@
)
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -354,13 +356,13 @@ def is_lc_serializable(cls) -> bool:
max_retries: Optional[int] = None
"""Maximum number of retries to make when generating."""

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True
model_config = ConfigDict(
populate_by_name=True,
)

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
Expand All @@ -369,32 +371,32 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
)
return values

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
if self.n < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")

client_params = {
"api_key": (
values["fireworks_api_key"].get_secret_value()
if values["fireworks_api_key"]
self.fireworks_api_key.get_secret_value()
if self.fireworks_api_key
else None
),
"base_url": values["fireworks_api_base"],
"timeout": values["request_timeout"],
"base_url": self.fireworks_api_base,
"timeout": self.request_timeout,
}

if not values.get("client"):
values["client"] = Fireworks(**client_params).chat.completions
if not values.get("async_client"):
values["async_client"] = AsyncFireworks(**client_params).chat.completions
if values["max_retries"]:
values["client"]._max_retries = values["max_retries"]
values["async_client"]._max_retries = values["max_retries"]
return values
if not (self.client or None):
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
self.client = Fireworks(**client_params).chat.completions
if not (self.async_client or None):
baskaryan marked this conversation as resolved.
Show resolved Hide resolved
self.async_client = AsyncFireworks(**client_params).chat.completions
if self.max_retries:
self.client._max_retries = self.max_retries
self.async_client._max_retries = self.max_retries
return self

@property
def _default_params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -803,7 +805,7 @@ def with_structured_output(
from typing import Optional

from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field


class AnswerWithJustification(BaseModel):
Expand Down Expand Up @@ -834,7 +836,7 @@ class AnswerWithJustification(BaseModel):
.. code-block:: python

from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel


class AnswerWithJustification(BaseModel):
Expand Down Expand Up @@ -921,7 +923,7 @@ class AnswerWithJustification(TypedDict):
.. code-block::

from langchain_fireworks import ChatFireworks
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel

class AnswerWithJustification(BaseModel):
answer: str
Expand Down
28 changes: 18 additions & 10 deletions libs/partners/fireworks/langchain_fireworks/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any, Dict, List
from typing import List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import secret_from_env
from openai import OpenAI # type: ignore
from openai import OpenAI
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

# type: ignore


class FireworksEmbeddings(BaseModel, Embeddings):
Expand Down Expand Up @@ -65,7 +68,7 @@ class FireworksEmbeddings(BaseModel, Embeddings):
[-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
"""

_client: OpenAI = Field(default=None)
client: OpenAI = Field(default=None, exclude=True) #: :meta private:
fireworks_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
Expand All @@ -79,20 +82,25 @@ class FireworksEmbeddings(BaseModel, Embeddings):
"""
model: str = "nomic-ai/nomic-embed-text-v1.5"

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate environment variables."""
values["_client"] = OpenAI(
api_key=values["fireworks_api_key"].get_secret_value(),
self.client = OpenAI(
api_key=self.fireworks_api_key.get_secret_value(),
base_url="https://api.fireworks.ai/inference/v1",
)
return values
return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs."""
return [
i.embedding
for i in self._client.embeddings.create(input=texts, model=self.model).data
for i in self.client.embeddings.create(input=texts, model=self.model).data
]

def embed_query(self, text: str) -> List[float]:
Expand Down
49 changes: 25 additions & 24 deletions libs/partners/fireworks/langchain_fireworks/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs
from langchain_core.utils import get_pydantic_field_names
from langchain_core.utils.utils import build_extra_kwargs, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator

from langchain_fireworks.version import __version__

Expand All @@ -39,8 +35,21 @@ class Fireworks(LLM):

base_url: str = "https://api.fireworks.ai/inference/v1/completions"
"""Base inference API URL."""
fireworks_api_key: SecretStr = Field(default=None, alias="api_key")
"""Fireworks AI API key. Get it here: https://fireworks.ai"""
fireworks_api_key: SecretStr = Field(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alias="api_key",
default_factory=secret_from_env(
"FIREWORKS_API_KEY",
error_message=(
"You must specify an api key. "
"You can pass it an argument as `api_key=...` or "
"set the environment variable `FIREWORKS_API_KEY`."
),
),
)
"""Fireworks API key.

Automatically read from env variable `FIREWORKS_API_KEY` if not provided.
"""
model: str
"""Model name. Available models listed here:
https://readme.fireworks.ai/
Expand Down Expand Up @@ -74,14 +83,14 @@ class Fireworks(LLM):
the response for each token generation step.
"""

class Config:
"""Configuration for this pydantic object."""

extra = "forbid"
allow_population_by_field_name = True
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
Expand All @@ -90,14 +99,6 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
)
return values

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["fireworks_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
)
return values

@property
def _llm_type(self) -> str:
"""Return type of model."""
Expand Down
4 changes: 2 additions & 2 deletions libs/partners/fireworks/scripts/check_pydantic.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "Please replace the code with an import from pydantic."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
echo "with 'from pydantic import BaseModel'"
exit 1
fi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional

from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel

from langchain_fireworks import ChatFireworks

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Standard LangChain interface tests"""

from typing import Tuple, Type

from langchain_core.embeddings import Embeddings
from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests

from langchain_fireworks import FireworksEmbeddings


class TestFireworksStandard(EmbeddingsUnitTests):
@property
def embeddings_class(self) -> Type[Embeddings]:
return FireworksEmbeddings

@property
def embeddings_params(self) -> dict:
return {"api_key": "test_api_key"}

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
},
{},
{
"fireworks_api_key": "api_key",
},
)
2 changes: 1 addition & 1 deletion libs/partners/fireworks/tests/unit_tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_fireworks import Fireworks
Expand Down
16 changes: 15 additions & 1 deletion libs/partners/fireworks/tests/unit_tests/test_standard.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Standard LangChain interface tests"""

from typing import Type
from typing import Tuple, Type

from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
Expand All @@ -18,3 +18,17 @@ def chat_model_class(self) -> Type[BaseChatModel]:
@property
def chat_model_params(self) -> dict:
return {"api_key": "test_api_key"}

@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"FIREWORKS_API_KEY": "api_key",
"FIREWORKS_API_BASE": "https://base.com",
},
{},
{
"fireworks_api_key": "api_key",
"fireworks_api_base": "https://base.com",
},
)
Loading