From db7b51830880ab391e3e37d8baf9dce631cd30c7 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 12 Nov 2024 10:29:43 -0500 Subject: [PATCH] update anthropic --- .../langchain_anthropic/chat_models.py | 47 ++++++++++--------- .../integration_tests/test_chat_models.py | 31 +++++++++--- 2 files changed, 48 insertions(+), 30 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 0784ff2bdeb49..310380a419968 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1114,37 +1114,38 @@ class AnswerWithJustification(BaseModel): return llm | output_parser @beta() - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: List[BaseMessage], + tools: Optional[ + Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]] + ] = None, + ) -> int: """Count tokens in a sequence of input messages. + Args: + messages: The message inputs to tokenize. + tools: If provided, sequence of dict, BaseModel, function, or BaseTools + to be converted to tool schemas. + .. versionchanged:: 0.2.5 Uses Anthropic's token counting API to count tokens in messages. See: - https://docs.anthropic.com/en/api/messages-count-tokens + https://docs.anthropic.com/en/docs/build-with-claude/token-counting """ - if any( - isinstance(tool, ToolMessage) - or (isinstance(tool, AIMessage) and tool.tool_calls) - for tool in messages - ): - raise NotImplementedError( - "get_num_tokens_from_messages does not yet support counting tokens " - "in tool calls." - ) formatted_system, formatted_messages = _format_messages(messages) + kwargs: Dict[str, Any] = {} if isinstance(formatted_system, str): - response = self._client.beta.messages.count_tokens( - betas=["token-counting-2024-11-01"], - model=self.model, - system=formatted_system, - messages=formatted_messages, # type: ignore[arg-type] - ) - else: - response = self._client.beta.messages.count_tokens( - betas=["token-counting-2024-11-01"], - model=self.model, - messages=formatted_messages, # type: ignore[arg-type] - ) + kwargs["system"] = formatted_system + if tools: + kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools] + + response = self._client.beta.messages.count_tokens( + betas=["token-counting-2024-11-01"], + model=self.model, + messages=formatted_messages, # type: ignore[arg-type] + **kwargs, + ) return response.input_tokens diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 47736880b2533..5c2295492632d 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -508,18 +508,34 @@ def test_with_structured_output() -> None: def test_get_num_tokens_from_messages() -> None: - llm = ChatAnthropic(model="claude-3-5-haiku-20241022") # type: ignore[call-arg] + llm = ChatAnthropic(model="claude-3-5-sonnet-20241022") # type: ignore[call-arg] # Test simple case messages = [ - SystemMessage(content="You are an assistant."), - HumanMessage(content="What is the weather in SF?"), + SystemMessage(content="You are a scientist"), + HumanMessage(content="Hello, Claude"), ] num_tokens = llm.get_num_tokens_from_messages(messages) assert num_tokens > 0 - # Test tool use (not yet supported) + # Test tool use + @tool(parse_docstring=True) + def get_weather(location: str) -> str: + """Get the current weather in a given location + + Args: + location: The city and state, e.g. San Francisco, CA + """ + return "Sunny" + messages = [ + HumanMessage(content="What's the weather like in San Francisco?"), + ] + num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather]) + assert num_tokens > 0 + + messages = [ + HumanMessage(content="What's the weather like in San Francisco?"), AIMessage( content=[ {"text": "Let's see.", "type": "text"}, @@ -538,10 +554,11 @@ def test_get_num_tokens_from_messages() -> None: "type": "tool_call", }, ], - ) + ), + ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"), ] - with pytest.raises(NotImplementedError): - num_tokens = llm.get_num_tokens_from_messages(messages) + num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather]) + assert num_tokens > 0 class GetWeather(BaseModel):