diff --git a/libs/genai/langchain_google_genai/__init__.py b/libs/genai/langchain_google_genai/__init__.py
index 505e121d..187f7e3e 100644
--- a/libs/genai/langchain_google_genai/__init__.py
+++ b/libs/genai/langchain_google_genai/__init__.py
@@ -54,6 +54,8 @@ embeddings.embed_query("hello, world!")
 ```
 """  # noqa: E501
+
+from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
 from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
 from langchain_google_genai.llms import GoogleGenerativeAI
@@ -62,4 +64,6 @@
     "ChatGoogleGenerativeAI",
     "GoogleGenerativeAIEmbeddings",
     "GoogleGenerativeAI",
+    "HarmBlockThreshold",
+    "HarmCategory",
 ]
diff --git a/libs/genai/langchain_google_genai/_enums.py b/libs/genai/langchain_google_genai/_enums.py
new file mode 100644
index 00000000..b2a1de26
--- /dev/null
+++ b/libs/genai/langchain_google_genai/_enums.py
@@ -0,0 +1,6 @@
+from google.generativeai.types.safety_types import (  # type: ignore
+    HarmBlockThreshold,
+    HarmCategory,
+)
+
+__all__ = ["HarmBlockThreshold", "HarmCategory"]
diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py
new file mode 100644
index 00000000..d6719453
--- /dev/null
+++ b/libs/genai/langchain_google_genai/_function_utils.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+from typing import (
+    Dict,
+    List,
+    Type,
+    Union,
+)
+
+import google.ai.generativelanguage as glm
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import BaseTool
+from langchain_core.utils.json_schema import dereference_refs
+
+FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
+
+TYPE_ENUM = {
+    "string": glm.Type.STRING,
+    "number": glm.Type.NUMBER,
+    "integer": glm.Type.INTEGER,
+    "boolean": glm.Type.BOOLEAN,
+    "array": glm.Type.ARRAY,
+    "object": glm.Type.OBJECT,
+}
+
+
+def convert_to_genai_function_declarations(
+    function_calls: List[FunctionCallType],
+) -> List[glm.Tool]:
+    return [
+        glm.Tool(
+            function_declarations=[_convert_to_genai_function(fc)],
+        )
+        for fc in function_calls
+    ]
+
+
+def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
+    if isinstance(fc, BaseTool):
+        return _convert_tool_to_genai_function(fc)
+    elif isinstance(fc, type) and issubclass(fc, BaseModel):
+        return _convert_pydantic_to_genai_function(fc)
+    elif isinstance(fc, dict):
+        return glm.FunctionDeclaration(
+            name=fc["name"],
+            description=fc.get("description"),
+            parameters={
+                "properties": {
+                    k: {
+                        "type_": TYPE_ENUM[v["type"]],
+                        "description": v.get("description"),
+                    }
+                    for k, v in fc["parameters"]["properties"].items()
+                },
+                "required": fc["parameters"].get("required", []),
+                "type_": TYPE_ENUM[fc["parameters"]["type"]],
+            },
+        )
+    else:
+        raise ValueError(f"Unsupported function call type {fc}")
+
+
+def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
+    if tool.args_schema:
+        schema = dereference_refs(tool.args_schema.schema())
+        schema.pop("definitions", None)
+
+        return glm.FunctionDeclaration(
+            name=tool.name or schema["title"],
+            description=tool.description or schema["description"],
+            parameters={
+                "properties": {
+                    k: {
+                        "type_": TYPE_ENUM[v["type"]],
+                        "description": v.get("description"),
+                    }
+                    for k, v in schema["properties"].items()
+                },
+                "required": schema["required"],
+                "type_": TYPE_ENUM[schema["type"]],
+            },
+        )
+    else:
+        return glm.FunctionDeclaration(
+            name=tool.name,
+            description=tool.description,
+            parameters={
+                "properties": {
+                    "__arg1": {"type_": TYPE_ENUM["string"]},
+                },
+                "required": ["__arg1"],
+                "type_": TYPE_ENUM["object"],
+            },
+        )
+
+
+def _convert_pydantic_to_genai_function(
+    pydantic_model: Type[BaseModel],
+) -> glm.FunctionDeclaration:
+    schema = dereference_refs(pydantic_model.schema())
+    schema.pop("definitions", None)
+    return glm.FunctionDeclaration(
+        name=schema["title"],
+        description=schema.get("description", ""),
+        parameters={
+            "properties": {
+                k: {
+                    "type_": TYPE_ENUM[v["type"]],
+                    "description": v.get("description"),
+                }
+                for k, v in schema["properties"].items()
+            },
+            "required": schema["required"],
+            "type_": TYPE_ENUM[schema["type"]],
+        },
+    )
diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py
index 37cc5f34..4f0e9f0a 100644
--- a/libs/genai/langchain_google_genai/chat_models.py
+++ b/libs/genai/langchain_google_genai/chat_models.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import base64
+import json
 import logging
 import os
 from io import BytesIO
@@ -15,16 +16,17 @@
     Optional,
     Sequence,
     Tuple,
-    Type,
     Union,
     cast,
 )
 from urllib.parse import urlparse
 
+import google.ai.generativelanguage as glm
 import google.api_core
 
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -35,10 +37,8 @@
     AIMessage,
     AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
+    FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -53,6 +53,9 @@
 )
 
 from langchain_google_genai._common import GoogleGenerativeAIError
+from langchain_google_genai._function_utils import (
+    convert_to_genai_function_declarations,
+)
 from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
 
 IMAGE_TYPES: Tuple = ()
@@ -321,14 +324,47 @@ def _parse_chat_history(
             continue
         elif isinstance(message, AIMessage):
             role = "model"
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
+            parts = _convert_to_parts(message.content)
+        elif isinstance(message, FunctionMessage):
+            role = "user"
+            response: Any
+            if not isinstance(message.content, str):
+                response = message.content
+            else:
+                try:
+                    response = json.loads(message.content)
+                except json.JSONDecodeError:
+                    response = message.content  # leave as str representation
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=(
+                            {"output": response}
+                            if not isinstance(response, dict)
+                            else response
+                        ),
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
-        parts = _convert_to_parts(message.content)
 
         if raw_system_message:
             if role == "model":
                 raise ValueError(
@@ -341,71 +377,51 @@ def _parse_chat_history(
     return messages
 
 
-def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
-    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-        # Simple text response. The typical response
-        return parts[0].text
-    elif not parts:
-        logger.warning("Gemini produced an empty response.")
-        return ""
-    messages = []
-    for part in parts:
-        if part.text is not None:
-            messages.append(
-                {
-                    "type": "text",
-                    "text": part.text,
-                }
-            )
-        else:
-            # TODO: Handle inline_data if that's a thing?
-            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages
+def _parse_response_candidate(
+    response_candidate: glm.Candidate, stream: bool
+) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return (AIMessageChunk if stream else AIMessage)(
+            content="", additional_kwargs={"function_call": function_call}
+        )
+    else:
+        parts = response_candidate.content.parts
+
+        if len(parts) == 1 and parts[0].text:
+            content: Union[str, List[Union[str, Dict]]] = parts[0].text
+        else:
+            content = [proto.Message.to_dict(part) for part in parts]
+        return (AIMessageChunk if stream else AIMessage)(
+            content=content, additional_kwargs={}
+        )
 
 
 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
+    stream: bool = False,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
     generations: List[ChatGeneration] = []
-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(content=parts_content, role=content.role)
-        else:
-            msg = role_map[content.role](content=parts_content)
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        if candidate.safety_ratings:
-            generation_info["safety_ratings"] = [
-                type(rating).to_dict(rating) for rating in candidate.safety_ratings
-            ]
-        generations.append(generation_t(message=msg, generation_info=generation_info))
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=_parse_response_candidate(candidate, stream=stream),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
@@ -414,7 +430,12 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+        generations = [
+            (ChatGenerationChunk if stream else ChatGeneration)(
+                message=(AIMessageChunk if stream else AIMessage)(content=""),
+                generation_info={},
+            )
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)
 
 
@@ -496,6 +517,7 @@ def _identifying_params(self) -> Dict[str, Any]:
             "temperature": self.temperature,
             "top_k": self.top_k,
             "n": self.n,
+            "safety_settings": self.safety_settings,
         }
 
     def _prepare_params(
@@ -525,7 +547,11 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            **kwargs,
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -540,7 +566,11 @@ async def _agenerate(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            **kwargs,
+        )
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
             content=message,
             **params,
@@ -555,7 +585,11 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            **kwargs,
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -563,17 +597,12 @@ def _stream(
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-            yield gen
+
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
+            yield gen
 
     async def _astream(
         self,
@@ -582,24 +611,23 @@ async def _astream(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            **kwargs,
+        )
         async for chunk in await _achat_with_retry(
             content=message,
             **params,
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk, stream=True)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-            yield gen
+
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
+            yield gen
 
     def _prepare_chat(
         self,
@@ -607,13 +635,24 @@ def _prepare_chat(
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        client = self.client
+        functions = kwargs.pop("functions", None)
+        safety_settings = kwargs.pop("safety_settings", self.safety_settings)
+        if functions or safety_settings:
+            tools = (
+                convert_to_genai_function_declarations(functions) if functions else None
+            )
+            client = genai.GenerativeModel(
+                model_name=self.model, tools=tools, safety_settings=safety_settings
+            )
+
         params = self._prepare_params(stop, **kwargs)
         history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
         message = history.pop()
-        chat = self.client.start_chat(history=history)
+        chat = client.start_chat(history=history)
         return params, chat, message
 
     def get_num_tokens(self, text: str) -> int:
diff --git a/libs/genai/langchain_google_genai/llms.py b/libs/genai/langchain_google_genai/llms.py
index bf147410..9b483873 100644
--- a/libs/genai/langchain_google_genai/llms.py
+++ b/libs/genai/langchain_google_genai/llms.py
@@ -15,6 +15,11 @@
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import get_from_dict_or_env
 
+from langchain_google_genai._enums import (
+    HarmBlockThreshold,
+    HarmCategory,
+)
+
 
 class GoogleModelFamily(str, Enum):
     GEMINI = auto()
@@ -77,7 +82,10 @@ def _completion_with_retry(
         try:
             if is_gemini:
                 return llm.client.generate_content(
-                    contents=prompt, stream=stream, generation_config=generation_config
+                    contents=prompt,
+                    stream=stream,
+                    generation_config=generation_config,
+                    safety_settings=kwargs.pop("safety_settings", None),
                 )
             return llm.client.generate_text(prompt=prompt, **kwargs)
         except google.api_core.exceptions.FailedPrecondition as exc:
@@ -143,6 +151,22 @@ class _BaseGoogleGenerativeAI(BaseModel):
         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
     )
 
+    safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None
+    """The default safety settings to use for all generations.
+
+    For example:
+
+        from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
+
+        safety_settings = {
+            HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+        }
+    """  # noqa: E501
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"google_api_key": "GOOGLE_API_KEY"}
@@ -151,6 +175,18 @@ def lc_secrets(self) -> Dict[str, str]:
     def _model_family(self) -> str:
         return GoogleModelFamily(self.model)
 
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "max_output_tokens": self.max_output_tokens,
+            "candidate_count": self.n,
+        }
+
 
 class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
     """Google GenerativeAI models.
@@ -172,6 +208,8 @@ def validate_environment(cls, values: Dict) -> Dict:
         )
 
         model_name = values["model"]
+        safety_settings = values["safety_settings"]
+
         if isinstance(google_api_key, SecretStr):
             google_api_key = google_api_key.get_secret_value()
 
@@ -181,8 +219,15 @@
             client_options=values.get("client_options"),
         )
 
+        if safety_settings and (
+            not GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI
+        ):
+            raise ValueError("Safety settings are only supported for Gemini models")
+
         if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
-            values["client"] = genai.GenerativeModel(model_name=model_name)
+            values["client"] = genai.GenerativeModel(
+                model_name=model_name, safety_settings=safety_settings
+            )
         else:
             values["client"] = genai
 
@@ -225,6 +270,7 @@ def _generate(
                     is_gemini=True,
                     run_manager=run_manager,
                     generation_config=generation_config,
+                    safety_settings=kwargs.pop("safety_settings", None),
                 )
                 candidates = [
                     "".join([p.text for p in c.content.parts]) for c in res.candidates
@@ -266,6 +312,7 @@ def _stream(
             is_gemini=True,
             run_manager=run_manager,
             generation_config=generation_config,
+            safety_settings=kwargs.pop("safety_settings", None),
             **kwargs,
         ):
             chunk = GenerationChunk(text=stream_resp.text)
diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py
index 4551a860..26e85aa1 100644
--- a/libs/genai/tests/integration_tests/test_chat_models.py
+++ b/libs/genai/tests/integration_tests/test_chat_models.py
@@ -1,11 +1,16 @@
 """Test ChatGoogleGenerativeAI chat model."""
+
+from typing import Generator
+
 import pytest
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
-from langchain_google_genai.chat_models import (
+from langchain_google_genai import (
     ChatGoogleGenerativeAI,
-    ChatGoogleGenerativeAIError,
+    HarmBlockThreshold,
+    HarmCategory,
 )
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
 
 _MODEL = "gemini-pro"  # TODO: Use nano when it's available.
 _VISION_MODEL = "gemini-pro-vision"
@@ -102,7 +107,7 @@ def test_chat_google_genai_invoke_multimodal() -> None:
 
     # Try streaming
     for chunk in llm.stream(messages):
-        print(chunk)
+        print(chunk)  # noqa: T201
         assert isinstance(chunk.content, str)
         assert len(chunk.content.strip()) > 0
 
@@ -192,3 +197,32 @@ def test_generativeai_get_num_tokens_gemini() -> None:
     llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro")
     output = llm.get_num_tokens("How are you?")
     assert output == 4
+
+
+def test_safety_settings_gemini() -> None:
+    safety_settings = {
+        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+    }
+    # test with safety filters on bind
+    llm = ChatGoogleGenerativeAI(temperature=0, model="gemini-pro").bind(
+        safety_settings=safety_settings
+    )
+    output = llm.invoke("how to make a bomb?")
+    assert isinstance(output, AIMessage)
+    assert len(output.content) > 0
+
+    # test direct to stream
+    streamed_messages = []
+    output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings)
+    assert isinstance(output_stream, Generator)
+    for message in output_stream:
+        streamed_messages.append(message)
+    assert len(streamed_messages) > 0
+
+    # test as init param
+    llm = ChatGoogleGenerativeAI(
+        temperature=0, model="gemini-pro", safety_settings=safety_settings
+    )
+    out2 = llm.invoke("how to make a bomb")
+    assert isinstance(out2, AIMessage)
+    assert len(out2.content) > 0
diff --git a/libs/genai/tests/integration_tests/test_function_call.py b/libs/genai/tests/integration_tests/test_function_call.py
new file mode 100644
index 00000000..57d0005a
--- /dev/null
+++ b/libs/genai/tests/integration_tests/test_function_call.py
@@ -0,0 +1,84 @@
+"""Test ChatGoogleGenerativeAI function call."""
+
+import json
+
+from langchain_core.messages import AIMessage
+from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.tools import tool
+
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+)
+
+
+def test_function_call() -> None:
+    functions = [
+        {
+            "name": "get_weather",
+            "description": "Determine weather in my location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["c", "f"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=functions)
+    res = llm.invoke("what weather is today in san francisco?")
+    assert res
+    assert res.additional_kwargs
+    assert "function_call" in res.additional_kwargs
+    assert "get_weather" == res.additional_kwargs["function_call"]["name"]
+    arguments_str = res.additional_kwargs["function_call"]["arguments"]
+    assert isinstance(arguments_str, str)
+    arguments = json.loads(arguments_str)
+    assert "location" in arguments
+
+
+def test_tool_call() -> None:
+    @tool
+    def search_tool(query: str) -> str:
+        """Searches the web for `query` and returns the result."""
+        raise NotImplementedError
+
+    llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[search_tool])
+    response = llm.invoke("weather in san francisco")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert response.content == ""
+    function_call = response.additional_kwargs.get("function_call")
+    assert function_call
+    assert function_call["name"] == "search_tool"
+    arguments_str = function_call.get("arguments")
+    assert arguments_str
+    arguments = json.loads(arguments_str)
+    assert "query" in arguments
+
+
+class MyModel(BaseModel):
+    name: str
+    age: int
+
+
+def test_pydantic_call() -> None:
+    llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[MyModel])
+    response = llm.invoke("my name is Erick and I am 27 years old")
+    assert isinstance(response, AIMessage)
+    assert isinstance(response.content, str)
+    assert response.content == ""
+    function_call = response.additional_kwargs.get("function_call")
+    assert function_call
+    assert function_call["name"] == "MyModel"
+    arguments_str = function_call.get("arguments")
+    assert arguments_str
+    arguments = json.loads(arguments_str)
+    assert arguments == {
+        "name": "Erick",
+        "age": 27.0,
+    }
diff --git a/libs/genai/tests/integration_tests/test_llms.py b/libs/genai/tests/integration_tests/test_llms.py
index 9bdf49dd..b761e148 100644
--- a/libs/genai/tests/integration_tests/test_llms.py
+++ b/libs/genai/tests/integration_tests/test_llms.py
@@ -4,10 +4,12 @@
 valid API key.
""" +from typing import Generator + import pytest from langchain_core.outputs import LLMResult -from langchain_google_genai.llms import GoogleGenerativeAI +from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory model_names = ["models/text-bison-001", "gemini-pro"] @@ -66,3 +68,39 @@ def test_generativeai_get_num_tokens_gemini() -> None: llm = GoogleGenerativeAI(temperature=0, model="gemini-pro") output = llm.get_num_tokens("How are you?") assert output == 4 + + +def test_safety_settings_gemini() -> None: + # test with blocked prompt + llm = GoogleGenerativeAI(temperature=0, model="gemini-pro") + output = llm.generate(prompts=["how to make a bomb?"]) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) == 0 + + # safety filters + safety_settings = { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + + # test with safety filters directly to generate + output = llm.generate(["how to make a bomb?"], safety_settings=safety_settings) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) > 0 + + # test with safety filters directly to stream + streamed_messages = [] + output_stream = llm.stream("how to make a bomb?", safety_settings=safety_settings) + assert isinstance(output_stream, Generator) + for message in output_stream: + streamed_messages.append(message) + assert len(streamed_messages) > 0 + + # test with safety filters on instantiation + llm = GoogleGenerativeAI( + model="gemini-pro", + safety_settings=safety_settings, + temperature=0, + ) + output = llm.generate(prompts=["how to make a bomb?"]) + assert isinstance(output, LLMResult) + assert len(output.generations[0]) > 0 diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index 93d5bb32..61990909 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -1,5 +1,14 @@ """Test chat model integration.""" -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from typing import Dict, List, Union + +import pytest +from langchain_core.messages import ( + AIMessage, + FunctionMessage, + HumanMessage, + SystemMessage, +) from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture @@ -36,7 +45,7 @@ def test_api_key_is_string() -> None: def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") - print(chat.google_api_key, end="") + print(chat.google_api_key, end="") # noqa: T201 captured = capsys.readouterr() assert captured.out == "**********" @@ -58,3 +67,9 @@ def test_parse_history() -> None: "parts": [{"text": system_input}, {"text": text_question1}], } assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]} + + +@pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"]) +def test_parse_function_history(content: Union[str, List[Union[str, Dict]]]) -> None: + function_message = FunctionMessage(name="search_tool", content=content) + _parse_chat_history([function_message], convert_system_message_to_human=True) diff --git a/libs/genai/tests/unit_tests/test_embeddings.py b/libs/genai/tests/unit_tests/test_embeddings.py index 45acffb3..9bb57274 100644 --- a/libs/genai/tests/unit_tests/test_embeddings.py +++ b/libs/genai/tests/unit_tests/test_embeddings.py @@ -1,4 +1,5 @@ """Test embeddings model integration.""" + from 
langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture @@ -31,7 +32,7 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> N model="models/embedding-001", google_api_key="secret-api-key", ) - print(embeddings.google_api_key, end="") + print(embeddings.google_api_key, end="") # noqa: T201 captured = capsys.readouterr() assert captured.out == "**********" diff --git a/libs/genai/tests/unit_tests/test_imports.py b/libs/genai/tests/unit_tests/test_imports.py index e189a9fc..8c90cb2b 100644 --- a/libs/genai/tests/unit_tests/test_imports.py +++ b/libs/genai/tests/unit_tests/test_imports.py @@ -4,6 +4,8 @@ "ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings", "GoogleGenerativeAI", + "HarmBlockThreshold", + "HarmCategory", ]
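
Taken together, the changes above add two user-facing knobs: `safety_settings` (accepted at construction time or passed per call) and `functions` (bound via `.bind()` and converted to Gemini tool declarations by `convert_to_genai_function_declarations`). A minimal usage sketch follows, mirroring the integration tests in this diff; the model name, prompts, and the assumption that `GOOGLE_API_KEY` is set in the environment are illustrative rather than prescriptive.

```python
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)

# Relax the dangerous-content filter for this model instance,
# as in test_safety_settings_gemini above.
safety_settings = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
llm = ChatGoogleGenerativeAI(model="gemini-pro", safety_settings=safety_settings)

# Bind an OpenAI-style function declaration; the model's function call is
# surfaced on additional_kwargs["function_call"], as in test_function_call above.
functions = [
    {
        "name": "get_weather",
        "description": "Determine weather in my location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state e.g. San Francisco, CA",
                },
            },
            "required": ["location"],
        },
    }
]
response = llm.bind(functions=functions).invoke("what weather is today in san francisco?")
print(response.additional_kwargs.get("function_call"))
```

The function arguments come back JSON-encoded under `additional_kwargs["function_call"]["arguments"]`, matching the OpenAI-style shape asserted in `test_function_call`.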