diff --git a/docs/extras/integrations/chat/google_vertex_ai_palm.ipynb b/docs/extras/integrations/chat/google_vertex_ai_palm.ipynb index 092b7cbd67d9f..61f6809476668 100644 --- a/docs/extras/integrations/chat/google_vertex_ai_palm.ipynb +++ b/docs/extras/integrations/chat/google_vertex_ai_palm.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Google Cloud Platform Vertex AI PaLM \n", + "# GCP Vertex AI \n", "\n", "Note: This is seperate from the Google PaLM integration. Google has chosen to offer an enterprise version of PaLM through GCP, and this supports the models made available through there. \n", "\n", @@ -31,7 +31,7 @@ }, "outputs": [], "source": [ - "#!pip install google-cloud-aiplatform" + "#!pip install langchain google-cloud-aiplatform" ] }, { @@ -41,12 +41,7 @@ "outputs": [], "source": [ "from langchain.chat_models import ChatVertexAI\n", - "from langchain.prompts.chat import (\n", - " ChatPromptTemplate,\n", - " SystemMessagePromptTemplate,\n", - " HumanMessagePromptTemplate,\n", - ")\n", - "from langchain.schema import HumanMessage, SystemMessage" + "from langchain.prompts import ChatPromptTemplate" ] }, { @@ -60,82 +55,78 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "system = \"You are a helpful assistant who translate English to French\"\n", + "human = \"Translate this sentence from English to French. I love programming.\"\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", system), (\"human\", human)]\n", + ")\n", + "messages = prompt.format_messages()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='Sure, here is the translation of the sentence \"I love programming\" from English to French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)" + "AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)" ] }, - "execution_count": 4, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "messages = [\n", - " SystemMessage(\n", - " content=\"You are a helpful assistant that translates English to French.\"\n", - " ),\n", - " HumanMessage(\n", - " content=\"Translate this sentence from English to French. I love programming.\"\n", - " ),\n", - "]\n", "chat(messages)" ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n", - "\n", - "For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:" + "If we want to construct a simple chain that takes user specified parameters:" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ - "template = (\n", - " \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n", - ")\n", - "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n", - "human_template = \"{text}\"\n", - "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)" + "system = \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n", + "human = \"{text}\"\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", system), (\"human\", human)]\n", + ")" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='Sure, here is the translation of \"I love programming\" in French:\\n\\nJ\\'aime programmer.', additional_kwargs={}, example=False)" + "AIMessage(content=' 私はプログラミングが大好きです。', additional_kwargs={}, example=False)" ] }, - "execution_count": 7, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chat_prompt = ChatPromptTemplate.from_messages(\n", - " [system_message_prompt, human_message_prompt]\n", - ")\n", - "\n", - "# get a chat completion from the formatted messages\n", - "chat(\n", - " chat_prompt.format_prompt(\n", - " input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n", - " ).to_messages()\n", + "chain = prompt | chat\n", + "chain.invoke(\n", + " {\"input_language\": \"English\", \"output_language\": \"Japanese\", \"text\": \"I love programming\"}\n", ")" ] }, @@ -153,60 +144,129 @@ "tags": [] }, "source": [ + "## Code generation chat models\n", "You can now leverage the Codey API for code chat within Vertex AI. The model name is:\n", "- codechat-bison: for code assistance" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 18, "metadata": { - "execution": { - "iopub.execute_input": "2023-06-17T21:30:43.974841Z", - "iopub.status.busy": "2023-06-17T21:30:43.974431Z", - "iopub.status.idle": "2023-06-17T21:30:44.248119Z", - "shell.execute_reply": "2023-06-17T21:30:44.247362Z", - "shell.execute_reply.started": "2023-06-17T21:30:43.974820Z" - }, "tags": [] }, "outputs": [], "source": [ - "chat = ChatVertexAI(model_name=\"codechat-bison\")" + "chat = ChatVertexAI(\n", + " model_name=\"codechat-bison\",\n", + " max_output_tokens=1000,\n", + " temperature=0.5\n", + ")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 20, "metadata": { - "execution": { - "iopub.execute_input": "2023-06-17T21:30:45.146093Z", - "iopub.status.busy": "2023-06-17T21:30:45.145752Z", - "iopub.status.idle": "2023-06-17T21:30:47.449126Z", - "shell.execute_reply": "2023-06-17T21:30:47.448609Z", - "shell.execute_reply.started": "2023-06-17T21:30:45.146069Z" - }, "tags": [] }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ```python\n", + "def is_prime(x): \n", + " if (x <= 1): \n", + " return False\n", + " for i in range(2, x): \n", + " if (x % i == 0): \n", + " return False\n", + " return True\n", + "```\n" + ] + } + ], + "source": [ + "# For simple string in string out usage, we can use the `predict` method:\n", + "print(chat.predict(\"Write a Python function to identify all prime numbers\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Asynchronous calls\n", + "\n", + "We can make asynchronous calls via the `agenerate` and `ainvoke` methods." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "# import nest_asyncio\n", + "# nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "AIMessage(content='The following Python function can be used to identify all prime numbers up to a given integer:\\n\\n```\\ndef is_prime(n):\\n \"\"\"\\n Determines whether the given integer is prime.\\n\\n Args:\\n n: The integer to be tested for primality.\\n\\n Returns:\\n True if n is prime, False otherwise.\\n \"\"\"\\n\\n # Check if n is divisible by 2.\\n if n % 2 == 0:\\n return False\\n\\n # Check if n is divisible by any integer from 3 to the square root', additional_kwargs={}, example=False)" + "LLMResult(generations=[[ChatGeneration(text=\" J'aime la programmation.\", generation_info=None, message=AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('223599ef-38f8-4c79-ac6d-a5013060eb9d'))])" ] }, - "execution_count": 4, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "messages = [\n", - " HumanMessage(\n", - " content=\"How do I create a python function to identify all prime numbers?\"\n", - " )\n", - "]\n", - "chat(messages)" + "chat = ChatVertexAI(\n", + " model_name=\"chat-bison\",\n", + " max_output_tokens=1000,\n", + " temperature=0.7,\n", + " top_p=0.95,\n", + " top_k=40,\n", + ")\n", + "\n", + "asyncio.run(chat.agenerate([messages]))" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' अहं प्रोग्रामिंग प्रेमामि', additional_kwargs={}, example=False)" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "asyncio.run(chain.ainvoke({\"input_language\": \"English\", \"output_language\": \"Sanskrit\", \"text\": \"I love programming\"}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming calls\n", + "\n", + "We can also stream outputs via the `stream` method:" ] }, { @@ -214,14 +274,51 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "import sys" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 1. China (1,444,216,107)\n", + "2. India (1,393,409,038)\n", + "3. United States (332,403,650)\n", + "4. Indonesia (273,523,615)\n", + "5. Pakistan (220,892,340)\n", + "6. Brazil (212,559,409)\n", + "7. Nigeria (206,139,589)\n", + "8. Bangladesh (164,689,383)\n", + "9. Russia (145,934,462)\n", + "10. Mexico (128,932,488)\n", + "11. Japan (126,476,461)\n", + "12. Ethiopia (115,063,982)\n", + "13. Philippines (109,581,078)\n", + "14. Egypt (102,334,404)\n", + "15. Vietnam (97,338,589)" + ] + } + ], + "source": [ + "prompt = ChatPromptTemplate.from_messages([(\"human\", \"List out the 15 most populous countries in the world\")])\n", + "messages = prompt.format_messages()\n", + "for chunk in chat.stream(messages):\n", + " sys.stdout.write(chunk.content)\n", + " sys.stdout.flush()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "poetry-venv", "language": "python", - "name": "python3" + "name": "poetry-venv" }, "language_info": { "codemirror_mode": { diff --git a/docs/extras/integrations/chat/index.mdx b/docs/extras/integrations/chat/index.mdx index 82c5f76169fcb..38e4a81be3332 100644 --- a/docs/extras/integrations/chat/index.mdx +++ b/docs/extras/integrations/chat/index.mdx @@ -26,7 +26,7 @@ ChatLiteLLM|✅|✅|✅|✅ ChatMLflowAIGateway|✅|❌|❌|❌ ChatOllama|✅|❌|✅|❌ ChatOpenAI|✅|✅|✅|✅ -ChatVertexAI|✅|❌|✅|❌ +ChatVertexAI|✅|✅|✅|❌ ErnieBotChat|✅|❌|❌|❌ JinaChat|✅|✅|✅|✅ MiniMaxChat|✅|✅|❌|❌ diff --git a/docs/extras/integrations/llms/google_vertex_ai_palm.ipynb b/docs/extras/integrations/llms/google_vertex_ai_palm.ipynb index 0327465f343df..40435cf37b02c 100644 --- a/docs/extras/integrations/llms/google_vertex_ai_palm.ipynb +++ b/docs/extras/integrations/llms/google_vertex_ai_palm.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Google Vertex AI PaLM \n", + "# GCP Vertex AI\n", "\n", "**Note:** This is separate from the `Google PaLM` integration, it exposes [Vertex AI PaLM API](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview) on `Google Cloud`. \n" ] @@ -41,12 +41,12 @@ }, "outputs": [], "source": [ - "#!pip install google-cloud-aiplatform" + "#!pip install langchain google-cloud-aiplatform" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -54,41 +54,55 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 9, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Python is a widely used, interpreted, object-oriented, and high-level programming language with dynamic semantics, used for general-purpose programming. It is known for its readability, simplicity, and versatility. Here are some of the pros and cons of Python:\n", + "\n", + "**Pros:**\n", + "\n", + "- **Easy to learn:** Python is known for its simple and intuitive syntax, making it easy for beginners to learn. It has a relatively shallow learning curve compared to other programming languages.\n", + "\n", + "- **Versatile:** Python is a general-purpose programming language, meaning it can be used for a wide variety of tasks, including web development, data science, machine\n" + ] + } + ], "source": [ - "## Question-answering example" + "llm = VertexAI()\n", + "print(llm(\"What are some of the pros and cons of Python as a programming language?\"))" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain" + "## Using in a chain" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "template = \"\"\"Question: {question}\n", - "\n", - "Answer: Let's think step by step.\"\"\"\n", - "\n", - "prompt = PromptTemplate(template=template, input_variables=[\"question\"])" + "from langchain.prompts import PromptTemplate" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "llm = VertexAI()" + "template = \"\"\"Question: {question}\n", + "\n", + "Answer: Let's think step by step.\"\"\"\n", + "prompt = PromptTemplate.from_template(template)" ] }, { @@ -97,29 +111,26 @@ "metadata": {}, "outputs": [], "source": [ - "llm_chain = LLMChain(prompt=prompt, llm=llm)" + "chain = prompt | llm" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "'Justin Bieber was born on March 1, 1994. The Super Bowl in 1994 was won by the San Francisco 49ers.\\nThe final answer: San Francisco 49ers.'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + " Justin Bieber was born on March 1, 1994. Bill Clinton was the president of the United States from January 20, 1993, to January 20, 2001.\n", + "The final answer is Bill Clinton\n" + ] } ], "source": [ - "question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n", - "\n", - "llm_chain.run(question)" + "question = \"Who was the president in the year Justin Beiber was born?\"\n", + "print(chain.invoke({\"question\": question}))" ] }, { @@ -142,76 +153,198 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, "metadata": { - "execution": { - "iopub.execute_input": "2023-06-17T21:16:53.149438Z", - "iopub.status.busy": "2023-06-17T21:16:53.149065Z", - "iopub.status.idle": "2023-06-17T21:16:53.421824Z", - "shell.execute_reply": "2023-06-17T21:16:53.421136Z", - "shell.execute_reply.started": "2023-06-17T21:16:53.149415Z" - }, "tags": [] }, "outputs": [], "source": [ - "llm = VertexAI(model_name=\"code-bison\")" + "llm = VertexAI(model_name=\"code-bison\", max_output_tokens=1000, temperature=0.3)" ] }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "execution": { - "iopub.execute_input": "2023-06-17T21:17:11.179077Z", - "iopub.status.busy": "2023-06-17T21:17:11.178686Z", - "iopub.status.idle": "2023-06-17T21:17:11.182499Z", - "shell.execute_reply": "2023-06-17T21:17:11.181895Z", - "shell.execute_reply.started": "2023-06-17T21:17:11.179052Z" - }, - "tags": [] - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ - "llm_chain = LLMChain(prompt=prompt, llm=llm)" + "question = \"Write a python function that checks if a string is a valid email address\"" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "metadata": { - "execution": { - "iopub.execute_input": "2023-06-17T21:18:47.024785Z", - "iopub.status.busy": "2023-06-17T21:18:47.024230Z", - "iopub.status.idle": "2023-06-17T21:18:49.352249Z", - "shell.execute_reply": "2023-06-17T21:18:49.351695Z", - "shell.execute_reply.started": "2023-06-17T21:18:47.024762Z" - }, "tags": [] }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "```python\n", + "import re\n", + "\n", + "def is_valid_email(email):\n", + " pattern = re.compile(r\"[^@]+@[^@]+\\.[^@]+\")\n", + " return pattern.match(email)\n", + "```\n" + ] + } + ], + "source": [ + "print(llm(question))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Full generation info\n", + "\n", + "We can use the `generate` method to get back extra metadata like [safety attributes](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/responsible-ai#safety_attribute_confidence_scoring) and not just text completions" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = llm.generate([question])\n", + "result.generations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Asynchronous calls\n", + "\n", + "With `agenerate` we can make asynchronous calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# If running in a Jupyter notebook you'll need to install nest_asyncio\n", + "\n", + "# !pip install nest_asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "# import nest_asyncio\n", + "# nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'```python\\ndef is_prime(n):\\n \"\"\"\\n Determines if a number is prime.\\n\\n Args:\\n n: The number to be tested.\\n\\n Returns:\\n True if the number is prime, False otherwise.\\n \"\"\"\\n\\n # Check if the number is 1.\\n if n == 1:\\n return False\\n\\n # Check if the number is 2.\\n if n == 2:\\n return True\\n\\n'" + "LLMResult(generations=[[GenerationChunk(text='```python\\nimport re\\n\\ndef is_valid_email(email):\\n pattern = re.compile(r\"[^@]+@[^@]+\\\\.[^@]+\")\\n return pattern.match(email)\\n```', generation_info={'is_blocked': False, 'safety_attributes': {'Health': 0.1}})]], llm_output=None, run=[RunInfo(run_id=UUID('caf74e91-aefb-48ac-8031-0c505fcbbcc6'))])" ] }, - "execution_count": 15, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "question = \"Write a python function that identifies if the number is a prime number?\"\n", + "asyncio.run(llm.agenerate([question]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming calls\n", "\n", - "llm_chain.run(question)" + "With `stream` we can stream results from the model" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "import sys" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "```python\n", + "import re\n", + "\n", + "def is_valid_email(email):\n", + " \"\"\"\n", + " Checks if a string is a valid email address.\n", + "\n", + " Args:\n", + " email: The string to check.\n", + "\n", + " Returns:\n", + " True if the string is a valid email address, False otherwise.\n", + " \"\"\"\n", + "\n", + " # Check for a valid email address format.\n", + " if not re.match(r\"^[A-Za-z0-9\\.\\+_-]+@[A-Za-z0-9\\._-]+\\.[a-zA-Z]*$\", email):\n", + " return False\n", + "\n", + " # Check if the domain name exists.\n", + " try:\n", + " domain = email.split(\"@\")[1]\n", + " socket.gethostbyname(domain)\n", + " except socket.gaierror:\n", + " return False\n", + "\n", + " return True\n", + "```" + ] + } + ], + "source": [ + "for chunk in llm.stream(question):\n", + " sys.stdout.write(chunk)\n", + " sys.stdout.flush()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Using models deployed on Vertex Model Garden" + "## Vertex Model Garden" ] }, { @@ -248,7 +381,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm(\"What is the meaning of life?\")" + "print(llm(\"What is the meaning of life?\"))" ] }, { @@ -264,8 +397,6 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.prompts import PromptTemplate\n", - "\n", "prompt = PromptTemplate.from_template(\"What is the meaning of {thing}?\")" ] }, @@ -275,9 +406,8 @@ "metadata": {}, "outputs": [], "source": [ - "llm_oss_chain = prompt | llm\n", - "\n", - "llm_oss_chain.invoke({\"thing\": \"life\"})" + "chian = prompt | llm\n", + "print(chain.invoke({\"thing\": \"life\"}))" ] } ], diff --git a/docs/extras/integrations/llms/index.mdx b/docs/extras/integrations/llms/index.mdx index c76e92a0a3773..872bd8a6d1b98 100644 --- a/docs/extras/integrations/llms/index.mdx +++ b/docs/extras/integrations/llms/index.mdx @@ -83,8 +83,8 @@ TitanTakeoff|✅|❌|✅|❌|❌|❌ Tongyi|✅|❌|❌|❌|❌|❌ VLLM|✅|❌|❌|❌|✅|❌ VLLMOpenAI|✅|✅|✅|✅|✅|✅ -VertexAI|✅|✅|❌|❌|❌|❌ -VertexAIModelGarden|✅|✅|❌|❌|❌|❌ +VertexAI|✅|✅|✅|❌|✅|✅ +VertexAIModelGarden|✅|✅|❌|❌|✅|✅ Writer|✅|❌|❌|❌|❌|❌ Xinference|✅|❌|❌|❌|❌|❌ diff --git a/docs/extras/integrations/platforms/google.mdx b/docs/extras/integrations/platforms/google.mdx index a4a50e8720e7d..7f7ab56cfb6f4 100644 --- a/docs/extras/integrations/platforms/google.mdx +++ b/docs/extras/integrations/platforms/google.mdx @@ -2,6 +2,35 @@ All functionality related to Google Platform +## LLMs + +### Vertex AI + +Access PaLM LLMs like `text-bison` and `code-bison` via Google Cloud. + +```python +from langchain.llms import VertexAI +``` + +### Model Garden + +Access PaLM and hundreds of OSS models via Vertex AI Model Garden. + +```python +from langchain.llms import VertexAIModelGarden +``` + +## Chat models + +### Vertex AI + +Access PaLM chat models like `chat-bison` and `codechat-bison` via Google Cloud. + +```python +from langchain.chat_models import ChatVertexAI +``` + + ## Document Loader ### Google BigQuery diff --git a/libs/langchain/langchain/chat_models/vertexai.py b/libs/langchain/langchain/chat_models/vertexai.py index 0b0909f74d44f..a2e285030537a 100644 --- a/libs/langchain/langchain/chat_models/vertexai.py +++ b/libs/langchain/langchain/chat_models/vertexai.py @@ -1,10 +1,14 @@ """Wrapper around Google VertexAI chat-based models.""" from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union -from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel, _generate_from_stream from langchain.llms.vertexai import _VertexAICommon, is_codey_model from langchain.pydantic_v1 import root_validator @@ -30,6 +34,8 @@ InputOutputTextPair, ) +logger = logging.getLogger(__name__) + @dataclass class _ChatHistory: @@ -116,7 +122,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): """`Vertex AI` Chat large language models API.""" model_name: str = "chat-bison" - streaming: bool = False + "Underlying model name." @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -177,6 +183,42 @@ def _generate( text = self._enforce_stop_words(response.text, stop) return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Asynchronously generate next turn in the conversation. + + Args: + messages: The history of the conversation as a list of messages. Code chat + does not support context. + stop: The list of stop words (optional). + run_manager: The CallbackManager for LLM run, it's not used at the moment. + + Returns: + The ChatResult that contains outputs generated by the model. + + Raises: + ValueError: if the last message in the list is not from human. + """ + if "stream" in kwargs: + kwargs.pop("stream") + logger.warning("ChatVertexAI does not currently support async streaming.") + question = _get_question(messages) + history = _parse_chat_history(messages[:-1]) + params = {**self._default_params, **kwargs} + examples = kwargs.get("examples", None) + if examples: + params["examples"] = _parse_examples(examples) + + chat = self._start_chat(history, params) + response = await chat.send_message_async(question.content) + text = self._enforce_stop_words(response.text, stop) + return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))]) + def _stream( self, messages: List[BaseMessage], diff --git a/libs/langchain/langchain/llms/vertexai.py b/libs/langchain/langchain/llms/vertexai.py index aaa5efbecb743..1367f9d6bbb81 100644 --- a/libs/langchain/langchain/llms/vertexai.py +++ b/libs/langchain/langchain/llms/vertexai.py @@ -1,28 +1,58 @@ from __future__ import annotations -import asyncio from concurrent.futures import Executor, ThreadPoolExecutor -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Iterator, + List, + Optional, + Union, +) from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.llms.base import LLM, create_base_retry_decorator +from langchain.llms.base import BaseLLM, create_base_retry_decorator from langchain.llms.utils import enforce_stop_tokens from langchain.pydantic_v1 import BaseModel, root_validator from langchain.schema import ( Generation, LLMResult, ) +from langchain.schema.output import GenerationChunk from langchain.utilities.vertexai import ( init_vertexai, raise_vertex_import_error, ) if TYPE_CHECKING: - from google.cloud.aiplatform.gapic import PredictionServiceClient - from vertexai.language_models._language_models import _LanguageModel + from google.cloud.aiplatform.gapic import ( + PredictionServiceAsyncClient, + PredictionServiceClient, + ) + from vertexai.language_models._language_models import ( + TextGenerationResponse, + _LanguageModel, + ) + + +def _response_to_generation( + response: TextGenerationResponse, +) -> GenerationChunk: + """Convert a stream response to a generation chunk.""" + try: + generation_info = { + "is_blocked": response.is_blocked, + "safety_attributes": response.safety_attributes, + } + except Exception: + generation_info = None + return GenerationChunk(text=response.text, generation_info=generation_info) def is_codey_model(model_name: str) -> bool: @@ -36,7 +66,13 @@ def is_codey_model(model_name: str) -> bool: return "code" in model_name -def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]: +def _create_retry_decorator( + llm: VertexAI, + *, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: import google.api_core errors = [ @@ -46,14 +82,19 @@ def _create_retry_decorator(llm: VertexAI) -> Callable[[Any], Any]: google.api_core.exceptions.DeadlineExceeded, ] decorator = create_base_retry_decorator( - error_types=errors, max_retries=llm.max_retries # type: ignore + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager ) return decorator -def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any: +def completion_with_retry( + llm: VertexAI, + *args: Any, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(llm) + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) @retry_decorator def _completion_with_retry(*args: Any, **kwargs: Any) -> Any: @@ -62,6 +103,38 @@ def _completion_with_retry(*args: Any, **kwargs: Any) -> Any: return _completion_with_retry(*args, **kwargs) +def stream_completion_with_retry( + llm: VertexAI, + *args: Any, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(*args: Any, **kwargs: Any) -> Any: + return llm.client.predict_streaming(*args, **kwargs) + + return _completion_with_retry(*args, **kwargs) + + +async def acompletion_with_retry( + llm: VertexAI, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, +) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(llm, run_manager=run_manager) + + @retry_decorator + async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any: + return await llm.client.predict_async(*args, **kwargs) + + return await _acompletion_with_retry(*args, **kwargs) + + class _VertexAIBase(BaseModel): project: Optional[str] = None "The default GCP project to use when making Vertex API calls." @@ -110,6 +183,11 @@ class _VertexAICommon(_VertexAIBase): "The default custom credentials (google.auth.credentials.Credentials) to use " "when making API calls. If not provided, credentials will be ascertained from " "the environment." + streaming: bool = False + + @property + def _llm_type(self) -> str: + return "vertexai" @property def is_codey_model(self) -> bool: @@ -135,17 +213,6 @@ def _default_params(self) -> Dict[str, Any]: "top_p": self.top_p, } - def _predict( - self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any - ) -> str: - params = {**self._default_params, **kwargs} - res = completion_with_retry(self, prompt, **params) # type: ignore - return self._enforce_stop_words(res.text, stop) - - @property - def _llm_type(self) -> str: - return "vertexai" - @classmethod def _try_init_vertexai(cls, values: Dict) -> None: allowed_params = ["project", "location", "credentials"] @@ -154,13 +221,14 @@ def _try_init_vertexai(cls, values: Dict) -> None: return None -class VertexAI(_VertexAICommon, LLM): +class VertexAI(_VertexAICommon, BaseLLM): """Google Vertex AI large language models.""" model_name: str = "text-bison" "The name of the Vertex AI large language model." tuned_model_name: Optional[str] = None "The name of a tuned model. If provided, model_name is ignored." + streaming: bool = False @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -191,51 +259,78 @@ def validate_environment(cls, values: Dict) -> Dict: raise_vertex_import_error() return values - def _call( + def _generate( self, - prompt: str, + prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, **kwargs: Any, - ) -> str: - """Call Vertex model to get predictions based on the prompt. - - Args: - prompt: The prompt to pass into the model. - stop: A list of stop words (optional). - run_manager: A Callbackmanager for LLM run, optional. + ) -> LLMResult: + stop_sequences = stop or self.stop + should_stream = stream if stream is not None else self.streaming - Returns: - The string generated by the model. - """ - return self._predict(prompt, stop, **kwargs) + params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs} + generations = [] + for prompt in prompts: + if should_stream: + generation = GenerationChunk(text="") + for chunk in self._stream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): + generation += chunk + generations.append([generation]) + else: + res = completion_with_retry( + self, prompt, run_manager=run_manager, **params + ) + generations.append([_response_to_generation(res)]) + return LLMResult(generations=generations) - async def _acall( + async def _agenerate( self, - prompt: str, + prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - """Call Vertex model to get predictions based on the prompt. - - Args: - prompt: The prompt to pass into the model. - stop: A list of stop words (optional). - run_manager: A callback manager for async interaction with LLMs. - - Returns: - The string generated by the model. - """ - return await asyncio.wrap_future( - self._get_task_executor().submit(self._call, prompt, stop) - ) - + ) -> LLMResult: + stop_sequences = stop or self.stop + params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs} + generations = [] + for prompt in prompts: + res = await acompletion_with_retry( + self, prompt, run_manager=run_manager, **params + ) + generations.append([_response_to_generation(res)]) + return LLMResult(generations=generations) -class VertexAIModelGarden(_VertexAIBase, LLM): + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + stop_sequences = stop or self.stop + params = {**self._default_params, "stop_sequences": stop_sequences, **kwargs} + for stream_resp in stream_completion_with_retry( + self, prompt, run_manager=run_manager, **params + ): + chunk = _response_to_generation(stream_resp) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + chunk=chunk, + verbose=self.verbose, + ) + + +class VertexAIModelGarden(_VertexAIBase, BaseLLM): """Large language models served from Vertex AI Model Garden.""" client: "PredictionServiceClient" = None #: :meta private: + async_client: "PredictionServiceAsyncClient" = None #: :meta private: endpoint_id: str "A name of an endpoint where the model has been deployed." allowed_model_args: Optional[List[str]] = None @@ -247,7 +342,11 @@ class VertexAIModelGarden(_VertexAIBase, LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that the python package exists in environment.""" try: - from google.cloud.aiplatform.gapic import PredictionServiceClient + from google.api_core.client_options import ClientOptions + from google.cloud.aiplatform.gapic import ( + PredictionServiceAsyncClient, + PredictionServiceClient, + ) except ImportError: raise_vertex_import_error() @@ -256,38 +355,19 @@ def validate_environment(cls, values: Dict) -> Dict: "A GCP project should be provided to run inference on Model Garden!" ) - client_options = { - "api_endpoint": f"{values['location']}-aiplatform.googleapis.com" - } + client_options = ClientOptions( + api_endpoint=f"{values['location']}-aiplatform.googleapis.com" + ) values["client"] = PredictionServiceClient(client_options=client_options) + values["async_client"] = PredictionServiceAsyncClient( + client_options=client_options + ) return values @property def _llm_type(self) -> str: return "vertexai_model_garden" - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - """Call Vertex model to get predictions based on the prompt. - - Args: - prompt: The prompt to pass into the model. - stop: A list of stop words (optional). - run_manager: A Callbackmanager for LLM run, optional. - - Returns: - The string generated by the model. - """ - result = self._generate( - prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs - ) - return result.generations[0][0].text - def _generate( self, prompts: List[str], @@ -331,23 +411,47 @@ def _generate( ) return LLMResult(generations=generations) - async def _acall( + async def _agenerate( self, - prompt: str, + prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - """Call Vertex model to get predictions based on the prompt. - - Args: - prompt: The prompt to pass into the model. - stop: A list of stop words (optional). - run_manager: A callback manager for async interaction with LLMs. - - Returns: - The string generated by the model. - """ - return await asyncio.wrap_future( - self._get_task_executor().submit(self._call, prompt, stop) + ) -> LLMResult: + """Run the LLM on the given prompt and input.""" + try: + from google.protobuf import json_format + from google.protobuf.struct_pb2 import Value + except ImportError: + raise ImportError( + "protobuf package not found, please install it with" + " `pip install protobuf`" + ) + + instances = [] + for prompt in prompts: + if self.allowed_model_args: + instance = { + k: v for k, v in kwargs.items() if k in self.allowed_model_args + } + else: + instance = {} + instance[self.prompt_arg] = prompt + instances.append(instance) + + predict_instances = [ + json_format.ParseDict(instance_dict, Value()) for instance_dict in instances + ] + + endpoint = self.async_client.endpoint_path( + project=self.project, location=self.location, endpoint=self.endpoint_id + ) + response = await self.async_client.predict( + endpoint=endpoint, instances=predict_instances ) + generations: List[List[Generation]] = [] + for result in response.predictions: + generations.append( + [Generation(text=prediction[self.result_arg]) for prediction in result] + ) + return LLMResult(generations=generations) diff --git a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py index 1ae4f77a233c3..b3e3feca95114 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_vertexai.py @@ -13,6 +13,7 @@ from langchain.chat_models import ChatVertexAI from langchain.chat_models.vertexai import _parse_chat_history, _parse_examples +from langchain.schema import LLMResult from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage @@ -26,10 +27,22 @@ def test_vertexai_single_call(model_name: str) -> None: response = model([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) - assert model._llm_type == "vertexai" + assert model._llm_type == "chat-vertexai" assert model.model_name == model.client._model_id +@pytest.mark.asyncio +async def test_vertexai_agenerate() -> None: + model = ChatVertexAI(temperature=0) + message = HumanMessage(content="Hello") + response = await model.agenerate([[message]]) + assert isinstance(response, LLMResult) + assert isinstance(response.generations[0][0].message, AIMessage) # type: ignore + + sync_response = model.generate([[message]]) + assert response.generations[0][0] == sync_response.generations[0][0] + + def test_vertexai_single_call_with_context() -> None: model = ChatVertexAI() raw_context = ( diff --git a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py b/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py index 5d711275b93a4..e59547af8b1ba 100644 --- a/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/embeddings/test_vertexai.py @@ -14,7 +14,6 @@ def test_embedding_documents() -> None: output = model.embed_documents(documents) assert len(output) == 1 assert len(output[0]) == 768 - assert model._llm_type == "vertexai" assert model.model_name == model.client._model_id @@ -40,5 +39,4 @@ def test_paginated_texts() -> None: output = model.embed_documents(documents) assert len(output) == 8 assert len(output[0]) == 768 - assert model._llm_type == "vertexai" assert model.model_name == model.client._model_id diff --git a/libs/langchain/tests/integration_tests/llms/test_vertexai.py b/libs/langchain/tests/integration_tests/llms/test_vertexai.py index c995147c0ee09..994f75a86d490 100644 --- a/libs/langchain/tests/integration_tests/llms/test_vertexai.py +++ b/libs/langchain/tests/integration_tests/llms/test_vertexai.py @@ -9,18 +9,49 @@ """ import os +import pytest + from langchain.llms import VertexAI, VertexAIModelGarden from langchain.schema import LLMResult def test_vertex_call() -> None: - llm = VertexAI() + llm = VertexAI(temperature=0) output = llm("Say foo:") assert isinstance(output, str) assert llm._llm_type == "vertexai" assert llm.model_name == llm.client._model_id +def test_vertex_generate() -> None: + llm = VertexAI(temperate=0) + output = llm.generate(["Please say foo:"]) + assert isinstance(output, LLMResult) + + +@pytest.mark.asyncio +async def test_vertex_agenerate() -> None: + llm = VertexAI(temperate=0) + output = await llm.agenerate(["Please say foo:"]) + assert isinstance(output, LLMResult) + + +def test_vertext_stream() -> None: + llm = VertexAI(temperate=0) + outputs = list(llm.stream("Please say foo:")) + assert isinstance(outputs[0], str) + + +@pytest.mark.asyncio +async def test_vertex_consistency() -> None: + llm = VertexAI(temperate=0) + output = llm.generate(["Please say foo:"]) + streaming_output = llm.generate(["Please say foo:"], stream=True) + async_output = await llm.agenerate(["Please say foo:"]) + assert output.generations[0][0].text == streaming_output.generations[0][0].text + assert output.generations[0][0].text == async_output.generations[0][0].text + + def test_model_garden() -> None: """In order to run this test, you should provide an endpoint name. @@ -37,7 +68,7 @@ def test_model_garden() -> None: assert llm._llm_type == "vertexai_model_garden" -def test_model_garden_batch() -> None: +def test_model_garden_generate() -> None: """In order to run this test, you should provide an endpoint name. Example: @@ -47,6 +78,16 @@ def test_model_garden_batch() -> None: endpoint_id = os.environ["ENDPOINT_ID"] project = os.environ["PROJECT"] llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project) - output = llm._generate(["What is the meaning of life?", "How much is 2+2"]) + output = llm.generate(["What is the meaning of life?", "How much is 2+2"]) + assert isinstance(output, LLMResult) + assert len(output.generations) == 2 + + +@pytest.mark.asyncio +async def test_model_garden_agenerate() -> None: + endpoint_id = os.environ["ENDPOINT_ID"] + project = os.environ["PROJECT"] + llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project) + output = await llm.agenerate(["What is the meaning of life?", "How much is 2+2"]) assert isinstance(output, LLMResult) assert len(output.generations) == 2