-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'langchain-ai:main' into main
- Loading branch information
Showing
3 changed files
with
33 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
27 changes: 27 additions & 0 deletions
27
libs/ai-endpoints/tests/integration_tests/test_streaming.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import time | ||
|
||
from langchain_nvidia_ai_endpoints import ChatNVIDIA | ||
|
||
|
||
def test_ttft(chat_model: str, mode: dict) -> None: | ||
# we had an issue where streaming took a long time to start. the issue | ||
# was all streamed results were collected before yielding them to the | ||
# user. this test tries to detect the incorrect behavior. | ||
# | ||
# warning: | ||
# - this can false positive if the model itself is slow to start | ||
# - this can false nagative if there is a delay after the first chunk | ||
# | ||
# potential mitigation for false negative is to check mean & stdev and | ||
# filter outliers. | ||
# | ||
# credit to Pouyan Rezakhani for finding this issue | ||
llm = ChatNVIDIA(model=chat_model, **mode) | ||
chunk_times = [time.time()] | ||
for chunk in llm.stream("Count to 1000 by 2s, e.g. 2 4 6 8 ...", max_tokens=512): | ||
chunk_times.append(time.time()) | ||
ttft = chunk_times[1] - chunk_times[0] | ||
total_time = chunk_times[-1] - chunk_times[0] | ||
assert ttft < ( | ||
total_time / 2 | ||
), "potential streaming issue, TTFT should be less than half of the total time" |