From de3e6186831d351ee071ca40f97f9038a28a31dc Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 11 Jun 2024 23:38:23 +0300 Subject: [PATCH] community[major]: DeepInfra llm and chat (#5672) * Init * fix(type errors) * feat(deepinfra embeddings) * fix(default model) * fix(deepinfra): axios is removed * ref(deepinfra): remove redundant cast * format(deepinfra) * doc(deepinfra) * doc(deepinfra) * Update deepinfra.mdx * Format * feat(deepinfra): implement llm and chat. * ref(deepinfra): lint and prettier * ref(deepinfra): remove console.log * fix(chatdeepinfra): body * fix(import map): deepinfra * fix(gitignore) * revert(.gitignore) * revert(.gitignore) * Adds docs --------- Co-authored-by: Jacob Lee --- docs/core_docs/.gitignore | 2 +- .../docs/integrations/chat/deep_infra.mdx | 25 ++ .../docs/integrations/llms/deep_infra.mdx | 25 ++ .../src/models/chat/integration_deepinfra.ts | 17 ++ examples/src/models/llm/deepinfra.ts | 18 ++ libs/langchain-community/.gitignore | 8 + libs/langchain-community/langchain.config.js | 2 + libs/langchain-community/package.json | 26 +++ .../src/chat_models/deepinfra.ts | 215 ++++++++++++++++++ .../tests/chatdeepinfra.int.test.ts | 19 ++ .../langchain-community/src/llms/deepinfra.ts | 69 ++++++ .../src/llms/tests/deepinfra.int.test.ts | 8 + .../src/load/import_map.ts | 2 + 13 files changed, 435 insertions(+), 1 deletion(-) create mode 100644 docs/core_docs/docs/integrations/chat/deep_infra.mdx create mode 100644 docs/core_docs/docs/integrations/llms/deep_infra.mdx create mode 100644 examples/src/models/chat/integration_deepinfra.ts create mode 100644 examples/src/models/llm/deepinfra.ts create mode 100644 libs/langchain-community/src/chat_models/deepinfra.ts create mode 100644 libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts create mode 100644 libs/langchain-community/src/llms/deepinfra.ts create mode 100644 libs/langchain-community/src/llms/tests/deepinfra.int.test.ts diff --git a/docs/core_docs/.gitignore 
b/docs/core_docs/.gitignore index 7e987abbe326..b88cb13fc1c6 100644 --- a/docs/core_docs/.gitignore +++ b/docs/core_docs/.gitignore @@ -176,4 +176,4 @@ docs/how_to/assign.mdx docs/how_to/agent_executor.md docs/how_to/agent_executor.mdx docs/integrations/llms/mistral.md -docs/integrations/llms/mistral.mdx \ No newline at end of file +docs/integrations/llms/mistral.mdx diff --git a/docs/core_docs/docs/integrations/chat/deep_infra.mdx b/docs/core_docs/docs/integrations/chat/deep_infra.mdx new file mode 100644 index 000000000000..5e5805c84bf9 --- /dev/null +++ b/docs/core_docs/docs/integrations/chat/deep_infra.mdx @@ -0,0 +1,25 @@ +--- +sidebar_label: Deep Infra +--- + +import CodeBlock from "@theme/CodeBlock"; + +# ChatDeepInfra + +LangChain supports chat models hosted by [Deep Infra](https://deepinfra.com/) through the `ChatDeepInfra` wrapper. +First, you'll need to install the `@langchain/community` package: + +import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx"; + +<IntegrationInstallTooltip></IntegrationInstallTooltip> + +```bash npm2yarn +npm install @langchain/community +``` + +You'll need to obtain an API key and set it as an environment variable named `DEEPINFRA_API_TOKEN` +(or pass it into the constructor), then call the model as shown below: + +import Example from "@examples/models/chat/integration_deepinfra.ts"; + +<CodeBlock language="typescript">{Example}</CodeBlock> diff --git a/docs/core_docs/docs/integrations/llms/deep_infra.mdx b/docs/core_docs/docs/integrations/llms/deep_infra.mdx new file mode 100644 index 000000000000..76e75db0e134 --- /dev/null +++ b/docs/core_docs/docs/integrations/llms/deep_infra.mdx @@ -0,0 +1,25 @@ +--- +sidebar_label: Deep Infra +--- + +import CodeBlock from "@theme/CodeBlock"; + +# DeepInfra + +LangChain supports LLMs hosted by [Deep Infra](https://deepinfra.com/) through the `DeepInfra` wrapper.
+First, you'll need to install the `@langchain/community` package: + +import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx"; + +<IntegrationInstallTooltip></IntegrationInstallTooltip> + +```bash npm2yarn +npm install @langchain/community +``` + +You'll need to obtain an API key and set it as an environment variable named `DEEPINFRA_API_TOKEN` +(or pass it into the constructor), then call the model as shown below: + +import Example from "@examples/models/llm/deepinfra.ts"; + +<CodeBlock language="typescript">{Example}</CodeBlock> diff --git a/examples/src/models/chat/integration_deepinfra.ts b/examples/src/models/chat/integration_deepinfra.ts new file mode 100644 index 000000000000..def635fd676b --- /dev/null +++ b/examples/src/models/chat/integration_deepinfra.ts @@ -0,0 +1,17 @@ +import { ChatDeepInfra } from "@langchain/community/chat_models/deepinfra"; +import { HumanMessage } from "@langchain/core/messages"; + +const apiKey = process.env.DEEPINFRA_API_TOKEN; + +const model = "meta-llama/Meta-Llama-3-70B-Instruct"; + +const chat = new ChatDeepInfra({ + model, + apiKey, +}); + +const messages = [new HumanMessage("Hello")]; + +const res = await chat.invoke(messages); + +console.log(res); diff --git a/examples/src/models/llm/deepinfra.ts b/examples/src/models/llm/deepinfra.ts new file mode 100644 index 000000000000..28571d07cb04 --- /dev/null +++ b/examples/src/models/llm/deepinfra.ts @@ -0,0 +1,18 @@ +import { DeepInfraLLM } from "@langchain/community/llms/deepinfra"; + +const apiKey = process.env.DEEPINFRA_API_TOKEN; +const model = "meta-llama/Meta-Llama-3-70B-Instruct"; + +const llm = new DeepInfraLLM({ + temperature: 0.7, + maxTokens: 20, + model, + apiKey, + maxRetries: 5, +}); + +const res = await llm.invoke( + "What is the next step in the process of making a good game?"
+); + +console.log({ res }); diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore index bad275d50cf4..585ec0c0f9df 100644 --- a/libs/langchain-community/.gitignore +++ b/libs/langchain-community/.gitignore @@ -234,6 +234,10 @@ llms/cohere.cjs llms/cohere.js llms/cohere.d.ts llms/cohere.d.cts +llms/deepinfra.cjs +llms/deepinfra.js +llms/deepinfra.d.ts +llms/deepinfra.d.cts llms/fireworks.cjs llms/fireworks.js llms/fireworks.d.ts @@ -510,6 +514,10 @@ chat_models/cloudflare_workersai.cjs chat_models/cloudflare_workersai.js chat_models/cloudflare_workersai.d.ts chat_models/cloudflare_workersai.d.cts +chat_models/deepinfra.cjs +chat_models/deepinfra.js +chat_models/deepinfra.d.ts +chat_models/deepinfra.d.cts chat_models/fireworks.cjs chat_models/fireworks.js chat_models/fireworks.d.ts diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index d5c4c6b15eb3..53b32cfb4dd3 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -93,6 +93,7 @@ export const config = { "llms/bedrock/web": "llms/bedrock/web", "llms/cloudflare_workersai": "llms/cloudflare_workersai", "llms/cohere": "llms/cohere", + "llms/deepinfra": "llms/deepinfra", "llms/fireworks": "llms/fireworks", "llms/friendli": "llms/friendli", "llms/googlepalm": "llms/googlepalm", @@ -164,6 +165,7 @@ export const config = { "chat_models/bedrock": "chat_models/bedrock/index", "chat_models/bedrock/web": "chat_models/bedrock/web", "chat_models/cloudflare_workersai": "chat_models/cloudflare_workersai", + "chat_models/deepinfra": "chat_models/deepinfra", "chat_models/fireworks": "chat_models/fireworks", "chat_models/friendli": "chat_models/friendli", "chat_models/googlevertexai": "chat_models/googlevertexai/index", diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index b6228c721400..4192553a17a7 100644 --- 
a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -1231,6 +1231,15 @@ "import": "./llms/cohere.js", "require": "./llms/cohere.cjs" }, + "./llms/deepinfra": { + "types": { + "import": "./llms/deepinfra.d.ts", + "require": "./llms/deepinfra.d.cts", + "default": "./llms/deepinfra.d.ts" + }, + "import": "./llms/deepinfra.js", + "require": "./llms/deepinfra.cjs" + }, "./llms/fireworks": { "types": { "import": "./llms/fireworks.d.ts", @@ -1852,6 +1861,15 @@ "import": "./chat_models/cloudflare_workersai.js", "require": "./chat_models/cloudflare_workersai.cjs" }, + "./chat_models/deepinfra": { + "types": { + "import": "./chat_models/deepinfra.d.ts", + "require": "./chat_models/deepinfra.d.cts", + "default": "./chat_models/deepinfra.d.ts" + }, + "import": "./chat_models/deepinfra.js", + "require": "./chat_models/deepinfra.cjs" + }, "./chat_models/fireworks": { "types": { "import": "./chat_models/fireworks.d.ts", @@ -3235,6 +3253,10 @@ "llms/cohere.js", "llms/cohere.d.ts", "llms/cohere.d.cts", + "llms/deepinfra.cjs", + "llms/deepinfra.js", + "llms/deepinfra.d.ts", + "llms/deepinfra.d.cts", "llms/fireworks.cjs", "llms/fireworks.js", "llms/fireworks.d.ts", @@ -3511,6 +3533,10 @@ "chat_models/cloudflare_workersai.js", "chat_models/cloudflare_workersai.d.ts", "chat_models/cloudflare_workersai.d.cts", + "chat_models/deepinfra.cjs", + "chat_models/deepinfra.js", + "chat_models/deepinfra.d.ts", + "chat_models/deepinfra.d.cts", "chat_models/fireworks.cjs", "chat_models/fireworks.js", "chat_models/fireworks.d.ts", diff --git a/libs/langchain-community/src/chat_models/deepinfra.ts b/libs/langchain-community/src/chat_models/deepinfra.ts new file mode 100644 index 000000000000..82626ecc0a9f --- /dev/null +++ b/libs/langchain-community/src/chat_models/deepinfra.ts @@ -0,0 +1,215 @@ +import { + BaseChatModel, + type BaseChatModelParams, +} from "@langchain/core/language_models/chat_models"; +import { AIMessage, type BaseMessage } from 
"@langchain/core/messages"; +import { type ChatResult } from "@langchain/core/outputs"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +export const DEFAULT_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct"; + +export type DeepInfraMessageRole = "system" | "assistant" | "user"; + +export const API_BASE_URL = + "https://api.deepinfra.com/v1/openai/chat/completions"; + +export const ENV_VARIABLE_API_KEY = "DEEPINFRA_API_TOKEN"; + +interface DeepInfraMessage { + role: DeepInfraMessageRole; + content: string; +} + +interface ChatCompletionRequest { + model: string; + messages?: DeepInfraMessage[]; + stream?: boolean; + max_tokens?: number | null; + temperature?: number | null; +} + +interface BaseResponse { + code?: string; + message?: string; +} + +interface ChoiceMessage { + role: string; + content: string; +} + +interface ResponseChoice { + index: number; + finish_reason: "stop" | "length" | "null" | null; + delta: ChoiceMessage; + message: ChoiceMessage; +} + +interface ChatCompletionResponse extends BaseResponse { + choices: ResponseChoice[]; + usage: { + completion_tokens: number; + prompt_tokens: number; + total_tokens: number; + }; + output: { + text: string; + finish_reason: "stop" | "length" | "null" | null; + }; +} + +export interface ChatDeepInfraParams { + model: string; + apiKey?: string; + temperature?: number; + maxTokens?: number; +} + +function messageToRole(message: BaseMessage): DeepInfraMessageRole { + const type = message._getType(); + switch (type) { + case "ai": + return "assistant"; + case "human": + return "user"; + case "system": + return "system"; + default: + throw new Error(`Unknown message type: ${type}`); + } +} + +export class ChatDeepInfra + extends BaseChatModel + implements ChatDeepInfraParams +{ + static lc_name() { + return "ChatDeepInfra"; + } + + get callKeys() { + return ["stop", "signal", "options"]; + } + + apiKey?: string; + + model: string; + + apiUrl: string; + + maxTokens?: number; + + temperature?: 
number; + + constructor(fields: Partial<ChatDeepInfraParams> & BaseChatModelParams = {}) { + super(fields); + + this.apiKey = + fields?.apiKey ?? getEnvironmentVariable(ENV_VARIABLE_API_KEY); + if (!this.apiKey) { + throw new Error( + "API key is required, set `DEEPINFRA_API_TOKEN` environment variable or pass it as a parameter" + ); + } + + this.apiUrl = API_BASE_URL; + this.model = fields.model ?? DEFAULT_MODEL; + this.temperature = fields.temperature ?? 0; + this.maxTokens = fields.maxTokens; + } + + invocationParams(): Omit<ChatCompletionRequest, "messages"> { + return { + model: this.model, + stream: false, + temperature: this.temperature, + max_tokens: this.maxTokens, + }; + } + + identifyingParams(): Omit<ChatCompletionRequest, "messages"> { + return this.invocationParams(); + } + + async _generate( + messages: BaseMessage[], + options?: this["ParsedCallOptions"] + ): Promise<ChatResult> { + const parameters = this.invocationParams(); + + const messagesMapped: DeepInfraMessage[] = messages.map((message) => ({ + role: messageToRole(message), + content: message.content as string, + })); + + const data = await this.completionWithRetry( + { ...parameters, messages: messagesMapped }, + false, + options?.signal + ).then((data) => { + if (data?.code) { + throw new Error(data?.message); + } + const { finish_reason, message } = data.choices[0]; + const text = message.content; + return { + ...data, + output: { text, finish_reason }, + }; + }); + + const { + prompt_tokens = 0, + completion_tokens = 0, + total_tokens = 0, + } = data.usage ??
{}; + + const { text } = data.output; + + return { + generations: [{ text, message: new AIMessage(text) }], + llmOutput: { + tokenUsage: { + promptTokens: prompt_tokens, + completionTokens: completion_tokens, + totalTokens: total_tokens, + }, + }, + }; + } + + async completionWithRetry( + request: ChatCompletionRequest, + stream: boolean, + signal?: AbortSignal + ) { + const body = { + temperature: this.temperature, + max_tokens: this.maxTokens, + ...request, + model: this.model, + }; + + const makeCompletionRequest = async () => { + const response = await fetch(this.apiUrl, { + method: "POST", + headers: { + Authorization: `Bearer ${this.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + signal, + }); + + if (!stream) { + return response.json(); + } + }; + + return this.caller.call(makeCompletionRequest); + } + + _llmType(): string { + return "DeepInfra"; + } +} diff --git a/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts new file mode 100644 index 000000000000..e8b8cfbc97af --- /dev/null +++ b/libs/langchain-community/src/chat_models/tests/chatdeepinfra.int.test.ts @@ -0,0 +1,19 @@ +import { test } from "@jest/globals"; +import { HumanMessage } from "@langchain/core/messages"; +import { ChatDeepInfra } from "../deepinfra.js"; + +describe("ChatDeepInfra", () => { + test("call", async () => { + const deepInfraChat = new ChatDeepInfra({ maxTokens: 20 }); + const message = new HumanMessage("1 + 1 = "); + const res = await deepInfraChat.invoke([message]); + console.log({ res }); + }); + + test("generate", async () => { + const deepInfraChat = new ChatDeepInfra({ maxTokens: 20 }); + const message = new HumanMessage("1 + 1 = "); + const res = await deepInfraChat.generate([[message]]); + console.log(JSON.stringify(res, null, 2)); + }); +}); diff --git a/libs/langchain-community/src/llms/deepinfra.ts 
b/libs/langchain-community/src/llms/deepinfra.ts new file mode 100644 index 000000000000..e55c37719aff --- /dev/null +++ b/libs/langchain-community/src/llms/deepinfra.ts @@ -0,0 +1,69 @@ +import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; + +export const DEEPINFRA_API_BASE = + "https://api.deepinfra.com/v1/openai/completions"; + +export const DEFAULT_MODEL_NAME = "mistralai/Mixtral-8x22B-Instruct-v0.1"; + +export const ENV_VARIABLE = "DEEPINFRA_API_TOKEN"; + +export interface DeepInfraLLMParams extends BaseLLMParams { + apiKey?: string; + model?: string; + maxTokens?: number; + temperature?: number; +} + +export class DeepInfraLLM extends LLM implements DeepInfraLLMParams { + static lc_name() { + return "DeepInfraLLM"; + } + + lc_serializable = true; + + apiKey?: string; + + model?: string; + + maxTokens?: number; + + temperature?: number; + + constructor(fields: Partial<DeepInfraLLMParams> = {}) { + super(fields); + + this.apiKey = fields.apiKey ?? getEnvironmentVariable(ENV_VARIABLE); + this.model = fields.model ??
DEFAULT_MODEL_NAME; + this.maxTokens = fields.maxTokens; + this.temperature = fields.temperature; + } + + _llmType(): string { + return "DeepInfra"; + } + + async _call( + prompt: string, + options: this["ParsedCallOptions"] + ): Promise<string> { + const body = { + temperature: this.temperature, + max_tokens: this.maxTokens, + ...options, + prompt, + model: this.model, + }; + const response = await this.caller.call(() => + fetch(DEEPINFRA_API_BASE, { + method: "POST", + headers: { + Authorization: `Bearer ${this.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }).then((res) => res.json()) + ); + return response as string; + } +} diff --git a/libs/langchain-community/src/llms/tests/deepinfra.int.test.ts b/libs/langchain-community/src/llms/tests/deepinfra.int.test.ts new file mode 100644 index 000000000000..1c8853d2782a --- /dev/null +++ b/libs/langchain-community/src/llms/tests/deepinfra.int.test.ts @@ -0,0 +1,8 @@ +import { test } from "@jest/globals"; +import { DeepInfraLLM } from "../deepinfra.js"; + +test("Test DeepInfra", async () => { + const model = new DeepInfraLLM({ maxTokens: 20 }); + const res = await model.invoke("1 + 1 ="); + console.log(res); +}, 50000); diff --git a/libs/langchain-community/src/load/import_map.ts b/libs/langchain-community/src/load/import_map.ts index bc4d5b33bd77..b8b0a18e0564 100644 --- a/libs/langchain-community/src/load/import_map.ts +++ b/libs/langchain-community/src/load/import_map.ts @@ -34,6 +34,7 @@ export * as embeddings__voyage from "../embeddings/voyage.js"; export * as llms__ai21 from "../llms/ai21.js"; export * as llms__aleph_alpha from "../llms/aleph_alpha.js"; export * as llms__cloudflare_workersai from "../llms/cloudflare_workersai.js"; +export * as llms__deepinfra from "../llms/deepinfra.js"; export * as llms__fireworks from "../llms/fireworks.js"; export * as llms__friendli from "../llms/friendli.js"; export * as llms__ollama from "../llms/ollama.js"; @@ -45,6 +46,7 @@ export * as
vectorstores__vectara from "../vectorstores/vectara.js"; export * as chat_models__alibaba_tongyi from "../chat_models/alibaba_tongyi.js"; export * as chat_models__baiduwenxin from "../chat_models/baiduwenxin.js"; export * as chat_models__cloudflare_workersai from "../chat_models/cloudflare_workersai.js"; +export * as chat_models__deepinfra from "../chat_models/deepinfra.js"; export * as chat_models__fireworks from "../chat_models/fireworks.js"; export * as chat_models__friendli from "../chat_models/friendli.js"; export * as chat_models__minimax from "../chat_models/minimax.js";