From 085b292e280fe153be95f8161bcdfba9eb49ef77 Mon Sep 17 00:00:00 2001 From: Leonid Kuligin Date: Tue, 27 Feb 2024 13:07:49 +0100 Subject: [PATCH] fixed gemma_hf (#33) * fix gemma_hf --- .../langchain_google_vertexai/__init__.py | 13 ++++++ .../langchain_google_vertexai/gemma.py | 41 ++++++++++++++----- libs/vertexai/pyproject.toml | 2 +- .../vertexai/tests/unit_tests/test_imports.py | 5 +++ 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/libs/vertexai/langchain_google_vertexai/__init__.py b/libs/vertexai/langchain_google_vertexai/__init__.py index 858d0f43..9adfcd22 100644 --- a/libs/vertexai/langchain_google_vertexai/__init__.py +++ b/libs/vertexai/langchain_google_vertexai/__init__.py @@ -3,6 +3,7 @@ from langchain_google_vertexai.chat_models import ChatVertexAI from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser from langchain_google_vertexai.gemma import ( + GemmaChatLocalHF, GemmaChatLocalKaggle, GemmaChatVertexAIModelGarden, GemmaLocalHF, @@ -12,6 +13,13 @@ from langchain_google_vertexai.llms import VertexAI from langchain_google_vertexai.model_garden import VertexAIModelGarden from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore +from langchain_google_vertexai.vision_models import ( + VertexAIImageCaptioning, + VertexAIImageCaptioningChat, + VertexAIImageEditorChat, + VertexAIImageGeneratorChat, + VertexAIVisualQnAChat, +) __all__ = [ "ChatVertexAI", @@ -29,4 +37,9 @@ "PydanticFunctionsOutputParser", "create_structured_runnable", "VectorSearchVectorStore", + "VertexAIImageCaptioning", + "VertexAIImageCaptioningChat", + "VertexAIImageEditorChat", + "VertexAIImageGeneratorChat", + "VertexAIVisualQnAChat", ] diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index 30280de0..0d1ff989 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -53,6 +53,16 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str: return "".join(messages) +def _parse_gemma_chat_response(response: str) -> str: + """Removes chat history from the response.""" + pattern = "model\n" + pos = response.rfind(pattern) + if pos == -1: + return response + else: + return response[(pos + len(pattern)) :] + + class _GemmaBase(BaseModel): max_tokens: Optional[int] = None """The maximum number of tokens to generate.""" @@ -119,7 +129,7 @@ def _generate( request = self._get_params(**kwargs) request["prompt"] = gemma_messages_to_prompt(messages) output = self.client.predict(endpoint=self.endpoint_path, instances=[request]) - text = output.predictions[0] + text = _parse_gemma_chat_response(output.predictions[0]) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -142,7 +152,7 @@ async def _agenerate( output = await self.async_client.predict( endpoint=self.endpoint_path, instances=[request] ) - text = output.predictions[0] + text = _parse_gemma_chat_response(output.predictions[0]) if stop: text = enforce_stop_tokens(text, stop) generations = [ @@ -183,6 +193,11 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): """Local gemma chat model loaded from Kaggle.""" @@ -195,9 +210,9 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} + params = self._get_params(**kwargs) results = self.client.generate(prompts, **params) - results = results if isinstance(results, str) else [results] + results = [results] if isinstance(results, str) else results if stop: results = [enforce_stop_tokens(text, stop) for text in results] return LLMResult(generations=[[Generation(text=result)] for result in results]) @@ -218,7 +233,7 @@ def _generate( ) -> ChatResult: params = {"max_length": self.max_tokens} if self.max_tokens else {} prompt = gemma_messages_to_prompt(messages) - text = self.client.generate(prompt, **params) + text = _parse_gemma_chat_response(self.client.generate(prompt, **params)) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text)) @@ -268,9 +283,15 @@ def _default_params(self) -> Dict[str, Any]: params = {"max_length": self.max_tokens} return {k: v for k, v in params.items() if v is not None} - def _run(self, prompt: str, kwargs: Any) -> str: + def _get_params(self, **kwargs) -> Dict[str, Any]: + mapping = {"max_tokens": "max_length"} + params = {mapping[k]: v for k, v in kwargs.items() if k in mapping} + return {**self._default_params, **params} + + def _run(self, prompt: str, **kwargs: Any) -> str: inputs = self.tokenizer(prompt, return_tensors="pt") - generate_ids = self.client.generate(inputs.input_ids, **kwargs) + params = self._get_params(**kwargs) + generate_ids = self.client.generate(inputs.input_ids, **params) return self.tokenizer.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] @@ -287,8 +308,7 @@ def _generate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" - params = {"max_length": self.max_tokens} if self.max_tokens else {} - results = [self._run(prompt, **params) for prompt in prompts] + results = [self._run(prompt, **kwargs) for prompt in prompts] if stop: results = [enforce_stop_tokens(text, stop) for text in results] return LLMResult(generations=[[Generation(text=text)] for text in results]) @@ -309,9 +329,8 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = {"max_length": self.max_tokens} if self.max_tokens else {} prompt = gemma_messages_to_prompt(messages) - text = self._run(prompt, **params) + text = _parse_gemma_chat_response(self._run(prompt, **kwargs)) if stop: text = enforce_stop_tokens(text, stop) generation = ChatGeneration(message=AIMessage(content=text)) diff --git a/libs/vertexai/pyproject.toml b/libs/vertexai/pyproject.toml index ff0f6976..4dbc328e 100644 --- a/libs/vertexai/pyproject.toml +++ b/libs/vertexai/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-google-vertexai" -version = "0.0.6" +version = "0.0.7" description = "An integration package connecting GoogleVertexAI and LangChain" authors = [] readme = "README.md" diff --git a/libs/vertexai/tests/unit_tests/test_imports.py b/libs/vertexai/tests/unit_tests/test_imports.py index ab722808..02f721fe 100644 --- a/libs/vertexai/tests/unit_tests/test_imports.py +++ b/libs/vertexai/tests/unit_tests/test_imports.py @@ -16,6 +16,11 @@ "PydanticFunctionsOutputParser", "create_structured_runnable", "VectorSearchVectorStore", + "VertexAIImageCaptioning", + "VertexAIImageCaptioningChat", + "VertexAIImageEditorChat", + "VertexAIImageGeneratorChat", + "VertexAIVisualQnAChat", ]