Commit c8e0678

aviolante committed Aug 29, 2024
2 parents 8d73f3b + 1b49fc8
Showing 3 changed files with 33 additions and 4 deletions.
8 changes: 5 additions & 3 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -17,7 +17,7 @@
     Tuple,
     Union,
 )
-from urllib.parse import urlparse
+from urllib.parse import urlparse, urlunparse
 
 import requests
 from langchain_core.pydantic_v1 import (
@@ -138,7 +138,9 @@ def _preprocess_args(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         ):
             warnings.warn(f"Using {base_url}, ignoring the rest")
 
-        values["base_url"] = base_url
+        values["base_url"] = base_url = urlunparse(
+            (parsed.scheme, parsed.netloc, "v1", None, None, None)
+        )
         values["infer_path"] = values["infer_path"].format(base_url=base_url)
 
         return values
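
For reference, a minimal standalone sketch of what the new normalization does (the host and path in the example are made up): urlunparse keeps only the scheme and network location of the user-supplied URL and pins the path to "v1", so any extra path, query, or fragment is dropped.

    from urllib.parse import urlparse, urlunparse

    # hypothetical user-supplied base_url with extra path and query parts
    base_url = "https://internal-host:8000/some/extra/path?token=abc"
    parsed = urlparse(base_url)
    normalized = urlunparse((parsed.scheme, parsed.netloc, "v1", None, None, None))
    print(normalized)  # -> https://internal-host:8000/v1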
@@ -523,7 +525,7 @@ def get_req_stream(
         }
 
         response = self.get_session_fn().post(
-            **self.__add_authorization(self.last_inputs)
+            stream=True, **self.__add_authorization(self.last_inputs)
         )
         self._try_raise(response)
         call = self.copy()
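The stream=True change is the heart of the TTFT fix: without it, requests downloads the entire response body before post() returns, so iteration over the streamed chunks only begins after generation has finished. A minimal sketch of the difference, using a placeholder endpoint rather than the library's actual request code:

    import requests

    # with stream=True, post() returns as soon as the response headers
    # arrive, and the body is consumed incrementally as it is produced
    response = requests.post(
        "https://example.com/v1/chat/completions",  # placeholder endpoint
        json={"messages": [], "stream": True},
        stream=True,
    )
    for line in response.iter_lines():  # yields each SSE line as it lands
        if line:
            print(line.decode())
    # without stream=True, the same loop only starts after the full body
    # has been downloaded, so time-to-first-chunk equals total time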
2 changes: 1 addition & 1 deletion libs/ai-endpoints/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-nvidia-ai-endpoints"
-version = "0.2.1"
+version = "0.2.2"
 description = "An integration package connecting NVIDIA AI Endpoints and LangChain"
 authors = []
 readme = "README.md"
27 changes: 27 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_streaming.py
@@ -0,0 +1,27 @@
+import time
+
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
+
+def test_ttft(chat_model: str, mode: dict) -> None:
+    # we had an issue where streaming took a long time to start. the issue
+    # was all streamed results were collected before yielding them to the
+    # user. this test tries to detect the incorrect behavior.
+    #
+    # warning:
+    #  - this can false positive if the model itself is slow to start
+    #  - this can false negative if there is a delay after the first chunk
+    #
+    # potential mitigation for false negative is to check mean & stdev and
+    # filter outliers.
+    #
+    # credit to Pouyan Rezakhani for finding this issue
+    llm = ChatNVIDIA(model=chat_model, **mode)
+    chunk_times = [time.time()]
+    for chunk in llm.stream("Count to 1000 by 2s, e.g. 2 4 6 8 ...", max_tokens=512):
+        chunk_times.append(time.time())
+    ttft = chunk_times[1] - chunk_times[0]
+    total_time = chunk_times[-1] - chunk_times[0]
+    assert ttft < (
+        total_time / 2
+    ), "potential streaming issue, TTFT should be less than half of the total time"
