Commit

fix stream collection
mattf committed Aug 28, 2024
1 parent 9f9b762 commit 72e2931
Showing 2 changed files with 28 additions and 0 deletions.
1 change: 1 addition & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -523,6 +523,7 @@ def get_req_stream(
        }

        response = self.get_session_fn().post(
            stream=True,
            **self.__add_authorization(self.last_inputs)
        )
        self._try_raise(response)
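The one-line fix matters because `requests` without `stream=True` downloads the entire response body before returning, so every chunk of a server-sent stream was buffered and then replayed at once. The same collect-everything-first vs. yield-as-you-go distinction can be shown with a stdlib-only sketch (the `produce_chunks` generator below is a hypothetical stand-in for a slow server, not the library's API):

```python
import time
from typing import Iterator, List


def produce_chunks(n: int = 5, delay: float = 0.05) -> Iterator[str]:
    # stand-in for a server that emits one chunk every `delay` seconds
    for i in range(n):
        time.sleep(delay)
        yield f"chunk-{i}"


def stream_collected(chunks: Iterator[str]) -> Iterator[str]:
    # buggy pattern (like stream=False): drain the source fully, then yield
    buffered: List[str] = list(chunks)
    yield from buffered


def stream_lazy(chunks: Iterator[str]) -> Iterator[str]:
    # fixed pattern (like stream=True): forward each chunk as it arrives
    for chunk in chunks:
        yield chunk


def ttft(stream: Iterator[str]) -> float:
    # time to first token/chunk: wait for the first item only
    start = time.time()
    next(stream)
    return time.time() - start


# the buffered variant pays all five delays before the first chunk appears;
# the lazy variant pays only the first delay
t_buffered = ttft(stream_collected(produce_chunks()))
t_lazy = ttft(stream_lazy(produce_chunks()))
```

With five chunks at 0.05s each, `t_buffered` is at least ~0.25s while `t_lazy` is roughly one delay, which is exactly the asymmetry the new integration test checks for.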
27 changes: 27 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_streaming.py
@@ -0,0 +1,27 @@
import time
from langchain_nvidia_ai_endpoints import ChatNVIDIA

def test_ttft(chat_model: str, mode: dict) -> None:
[CI annotation, line 4 — GitHub Actions / cd libs/ai-endpoints / make lint, Ruff (I001): tests/integration_tests/test_streaming.py:1:1 Import block is un-sorted or un-formatted]
    # we had an issue where streaming took a long time to start. the issue
    # was all streamed results were collected before yielding them to the
    # user. this test tries to detect the incorrect behavior.
    #
    # warning:
    #  - this can false positive if the model itself is slow to start
    #  - this can false negative if there is a delay after the first chunk
    #
    # potential mitigation for false negative is to check mean & stdev and
    # filter outliers.
    #
    # credit to Pouyan Rezakhani for finding this issue
    llm = ChatNVIDIA(model=chat_model, **mode)
    chunk_times = [time.time()]
    print("Starting streaming test", time.time())

[CI annotation, line 19 — GitHub Actions / cd libs/ai-endpoints / make lint, Ruff (T201): tests/integration_tests/test_streaming.py:19:5 `print` found]
    for chunk in llm.stream("Count to 1000 by 2s, e.g. 2 4 6 8 ...", max_tokens=512):
        chunk_times.append(time.time())
    ttft = chunk_times[1] - chunk_times[0]
    total_time = chunk_times[-1] - chunk_times[0]
    assert ttft < (total_time / 2), (
        "potential streaming issue, "
        "TTFT should be less than half of the total time"
    )
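The test's own comment suggests checking mean and standard deviation of inter-chunk gaps to filter outliers, as a mitigation for the false-negative case where a single late chunk inflates the total time. One way that mitigation could look (a sketch, not part of the commit; the one-standard-deviation cutoff is an arbitrary assumed threshold):

```python
import statistics
from typing import List


def inter_chunk_gaps(chunk_times: List[float]) -> List[float]:
    # deltas between consecutive chunk timestamps
    return [b - a for a, b in zip(chunk_times, chunk_times[1:])]


def filter_outlier_gaps(gaps: List[float], num_stdev: float = 1.0) -> List[float]:
    # drop gaps more than num_stdev standard deviations above the mean,
    # so one delayed chunk late in the stream cannot mask a slow first chunk
    if len(gaps) < 2:
        return gaps
    mean = statistics.mean(gaps)
    stdev = statistics.pstdev(gaps)
    return [g for g in gaps if g <= mean + num_stdev * stdev]


# usage: rebuild total_time from the filtered gaps before the TTFT check
times = [0.0, 0.1, 0.2, 0.3, 5.0]  # the final 4.7s gap is an outlier
filtered = filter_outlier_gaps(inter_chunk_gaps(times))
total_time = sum(filtered)
```

Summing the filtered gaps gives a total time that reflects steady-state streaming, which makes the `ttft < total_time / 2` assertion less sensitive to a one-off stall after the first chunk.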
