Skip to content

Commit

Permalink
update anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Nov 12, 2024
1 parent e4bfc84 commit db7b518
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
47 changes: 24 additions & 23 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,37 +1114,38 @@ class AnswerWithJustification(BaseModel):
return llm | output_parser

@beta()
def get_num_tokens_from_messages(
    self,
    messages: List[BaseMessage],
    tools: Optional[
        Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
    ] = None,
) -> int:
    """Count tokens in a sequence of input messages.

    Uses Anthropic's token counting API, so this performs a network request
    rather than a local estimate. See:
    https://docs.anthropic.com/en/api/messages-count-tokens
    https://docs.anthropic.com/en/docs/build-with-claude/token-counting

    Args:
        messages: The message inputs to tokenize.
        tools: If provided, sequence of dict, BaseModel, function, or BaseTools
            to be converted to tool schemas.

    Returns:
        The number of input tokens for ``messages`` (including ``tools``,
        if supplied).

    .. versionchanged:: 0.2.5

        Uses Anthropic's token counting API to count tokens in messages.
    """
    formatted_system, formatted_messages = _format_messages(messages)
    # Collect optional request fields in one place so a single API call
    # covers every combination of system prompt / tools, instead of
    # duplicating the count_tokens call per branch.
    kwargs: Dict[str, Any] = {}
    if isinstance(formatted_system, str):
        kwargs["system"] = formatted_system
    if tools:
        kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]

    response = self._client.beta.messages.count_tokens(
        betas=["token-counting-2024-11-01"],
        model=self.model,
        messages=formatted_messages,  # type: ignore[arg-type]
        **kwargs,
    )
    return response.input_tokens


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,18 +508,34 @@ def test_with_structured_output() -> None:


def test_get_num_tokens_from_messages() -> None:
    """Integration test: token counting via Anthropic's count-tokens API.

    Exercises three shapes of input: plain messages, messages with tool
    schemas attached, and a full tool-use exchange (AIMessage with tool
    calls followed by a ToolMessage result). Each case only asserts a
    positive count, since exact token totals are model/API-defined.
    """
    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")  # type: ignore[call-arg]

    # Test simple case
    messages = [
        SystemMessage(content="You are a scientist"),
        HumanMessage(content="Hello, Claude"),
    ]
    num_tokens = llm.get_num_tokens_from_messages(messages)
    assert num_tokens > 0

    # Test tool use
    @tool(parse_docstring=True)
    def get_weather(location: str) -> str:
        """Get the current weather in a given location

        Args:
            location: The city and state, e.g. San Francisco, CA
        """
        return "Sunny"

    messages = [
        HumanMessage(content="What's the weather like in San Francisco?"),
    ]
    num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
    assert num_tokens > 0

    # Tool-call round trip: AI requests the tool, ToolMessage returns the result.
    # NOTE(review): part of the AIMessage payload was collapsed in the diff
    # view; the tool_use content block below is reconstructed from the visible
    # fragments (text block, tool_call entry, tool_call_id) — confirm against
    # the repository before relying on the exact payload.
    messages = [
        HumanMessage(content="What's the weather like in San Francisco?"),
        AIMessage(
            content=[
                {"text": "Let's see.", "type": "text"},
                {
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "input": {"location": "San Francisco, CA"},
                    "name": "get_weather",
                    "type": "tool_use",
                },
            ],
            tool_calls=[
                {
                    "name": "get_weather",
                    "args": {"location": "San Francisco, CA"},
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "type": "tool_call",
                },
            ],
        ),
        ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"),
    ]
    num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
    assert num_tokens > 0


class GetWeather(BaseModel):
Expand Down

0 comments on commit db7b518

Please sign in to comment.