From 05fd6a16a9802fc6e56d0ec65499ea4db608383d Mon Sep 17 00:00:00 2001 From: Akshata Date: Thu, 7 Nov 2024 14:34:24 -0600 Subject: [PATCH] Add ChatModels wrapper for Cloudflare Workers AI (#27645) Thank you for contributing to LangChain! - [x] **PR title**: "community: chat models wrapper for Cloudflare Workers AI" - [x] **PR message**: - **Description:** Add chat models wrapper for Cloudflare Workers AI. Enables Langgraph intergration via ChatModel for tool usage, agentic usage. - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: Erick Friis Co-authored-by: Chester Curme --- .../chat/cloudflare_workersai.ipynb | 264 ++++++++++++++++++ docs/scripts/check_templates.py | 2 +- .../chat_models/cloudflare_workersai.py | 245 ++++++++++++++++ .../chat_models/test_cloudflare_workersai.py | 78 ++++++ 4 files changed, 588 insertions(+), 1 deletion(-) create mode 100644 docs/docs/integrations/chat/cloudflare_workersai.ipynb create mode 100644 libs/community/langchain_community/chat_models/cloudflare_workersai.py create mode 100644 libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py diff --git a/docs/docs/integrations/chat/cloudflare_workersai.ipynb b/docs/docs/integrations/chat/cloudflare_workersai.ipynb new file mode 100644 index 0000000000000..df7c2a1cb667b --- /dev/null +++ b/docs/docs/integrations/chat/cloudflare_workersai.ipynb @@ -0,0 +1,264 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "30373ae2-f326-4e96-a1f7-062f57396886", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: Cloudflare Workers AI\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "f679592d", + "metadata": {}, + "source": [ + "# ChatCloudflareWorkersAI\n", + "\n", + "This will help you getting started with CloudflareWorkersAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all available Cloudflare WorkersAI models head to the [API reference](https://developers.cloudflare.com/workers-ai/).\n", + "\n", + "\n", + "## Overview\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/cloudflare_workersai) | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| ChatCloudflareWorkersAI | langchain-community| ❌ | ❌ | ✅ | ❌ | ❌ |\n", + "\n", + "### Model features\n", + "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | \n", + "\n", + "## Setup\n", + "\n", + "- To access Cloudflare Workers AI models you'll need to create a Cloudflare account, get an account number and API key, and install the `langchain-community` package.\n", + "\n", + "\n", + "### Credentials\n", + "\n", + "\n", + "Head to [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/) to sign up to Cloudflare Workers AI and generate an API key." + ] + }, + { + "cell_type": "markdown", + "id": "4a524cff", + "metadata": {}, + "source": [ + "If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "71b53c25", + "metadata": {}, + "outputs": [], + "source": [ + "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", + "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")" + ] + }, + { + "cell_type": "markdown", + "id": "777a8526", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The LangChain ChatCloudflareWorkersAI integration lives in the `langchain-community` package:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54990998", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU langchain-community" + ] + }, + { + "cell_type": "markdown", + "id": "629ba46f", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec13c2d9", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models.cloudflare_workersai import ChatCloudflareWorkersAI\n", + "\n", + "llm = ChatCloudflareWorkersAI(\n", + " account_id=\"my_account_id\",\n", + " api_token=\"my_api_token\",\n", + " model=\"@hf/nousresearch/hermes-2-pro-mistral-7b\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "119b6732", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2438a906", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-07 15:55:14 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to French. Translate the user sentence.\\nrole: user, content: I love programming.', 'tools': None}\n" + ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content='{\\'result\\': {\\'response\\': \\'Je suis un assistant virtuel qui peut traduire l\\\\\\'anglais vers le français. La phrase que vous avez dite est : \"J\\\\\\'aime programmer.\" En français, cela se traduit par : \"J\\\\\\'adore programmer.\"\\'}, \\'success\\': True, \\'errors\\': [], \\'messages\\': []}', additional_kwargs={}, response_metadata={}, id='run-838fd398-8594-4ca5-9055-03c72993caf6-0')" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + " ),\n", + " (\"human\", \"I love programming.\"),\n", + "]\n", + "ai_msg = llm.invoke(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b4911bd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'result': {'response': 'Je suis un assistant virtuel qui peut traduire l\\'anglais vers le français. La phrase que vous avez dite est : \"J\\'aime programmer.\" En français, cela se traduit par : \"J\\'adore programmer.\"'}, 'success': True, 'errors': [], 'messages': []}\n" + ] + } + ], + "source": [ + "print(ai_msg.content)" + ] + }, + { + "cell_type": "markdown", + "id": "111aa5d4", + "metadata": {}, + "source": [ + "## Chaining\n", + "\n", + "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b2a14282", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-11-07 15:55:24 - INFO - Sending prompt to Cloudflare Workers AI: {'prompt': 'role: system, content: You are a helpful assistant that translates English to German.\\nrole: user, content: I love programming.', 'tools': None}\n" + ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content=\"{'result': {'response': 'role: system, content: Das ist sehr nett zu hören! Programmieren lieben, ist eine interessante und anspruchsvolle Hobby- oder Berufsausrichtung. Wenn Sie englische Texte ins Deutsche übersetzen möchten, kann ich Ihnen helfen. Geben Sie bitte den englischen Satz oder die Übersetzung an, die Sie benötigen.'}, 'success': True, 'errors': [], 'messages': []}\", additional_kwargs={}, response_metadata={}, id='run-0d3be9a6-3d74-4dde-b49a-4479d6af00ef-0')" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | llm\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e1f311bd", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation on `ChatCloudflareWorkersAI` features and configuration options, please refer to the [API reference](https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.cloudflare_workersai.html)." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/scripts/check_templates.py b/docs/scripts/check_templates.py index 2023cf00fdb64..4a029009552c3 100644 --- a/docs/scripts/check_templates.py +++ b/docs/scripts/check_templates.py @@ -44,7 +44,7 @@ def _get_headers(doc_dir: str) -> Iterable[str]: for cell in nb["cells"]: if cell["cell_type"] == "markdown": for line in cell["source"]: - if not line.startswith("##") or "TODO" in line: + if not line.startswith("## ") or "TODO" in line: continue header = line.strip() headers.append(header) diff --git a/libs/community/langchain_community/chat_models/cloudflare_workersai.py b/libs/community/langchain_community/chat_models/cloudflare_workersai.py new file mode 100644 index 0000000000000..ff14da9e55adf --- /dev/null +++ b/libs/community/langchain_community/chat_models/cloudflare_workersai.py @@ -0,0 +1,245 @@ +import logging +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Type, + Union, + cast, +) +from uuid import uuid4 + +import requests +from langchain.schema import AIMessage, ChatGeneration, ChatResult, HumanMessage +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_core.messages.tool import tool_call +from langchain_core.output_parsers import ( + JsonOutputParser, + PydanticOutputParser, +) +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, +) +from langchain_core.runnables import Runnable, RunnablePassthrough +from langchain_core.runnables.base import RunnableMap +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic import BaseModel, Field + +# Initialize logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +_logger = logging.getLogger(__name__) + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and is_basemodel_subclass(obj) + + +def _convert_messages_to_cloudflare_messages( + messages: List[BaseMessage], +) -> List[Dict[str, Any]]: + """Convert LangChain messages to Cloudflare Workers AI format.""" + cloudflare_messages = [] + msg: Dict[str, Any] + for message in messages: + # Base structure for each message + msg = { + "role": "", + "content": message.content if isinstance(message.content, str) else "", + } + + # Determine role and additional fields based on message type + if isinstance(message, HumanMessage): + msg["role"] = "user" + elif isinstance(message, AIMessage): + msg["role"] = "assistant" + # If the AIMessage includes tool calls, format them as needed + if message.tool_calls: + tool_calls = [ + {"name": tool_call["name"], "arguments": tool_call["args"]} + for tool_call in message.tool_calls + ] + msg["tool_calls"] = tool_calls + elif isinstance(message, SystemMessage): + msg["role"] = "system" + elif isinstance(message, ToolMessage): + msg["role"] = "tool" + msg["tool_call_id"] = ( + message.tool_call_id + ) # Use tool_call_id if it's a ToolMessage + + # Add the formatted message to the list + cloudflare_messages.append(msg) + + return cloudflare_messages + + +def _get_tool_calls_from_response(response: requests.Response) -> List[ToolCall]: + """Get tool calls from ollama response.""" + tool_calls = [] + if "tool_calls" in response.json()["result"]: + for tc in response.json()["result"]["tool_calls"]: + tool_calls.append( + tool_call( + id=str(uuid4()), + name=tc["name"], + args=tc["arguments"], + ) + ) + return tool_calls + + +class ChatCloudflareWorkersAI(BaseChatModel): + """Custom chat model for Cloudflare Workers AI""" + + account_id: str = Field(...) + api_token: str = Field(...) + model: str = Field(...) + ai_gateway: str = "" + url: str = "" + base_url: str = "https://api.cloudflare.com/client/v4/accounts" + gateway_url: str = "https://gateway.ai.cloudflare.com/v1" + + def __init__(self, **kwargs: Any) -> None: + """Initialize with necessary credentials.""" + super().__init__(**kwargs) + if self.ai_gateway: + self.url = ( + f"{self.gateway_url}/{self.account_id}/" + f"{self.ai_gateway}/workers-ai/run/{self.model}" + ) + else: + self.url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate a response based on the messages provided.""" + formatted_messages = _convert_messages_to_cloudflare_messages(messages) + + headers = {"Authorization": f"Bearer {self.api_token}"} + prompt = "\n".join( + f"role: {msg['role']}, content: {msg['content']}" + + (f", tools: {msg['tool_calls']}" if "tool_calls" in msg else "") + + ( + f", tool_call_id: {msg['tool_call_id']}" + if "tool_call_id" in msg + else "" + ) + for msg in formatted_messages + ) + + # Initialize `data` with `prompt` + data = { + "prompt": prompt, + "tools": kwargs["tools"] if "tools" in kwargs else None, + **{key: value for key, value in kwargs.items() if key not in ["tools"]}, + } + + # Ensure `tools` is a list if it's included in `kwargs` + if data["tools"] is not None and not isinstance(data["tools"], list): + data["tools"] = [data["tools"]] + + _logger.info(f"Sending prompt to Cloudflare Workers AI: {data}") + + response = requests.post(self.url, headers=headers, json=data) + tool_calls = _get_tool_calls_from_response(response) + ai_message = AIMessage( + content=str(response.json()), tool_calls=cast(AIMessageChunk, tool_calls) + ) + chat_generation = ChatGeneration(message=ai_message) + return ChatResult(generations=[chat_generation]) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type, Callable[..., Any], BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tools for use in model generation.""" + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: Union[Dict, Type[BaseModel]], + *, + include_raw: bool = False, + method: Optional[Literal["json_mode", "function_calling"]] = "function_calling", + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema.""" + + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = _is_pydantic_class(schema) + if method == "function_calling": + if schema is None: + raise ValueError( + "schema must be specified when method is 'function_calling'. " + "Received None." + ) + tool_name = convert_to_openai_tool(schema)["function"]["name"] + llm = self.bind_tools([schema], tool_choice=tool_name) + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) + elif method == "json_mode": + llm = self.bind(response_format={"type": "json_object"}) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected one of 'function_calling' or " + f"'json_mode'. Received: '{method}'" + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser + + @property + def _llm_type(self) -> str: + """Return the type of the LLM (for Langchain compatibility).""" + return "cloudflare-workers-ai" diff --git a/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py b/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py new file mode 100644 index 0000000000000..940616048de02 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_cloudflare_workersai.py @@ -0,0 +1,78 @@ +"""Test CloudflareWorkersAI Chat API wrapper.""" + +from typing import Any, Dict, List, Type + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_standard_tests.unit_tests import ChatModelUnitTests + +from langchain_community.chat_models.cloudflare_workersai import ( + ChatCloudflareWorkersAI, + _convert_messages_to_cloudflare_messages, +) + + +class TestChatCloudflareWorkersAI(ChatModelUnitTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatCloudflareWorkersAI + + @property + def chat_model_params(self) -> dict: + return { + "account_id": "my_account_id", + "api_token": "my_api_token", + "model": "@hf/nousresearch/hermes-2-pro-mistral-7b", + } + + +@pytest.mark.parametrize( + ("messages", "expected"), + [ + # Test case with a single HumanMessage + ( + [HumanMessage(content="Hello, AI!")], + [{"role": "user", "content": "Hello, AI!"}], + ), + # Test case with SystemMessage, HumanMessage, and AIMessage without tool calls + ( + [ + SystemMessage(content="System initialized."), + HumanMessage(content="Hello, AI!"), + AIMessage(content="Response from AI"), + ], + [ + {"role": "system", "content": "System initialized."}, + {"role": "user", "content": "Hello, AI!"}, + {"role": "assistant", "content": "Response from AI"}, + ], + ), + # Test case with ToolMessage and tool_call_id + ( + [ + ToolMessage( + content="Tool message content", tool_call_id="tool_call_123" + ), + ], + [ + { + "role": "tool", + "content": "Tool message content", + "tool_call_id": "tool_call_123", + } + ], + ), + ], +) +def test_convert_messages_to_cloudflare_format( + messages: List[BaseMessage], expected: List[Dict[str, Any]] +) -> None: + result = _convert_messages_to_cloudflare_messages(messages) + assert result == expected