Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jasonozuzu/pts 7374 langchain chat v2 (stream added) #103

Open
wants to merge 46 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d0017ae
Add v2 request construction
jasonozuzu-cohere Oct 9, 2024
2fd16be
Add v2 get_generation_info
jasonozuzu-cohere Oct 9, 2024
bc0f3f2
Refactor cohere tool call conversion and usage metadata fetching for v2
jasonozuzu-cohere Oct 9, 2024
f074ca8
Add calls to refactored cohere tool call/usage metadata conversion to…
jasonozuzu-cohere Oct 9, 2024
220b570
Add default model logic
jasonozuzu-cohere Oct 9, 2024
812b9cf
Tweak v2 document request construction
jasonozuzu-cohere Oct 10, 2024
883981c
Refactoring for integration tests
jasonozuzu-cohere Oct 11, 2024
7d45d99
Refactor test_chat_models unit tests for new v2 tool formatting logic
jasonozuzu-cohere Oct 11, 2024
93ef9d5
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Oct 16, 2024
6ab2842
Bump ver. num, and add documents to generation_info
jasonozuzu-cohere Oct 31, 2024
aa9e566
Add support for v2 streaming
jasonozuzu-cohere Oct 31, 2024
8953389
Add deprecation warning for connectors
jasonozuzu-cohere Oct 31, 2024
e862669
Add documents to get_stream_info
jasonozuzu-cohere Oct 31, 2024
d15b94a
Minor cleanup
jasonozuzu-cohere Nov 11, 2024
046715a
Fix minor issues
jasonozuzu-cohere Nov 14, 2024
f65e0a6
Add missing cassettes, and fix pytest collection bug
jasonozuzu-cohere Nov 15, 2024
58ad738
Add test for tool call buffering in astream
jasonozuzu-cohere Nov 15, 2024
6bb7b3b
test debugging + coverage improvements, and replacing api key with pe…
jasonozuzu-cohere Nov 15, 2024
4d3bb05
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 15, 2024
88b3396
Fix typo
jasonozuzu-cohere Nov 18, 2024
87bbb97
Merge branch 'jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM' of gith…
jasonozuzu-cohere Nov 18, 2024
f2ff354
Remove release candidate flag
jasonozuzu-cohere Nov 18, 2024
23daad7
Fix lint errors
jasonozuzu-cohere Nov 19, 2024
7a0b390
Fix failing tests, and refactor to remove v2 instance vars
jasonozuzu-cohere Nov 19, 2024
1e13faf
Reformat for linting
jasonozuzu-cohere Nov 20, 2024
fdbc1ea
Update raw_prompting comment
jasonozuzu-cohere Nov 20, 2024
fa31242
Remove cohere_tools_agent, and add tests for format_to_cohere_tools_v2
jasonozuzu-cohere Nov 20, 2024
b85d1f7
Remove debugging print statements
jasonozuzu-cohere Nov 20, 2024
2bf2f8f
Fix more lint errors
jasonozuzu-cohere Nov 20, 2024
01e0c95
Fix linting
jasonozuzu-cohere Nov 20, 2024
b0bef74
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 20, 2024
d723047
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 20, 2024
e6be1d3
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 20, 2024
9d3b4b2
Update dependencies, and change release version
jasonozuzu-cohere Nov 21, 2024
9ce79fc
Trigger CI
jasonozuzu-cohere Nov 21, 2024
81c767d
Update release ver to 0.4.0, and improve formatting on tests
jasonozuzu-cohere Nov 25, 2024
c4c2554
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 26, 2024
e84899f
Ensure release ver is correct post-merge
jasonozuzu-cohere Nov 26, 2024
ed7e9e7
Merge branch 'main' into jasonozuzu/pts-7374-langchain-chat-v2-NOSTREAM
jasonozuzu-cohere Nov 27, 2024
4b9f163
Fix linting issues
jasonozuzu-cohere Nov 27, 2024
b2b326f
Improve typing and add autospec to default_model mock
jasonozuzu-cohere Nov 28, 2024
72636f0
Refactor to improve typing
jasonozuzu-cohere Nov 28, 2024
2c6407c
Improve type hinting
jasonozuzu-cohere Nov 29, 2024
7de5ef1
Fix lint errors
jasonozuzu-cohere Nov 29, 2024
1504b6f
Ensure if statement is valid
jasonozuzu-cohere Nov 29, 2024
7ef8058
Ensure tool plan is passed to AIMessage in generate
jasonozuzu-cohere Nov 30, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
515 changes: 460 additions & 55 deletions libs/cohere/langchain_cohere/chat_models.py

Large diffs are not rendered by default.

99 changes: 99 additions & 0 deletions libs/cohere/langchain_cohere/cohere_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from cohere.types import (
Tool,
ToolV2,
ToolV2Function,
ToolCall,
ToolParameterDefinitionsValue,
ToolResult,
Expand Down Expand Up @@ -69,6 +71,12 @@ def _format_to_cohere_tools(
return [_convert_to_cohere_tool(tool) for tool in tools]


def _format_to_cohere_tools_v2(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
) -> List[Dict[str, Any]]:
return [_convert_to_cohere_tool_v2(tool) for tool in tools]


def _format_to_cohere_tools_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -176,6 +184,97 @@ def _convert_to_cohere_tool(
)


def _convert_to_cohere_tool_v2(
tool: Union[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
) -> Dict[str, Any]:
"""
Convert a BaseTool instance, JSON schema dict, or BaseModel type to a V2 Cohere tool.
"""
if isinstance(tool, dict):
if not all(k in tool for k in ("title", "description", "properties")):
raise ValueError(
"Unsupported dict type. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
)
return ToolV2(
type="function",
function=ToolV2Function(
name=tool.get("title"),
description=tool.get("description"),
parameters={
"type": "object",
"properties": {
param_name: {
"description": param_definition.get("description"),
"type": JSON_TO_PYTHON_TYPES.get(
param_definition.get("type"), param_definition.get("type")
),
}
for param_name, param_definition in tool.get("properties", {}).items()
},
"required": [param_name
for param_name, param_definition
in tool.get("properties", {}).items()
if "default" not in param_definition],
},
)
).dict()
elif (
(isinstance(tool, type) and issubclass(tool, BaseModel))
or callable(tool)
or isinstance(tool, BaseTool)
):
as_json_schema_function = convert_to_openai_function(tool)
parameters = as_json_schema_function.get("parameters", {})
properties = parameters.get("properties", {})
parameter_definitions = {}
required_params = []
for param_name, param_definition in properties.items():
if "type" in param_definition:
_type_str = param_definition.get("type")
_type = JSON_TO_PYTHON_TYPES.get(_type_str)
elif "anyOf" in param_definition:
_type_str = next(
(
t.get("type")
for t in param_definition.get("anyOf", [])
if t.get("type") != "null"
),
param_definition.get("type"),
)
_type = JSON_TO_PYTHON_TYPES.get(_type_str)
else:
_type = None
tool_definition = {
"type": _type,
"description": param_definition.get("description"),
}
parameter_definitions[param_name] = tool_definition
if param_name in parameters.get("required", []):
required_params.append(param_name)
return ToolV2(
type="function",
function=ToolV2Function(
name=as_json_schema_function.get("name"),
description=as_json_schema_function.get(
# The Cohere API requires the description field.
"description",
as_json_schema_function.get("name"),
),
parameters={
"type": "object",
"properties": {
**parameter_definitions,
},
"required": required_params,
},
)
).dict()
else:
raise ValueError(
f"Unsupported tool type {type(tool)}. Tool must be passed in as a BaseTool instance, JSON schema dict, or BaseModel type." # noqa: E501
)


class _CohereToolsAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
):
Expand Down
36 changes: 35 additions & 1 deletion libs/cohere/langchain_cohere/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import re
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import cohere
from langchain_core.callbacks import (
Expand Down Expand Up @@ -30,6 +30,9 @@ def enforce_stop_tokens(text: str, stop: List[str]) -> str:

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from cohere.types import ListModelsResponse # noqa: F401


def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
Expand Down Expand Up @@ -58,6 +61,19 @@ class BaseCohere(Serializable):

client: Any = None #: :meta private:
async_client: Any = None #: :meta private:

chat_v2: Optional[Any] = None
"Cohere chat v2."

async_chat_v2: Optional[Any] = None
"Cohere async chat v2."

chat_stream_v2: Optional[Any] = None
"Cohere chat stream v2."

async_chat_stream_v2: Optional[Any] = None
"Cohere async chat stream v2."

model: Optional[str] = Field(default=None)
"""Model name to use."""

Expand All @@ -83,6 +99,15 @@ class BaseCohere(Serializable):
base_url: Optional[str] = None
"""Override the default Cohere API URL."""

def _get_default_model(self) -> str:
"""Fetches the current default model name."""
response = self.client.models.list(default_only=True, endpoint="chat") # type: "ListModelsResponse"
if not response.models:
raise Exception("invalid cohere list models response")
if not response.models[0].name:
raise Exception("invalid cohere list models response")
return response.models[0].name

@model_validator(mode="after")
def validate_environment(self) -> Self: # type: ignore[valid-type]
"""Validate that api key and python package exists in environment."""
Expand All @@ -98,12 +123,21 @@ def validate_environment(self) -> Self: # type: ignore[valid-type]
client_name=client_name,
base_url=self.base_url,
)
self.chat_v2 = self.client.v2.chat
self.chat_stream_v2 = self.client.v2.chat_stream

self.async_client = cohere.AsyncClient(
api_key=cohere_api_key,
client_name=client_name,
timeout=timeout_seconds,
base_url=self.base_url,
)
self.async_chat_v2 = self.async_client.v2.chat
self.async_chat_stream_v2 = self.async_client.v2.chat_stream

if not self.model:
self.model = self._get_default_model()

return self


Expand Down
2 changes: 0 additions & 2 deletions libs/cohere/langchain_cohere/rag_retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def _get_docs(response: Any) -> List[Document]:
metadata={
"type": "model_response",
"citations": response.generation_info["citations"],
"search_results": response.generation_info["search_results"],
"search_queries": response.generation_info["search_queries"],
"token_count": response.generation_info["token_count"],
},
)
Expand Down
2 changes: 1 addition & 1 deletion libs/cohere/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-cohere"
version = "0.3.1"
version = "1.0.0-rc1"
description = "An integration package connecting Cohere and LangChain"
authors = []
readme = "README.md"
Expand Down
18 changes: 18 additions & 0 deletions libs/cohere/tests/clear_cassettes.py
jasonozuzu-cohere marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import shutil

def delete_cassettes_directories(root_dir):
for dirpath, dirnames, filenames in os.walk(root_dir):
for dirname in dirnames:
if dirname == "cassettes":
dir_to_delete = os.path.join(dirpath, dirname)
print(f"Deleting directory: {dir_to_delete}")
shutil.rmtree(dir_to_delete)

if __name__ == "__main__":
# Clear all cassettes directories in the integration_tests directory
# run using: python clear_cassettes.py
directory_to_clear = os.path.join(os.getcwd(), "integration_tests")
if not os.path.isdir(directory_to_clear):
raise Exception("integration_tests directory not found in current working directory")
delete_cassettes_directories(directory_to_clear)
11 changes: 10 additions & 1 deletion libs/cohere/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict
from typing import Dict, Generator, Optional

import pytest
from unittest.mock import MagicMock, patch
from langchain_cohere.llms import BaseCohere


@pytest.fixture(scope="module")
Expand All @@ -10,3 +12,10 @@ def vcr_config() -> Dict:
"filter_headers": [("Authorization", None)],
"ignore_hosts": ["storage.googleapis.com"],
}

@pytest.fixture
def patch_base_cohere_get_default_model() -> Generator[Optional[MagicMock], None, None]:
with patch.object(
BaseCohere, "_get_default_model", return_value="command-r-plus"
) as mock_get_default_model:
yield mock_get_default_model
Loading