
Commit

remove extraneous stop params
mattf committed Jul 5, 2024
1 parent dfa53ad commit 9696e9d
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -355,14 +355,12 @@ def get_req(
         return self._wait(response, session)
 
     def postprocess(
-        self, response: Union[str, Response], stop: Optional[Sequence[str]] = None
+        self, response: Union[str, Response],
     ) -> Tuple[dict, bool]:
         """Parses a response from the AI Foundation Model Function API.
         Strongly assumes that the API will return a single response.
         """
-        msg_list = self._process_response(response)
-        msg, is_stopped = self._aggregate_msgs(msg_list)
-        return msg, is_stopped
+        return self._aggregate_msgs(self._process_response(response))
 
     def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]:
         """Dig out relevant details of aggregated message"""
@@ -402,7 +400,6 @@ def get_req_stream(
         self,
         payload: dict = {},
         invoke_url: Optional[str] = None,
-        stop: Optional[Sequence[str]] = None,
     ) -> Iterator:
         invoke_url = self._get_invoke_url(invoke_url)
         if payload.get("stream", True) is False:
@@ -425,7 +422,7 @@ def out_gen() -> Generator[dict, Any, Any]:
                 for line in response.iter_lines():
                     if line and line.strip() != b"data: [DONE]":
                         line = line.decode("utf-8")
-                        msg, final_line = call.postprocess(line, stop=stop)
+                        msg, final_line = call.postprocess(line)
                         yield msg
                         if final_line:
                             break
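The net effect in _common.py is that postprocess drops a parameter it never used and its body collapses to a single expression. Below is a minimal, self-contained sketch of that before/after equivalence; the _process_response and _aggregate_msgs stand-ins are assumptions added only so the comparison runs on its own, not the real helpers.

from typing import Optional, Sequence, Tuple

# Simplified stand-ins (assumptions) for the real helpers in _common.py,
# included only so this sketch is runnable by itself.
def _process_response(response: str) -> Sequence[dict]:
    return [{"content": response}]

def _aggregate_msgs(msg_list: Sequence[dict]) -> Tuple[dict, bool]:
    return dict(msg_list[0]), False

# Before this commit: an unused `stop` parameter and three statements.
def postprocess_before(
    response: str, stop: Optional[Sequence[str]] = None
) -> Tuple[dict, bool]:
    msg_list = _process_response(response)
    msg, is_stopped = _aggregate_msgs(msg_list)
    return msg, is_stopped

# After this commit: no `stop` parameter, a single expression, same result.
def postprocess_after(response: str) -> Tuple[dict, bool]:
    return _aggregate_msgs(_process_response(response))

assert postprocess_before("hello") == postprocess_after("hello")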
1 change: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def _generate(
         inputs = self._custom_preprocess(messages)
         payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs)
         response = self._client.client.get_req(payload=payload)
-        responses, _ = self._client.client.postprocess(response, stop=stop)
+        responses, _ = self._client.client.postprocess(response)
         self._set_callback_out(responses, run_manager)
         message = ChatMessage(**self._custom_postprocess(responses))
         generation = ChatGeneration(message=message)
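On the caller side, stop sequences are still honored: _generate already folds stop into the request payload via _get_payload(inputs=inputs, stop=stop, ...), so the endpoint applies them server-side and postprocess has nothing left to do with them. A toy sketch of that idea follows; build_payload is a hypothetical stand-in, not the real _get_payload.

from typing import Any, Dict, List, Optional, Sequence

# Hypothetical stand-in (an assumption, not the library's method) showing
# that stop sequences travel in the request payload, which is why
# postprocess no longer needs a stop argument.
def build_payload(
    inputs: List[Dict[str, Any]],
    stop: Optional[Sequence[str]] = None,
    stream: bool = False,
) -> Dict[str, Any]:
    payload: Dict[str, Any] = {"messages": inputs, "stream": stream}
    if stop:
        payload["stop"] = list(stop)  # enforced by the endpoint, not the client
    return payload

payload = build_payload([{"role": "user", "content": "hi"}], stop=["\n\n"])
assert payload["stop"] == ["\n\n"]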
