Commit

Merge branch 'main' into vertexai-clientinfo

holtskinner authored Feb 26, 2024
2 parents 26c2487 + fc2250e commit 4a393d5
Showing 9 changed files with 141 additions and 11 deletions.
6 changes: 5 additions & 1 deletion libs/vertexai/langchain_google_vertexai/__init__.py
@@ -1,28 +1,32 @@
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
from langchain_google_vertexai.gemma import (
    GemmaChatLocalHF,
    GemmaChatLocalKaggle,
    GemmaChatVertexAIModelGarden,
    GemmaLocalHF,
    GemmaLocalKaggle,
    GemmaVertexAIModelGarden,
)
from langchain_google_vertexai.llms import VertexAI
from langchain_google_vertexai.model_garden import VertexAIModelGarden
from langchain_google_vertexai.vectorstores.vectorstores import VectorSearchVectorStore

__all__ = [
"ChatVertexAI",
"GemmaVertexAIModelGarden",
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaLocalHF",
"GemmaChatLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"create_structured_runnable",
"VectorSearchVectorStore",
]
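
With this export in place, the HuggingFace-backed classes become importable from the package root; a quick sanity-check sketch:

from langchain_google_vertexai import GemmaChatLocalHF, GemmaLocalHF

print(GemmaLocalHF.__name__, GemmaChatLocalHF.__name__)
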
6 changes: 6 additions & 0 deletions libs/vertexai/langchain_google_vertexai/_utils.py
@@ -1,6 +1,7 @@
"""Utilities to init Vertex AI."""

import dataclasses
import re
from importlib import metadata
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -178,3 +179,8 @@ def get_generation_info(
info.pop("is_blocked")

return info


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text, maxsplit=1)[0]
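
For reference, a minimal usage sketch of the new helper (values are illustrative only):

from langchain_google_vertexai._utils import enforce_stop_tokens

text = "Final answer: 42\nObservation: tool output follows"
print(enforce_stop_tokens(text, ["\nObservation:"]))
# -> "Final answer: 42"
# The stop strings are joined into a regex alternation, so regex
# metacharacters in a stop sequence would need re.escape() first.
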
118 changes: 112 additions & 6 deletions libs/vertexai/langchain_google_vertexai/gemma.py
@@ -24,6 +24,7 @@
from langchain_core.pydantic_v1 import BaseModel, root_validator

from langchain_google_vertexai._base import _BaseVertexAIModelGarden
from langchain_google_vertexai._utils import enforce_stop_tokens
from langchain_google_vertexai.model_garden import VertexAIModelGarden

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
@@ -118,9 +119,12 @@ def _generate(
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = self.client.predict(endpoint=self.endpoint_path, instances=[request])
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)
@@ -135,19 +139,22 @@ async def _agenerate(
"""Top Level call"""
request = self._get_params(**kwargs)
request["prompt"] = gemma_messages_to_prompt(messages)
output = await self.async_client.predict(
endpoint=self.endpoint_path, instances=[request]
)
text = output.predictions[0]
if stop:
text = enforce_stop_tokens(text, stop)
generations = [
ChatGeneration(
message=AIMessage(content=text),
)
]
return ChatResult(generations=generations)
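
A hypothetical call sketch for the Model Garden chat class above, showing where the new stop handling applies (the endpoint id, project, and location values are placeholders):

from langchain_core.messages import HumanMessage
from langchain_google_vertexai import GemmaChatVertexAIModelGarden

# Placeholder deployment details for a Gemma endpoint in Vertex AI Model Garden.
llm = GemmaChatVertexAIModelGarden(
    endpoint_id="1234567890",
    project="my-project",
    location="us-central1",
)

# Messages are flattened into the <start_of_turn>/<end_of_turn> prompt format,
# and the returned text is cut at the first stop sequence via enforce_stop_tokens.
msg = llm.invoke([HumanMessage(content="Name three colors.")], stop=["<end_of_turn>"])
print(msg.content)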


class _GemmaLocalKaggleBase(_GemmaBase):
"""Local gemma model."""
"""Local gemma model loaded from Kaggle."""

client: Any = None #: :meta private:
keras_backend: str = "jax"
@@ -178,6 +185,8 @@ def _default_params(self) -> Dict[str, Any]:


class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM):
"""Local gemma chat model loaded from Kaggle."""

def _generate(
self,
prompts: List[str],
@@ -189,6 +198,8 @@ def _generate(
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = self.client.generate(prompts, **params)
results = [results] if isinstance(results, str) else results
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=result)] for result in results])

@property
@@ -207,11 +218,106 @@ def _generate(
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
text = self.client.generate(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_kaggle"


class _GemmaLocalHFBase(_GemmaBase):
"""Local gemma model loaded from HuggingFace."""

tokenizer: Any = None #: :meta private:
client: Any = None #: :meta private:
hf_access_token: str
cache_dir: Optional[str] = None
model_name: str = "gemma_2b_en"
"""Gemma model name."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
try:
from transformers import AutoTokenizer, GemmaForCausalLM # type: ignore
except ImportError:
raise ImportError(
"Could not import GemmaForCausalLM library. "
"Please install the GemmaForCausalLM library to "
"use this model: pip install transformers>=4.38.1"
)

values["tokenizer"] = AutoTokenizer.from_pretrained(
values["model_name"], token=values["hf_access_token"]
)
values["client"] = GemmaForCausalLM.from_pretrained(
values["model_name"],
token=values["hf_access_token"],
cache_dir=values["cache_dir"],
)
return values

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling gemma."""
params = {"max_length": self.max_tokens}
return {k: v for k, v in params.items() if v is not None}

def _run(self, prompt: str, **kwargs: Any) -> str:
inputs = self.tokenizer(prompt, return_tensors="pt")
generate_ids = self.client.generate(inputs.input_ids, **kwargs)
return self.tokenizer.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]


class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM):
"""Local gemma model loaded from HuggingFace."""

def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = {"max_length": self.max_tokens} if self.max_tokens else {}
results = [self._run(prompt, **params) for prompt in prompts]
if stop:
results = [enforce_stop_tokens(text, stop) for text in results]
return LLMResult(generations=[[Generation(text=text)] for text in results])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_hf"


class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel):
"""Local gemma chat model loaded from HuggingFace."""

def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {"max_length": self.max_tokens} if self.max_tokens else {}
prompt = gemma_messages_to_prompt(messages)
text = self._run(prompt, **params)
if stop:
text = enforce_stop_tokens(text, stop)
generation = ChatGeneration(message=AIMessage(content=text))
return ChatResult(generations=[generation])

@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "gemma_local_chat_hf"
libs/vertexai/langchain_google_vertexai/vectorstores/vectorstores.py
@@ -202,11 +202,11 @@ def _get_default_embeddings(cls) -> Embeddings:
)

# TODO: Change to vertexai embeddings
from langchain_community.embeddings import (  # type: ignore[import-not-found, unused-ignore]
    TensorflowHubEmbeddings,
)

return TensorflowHubEmbeddings()

def _generate_unique_ids(self, number: int) -> List[str]:
"""Generates a list of unique ids of length `number`
1 change: 0 additions & 1 deletion libs/vertexai/tests/integration_tests/test_gemma.py
@@ -14,7 +14,6 @@
)


@pytest.mark.skip("CI testing not set up")
@pytest.mark.skip("CI testing not set up")
def test_gemma_model_garden() -> None:
"""In order to run this test, you should provide endpoint names.
3 changes: 3 additions & 0 deletions libs/vertexai/tests/integration_tests/test_llms_safety.py
@@ -1,3 +1,4 @@
import pytest
from langchain_core.outputs import LLMResult

from langchain_google_vertexai import HarmBlockThreshold, HarmCategory, VertexAI
@@ -40,6 +41,7 @@
"""


@pytest.mark.skip("CI testing not set up")
def test_gemini_safety_settings_generate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = llm.generate(["What do you think about child abuse:"])
@@ -68,6 +70,7 @@ def test_gemini_safety_settings_generate() -> None:
assert not generation_info.get("is_blocked")


@pytest.mark.skip("CI testing not set up")
async def test_gemini_safety_settings_agenerate() -> None:
llm = VertexAI(model_name="gemini-pro", safety_settings=SAFETY_SETTINGS)
output = await llm.agenerate(["What do you think about child abuse:"])
3 changes: 3 additions & 0 deletions libs/vertexai/tests/integration_tests/test_tools.py
@@ -2,6 +2,7 @@
import re
from typing import Any, List, Union

import pytest
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessageChunk
from langchain_core.output_parsers import BaseOutputParser
@@ -43,6 +44,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
raise ValueError("Can only parse messages")


@pytest.mark.skip("CI testing not set up")
def test_tools() -> None:
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import (
@@ -91,6 +93,7 @@ def test_tools() -> None:
assert round(float(just_numbers), 2) == 2.16


@pytest.mark.skip("CI testing not set up")
def test_stream() -> None:
from langchain.chains import LLMMathChain

6 changes: 6 additions & 0 deletions libs/vertexai/tests/integration_tests/test_vectorstores.py
@@ -41,6 +41,7 @@ def sdk_manager() -> VectorSearchSDKManager:
return sdk_manager


@pytest.mark.skip("CI testing not set up")
def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager):
gcs_client = sdk_manager.get_gcs_client()
assert isinstance(gcs_client, storage.Client)
@@ -55,6 +56,7 @@ def test_vector_search_sdk_manager(sdk_manager: VectorSearchSDKManager):
assert isinstance(endpoint, MatchingEngineIndexEndpoint)


@pytest.mark.skip("CI testing not set up")
def test_gcs_document_storage(sdk_manager: VectorSearchSDKManager):
bucket = sdk_manager.get_gcs_bucket(os.environ["GCS_BUCKET_NAME"])
prefix = "integration-test"
@@ -78,6 +80,7 @@ def test_gcs_document_storage(sdk_manager: VectorSearchSDKManager):
assert original_text == retrieved_text


@pytest.mark.skip("CI testing not set up")
def test_datastore_document_storage(sdk_manager: VectorSearchSDKManager):
ds_client = sdk_manager.get_datastore_client(namespace="Foo")

@@ -100,6 +103,7 @@ def test_datastore_document_storage(sdk_manager: VectorSearchSDKManager):
assert original_text == retrieved_text


@pytest.mark.skip("CI testing not set up")
def test_public_endpoint_vector_searcher(sdk_manager: VectorSearchSDKManager):
index = sdk_manager.get_index(os.environ["INDEX_ID"])
endpoint = sdk_manager.get_endpoint(os.environ["ENDPOINT_ID"])
@@ -116,6 +120,7 @@ def test_public_endpoint_vector_searcher(sdk_manager: VectorSearchSDKManager):
assert len(matching_neighbors_list) == 2


@pytest.mark.skip("CI testing not set up")
def test_vector_store():
embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default")

@@ -143,6 +148,7 @@ def test_vector_store():
assert isinstance(doc, Document)


@pytest.mark.skip("CI testing not set up")
def test_vector_store_update_index():
embeddings = VertexAIEmbeddings(model_name="textembedding-gecko-default")

3 changes: 3 additions & 0 deletions libs/vertexai/tests/unit_tests/test_imports.py
@@ -6,13 +6,16 @@
"GemmaChatVertexAIModelGarden",
"GemmaLocalKaggle",
"GemmaChatLocalKaggle",
"GemmaChatLocalHF",
"GemmaLocalHF",
"VertexAIEmbeddings",
"VertexAI",
"VertexAIModelGarden",
"HarmBlockThreshold",
"HarmCategory",
"PydanticFunctionsOutputParser",
"create_structured_runnable",
"VectorSearchVectorStore",
]


