fixed gemma_hf (#33)
* fix gemma_hf
lkuligin committed Feb 27, 2024
1 parent 6421180 commit 085b292
Showing 4 changed files with 49 additions and 12 deletions.
13 changes: 13 additions & 0 deletions libs/vertexai/langchain_google_vertexai/__init__.py
@@ -3,6 +3,7 @@
 from langchain_google_vertexai.chat_models import ChatVertexAI
 from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 from langchain_google_vertexai.gemma import (
+    GemmaChatLocalHF,
     GemmaChatLocalKaggle,
     GemmaChatVertexAIModelGarden,
     GemmaLocalHF,
@@ -12,6 +13,13 @@
 from langchain_google_vertexai.llms import VertexAI
 from langchain_google_vertexai.model_garden import VertexAIModelGarden
 from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore
+from langchain_google_vertexai.vision_models import (
+    VertexAIImageCaptioning,
+    VertexAIImageCaptioningChat,
+    VertexAIImageEditorChat,
+    VertexAIImageGeneratorChat,
+    VertexAIVisualQnAChat,
+)
 
 __all__ = [
     "ChatVertexAI",
@@ -29,4 +37,9 @@
     "PydanticFunctionsOutputParser",
     "create_structured_runnable",
     "VectorSearchVectorStore",
+    "VertexAIImageCaptioning",
+    "VertexAIImageCaptioningChat",
+    "VertexAIImageEditorChat",
+    "VertexAIImageGeneratorChat",
+    "VertexAIVisualQnAChat",
 ]
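With these exports in place, the new names can be imported directly from the package root. A minimal sketch (assumes langchain-google-vertexai 0.0.7 is installed; GemmaChatLocalHF is the export added here, the vision classes are newly re-exported, GemmaLocalHF was already available):

from langchain_google_vertexai import (
    GemmaChatLocalHF,         # Gemma chat model backed by Hugging Face transformers (new export)
    GemmaLocalHF,             # Gemma completion model backed by Hugging Face transformers
    VertexAIImageCaptioning,  # vision helpers now re-exported at the package root
    VertexAIVisualQnAChat,
)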
41 changes: 30 additions & 11 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -53,6 +53,16 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
     return "".join(messages)
 
 
+def _parse_gemma_chat_response(response: str) -> str:
+    """Removes chat history from the response."""
+    pattern = "<start_of_turn>model\n"
+    pos = response.rfind(pattern)
+    if pos == -1:
+        return response
+    else:
+        return response[(pos + len(pattern)) :]
+
+
 class _GemmaBase(BaseModel):
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
@@ -119,7 +129,7 @@ def _generate(
         request = self._get_params(**kwargs)
         request["prompt"] = gemma_messages_to_prompt(messages)
         output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
-        text = output.predictions[0]
+        text = _parse_gemma_chat_response(output.predictions[0])
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -142,7 +152,7 @@ async def _agenerate(
         output = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=[request]
         )
-        text = output.predictions[0]
+        text = _parse_gemma_chat_response(output.predictions[0])
         if stop:
             text = enforce_stop_tokens(text, stop)
         generations = [
@@ -183,6 +193,11 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
 
 class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
     """Local gemma chat model loaded from Kaggle."""
@@ -195,9 +210,9 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
+        params = self._get_params(**kwargs)
         results = self.client.generate(prompts, **params)
-        results = results if isinstance(results, str) else [results]
+        results = [results] if isinstance(results, str) else results
         if stop:
             results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=result)] for result in results])
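The second change in the hunk above fixes result handling for the single-prompt case: previously a plain string was left unwrapped (and later iterated character by character) while a list was wrapped into a nested list; the fixed line wraps only when `generate` returns a string. A tiny runnable illustration (the value is hypothetical, assuming keras_nlp returns a bare string for a single prompt):

results = "Paris"  # hypothetical single-prompt output
results = [results] if isinstance(results, str) else results
print(results)  # -> ['Paris']; a list such as ['Paris', 'Berlin'] would pass through unchanged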
@@ -218,7 +233,7 @@ def _generate(
     ) -> ChatResult:
         params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        text = self.client.generate(prompt, **params)
+        text = _parse_gemma_chat_response(self.client.generate(prompt, **params))
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
@@ -268,9 +283,15 @@ def _default_params(self) -> Dict[str, Any]:
         params = {"max_length": self.max_tokens}
         return {k: v for k, v in params.items() if v is not None}
 
-    def _run(self, prompt: str, kwargs: Any) -> str:
+    def _get_params(self, **kwargs) -> Dict[str, Any]:
+        mapping = {"max_tokens": "max_length"}
+        params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
+        return {**self._default_params, **params}
+
+    def _run(self, prompt: str, **kwargs: Any) -> str:
         inputs = self.tokenizer(prompt, return_tensors="pt")
-        generate_ids = self.client.generate(inputs.input_ids, **kwargs)
+        params = self._get_params(**kwargs)
+        generate_ids = self.client.generate(inputs.input_ids, **params)
         return self.tokenizer.batch_decode(
             generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )[0]
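With `_run` now accepting real keyword arguments and routing them through `_get_params`, per-call overrides reach the Hugging Face `generate` call. A rough usage sketch only: the model name and token below are placeholders, and running it requires Gemma weight access on Hugging Face plus the usual transformers/torch dependencies.

from langchain_google_vertexai import GemmaLocalHF

llm = GemmaLocalHF(model_name="google/gemma-2b", hf_access_token="hf_...")  # placeholders
print(llm.invoke("What is the capital of France?", max_tokens=30))  # max_tokens now maps to max_length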
@@ -287,8 +308,7 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
-        results = [self._run(prompt, **params) for prompt in prompts]
+        results = [self._run(prompt, **kwargs) for prompt in prompts]
         if stop:
             results = [enforce_stop_tokens(text, stop) for text in results]
         return LLMResult(generations=[[Generation(text=text)] for text in results])
@@ -309,9 +329,8 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = {"max_length": self.max_tokens} if self.max_tokens else {}
         prompt = gemma_messages_to_prompt(messages)
-        text = self._run(prompt, **params)
+        text = _parse_gemma_chat_response(self._run(prompt, **kwargs))
         if stop:
             text = enforce_stop_tokens(text, stop)
         generation = ChatGeneration(message=AIMessage(content=text))
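Taken together with `_parse_gemma_chat_response`, the chat path now returns only the model's turn. A corresponding sketch for the chat wrapper, with the same placeholder caveats as above:

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import GemmaChatLocalHF

chat = GemmaChatLocalHF(model_name="google/gemma-2b-it", hf_access_token="hf_...")  # placeholders
reply = chat.invoke([HumanMessage(content="Hi!")], max_tokens=30)
print(reply.content)  # the reply no longer includes the echoed chat history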
2 changes: 1 addition & 1 deletion libs/vertexai/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-google-vertexai"
-version = "0.0.6"
+version = "0.0.7"
 description = "An integration package connecting GoogleVertexAI and LangChain"
 authors = []
 readme = "README.md"
5 changes: 5 additions & 0 deletions libs/vertexai/tests/unit_tests/test_imports.py
@@ -16,6 +16,11 @@
     "PydanticFunctionsOutputParser",
     "create_structured_runnable",
     "VectorSearchVectorStore",
+    "VertexAIImageCaptioning",
+    "VertexAIImageCaptioningChat",
+    "VertexAIImageEditorChat",
+    "VertexAIImageGeneratorChat",
+    "VertexAIVisualQnAChat",
 ]


