Merge branch 'main' into brace/standard-tests
bracesproul authored May 31, 2024
2 parents c6addda + cc80b12 commit 40b8109
Showing 33 changed files with 1,328 additions and 135 deletions.
17 changes: 6 additions & 11 deletions docs/core_docs/docs/integrations/chat/mistral.mdx
@@ -7,21 +7,16 @@ import CodeBlock from "@theme/CodeBlock";
# ChatMistralAI

[Mistral AI](https://mistral.ai/) is a research organization and hosting platform for LLMs.
They're most known for their family of 7B models ([`mistral7b` // `mistral-tiny`](https://mistral.ai/news/announcing-mistral-7b/), [`mixtral8x7b` // `mistral-small`](https://mistral.ai/news/mixtral-of-experts/)).

The LangChain implementation of Mistral's models uses their hosted generation API, making it easier to access their models without needing to run them locally.

## Models

Mistral's API offers access to two of their open source, and proprietary models:
:::tip
Want to run Mistral's models locally? Check out our [Ollama integration](/docs/integrations/chat/ollama).
:::

- `open-mistral-7b` (aka `mistral-tiny-2312`)
- `open-mixtral-8x7b` (aka `mistral-small-2312`)
- `mistral-small-latest` (aka `mistral-small-2402`) (default)
- `mistral-medium-latest` (aka `mistral-medium-2312`)
- `mistral-large-latest` (aka `mistral-large-2402`)
## Models

See [this page](https://docs.mistral.ai/guides/model-selection/) for an up to date list.
Mistral's API offers access to their open source and proprietary models.
See [this page](https://docs.mistral.ai/getting-started/models/) for an up to date list.
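For reference, here is a minimal sketch of selecting one of these models by name when constructing the chat model (this assumes `@langchain/mistralai` is installed and `MISTRAL_API_KEY` is set; the model name shown is just an example, and full setup is covered below):

```typescript
import { ChatMistralAI } from "@langchain/mistralai";

// Any model name from Mistral's model list can be used here,
// e.g. "mistral-large-latest" or "open-mixtral-8x7b".
const model = new ChatMistralAI({
  apiKey: process.env.MISTRAL_API_KEY,
  model: "mistral-large-latest",
});

const response = await model.invoke("Why is the sky blue?");
console.log(response.content);
```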

## Setup

13 changes: 13 additions & 0 deletions docs/core_docs/docs/integrations/chat/openai.mdx
@@ -125,3 +125,16 @@ You can also use the callbacks system:
### With `.generate()`

<CodeBlock language="typescript">{OpenAIGenerationInfo}</CodeBlock>

### Streaming tokens

OpenAI supports streaming token counts via an opt-in call option. This can be set by passing `{ stream_options: { include_usage: true } }`.
Setting this call option will cause the model to return an additional chunk at the end of the stream, containing the token usage.

import OpenAIStreamTokens from "@examples/models/chat/integration_openai_stream_tokens.ts";

<CodeBlock language="typescript">{OpenAIStreamTokens}</CodeBlock>

:::tip
See the LangSmith trace [here](https://smith.langchain.com/public/66bf7377-cc69-4676-91b6-25929a05e8b7/r)
:::
160 changes: 160 additions & 0 deletions docs/core_docs/docs/integrations/llms/mistral.ipynb
@@ -0,0 +1,160 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MistralAI\n",
"\n",
"```{=mdx}\n",
":::tip\n",
"Want to run Mistral's models locally? Check out our [Ollama integration](/docs/integrations/chat/ollama).\n",
":::\n",
"```\n",
"\n",
"Here's how you can initialize an `MistralAI` LLM instance:\n",
"\n",
"```{=mdx}\n",
"import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n",
"import Npm2Yarn from \"@theme/Npm2Yarn\";\n",
"\n",
"<IntegrationInstallTooltip></IntegrationInstallTooltip>\n",
"\n",
"<Npm2Yarn>\n",
" @langchain/mistralai\n",
"</Npm2Yarn>\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"```\n",
"This will output 'hello world' to the console.\n"
]
}
],
"source": [
"import { MistralAI } from \"@langchain/mistralai\";\n",
"\n",
"const model = new MistralAI({\n",
" model: \"codestral-latest\", // Defaults to \"codestral-latest\" if no model provided.\n",
" temperature: 0,\n",
" apiKey: \"YOUR-API-KEY\", // In Node.js defaults to process.env.MISTRAL_API_KEY\n",
"});\n",
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\"\n",
");\n",
"console.log(res);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the Mistral LLM is a completions model, they also allow you to insert a `suffix` to the prompt. Suffixes can be passed via the call options when invoking a model like so:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"```\n"
]
}
],
"source": [
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n",
" suffix: \"```\"\n",
" }\n",
");\n",
"console.log(res);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As seen in the first example, the model generated the requested `console.log('hello world')` code snippet, but also included extra unwanted text. By adding a suffix, we can constrain the model to only complete the prompt up to the suffix (in this case, three backticks). This allows us to easily parse the completion and extract only the desired response without the suffix using a custom output parser."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"console.log('hello world');\n",
"\n"
]
}
],
"source": [
"import { MistralAI } from \"@langchain/mistralai\";\n",
"\n",
"const model = new MistralAI({\n",
" model: \"codestral-latest\",\n",
" temperature: 0,\n",
" apiKey: \"YOUR-API-KEY\",\n",
"});\n",
"\n",
"const suffix = \"```\";\n",
"\n",
"const customOutputParser = (input: string) => {\n",
" if (input.includes(suffix)) {\n",
" return input.split(suffix)[0];\n",
" }\n",
" throw new Error(\"Input does not contain suffix.\")\n",
"};\n",
"\n",
"const res = await model.invoke(\n",
" \"You can print 'hello world' to the console in javascript like this:\\n```javascript\", {\n",
" suffix,\n",
" }\n",
");\n",
"\n",
"console.log(customOutputParser(res));"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TypeScript",
"language": "typescript",
"name": "tslab"
},
"language_info": {
"codemirror_mode": {
"mode": "typescript",
"name": "javascript",
"typescript": true
},
"file_extension": ".ts",
"mimetype": "text/typescript",
"name": "typescript",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_tools.ts
@@ -33,7 +33,7 @@ class CalculatorTool extends StructuredTool {

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Bind the tool to the model
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_wsa.ts
@@ -14,7 +14,7 @@ const calculatorSchema = z

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Pass the schema and tool name to the withStructuredOutput method
2 changes: 1 addition & 1 deletion examples/src/models/chat/chat_mistralai_wsa_json.ts
@@ -21,7 +21,7 @@ const calculatorJsonSchema = {

const model = new ChatMistralAI({
apiKey: process.env.MISTRAL_API_KEY,
model: "mistral-large",
model: "mistral-large-latest",
});

// Pass the schema and tool name to the withStructuredOutput method
30 changes: 30 additions & 0 deletions examples/src/models/chat/integration_openai_stream_tokens.ts
@@ -0,0 +1,30 @@
import { AIMessageChunk } from "@langchain/core/messages";
import { ChatOpenAI } from "@langchain/openai";

// Instantiate the model
const model = new ChatOpenAI();

const response = await model.stream("Hello, how are you?", {
// Pass the stream options
stream_options: {
include_usage: true,
},
});

// Iterate over the response, only saving the last chunk
let finalResult: AIMessageChunk | undefined;
for await (const chunk of response) {
if (finalResult) {
finalResult = finalResult.concat(chunk);
} else {
finalResult = chunk;
}
}

console.log(finalResult?.usage_metadata);

/*
{ input_tokens: 13, output_tokens: 30, total_tokens: 43 }
*/
2 changes: 1 addition & 1 deletion langchain-core/package.json
@@ -1,6 +1,6 @@
{
"name": "@langchain/core",
"version": "0.2.4",
"version": "0.2.5",
"description": "Core LangChain.js abstractions and schemas",
"type": "module",
"engines": {
58 changes: 55 additions & 3 deletions langchain-core/src/messages/ai.ts
@@ -18,6 +18,25 @@ import {
export type AIMessageFields = BaseMessageFields & {
tool_calls?: ToolCall[];
invalid_tool_calls?: InvalidToolCall[];
usage_metadata?: UsageMetadata;
};

/**
* Usage metadata for a message, such as token counts.
*/
export type UsageMetadata = {
/**
* The count of input (or prompt) tokens.
*/
input_tokens: number;
/**
* The count of output (or completion) tokens
*/
output_tokens: number;
/**
* The total token count
*/
total_tokens: number;
};

/**
@@ -30,6 +49,11 @@ export class AIMessage extends BaseMessage {

invalid_tool_calls?: InvalidToolCall[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

get lc_aliases(): Record<string, string> {
// exclude snake case conversion to pascal case
return {
@@ -94,6 +118,7 @@ export class AIMessage extends BaseMessage {
this.invalid_tool_calls =
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
}
this.usage_metadata = initParams.usage_metadata;
}

static lc_name() {
@@ -127,6 +152,11 @@ export class AIMessageChunk extends BaseMessageChunk {

tool_call_chunks?: ToolCallChunk[] = [];

/**
* If provided, token usage information associated with the message.
*/
usage_metadata?: UsageMetadata;

constructor(fields: string | AIMessageChunkFields) {
let initParams: AIMessageChunkFields;
if (typeof fields === "string") {
@@ -177,10 +207,11 @@
// properties with initializers, so we have to check types twice.
super(initParams);
this.tool_call_chunks =
initParams?.tool_call_chunks ?? this.tool_call_chunks;
this.tool_calls = initParams?.tool_calls ?? this.tool_calls;
initParams.tool_call_chunks ?? this.tool_call_chunks;
this.tool_calls = initParams.tool_calls ?? this.tool_calls;
this.invalid_tool_calls =
initParams?.invalid_tool_calls ?? this.invalid_tool_calls;
initParams.invalid_tool_calls ?? this.invalid_tool_calls;
this.usage_metadata = initParams.usage_metadata;
}

get lc_aliases(): Record<string, string> {
@@ -226,6 +257,27 @@ export class AIMessageChunk extends BaseMessageChunk {
combinedFields.tool_call_chunks = rawToolCalls;
}
}
if (
this.usage_metadata !== undefined ||
chunk.usage_metadata !== undefined
) {
const left: UsageMetadata = this.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const right: UsageMetadata = chunk.usage_metadata ?? {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
};
const usage_metadata: UsageMetadata = {
input_tokens: left.input_tokens + right.input_tokens,
output_tokens: left.output_tokens + right.output_tokens,
total_tokens: left.total_tokens + right.total_tokens,
};
combinedFields.usage_metadata = usage_metadata;
}
return new AIMessageChunk(combinedFields);
}
}
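As a quick illustration of the new field (a sketch added for this writeup, not part of the diff): when two `AIMessageChunk`s are concatenated, their `usage_metadata` counts are summed field by field, so aggregating a stream yields the total token usage.

```typescript
import { AIMessageChunk } from "@langchain/core/messages";

// Two partial chunks, e.g. as produced while streaming.
const first = new AIMessageChunk({
  content: "Hello, ",
  usage_metadata: { input_tokens: 13, output_tokens: 2, total_tokens: 15 },
});
const second = new AIMessageChunk({
  content: "world!",
  usage_metadata: { input_tokens: 0, output_tokens: 3, total_tokens: 3 },
});

// concat() merges the content and sums each usage_metadata field.
const merged = first.concat(second);

console.log(merged.usage_metadata);
// { input_tokens: 13, output_tokens: 5, total_tokens: 18 }
```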