Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IMPROVEMENT: langchain anthropic nits #13835

Merged
merged 9 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions libs/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Anthropic chat models."""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast

from langchain_core.callbacks import (
Expand Down Expand Up @@ -80,9 +81,13 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
Example:
.. code-block:: python

import anthropic
from langchain.chat_models import ChatAnthropic
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
from langchain_anthropic import ChatAnthropic

model = ChatAnthropic(
model="claude-2",
anthropic_api_key="<my-api-key>",
max_tokens_to_sample=1024,
)
"""

class Config:
Expand All @@ -106,11 +111,13 @@ def is_lc_serializable(cls) -> bool:
return True

def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
"""Format a list of messages into a full prompt for the Anthropic model.

Args:
messages (List[BaseMessage]): List of BaseMessage to combine.

Returns:
str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
String with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
prompt_params = {}
if self.HUMAN_PROMPT:
Expand All @@ -120,6 +127,14 @@ def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
return convert_messages_to_prompt_anthropic(messages=messages, **prompt_params)

def convert_prompt(self, prompt: PromptValue) -> str:
"""Format a PromptValue into a string prompt for the Anthropic model.

Args:
prompt (PromptValue): The prompt to convert.

Returns:
String with necessary HUMAN_PROMPT and AI_PROMPT tags.
"""
return self._convert_messages_to_prompt(prompt.to_messages())

def _stream(
Expand Down Expand Up @@ -216,6 +231,7 @@ async def _agenerate(

def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
if self.count_tokens is not None:
return self.count_tokens(text)
else:
return self.client.count_tokens(text)
72 changes: 45 additions & 27 deletions libs/anthropic/langchain_anthropic/llms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Anthropic LLMs."""
import os
import re
import warnings
Expand Down Expand Up @@ -47,20 +48,22 @@ class _AnthropicCommon(BaseLanguageModel):
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""

streaming: bool = False
"""Whether to stream the results."""

default_request_timeout: Optional[float] = None
default_request_timeout: Optional[float] = Field(default=None, alias="timeout")
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""

anthropic_api_url: Optional[str] = None
anthropic_api_url: Optional[str] = Field(default=None, alias="base_url")
"""Base API url."""

anthropic_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `ANTHROPIC_API_KEY` if not provided."""

anthropic_api_key: Optional[SecretStr] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Additional keyword arguments to pass in when invoking model."""

streaming: bool = False
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)

@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -94,24 +97,35 @@ def validate_environment(cls, values: Dict) -> Dict:
import anthropic

check_package_version("anthropic", gte_version="0.3")

base_url = values["anthropic_api_url"]
api_key = cast(SecretStr, values["anthropic_api_key"]).get_secret_value()
timeout = values["default_request_timeout"]
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=cast(SecretStr, values["anthropic_api_key"]).get_secret_value(),
timeout=values["default_request_timeout"],
base_url=base_url,
api_key=api_key,
timeout=timeout,
)
values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"],
api_key=cast(SecretStr, values["anthropic_api_key"]).get_secret_value(),
timeout=values["default_request_timeout"],
base_url=base_url,
api_key=api_key,
timeout=timeout,
)
values["HUMAN_PROMPT"] = (
values["HUMAN_PROMPT"]
if values["HUMAN_PROMPT"] is not None
else anthropic.HUMAN_PROMPT
)
values["AI_PROMPT"] = (
values["AI_PROMPT"]
if values["AI_PROMPT"] is not None
else anthropic.AI_PROMPT
)
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens

except ImportError:
raise ImportError(
"Could not import anthropic python package. "
"Please install it with `pip install anthropic`."
"Please install it with `pip install -U anthropic`."
)
return values

Expand All @@ -133,7 +147,7 @@ def _default_params(self) -> Mapping[str, Any]:
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{}, **self._default_params}
return self._default_params

def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
Expand All @@ -158,20 +172,23 @@ class Anthropic(LLM, _AnthropicCommon):
Example:
.. code-block:: python

import anthropic
from langchain.llms import Anthropic
from langchain_anthropic import Anthropic

model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
model = Anthropic(
model="claude-2",
anthropic_api_key="<my-api-key>",
max_tokens_to_sample=1024,
)

# Simplest invocation, automatically wrapped with HUMAN_PROMPT
# and AI_PROMPT.
response = model("What are the biggest risks facing humanity?")
response = model.invoke("What are the biggest risks facing humanity?")

# Or if you want to use the chat mode, build a few-shot-prompt, or
# put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
raw_prompt = "What are the biggest risks facing humanity?"
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
response = model(prompt)
response = model.invoke(prompt)
"""

class Config:
Expand All @@ -185,7 +202,7 @@ def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""
warnings.warn(
"This Anthropic LLM is deprecated. "
"Please use `from langchain.chat_models import ChatAnthropic` instead"
"Please use `from langchain_anthropic import ChatAnthropic` instead"
)
return values

Expand Down Expand Up @@ -351,6 +368,7 @@ async def _astream(

def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
if self.count_tokens is not None:
return self.count_tokens(text)
else:
return self.client.count_tokens(text)
43 changes: 34 additions & 9 deletions libs/anthropic/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,62 @@
os.environ["ANTHROPIC_API_KEY"] = "foo"


def test_anthropic_model_name_param() -> None:
def test_model_name_param() -> None:
llm = ChatAnthropic(model_name="foo")
assert llm.model == "foo"


def test_anthropic_model_param() -> None:
def test_model_param() -> None:
llm = ChatAnthropic(model="foo")
assert llm.model == "foo"


def test_anthropic_model_kwargs() -> None:
def test_model_kwargs() -> None:
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}


def test_anthropic_invalid_model_kwargs() -> None:
def test_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})


def test_anthropic_incorrect_field() -> None:
def test_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = ChatAnthropic(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}


def test_anthropic_initialization() -> None:
def test_initialization() -> None:
"""Test anthropic initialization."""
# Verify that chat anthropic can be initialized using a secret key provided
# as a parameter rather than an environment variable.
ChatAnthropic(model="test", anthropic_api_key="test")
# No params.
ChatAnthropic()

# All params.
ChatAnthropic(
model="test",
max_tokens_to_sample=1000,
temperature=0.2,
top_k=2,
top_p=0.9,
default_request_timeout=123,
anthropic_api_url="foobar.com",
anthropic_api_key="test",
model_kwargs={"fake_param": 2},
)

# Alias params
ChatAnthropic(
model_name="test",
timeout=123,
base_url="foobar.com",
api_key="test",
)


def test_get_num_tokens() -> None:
chat = ChatAnthropic(model="test", anthropic_api_key="test")
assert chat.get_num_tokens("Hello claude") > 0


@pytest.mark.parametrize(
Expand Down
17 changes: 11 additions & 6 deletions libs/anthropic/tests/unit_tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,39 @@
os.environ["ANTHROPIC_API_KEY"] = "foo"


def test_anthropic_model_name_param() -> None:
def test_model_name_param() -> None:
llm = Anthropic(model_name="foo")
assert llm.model == "foo"


def test_anthropic_model_param() -> None:
def test_model_param() -> None:
llm = Anthropic(model="foo")
assert llm.model == "foo"


def test_anthropic_model_kwargs() -> None:
def test_model_kwargs() -> None:
llm = Anthropic(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}


def test_anthropic_invalid_model_kwargs() -> None:
def test_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
Anthropic(model_kwargs={"max_tokens_to_sample": 5})


def test_anthropic_incorrect_field() -> None:
def test_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = Anthropic(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}


def test_anthropic_initialization() -> None:
def test_initialization() -> None:
"""Test anthropic initialization."""
# Verify that chat anthropic can be initialized using a secret key provided
# as a parameter rather than an environment variable.
Anthropic(model="test", anthropic_api_key="test")


def test_get_num_tokens() -> None:
llm = Anthropic(model="test", anthropic_api_key="test")
assert llm.get_num_tokens("Hello claude") > 0