Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cohere[minor]: Add support for tool calling cohere #5810

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/core_docs/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ docs/how_to/assign.mdx
docs/how_to/agent_executor.md
docs/how_to/agent_executor.mdx
docs/integrations/llms/mistral.md
docs/integrations/llms/mistral.mdx
docs/integrations/llms/mistral.mdx
8 changes: 8 additions & 0 deletions docs/core_docs/docs/integrations/chat/cohere.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ import StatefulChatExample from "@examples/models/chat/cohere/stateful_conversat
You can see the LangSmith traces from this example [here](https://smith.langchain.com/public/8e67b05a-4e63-414e-ac91-a91acf21b262/r) and [here](https://smith.langchain.com/public/50fabc25-46fe-4727-a59c-7e4eb0de8e70/r)
:::

### Tools

The Cohere API supports tool calling, along with multi-hop-tool calling. The following example demonstrates how to call tools:

import CohereToolCaling from "@examples/models/chat/cohere/tool_calling.ts";

<CodeBlock language="typescript">{CohereToolCaling}</CodeBlock>

### RAG

Cohere also comes out of the box with RAG support.
Expand Down
1 change: 0 additions & 1 deletion examples/src/models/chat/cohere/chat_cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { ChatPromptTemplate } from "@langchain/core/prompts";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
model: "command", // Default
});
const prompt = ChatPromptTemplate.fromMessages([
["ai", "You are a helpful assistant"],
Expand Down
1 change: 0 additions & 1 deletion examples/src/models/chat/cohere/chat_stream_cohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
model: "command", // Default
});
const prompt = ChatPromptTemplate.fromMessages([
["ai", "You are a helpful assistant"],
Expand Down
1 change: 0 additions & 1 deletion examples/src/models/chat/cohere/connectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
model: "command", // Default
});

const response = await model.invoke(
Expand Down
1 change: 0 additions & 1 deletion examples/src/models/chat/cohere/rag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
model: "command", // Default
});

const documents = [
Expand Down
1 change: 0 additions & 1 deletion examples/src/models/chat/cohere/stateful_conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
model: "command", // Default
});

const conversationId = `demo_test_id-${Math.random()}`;
Expand Down
79 changes: 79 additions & 0 deletions examples/src/models/chat/cohere/tool_calling.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { ChatCohere } from "@langchain/cohere";
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! 👋 This is a friendly flag to highlight that the recent code change explicitly accesses an environment variable using process.env. This is an important point for maintainers to review. Keep up the great work! 🚀

import { HumanMessage, ToolMessage } from "@langchain/core/messages";
import { convertToCohereTool } from "@langchain/core/utils/function_calling";
import { z } from "zod";
import { DynamicStructuredTool } from "@langchain/core/tools";

const model = new ChatCohere({
apiKey: process.env.COHERE_API_KEY, // Default
});

const magicFunctionTool = new DynamicStructuredTool({
name: "magic_function",
description: "Apply a magic function to the input number",
schema: z.object({
num: z.number().describe("The number to apply the magic function for"),
}),
func: async ({ num }) => {
return `The magic function of ${num} is ${num + 5}`;
},
});

const tools = [magicFunctionTool];
const modelWithTools = model.bind({
tools: tools.map(convertToCohereTool),
});

let messages = [new HumanMessage("What is the magic function of number 5?")];
const response = await modelWithTools.invoke(
messages,
);
/**
response: AIMessage {
lc_serializable: true,
lc_kwargs: {
content: 'I will use the magic_function tool to answer this question.',
additional_kwargs: {
response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
chatHistory: [Array],
finishReason: 'COMPLETE',
meta: [Object],
toolCalls: [Array]
},
tool_calls: [ [Object] ],
usage_metadata: { input_tokens: 920, output_tokens: 54, total_tokens: 974 },
invalid_tool_calls: [],
response_metadata: {}
},
lc_namespace: [ 'langchain_core', 'messages' ],
content: 'I will use the magic_function tool to answer this question.',
name: undefined,
additional_kwargs: {
response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
chatHistory: [ [Object], [Object] ],
finishReason: 'COMPLETE',
meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
toolCalls: [ [Object] ]
},
response_metadata: {
estimatedTokenUsage: { completionTokens: 54, promptTokens: 920, totalTokens: 974 },
response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
chatHistory: [ [Object], [Object] ],
finishReason: 'COMPLETE',
meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
toolCalls: [ [Object] ]
},
tool_calls: [
{
name: 'magic_function',
args: [Object],
id: '4ec98550-ba9a-4043-adfe-566230e5'
}
],
invalid_tool_calls: [],
usage_metadata: { input_tokens: 920, output_tokens: 54, total_tokens: 974 }
}
*/
33 changes: 33 additions & 0 deletions langchain-core/src/utils/function_calling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,39 @@ export function convertToOpenAIFunction(
};
}

/**
* Formats a `StructuredTool` instance into a format that is compatible
* with OpenAI function calling. It uses the `zodToJsonSchema`
* function to convert the schema of the `StructuredTool` into a JSON
* schema, which is then used as the parameters for the Cohere function.
*/
export function convertToCohereTool(
tool: StructuredToolInterface
): {name: string, description: string, parameterDefinitions: Record<string, any>} { /* eslint-disable-line @typescript-eslint/no-explicit-any */
const parameterDefinitionsFromZod = zodToJsonSchema(tool.schema);
const parameterDefinitionsProperties = "properties" in parameterDefinitionsFromZod ? parameterDefinitionsFromZod.properties : {};
let parameterDefinitionsRequired = "required" in parameterDefinitionsFromZod ? parameterDefinitionsFromZod.required : [];

const parameterDefinitionsFinal: Record<string, any> = {}; /* eslint-disable-line @typescript-eslint/no-explicit-any */

// Iterate through all properties
Object.keys(parameterDefinitionsProperties).forEach(propertyName => {
// Create the property in the new object
parameterDefinitionsFinal[propertyName] = parameterDefinitionsProperties[propertyName];
// Set the required property based on the 'required' array
if (parameterDefinitionsRequired === undefined) {
parameterDefinitionsRequired = [];
}
parameterDefinitionsFinal[propertyName].required = parameterDefinitionsRequired.includes(propertyName);
});

return {
name: tool.name,
description: tool.description,
parameterDefinitions: parameterDefinitionsFinal,
};
}

/**
* Formats a `StructuredTool` instance into a format that is compatible
* with OpenAI tool calling. It uses the `zodToJsonSchema`
Expand Down
2 changes: 1 addition & 1 deletion langchain-core/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"ES2022.Object",
"DOM"
],
"module": "ES2020",
"module": "NodeNext",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert these? Should not be in this PR.

"moduleResolution": "nodenext",
"esModuleInterop": true,
"declaration": true,
Expand Down
2 changes: 1 addition & 1 deletion langchain/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"ES2022.Object",
"DOM"
],
"module": "ES2020",
"module": "NodeNext",
"moduleResolution": "nodenext",
"esModuleInterop": true,
"declaration": true,
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-anthropic/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"ES2022.Object",
"DOM"
],
"module": "ES2020",
"module": "NodeNext",
"moduleResolution": "nodenext",
"esModuleInterop": true,
"declaration": true,
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-azure-openai/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"ES2022.Object",
"DOM"
],
"module": "ES2020",
"module": "NodeNext",
"moduleResolution": "nodenext",
"esModuleInterop": true,
"declaration": true,
Expand Down
6 changes: 4 additions & 2 deletions libs/langchain-cohere/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"test": "NODE_OPTIONS=--experimental-vm-modules jest --testPathIgnorePatterns=\\.int\\.test.ts --testTimeout 30000 --maxWorkers=50%",
"test:watch": "NODE_OPTIONS=--experimental-vm-modules jest --watch --testPathIgnorePatterns=\\.int\\.test.ts",
"test:single": "NODE_OPTIONS=--experimental-vm-modules yarn run jest --config jest.config.cjs --testTimeout 100000",
"test:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.int\\.test.ts --testTimeout 100000 --maxWorkers=50%",
"test:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\chat_models.int\\.test.ts --testTimeout 100000 --maxWorkers=50%",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert

"test:standard:unit": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.standard\\.test.ts --testTimeout 100000 --maxWorkers=50%",
"test:standard:int": "NODE_OPTIONS=--experimental-vm-modules jest --testPathPattern=\\.standard\\.int\\.test.ts --testTimeout 100000 --maxWorkers=50%",
"test:standard": "yarn test:standard:unit && yarn test:standard:int",
Expand All @@ -36,7 +36,9 @@
"license": "MIT",
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey there! I noticed that a new dependency "zod" has been added to the "dependencies" section, which is a change in the project's dependencies. This comment is to flag the change for maintainers to review. Great work!

"dependencies": {
"@langchain/core": ">=0.2.5 <0.3.0",
"cohere-ai": "^7.10.5"
"cohere-ai": "^7.10.5",
"uuid": "^10.0.0",
"zod": "^3.23.8"
},
"devDependencies": {
"@jest/globals": "^29.5.0",
Expand Down
Loading