Merge branch 'dev-v0.3' into add-vlm-support
mattf committed Sep 19, 2024
2 parents ebfceed + ed8dfd0 commit 1098fdd
Showing 22 changed files with 331 additions and 545 deletions.
1 change: 0 additions & 1 deletion libs/ai-endpoints/Makefile
@@ -33,7 +33,6 @@ lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
 
 lint lint_diff lint_package lint_tests:
-	./scripts/check_pydantic.sh .
 	./scripts/lint_imports.sh
 	poetry run ruff .
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
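(The deleted `./scripts/check_pydantic.sh .` step appears to have policed imports of the `langchain_core.pydantic_v1` compatibility shim; since this merge moves the package to direct pydantic v2 imports, as the diffs below show, the check is obsolete.)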
282 changes: 26 additions & 256 deletions libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb

Large diffs are not rendered by default.

56 changes: 31 additions & 25 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -20,40 +20,46 @@
 from urllib.parse import urlparse, urlunparse
 
 import requests
-from langchain_core.pydantic_v1 import (
+from pydantic import (
     BaseModel,
+    ConfigDict,
     Field,
     PrivateAttr,
     SecretStr,
-    validator,
+    field_validator,
 )
 from requests.models import Response
 
 from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE, Model, determine_model
 
 logger = logging.getLogger(__name__)
 
+_API_KEY_VAR = "NVIDIA_API_KEY"
+_BASE_URL_VAR = "NVIDIA_BASE_URL"
+
 
 class _NVIDIAClient(BaseModel):
     """
     Low level client library interface to NIM endpoints.
     """
 
     default_hosted_model_name: str = Field(..., description="Default model name to use")
-    model_name: Optional[str] = Field(..., description="Name of the model to invoke")
+    # "mdl_name" because "model_" is a protected namespace in pydantic
+    mdl_name: Optional[str] = Field(..., description="Name of the model to invoke")
     model: Optional[Model] = Field(None, description="The model to invoke")
     is_hosted: bool = Field(True)
     cls: str = Field(..., description="Class Name")
 
     # todo: add a validator for requests.Response (last_response attribute) and
     # remove arbitrary_types_allowed=True
-    class Config:
-        arbitrary_types_allowed = True
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
 
     ## Core defaults. These probably should not be changed
     base_url: str = Field(
         default_factory=lambda: os.getenv(
-            "NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"
+            _BASE_URL_VAR, "https://integrate.api.nvidia.com/v1"
         ),
         description="Base URL for standard inference",
     )
@@ -73,9 +79,9 @@

     api_key: Optional[SecretStr] = Field(
         default_factory=lambda: SecretStr(
-            os.getenv("NVIDIA_API_KEY", "INTERNAL_LCNVAIE_ERROR")
+            os.getenv(_API_KEY_VAR, "INTERNAL_LCNVAIE_ERROR")
         )
-        if "NVIDIA_API_KEY" in os.environ
+        if _API_KEY_VAR in os.environ
         else None,
         description="API Key for service of choice",
     )
@@ -92,7 +98,7 @@
description="Interval (in sec) between polling attempts after a 202 response",
)
last_inputs: Optional[dict] = Field(
description="Last inputs sent over to the server"
default={}, description="Last inputs sent over to the server"
)
last_response: Response = Field(
None, description="Last response sent from the server"
@@ -118,7 +124,7 @@
     ###################################################################################
     ################### Validation and Initialization #################################
 
-    @validator("base_url")
+    @field_validator("base_url")
     def _validate_base_url(cls, v: str) -> str:
         ## Making sure /v1 in added to the url
         if v is not None:
@@ -158,10 +164,10 @@ def __init__(self, **kwargs: Any):
             )
 
             # set default model for hosted endpoint
-            if not self.model_name:
-                self.model_name = self.default_hosted_model_name
+            if not self.mdl_name:
+                self.mdl_name = self.default_hosted_model_name
 
-            if model := determine_model(self.model_name):
+            if model := determine_model(self.mdl_name):
                 if not model.client:
                     warnings.warn(f"Unable to determine validity of {model.id}")
                 elif model.client != self.cls:
@@ -179,37 +185,37 @@ def __init__(self, **kwargs: Any):
                 candidates = [
                     model
                     for model in self.available_models
-                    if model.id == self.model_name
+                    if model.id == self.mdl_name
                 ]
                 assert len(candidates) <= 1, (
-                    f"Multiple candidates for {self.model_name} "
+                    f"Multiple candidates for {self.mdl_name} "
                     f"in `available_models`: {candidates}"
                 )
                 if candidates:
                     model = candidates[0]
                     warnings.warn(
-                        f"Found {self.model_name} in available_models, but type is "
+                        f"Found {self.mdl_name} in available_models, but type is "
                         "unknown and inference may fail."
                     )
                 else:
                     raise ValueError(
-                        f"Model {self.model_name} is unknown, check `available_models`"
+                        f"Model {self.mdl_name} is unknown, check `available_models`"
                     )
             self.model = model
-            self.model_name = self.model.id  # name may change because of aliasing
+            self.mdl_name = self.model.id  # name may change because of aliasing
         else:
             # set default model
-            if not self.model_name:
+            if not self.mdl_name:
                 valid_models = [
                     model
                     for model in self.available_models
                     if not model.base_model or model.base_model == model.id
                 ]
                 self.model = next(iter(valid_models), None)
                 if self.model:
-                    self.model_name = self.model.id
+                    self.mdl_name = self.model.id
                     warnings.warn(
-                        f"Default model is set as: {self.model_name}. \n"
+                        f"Default model is set as: {self.mdl_name}. \n"
                         "Set model using model parameter. \n"
                         "To get available models use available_models property.",
                         UserWarning,
@@ -226,15 +232,15 @@ def is_lc_serializable(cls) -> bool:

     @property
     def lc_secrets(self) -> Dict[str, str]:
-        return {"api_key": "NVIDIA_API_KEY"}
+        return {"api_key": _API_KEY_VAR}
 
     @property
     def lc_attributes(self) -> Dict[str, Any]:
         attributes: Dict[str, Any] = {}
         attributes["base_url"] = self.base_url
 
-        if self.model_name:
-            attributes["model"] = self.model_name
+        if self.mdl_name:
+            attributes["model"] = self.mdl_name
 
         return attributes

@@ -535,7 +541,7 @@ def get_req_stream(
             stream=True, **self.__add_authorization(self.last_inputs)
         )
         self._try_raise(response)
-        call = self.copy()
+        call: _NVIDIAClient = self.model_copy()
 
         def out_gen() -> Generator[dict, Any, Any]:
             ## Good for client, since it allows self.last_inputs
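Taken together, the `_common.py` changes are a standard pydantic v1-to-v2 migration: `@validator` becomes `@field_validator`, the nested `class Config:` becomes `model_config = ConfigDict(...)`, `.copy()` becomes `.model_copy()`, and the `model_name` field is renamed to `mdl_name` to dodge pydantic v2's protected `model_` namespace. A minimal runnable sketch of the same patterns, with illustrative names rather than the real `_NVIDIAClient`:

import os
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator

_BASE_URL_VAR = "NVIDIA_BASE_URL"


class Client(BaseModel):
    # v2: model_config replaces the nested `class Config:`
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # v2 reserves the "model_" prefix, hence the mdl_name rename above
    mdl_name: Optional[str] = Field(None, description="Name of the model to invoke")
    base_url: str = Field(
        default_factory=lambda: os.getenv(
            _BASE_URL_VAR, "https://integrate.api.nvidia.com/v1"
        )
    )

    # v2: @field_validator replaces @validator
    @field_validator("base_url")
    def _validate_base_url(cls, v: str) -> str:
        # toy check standing in for the real /v1 normalization
        if not v.rstrip("/").endswith("/v1"):
            raise ValueError(f"expected a .../v1 base url, got {v!r}")
        return v


client = Client()
clone = client.model_copy()  # v2: model_copy() replaces copy()
print(clone.base_url)  # https://integrate.api.nvidia.com/v1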
16 changes: 8 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
@@ -2,7 +2,7 @@
 import warnings
 from typing import Literal, Optional
 
-from langchain_core.pydantic_v1 import BaseModel, validator
+from pydantic import BaseModel, model_validator
 
 
 class Model(BaseModel):
@@ -37,21 +37,21 @@ class Model(BaseModel):
     def __hash__(self) -> int:
         return hash(self.id)
 
-    @validator("client", always=True)
-    def validate_client(cls, client: str, values: dict) -> str:
-        if client:
+    @model_validator(mode="after")
+    def validate_client(self) -> "Model":
+        if self.client:
             supported = {
                 "ChatNVIDIA": ("chat", "vlm", "nv-vlm", "qa"),
                 "NVIDIAEmbeddings": ("embedding",),
                 "NVIDIARerank": ("ranking",),
                 "NVIDIA": ("completions",),
             }
-            model_type = values.get("model_type")
-            if model_type not in supported[client]:
+            if self.model_type not in supported.get(self.client, ()):
                 raise ValueError(
-                    f"Model type '{model_type}' not supported by client '{client}'"
+                    f"Model type '{self.model_type}' not supported "
+                    f"by client '{self.client}'"
                 )
-        return client
+        return self


CHAT_MODEL_TABLE = {
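The `Model` validator migration above is the v2 pattern for cross-field checks: a v1 `@validator("client", always=True)` that dug through the half-built `values` dict becomes a `@model_validator(mode="after")` that reads completed fields off `self`. It also swaps `supported[client]` (a `KeyError` for unknown clients) for `supported.get(self.client, ())`, which now fails with the `ValueError` instead. A runnable sketch with `Model` trimmed to the fields the validator touches; the `protected_namespaces=()` line is an assumption added to silence v2's warning about the `model_type` field name:

from typing import Optional

from pydantic import BaseModel, ConfigDict, model_validator


class Model(BaseModel):
    # assumption: quiets pydantic v2's protected-namespace warning for
    # the model_type field; the real Model may handle this differently
    model_config = ConfigDict(protected_namespaces=())

    id: str
    model_type: Optional[str] = None
    client: Optional[str] = None

    # mode="after" runs once the instance is fully built, so every
    # field is available on self
    @model_validator(mode="after")
    def validate_client(self) -> "Model":
        if self.client:
            supported = {"ChatNVIDIA": ("chat", "vlm", "nv-vlm", "qa")}
            if self.model_type not in supported.get(self.client, ()):
                raise ValueError(
                    f"Model type '{self.model_type}' not supported "
                    f"by client '{self.client}'"
                )
        return self


Model(id="meta/llama3-8b-instruct", model_type="chat", client="ChatNVIDIA")  # ok
# Model(id="x", model_type="ranking", client="ChatNVIDIA")  # raises ValueError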
40 changes: 23 additions & 17 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -45,19 +45,17 @@
     ChatResult,
     Generation,
 )
-from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
 from langchain_core.utils.pydantic import is_basemodel_subclass
+from pydantic import BaseModel, Field, PrivateAttr
 
 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
 from langchain_nvidia_ai_endpoints._statics import Model
 from langchain_nvidia_ai_endpoints._utils import convert_message_to_dict
 
 _CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-_DictOrPydanticOrEnumClass = Union[Dict[str, Any], Type[BaseModel], Type[enum.Enum]]
-_DictOrPydanticOrEnum = Union[Dict, BaseModel, enum.Enum]
 
 logger = logging.getLogger(__name__)

@@ -238,6 +236,9 @@ def _process_for_vlm(
     return inputs, extra_headers
 
 
+_DEFAULT_MODEL_NAME: str = "meta/llama3-8b-instruct"
+
+
 class ChatNVIDIA(BaseChatModel):
     """NVIDIA chat model.
 
"""

_client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
_default_model_name: str = "meta/llama3-8b-instruct"
base_url: Optional[str] = Field(
default=None,
description="Base url for model listing an invocation",
)
model: Optional[str] = Field(description="Name of the model to invoke")
temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]")
model: Optional[str] = Field(None, description="Name of the model to invoke")
temperature: Optional[float] = Field(
None, description="Sampling temperature in [0, 1]"
)
max_tokens: Optional[int] = Field(
1024, description="Maximum # of tokens to generate"
)
top_p: Optional[float] = Field(description="Top-p for distribution sampling")
seed: Optional[int] = Field(description="The seed for deterministic results")
stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")
top_p: Optional[float] = Field(None, description="Top-p for distribution sampling")
seed: Optional[int] = Field(None, description="The seed for deterministic results")
stop: Optional[Sequence[str]] = Field(None, description="Stop words (cased)")

def __init__(self, **kwargs: Any):
"""
@@ -306,15 +308,15 @@ def __init__(self, **kwargs: Any):
         api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))
         self._client = _NVIDIAClient(
             **({"base_url": base_url} if base_url else {}),  # only pass if set
-            model_name=self.model,
-            default_hosted_model_name=self._default_model_name,
+            mdl_name=self.model,
+            default_hosted_model_name=_DEFAULT_MODEL_NAME,
             **({"api_key": api_key} if api_key else {}),  # only pass if set
             infer_path="{base_url}/chat/completions",
             cls=self.__class__.__name__,
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
-        self.model = self._client.model_name
+        self.model = self._client.mdl_name
         # same for base_url
         self.base_url = self._client.base_url

@@ -529,7 +531,7 @@ def _get_payload(

     def bind_tools(
         self,
-        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
         *,
         tool_choice: Optional[
             Union[dict, str, Literal["auto", "none", "any", "required"], bool]
@@ -615,11 +617,11 @@ def bind_functions(
     # as a result need to type ignore for the schema parameter and return type.
     def with_structured_output(  # type: ignore
         self,
-        schema: _DictOrPydanticOrEnumClass,
+        schema: Union[Dict, Type],
         *,
         include_raw: bool = False,
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, _DictOrPydanticOrEnum]:
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
         """
         Bind a structured output schema to the model.
@@ -656,7 +658,7 @@ def with_structured_output(  # type: ignore
         1. If a Pydantic schema is provided, the model will return a Pydantic object.
            Example:
            ```
-           from langchain_core.pydantic_v1 import BaseModel, Field
+           from pydantic import BaseModel, Field
            class Joke(BaseModel):
                setup: str = Field(description="The setup of the joke")
                punchline: str = Field(description="The punchline to the joke")
@@ -814,7 +816,11 @@ def parse_result(
                     return None
 
             output_parser = ForgivingPydanticOutputParser(pydantic_object=schema)
-            nvext_param = {"guided_json": schema.schema()}
+            if hasattr(schema, "model_json_schema"):
+                json_schema = schema.model_json_schema()
+            else:
+                json_schema = schema.schema()
+            nvext_param = {"guided_json": json_schema}
 
         else:
             raise ValueError(
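The `parse_result` hunk is the one behavioral accommodation for callers still passing pydantic v1-style classes: v2 models expose `model_json_schema()`, v1-style ones only `.schema()`, and the `guided_json` payload needs a plain dict either way. A standalone sketch of that branch; the helper name is illustrative, not part of the package:

from typing import Type

from pydantic import BaseModel, Field


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


def guided_json_schema(schema: Type) -> dict:
    # pydantic v2 classes carry model_json_schema(); fall back to the
    # v1-era .schema() for anything older
    if hasattr(schema, "model_json_schema"):
        return schema.model_json_schema()
    return schema.schema()


nvext_param = {"guided_json": guided_json_schema(Joke)}
print(sorted(nvext_param["guided_json"]["required"]))  # ['punchline', 'setup']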