From 9696e9db3bb1d849fd47e92af57f70b50fa57b52 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 5 Jul 2024 11:17:51 -0400 Subject: [PATCH] remove extraneous stop params --- .../langchain_nvidia_ai_endpoints/_common.py | 9 +++------ .../langchain_nvidia_ai_endpoints/chat_models.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index aed48dd0..22124e52 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -355,14 +355,12 @@ def get_req( return self._wait(response, session) def postprocess( - self, response: Union[str, Response], stop: Optional[Sequence[str]] = None + self, response: Union[str, Response], ) -> Tuple[dict, bool]: """Parses a response from the AI Foundation Model Function API. Strongly assumes that the API will return a single response. """ - msg_list = self._process_response(response) - msg, is_stopped = self._aggregate_msgs(msg_list) - return msg, is_stopped + return self._aggregate_msgs(self._process_response(response)) def _aggregate_msgs(self, msg_list: Sequence[dict]) -> Tuple[dict, bool]: """Dig out relevant details of aggregated message""" @@ -402,7 +400,6 @@ def get_req_stream( self, payload: dict = {}, invoke_url: Optional[str] = None, - stop: Optional[Sequence[str]] = None, ) -> Iterator: invoke_url = self._get_invoke_url(invoke_url) if payload.get("stream", True) is False: @@ -425,7 +422,7 @@ def out_gen() -> Generator[dict, Any, Any]: for line in response.iter_lines(): if line and line.strip() != b"data: [DONE]": line = line.decode("utf-8") - msg, final_line = call.postprocess(line, stop=stop) + msg, final_line = call.postprocess(line) yield msg if final_line: break diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 1808bddc..7cce280d 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -212,7 +212,7 @@ def _generate( inputs = self._custom_preprocess(messages) payload = self._get_payload(inputs=inputs, stop=stop, stream=False, **kwargs) response = self._client.client.get_req(payload=payload) - responses, _ = self._client.client.postprocess(response, stop=stop) + responses, _ = self._client.client.postprocess(response) self._set_callback_out(responses, run_manager) message = ChatMessage(**self._custom_postprocess(responses)) generation = ChatGeneration(message=message)