Commit e17aa91

remove broken stop word detection, rely on server side implementation
mattf committed Jul 2, 2024
1 parent 0338598 commit e17aa91
Showing 2 changed files with 58 additions and 34 deletions.
13 changes: 0 additions & 13 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -390,7 +390,6 @@ def postprocess(
"""
msg_list = self._process_response(response)
msg, is_stopped = self._aggregate_msgs(msg_list)
msg, is_stopped = self._early_stop_msg(msg, is_stopped, stop=stop)
return msg, is_stopped

def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
@@ -424,18 +423,6 @@ def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
         content_holder.update(token_usage=usage_holder)  ####
         return content_holder, is_stopped

-    def _early_stop_msg(
-        self, msg: dict, is_stopped: bool, stop: Optional[Sequence[str]] = None
-    ) -> Tuple[dict, bool]:
-        """Try to early-terminate streaming or generation by iterating over stop list"""
-        content = msg.get("content", "")
-        if content and stop:
-            for stop_str in stop:
-                if stop_str and stop_str in content:
-                    msg["content"] = content[: content.find(stop_str) + 1]
-                    is_stopped = True
-        return msg, is_stopped
-
     ####################################################################################
     ## Streaming interface to allow you to iterate through progressive generations

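For reference, one visible quirk in the removed helper: the slice `content[: content.find(stop_str) + 1]` keeps the first character of the stop sequence, where conventional stop semantics truncate before it. A minimal standalone sketch (the module-level name `early_stop_msg` is adapted for illustration) that reproduces the removed logic and the stray character it leaves behind:

from typing import Optional, Sequence, Tuple

def early_stop_msg(
    msg: dict, is_stopped: bool, stop: Optional[Sequence[str]] = None
) -> Tuple[dict, bool]:
    # Standalone copy of the removed _early_stop_msg logic, for illustration only.
    content = msg.get("content", "")
    if content and stop:
        for stop_str in stop:
            if stop_str and stop_str in content:
                # The "+ 1" keeps one character of the stop string;
                # a clean truncation would be content[: content.find(stop_str)].
                msg["content"] = content[: content.find(stop_str) + 1]
                is_stopped = True
    return msg, is_stopped

msg, stopped = early_stop_msg({"content": "1 2 3 4 5 6 7 8 9 10 11"}, False, stop=["10"])
print(repr(msg["content"]))  # '1 2 3 4 5 6 7 8 9 1' -- the stray '1' survives the cut
print(stopped)               # True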
79 changes: 58 additions & 21 deletions libs/ai-endpoints/tests/integration_tests/test_chat_models.py
@@ -202,27 +202,6 @@ def test_ai_endpoints_invoke(chat_model: str, mode: dict) -> None:
     assert isinstance(result.content, str)


-# todo: test that stop is cased and works with multiple words
-@pytest.mark.xfail(reason="stop is not consistently implemented")
-def test_invoke_stop(chat_model: str, mode: dict) -> None:
-    """Test invoke's stop words."""
-    llm = ChatNVIDIA(model=chat_model, **mode, stop=["10"])
-    result = llm.invoke("please count to 20 by 1s, e.g. 1 2 3 4")
-    assert isinstance(result.content, str)
-    assert "10" not in result.content
-
-
-@pytest.mark.xfail(reason="stop is not consistently implemented")
-def test_stream_stop(chat_model: str, mode: dict) -> None:
-    """Test stream's stop words."""
-    llm = ChatNVIDIA(model=chat_model, **mode, stop=["10"])
-    result = ""
-    for token in llm.stream("please count to 20 by 1s, e.g. 1 2 3 4"):
-        assert isinstance(token.content, str)
-        result += f"{token.content}|"
-    assert "10" not in result
-
-
 # todo: max_tokens test for ainvoke, batch, abatch, stream, astream


@@ -383,3 +362,61 @@ def test_serialize_chatnvidia(chat_model: str, mode: dict) -> None:
     model = loads(dumps(llm), valid_namespaces=["langchain_nvidia_ai_endpoints"])
     result = model.invoke("What is there if there is nothing?")
     assert isinstance(result.content, str)
+
+
+# todo: test that stop is cased and works with multiple words
+
+
+@pytest.mark.parametrize(
+    "prop",
+    [
+        False,
+        True,
+    ],
+    ids=["no_prop", "prop"],
+)
+@pytest.mark.parametrize(
+    "param",
+    [
+        False,
+        True,
+    ],
+    ids=["no_param", "param"],
+)
+@pytest.mark.parametrize(
+    "targets",
+    [["5"], ["6", "100"], ["100", "7"]],
+    ids=["5", "6,100", "100,7"],
+)
+@pytest.mark.parametrize(
+    "func",
+    [
+        "invoke",
+        "stream",
+    ],
+)
+@pytest.mark.xfail(reason="stop is not consistently implemented")
+def test_stop(
+    chat_model: str, mode: dict, func: str, prop: bool, param: bool, targets: List[str]
+) -> None:
+    if not prop and not param:
+        pytest.skip("Skipping test, no stop parameter")
+    llm = ChatNVIDIA(
+        model=chat_model, stop=targets if prop else None, max_tokens=512, **mode
+    )
+    result = ""
+    if func == "invoke":
+        response = llm.invoke(
+            "please count to 20 by 1s, e.g. 1 2 3 4",
+            stop=targets if param else None,
+        )  # invoke returns Union[str, List[Union[str, Dict[Any, Any]]]]
+        assert isinstance(response.content, str)
+        result = response.content
+    elif func == "stream":
+        for token in llm.stream(
+            "please count to 20 by 1s, e.g. 1 2 3 4",
+            stop=targets if param else None,
+        ):
+            assert isinstance(token.content, str)
+            result += f"{token.content}|"
+    assert all(target not in result for target in targets)

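The new test drives stop sequences through both paths that ChatNVIDIA supports, as a constructor property ("prop" in the test matrix) and as a per-call parameter ("param"), with the value now forwarded to the server rather than filtered client side. A usage sketch under the same assumptions the test makes; the model name is illustrative, substitute one your endpoint serves:

from langchain_nvidia_ai_endpoints import ChatNVIDIA

# stop as a constructor property ("prop" in the test matrix)
llm = ChatNVIDIA(model="meta/llama3-8b-instruct", stop=["10"])
print(llm.invoke("please count to 20 by 1s, e.g. 1 2 3 4").content)

# stop as a per-call parameter ("param" in the test matrix)
llm = ChatNVIDIA(model="meta/llama3-8b-instruct")
print(llm.invoke("please count to 20 by 1s, e.g. 1 2 3 4", stop=["10"]).content)

# the per-call form also applies to streaming
for chunk in llm.stream("please count to 20 by 1s, e.g. 1 2 3 4", stop=["10"]):
    print(chunk.content, end="")

Note the test remains marked xfail: per its reason string, stop is not consistently implemented across served models, so enforcement now depends entirely on the server.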