Commit

fixes after review
lkuligin committed Feb 27, 2024
1 parent 69a38ac commit 95ae4b3
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -59,13 +59,11 @@ def _parse_gemma_chat_response(response: str) -> str:
     pos = response.rfind(pattern)
     if pos == -1:
         return response
-    else:
-        text = response[(pos + len(pattern)) :]
-        pos = text.find("<start_of_turn>user\n")
-        if pos > 0:
-            return text[:pos]
-        else:
-            return text
+    text = response[(pos + len(pattern)) :]
+    pos = text.find("<start_of_turn>user\n")
+    if pos > 0:
+        return text[:pos]
+    return text
 
 
 class _GemmaBase(BaseModel):
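For readers skimming the hunk above, here is a minimal runnable sketch of the simplified helper with a sample Gemma-style response. The `pattern` value ("<start_of_turn>model\n") and the sample input are assumptions based on the visible context, not copied from the full file; the behavior is unchanged by the refactor, only the nested `else` branches are replaced with early returns.

def _parse_gemma_chat_response(response: str) -> str:
    # Keep only the model's last turn, dropping any chat history around it.
    # Assumed pattern, matching Gemma's chat turn delimiters.
    pattern = "<start_of_turn>model\n"
    pos = response.rfind(pattern)
    if pos == -1:
        return response
    text = response[(pos + len(pattern)) :]
    pos = text.find("<start_of_turn>user\n")
    if pos > 0:
        return text[:pos]
    return text


# Illustrative input only: a user turn followed by a model turn.
raw = (
    "<start_of_turn>user\nHi!<end_of_turn>\n"
    "<start_of_turn>model\nHello, how can I help?<start_of_turn>user\n"
)
print(_parse_gemma_chat_response(raw))  # -> "Hello, how can I help?"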
@@ -159,10 +157,10 @@ async def _agenerate(
         """Top Level call"""
         request = self._get_params(**kwargs)
         request["prompt"] = gemma_messages_to_prompt(messages)
-        text = await self.async_client.predict(
+        output = await self.async_client.predict(
             endpoint=self.endpoint_path, instances=[request]
         )
-        text = _parse_gemma_chat_response(text)
+        text = output.predictions[0]
         if self.parse_response or kwargs.get("parse_response"):
             text = _parse_gemma_chat_response(text)
         if stop:
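A hedged sketch of the corrected flow in this hunk: the awaited `predict()` call returns a response object, so the generated string is pulled out of `predictions[0]` before it is optionally run through the chat parser. `FakePrediction`, `FakeClient`, and `generate` below are illustrative stand-ins, not part of the library; the parser is the one sketched above.

import asyncio
from dataclasses import dataclass


@dataclass
class FakePrediction:
    # Stand-in mimicking the shape relied on here: text lives in predictions[0].
    predictions: list


class FakeClient:
    async def predict(self, endpoint: str, instances: list) -> FakePrediction:
        # Echo the prompt back in Gemma's chat format.
        prompt = instances[0]["prompt"]
        return FakePrediction(predictions=[f"<start_of_turn>model\n{prompt} ok"])


async def generate(client: FakeClient, prompt: str, parse_response: bool = False) -> str:
    output = await client.predict(endpoint="fake-endpoint", instances=[{"prompt": prompt}])
    text = output.predictions[0]  # extract the string first (the fix)
    if parse_response:
        text = _parse_gemma_chat_response(text)  # parse only when asked
    return text


print(asyncio.run(generate(FakeClient(), "Hi!", parse_response=True)))  # -> "Hi! ok"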
