Merge branch 'dev-v0.3' into add-vlm-support
mattf committed Sep 19, 2024
2 parents 6ec6b73 + 4d66127 commit ebfceed
Showing 13 changed files with 934 additions and 942 deletions.
86 changes: 41 additions & 45 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -25,7 +25,6 @@
Field,
PrivateAttr,
SecretStr,
root_validator,
validator,
)
from requests.models import Response
@@ -52,9 +51,10 @@ class Config:
arbitrary_types_allowed = True

## Core defaults. These probably should not be changed
_api_key_var = "NVIDIA_API_KEY"
base_url: str = Field(
...,
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"
),
description="Base URL for standard inference",
)
infer_path: str = Field(
@@ -71,11 +71,26 @@ class Config:
)
get_session_fn: Callable = Field(requests.Session)

api_key: Optional[SecretStr] = Field(description="API Key for service of choice")
api_key: Optional[SecretStr] = Field(
default_factory=lambda: SecretStr(
os.getenv("NVIDIA_API_KEY", "INTERNAL_LCNVAIE_ERROR")
)
if "NVIDIA_API_KEY" in os.environ
else None,
description="API Key for service of choice",
)

## Generation arguments
timeout: float = Field(60, ge=0, description="Timeout for waiting on response (s)")
interval: float = Field(0.02, ge=0, description="Interval for pulling response")
timeout: float = Field(
60,
ge=0,
description="The minimum amount of time (in sec) to poll after a 202 response",
)
interval: float = Field(
0.02,
ge=0,
description="Interval (in sec) between polling attempts after a 202 response",
)
last_inputs: Optional[dict] = Field(
description="Last inputs sent over to the server"
)
@@ -105,45 +120,23 @@ class Config:

@validator("base_url")
def _validate_base_url(cls, v: str) -> str:
## Making sure /v1 is added to the url
if v is not None:
result = urlparse(v)
expected_format = "Expected format is 'http://host:port'."
# Ensure scheme and netloc (domain name) are present
if not (result.scheme and result.netloc):
raise ValueError(f"Invalid base_url format. {expected_format} Got: {v}")
return v

@root_validator(pre=True)
def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["api_key"] = (
values.get(cls._api_key_var.lower())
or values.get("api_key")
or os.getenv(cls._api_key_var)
or None
)

## Making sure /v1 is added to the url, followed by infer_path
if "base_url" in values:
base_url = values["base_url"].strip("/")
parsed = urlparse(base_url)
expected_format = "Expected format is: http://host:port"
parsed = urlparse(v)

# Ensure scheme and netloc (domain name) are present
if not (parsed.scheme and parsed.netloc):
raise ValueError(
f"Invalid base_url format. {expected_format} Got: {base_url}"
)
expected_format = "Expected format is: http://host:port"
raise ValueError(f"Invalid base_url format. {expected_format} Got: {v}")

if base_url.endswith(
if v.strip("/").endswith(
("/embeddings", "/completions", "/rankings", "/reranking")
):
warnings.warn(f"Using {base_url}, ignoring the rest")
warnings.warn(f"Using {v}, ignoring the rest")

values["base_url"] = base_url = urlunparse(
(parsed.scheme, parsed.netloc, "v1", None, None, None)
)
values["infer_path"] = values["infer_path"].format(base_url=base_url)
v = urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None))

return values
return v

# final validation after model is constructed
# todo: when pydantic v2 is available,
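For reference, the consolidated validator now normalizes any valid URL down to scheme, host, and a single /v1 path. A minimal standalone sketch of that behavior (the helper name normalize_base_url is ours, not the library's):

import warnings
from urllib.parse import urlparse, urlunparse

def normalize_base_url(v: str) -> str:
    parsed = urlparse(v)
    # Ensure scheme and netloc (domain name) are present
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid base_url format. Expected format is: http://host:port Got: {v}")
    # Paths ending in a known API suffix are accepted with a warning
    if v.strip("/").endswith(("/embeddings", "/completions", "/rankings", "/reranking")):
        warnings.warn(f"Using {v}, ignoring the rest")
    # Rebuild as scheme://host/v1, dropping any other path components
    return urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None))

normalize_base_url("http://localhost:8080/v1/embeddings")  # -> "http://localhost:8080/v1"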
@@ -233,7 +226,7 @@ def is_lc_serializable(cls) -> bool:

@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": self._api_key_var}
return {"api_key": "NVIDIA_API_KEY"}

@property
def lc_attributes(self) -> Dict[str, Any]:
@@ -376,9 +369,7 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
start_time = time.time()
# note: the local NIM does not return a 202 status code
# (per RL 22may2024 circa 24.05)
while (
response.status_code == 202
): # todo: there are no tests that reach this point
while response.status_code == 202:
time.sleep(self.interval)
if (time.time() - start_time) > self.timeout:
raise TimeoutError(
@@ -389,10 +380,12 @@ def _wait(self, response: Response, session: requests.Session) -> Response:
"NVCF-REQID" in response.headers
), "Received 202 response with no request id to follow"
request_id = response.headers.get("NVCF-REQID")
# todo: this needs testing, missing auth header update
payload = {
"url": self.polling_url_tmpl.format(request_id=request_id),
"headers": self.headers_tmpl["call"],
}
self.last_response = response = session.get(
self.polling_url_tmpl.format(request_id=request_id),
headers=self.headers_tmpl["call"],
**self.__add_authorization(payload)
)
self._try_raise(response)
return response
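Taken together, the polling flow after a 202 response now looks roughly like this (a simplified sketch following the names in the diff; the authorization-header refresh and detailed error reporting are trimmed):

import time
import requests

def wait_for_result(
    response: requests.Response,
    session: requests.Session,
    polling_url_tmpl: str,
    call_headers: dict,
    timeout: float = 60.0,
    interval: float = 0.02,
) -> requests.Response:
    start_time = time.time()
    while response.status_code == 202:
        time.sleep(interval)
        if (time.time() - start_time) > timeout:
            raise TimeoutError(f"Timeout of {timeout}s exceeded while polling")
        # the request id header is required to follow up on a 202
        request_id = response.headers["NVCF-REQID"]
        response = session.get(
            polling_url_tmpl.format(request_id=request_id),
            headers=call_headers,
        )
    response.raise_for_status()
    return response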
@@ -492,7 +485,10 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
usage_holder = msg.get("usage", {}) ####
if "choices" in msg:
## Tease out ['choices'][0]...['delta'/'message']
msg = msg.get("choices", [{}])[0]
# when streaming w/ usage info, we may get a response
# w/ choices: [] that includes final usage info
choices = msg.get("choices", [{}])
msg = choices[0] if choices else {}
# todo: this needs to be fixed, the fact we only
# use the first choice breaks the interface
finish_reason_holder = msg.get("finish_reason", None)
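The new guard matters because, with usage reporting enabled during streaming, the final chunk can arrive with an empty choices list; a quick illustration (the token count is made up):

# shape of the final streamed message per the diff comment above
msg = {"choices": [], "usage": {"total_tokens": 46}}

# old: msg.get("choices", [{}])[0] raises IndexError on the empty list
# new: fall back to an empty dict and keep the usage info
choices = msg.get("choices", [{}])
msg = choices[0] if choices else {}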
10 changes: 10 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py
@@ -402,6 +402,11 @@ def validate_client(cls, client: str, values: dict) -> str:
model_type="chat",
client="ChatNVIDIA",
),
"abacusai/dracarys-llama-3.1-70b-instruct": Model(
id="abacusai/dracarys-llama-3.1-70b-instruct",
model_type="chat",
client="ChatNVIDIA",
),
}

QA_MODEL_TABLE = {
@@ -514,6 +519,11 @@ def validate_client(cls, client: str, values: dict) -> str:
model_type="embedding",
client="NVIDIAEmbeddings",
),
"nvidia/embed-qa-4": Model(
id="nvidia/embed-qa-4",
model_type="embedding",
client="NVIDIAEmbeddings",
),
}

RANKING_MODEL_TABLE = {
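Both additions are plain registry entries; once listed they can be requested by name. A hedged usage example (assumes a valid NVIDIA_API_KEY in the environment):

from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

llm = ChatNVIDIA(model="abacusai/dracarys-llama-3.1-70b-instruct")
embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4")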
50 changes: 32 additions & 18 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -45,7 +45,7 @@
ChatResult,
Generation,
)
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -253,8 +253,8 @@ class ChatNVIDIA(BaseChatModel):

_client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
_default_model_name: str = "meta/llama3-8b-instruct"
_default_base_url: str = "https://integrate.api.nvidia.com/v1"
base_url: str = Field(
base_url: Optional[str] = Field(
default=None,
description="Base url for model listing an invocation",
)
model: Optional[str] = Field(description="Name of the model to invoke")
@@ -266,18 +266,6 @@ class ChatNVIDIA(BaseChatModel):
seed: Optional[int] = Field(description="The seed for deterministic results")
stop: Optional[Sequence[str]] = Field(description="Stop words (cased)")

_base_url_var = "NVIDIA_BASE_URL"

@root_validator(pre=True)
def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["base_url"] = (
values.get(cls._base_url_var.lower())
or values.get("base_url")
or os.getenv(cls._base_url_var)
or cls._default_base_url
)
return values

def __init__(self, **kwargs: Any):
"""
Create a new ChatNVIDIA chat model.
@@ -312,17 +300,23 @@ def __init__(self, **kwargs: Any):
)
"""
super().__init__(**kwargs)
# allow nvidia_base_url as an alternative for base_url
base_url = kwargs.pop("nvidia_base_url", self.base_url)
# allow nvidia_api_key as an alternative for api_key
api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))
self._client = _NVIDIAClient(
base_url=self.base_url,
**({"base_url": base_url} if base_url else {}), # only pass if set
model_name=self.model,
default_hosted_model_name=self._default_model_name,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
**({"api_key": api_key} if api_key else {}), # only pass if set
infer_path="{base_url}/chat/completions",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
self.model = self._client.model_name
# same for base_url
self.base_url = self._client.base_url

@property
def available_models(self) -> List[Model]:
@@ -382,7 +376,21 @@ def _stream(
for message in [convert_message_to_dict(message) for message in messages]
]
inputs, extra_headers = _process_for_vlm(inputs, self._client.model)
payload = self._get_payload(inputs=inputs, stop=stop, stream=True, **kwargs)
payload = self._get_payload(
inputs=inputs,
stop=stop,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
# todo: get vlm endpoints fixed and remove this
# vlm endpoints do not accept standard stream_options parameter
if (
self._client.model
and self._client.model.model_type
and self._client.model.model_type == "nv-vlm"
):
payload.pop("stream_options")
for response in self._client.get_req_stream(
payload=payload, extra_headers=extra_headers
):
@@ -422,6 +430,12 @@ def _custom_postprocess(
"additional_kwargs": {},
"response_metadata": {},
}
if token_usage := kw_left.pop("token_usage", None):
out_dict["usage_metadata"] = {
"input_tokens": token_usage.get("prompt_tokens", 0),
"output_tokens": token_usage.get("completion_tokens", 0),
"total_tokens": token_usage.get("total_tokens", 0),
}
# "tool_calls" is set for invoke and stream responses
if tool_calls := kw_left.pop("tool_calls", None):
assert isinstance(
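With include_usage now requested by default for streaming, aggregated chunks should expose token accounting via usage_metadata. A hedged end-to-end example (assumes NVIDIA_API_KEY is set; the counts shown are illustrative):

from langchain_nvidia_ai_endpoints import ChatNVIDIA

llm = ChatNVIDIA(model="meta/llama3-8b-instruct")
final = None
for chunk in llm.stream("Hello!"):
    final = chunk if final is None else final + chunk  # chunks aggregate with +
print(final.usage_metadata)
# e.g. {'input_tokens': 12, 'output_tokens': 34, 'total_tokens': 46}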
30 changes: 11 additions & 19 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
@@ -1,16 +1,14 @@
"""Embeddings Components Derived from NVEModel/Embeddings"""

import os
import warnings
from typing import Any, Dict, List, Literal, Optional
from typing import Any, List, Literal, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.outputs.llm_result import LLMResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
root_validator,
validator,
)

@@ -36,8 +34,8 @@ class Config:
_client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)
_default_model_name: str = "nvidia/nv-embedqa-e5-v5"
_default_max_batch_size: int = 50
_default_base_url: str = "https://integrate.api.nvidia.com/v1"
base_url: str = Field(
base_url: Optional[str] = Field(
default=None,
description="Base url for model listing an invocation",
)
model: Optional[str] = Field(description="Name of the model to invoke")
@@ -53,18 +51,6 @@ class Config:
None, description="(DEPRECATED) The type of text to be embedded."
)

_base_url_var = "NVIDIA_BASE_URL"

@root_validator(pre=True)
def _validate_base_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["base_url"] = (
values.get(cls._base_url_var.lower())
or values.get("base_url")
or os.getenv(cls._base_url_var)
or cls._default_base_url
)
return values

def __init__(self, **kwargs: Any):
"""
Create a new NVIDIAEmbeddings embedder.
@@ -94,17 +80,23 @@ def __init__(self, **kwargs: Any):
embedder = NVIDIAEmbeddings(base_url="http://localhost:8080/v1")
"""
super().__init__(**kwargs)
# allow nvidia_base_url as an alternative for base_url
base_url = kwargs.pop("nvidia_base_url", self.base_url)
# allow nvidia_api_key as an alternative for api_key
api_key = kwargs.pop("nvidia_api_key", kwargs.pop("api_key", None))
self._client = _NVIDIAClient(
base_url=self.base_url,
**({"base_url": base_url} if base_url else {}), # only pass if set
model_name=self.model,
default_hosted_model_name=self._default_model_name,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
**({"api_key": api_key} if api_key else {}), # only pass if set
infer_path="{base_url}/embeddings",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
self.model = self._client.model_name
# same for base_url
self.base_url = self._client.base_url

# todo: remove when nvolveqa_40k is removed from MODEL_TABLE
if "model" in kwargs and kwargs["model"] in [
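The constructor now mirrors chat_models.py in accepting aliases; a hedged example pointing the embedder at a local NIM (the address is illustrative):

from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

# nvidia_base_url is accepted as an alternative to base_url
embedder = NVIDIAEmbeddings(nvidia_base_url="http://localhost:8080/v1")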