From 69a38ac7fbd7abca8fa9f1a7bd56f251da5531bf Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 27 Feb 2024 18:10:54 +0100 Subject: [PATCH 1/2] added post-processing for local gemma --- .../langchain_google_vertexai/gemma.py | 61 ++++++++++++++++--- 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index c24cb0b1..2f4e6f94 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -53,6 +53,21 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str: return "".join(messages) +def _parse_gemma_chat_response(response: str) -> str: + """Removes chat history from the response.""" + pattern = "model\n" + pos = response.rfind(pattern) + if pos == -1: + return response + else: + text = response[(pos + len(pattern)) :] + pos = text.find("user\n") + if pos > 0: + return text[:pos] + else: + return text + + class _GemmaBase(BaseModel): max_tokens: Optional[int] = None """The maximum number of tokens to generate.""" @@ -98,6 +113,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha "top_k", "max_tokens", ] + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" @property def _llm_type(self) -> str: @@ -120,6 +138,8 @@ def _generate( request["prompt"] = gemma_messages_to_prompt(messages) output = self.client.predict(endpoint=self.endpoint_path, instances=[request]) text = output.predictions[0] + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -139,10 +159,12 @@ async def _agenerate( """Top Level call""" request = self._get_params(**kwargs) request["prompt"] = gemma_messages_to_prompt(messages) - output = await self.async_client.predict( + text = await self.async_client.predict( endpoint=self.endpoint_path, instances=[request] ) - text = output.predictions[0] + text = _parse_gemma_chat_response(text) + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -183,6 +205,11 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): """Local gemma chat model loaded from Kaggle.""" @@ -195,7 +222,7 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} + params = self._get_params(**kwargs) results = self.client.generate(prompts, **params) results = [results] if isinstance(results, str) else results if stop: @@ -209,6 +236,10 @@ def _llm_type(self) -> str: class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel): + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" + def _generate( self, messages: List[BaseMessage], @@ -216,9 +247,11 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = {"max_length": self.max_tokens} if self.max_tokens else {} + params = self._get_params(**kwargs) prompt = gemma_messages_to_prompt(messages) text = self.client.generate(prompt, **params) + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text)) @@ -268,9 +301,15 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + def _run(self, prompt: str, **kwargs: Any) -> str: inputs = self.tokenizer(prompt, return_tensors="pt") - generate_ids = self.client.generate(inputs.input_ids, **kwargs) + params = self._get_params(**kwargs) + generate_ids = self.client.generate(inputs.input_ids, **params) return self.tokenizer.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] @@ -287,8 +326,7 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} - results = [self._run(prompt, **params) for prompt in prompts] + results = [self._run(prompt, **kwargs) for prompt in prompts] if stop: results = [enforce_stop_tokens(text, stop) for text in results] return LLMResult(generations=[[Generation(text=text)] for text in results]) @@ -300,7 +338,9 @@ def _llm_type(self) -> str: class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel): - """Local gemma chat model loaded from HuggingFace.""" + parse_response: bool = False + """Whether to post-process the chat response and clean repeations """ + """or multi-turn statements.""" def _generate( self, @@ -309,9 +349,10 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = {"max_length": self.max_tokens} if self.max_tokens else {} prompt = gemma_messages_to_prompt(messages) - text = self._run(prompt, **params) + text = self._run(prompt, **kwargs) + if self.parse_response or kwargs.get("parse_response"): + text = _parse_gemma_chat_response(text) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text)) From 95ae4b3973bc4fd18579c343ab491c37e9887e43 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 27 Feb 2024 18:28:31 +0100 Subject: [PATCH 2/2] fixes after review --- libs/vertexai/langchain_google_vertexai/gemma.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index 2f4e6f94..4a453d9a 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -59,13 +59,11 @@ def _parse_gemma_chat_response(response: str) -> str: pos = response.rfind(pattern) if pos == -1: return response - else: - text = response[(pos + len(pattern)) :] - pos = text.find("user\n") - if pos > 0: - return text[:pos] - else: - return text + text = response[(pos + len(pattern)) :] + pos = text.find("user\n") + if pos > 0: + return text[:pos] + return text class _GemmaBase(BaseModel): @@ -159,10 +157,10 @@ async def _agenerate( """Top Level call""" request = self._get_params(**kwargs) request["prompt"] = gemma_messages_to_prompt(messages) - text = await self.async_client.predict( + output = await self.async_client.predict( endpoint=self.endpoint_path, instances=[request] ) - text = _parse_gemma_chat_response(text) + text = output.predictions[0] if self.parse_response or kwargs.get("parse_response"): text = _parse_gemma_chat_response(text) if stop: