Skip to content

Commit

Permalink
added post-processing for local gemma
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Feb 27, 2024
1 parent 960ac13 commit 69a38ac
Showing 1 changed file with 51 additions and 10 deletions.
61 changes: 51 additions & 10 deletions libs/vertexai/langchain_google_vertexai/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def gemma_messages_to_prompt(history: List[BaseMessage]) -> str:
return "".join(messages)


def _parse_gemma_chat_response(response: str) -> str:
"""Removes chat history from the response."""
pattern = "<start_of_turn>model\n"
pos = response.rfind(pattern)
if pos == -1:
return response
else:
text = response[(pos + len(pattern)) :]
pos = text.find("<start_of_turn>user\n")
if pos > 0:
return text[:pos]
else:
return text


class _GemmaBase(BaseModel):
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
Expand Down Expand Up @@ -98,6 +113,9 @@ class GemmaChatVertexAIModelGarden(_GemmaBase, _BaseVertexAIModelGarden, BaseCha
"top_k",
"max_tokens",
]
parse_response: bool = False
"""Whether to post-process the chat response and clean repeations """
"""or multi-turn statements."""

@property
def _llm_type(self) -> str:
Expand All @@ -120,6 +138,8 @@ def _generate(
request["prompt"] = gemma_messages_to_prompt(messages)
output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
text = output.predictions[0]
if self.parse_response or kwargs.get("parse_response"):
text = _parse_gemma_chat_response(text)
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
Expand All @@ -139,10 +159,12 @@ async def _agenerate(
"""Top Level call"""
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = await self.async_client.predict(
text = await self.async_client.predict(
endpoint=self.endpoint_path, instances=[request]
)
text = output.predictions[0]
text = _parse_gemma_chat_response(text)
if self.parse_response or kwargs.get("parse_response"):
text = _parse_gemma_chat_response(text)
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
Expand Down Expand Up @@ -183,6 +205,11 @@ def _default_params(self) -> Dict[str, Any]:
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _get_params(self, **kwargs) -> Dict[str, Any]:
mapping = {"max_tokens": "max_length"}
params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
return {**self._default_params, **params}


class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
"""Local gemma chat model loaded from Kaggle."""
Expand All @@ -195,7 +222,7 @@ def _generate(
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
params = self._get_params(**kwargs)
results = self.client.generate(prompts, **params)
results = [results] if isinstance(results, str) else results
if stop:
Expand All @@ -209,16 +236,22 @@ def _llm_type(self) -> str:


class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel):
parse_response: bool = False
"""Whether to post-process the chat response and clean repeations """
"""or multi-turn statements."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
params = self._get_params(**kwargs)
prompt = gemma_messages_to_prompt(messages)
text = self.client.generate(prompt, **params)
if self.parse_response or kwargs.get("parse_response"):
text = _parse_gemma_chat_response(text)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
Expand Down Expand Up @@ -268,9 +301,15 @@ def _default_params(self) -> Dict[str, Any]:
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _get_params(self, **kwargs) -> Dict[str, Any]:
mapping = {"max_tokens": "max_length"}
params = {mapping[k]: v for k, v in kwargs.items() if k in mapping}
return {**self._default_params, **params}

def _run(self, prompt: str, **kwargs: Any) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
params = self._get_params(**kwargs)
generate_ids = self.client.generate(inputs.input_ids, **params)
return self.tokenizer.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
Expand All @@ -287,8 +326,7 @@ def _generate(
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = [self._run(prompt, **params) for prompt in prompts]
results = [self._run(prompt, **kwargs) for prompt in prompts]
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=text)] for text in results])
Expand All @@ -300,7 +338,9 @@ def _llm_type(self) -> str:


class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
"""Local gemma chat model loaded from HuggingFace."""
parse_response: bool = False
"""Whether to post-process the chat response and clean repeations """
"""or multi-turn statements."""

def _generate(
self,
Expand All @@ -309,9 +349,10 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
text = self._run(prompt, **params)
text = self._run(prompt, **kwargs)
if self.parse_response or kwargs.get("parse_response"):
text = _parse_gemma_chat_response(text)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
Expand Down

0 comments on commit 69a38ac

Please sign in to comment.