feat: allow custom model and provider #16

Merged · 2 commits · Jul 4, 2024
Changes from 1 commit
1 change: 0 additions & 1 deletion index.ts
@@ -6,5 +6,4 @@ export * from "./src/error";
 export * from "./src/types";
 export * from "./src/models";
 export { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "./src/constants";
-export { type Model, UpstashLLMClient, UpstashLLMClientConfig } from "./src/upstash-llm-client";
 export { ChatOpenAI } from "@langchain/openai";
16 changes: 4 additions & 12 deletions src/config.ts
@@ -1,12 +1,11 @@
 import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base";
-import { ChatOpenAI } from "@langchain/openai";
 import type { Ratelimit } from "@upstash/ratelimit";
 import { Redis } from "@upstash/redis";
 import { Index } from "@upstash/vector";
+import { DEFAULT_PROMPT } from "./constants";
+import { openaiModel, upstashModel } from "./models";
 import type { CustomPrompt } from "./rag-chat-base";
 import type { RAGChatConfig } from "./types";
-import { UpstashLLMClient } from "./upstash-llm-client";
-import { DEFAULT_PROMPT } from "./constants";

 export class Config {
   public readonly vector?: Index;
@@ -49,17 +48,10 @@ const initializeModel = () => {
   const openAIToken = process.env.OPENAI_API_KEY;

   if (qstashToken)
-    return new UpstashLLMClient({
-      model: "meta-llama/Meta-Llama-3-8B-Instruct",
-      apiKey: qstashToken,
-    });
+    return upstashModel("meta-llama/Meta-Llama-3-8B-Instruct", { apiKey: qstashToken });

   if (openAIToken) {
-    return new ChatOpenAI({
-      modelName: "gpt-4o",
-      verbose: false,
-      apiKey: openAIToken,
-    });
+    return openaiModel("gpt-4o", { apiKey: openAIToken });
   }

   throw new Error(
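For readers skimming the diff: `initializeModel` now delegates to the factory helpers instead of constructing clients inline. A condensed sketch of the resulting fallback order (the relative import mirrors `src/config.ts`; nothing here is new behavior beyond the hunk above):

```ts
import { openaiModel, upstashModel } from "./models";

// QSTASH_TOKEN takes precedence and selects the Upstash-hosted Llama 3 model;
// otherwise OPENAI_API_KEY selects gpt-4o; with neither set, initializeModel throws.
const model = process.env.QSTASH_TOKEN
  ? upstashModel("meta-llama/Meta-Llama-3-8B-Instruct", { apiKey: process.env.QSTASH_TOKEN })
  : openaiModel("gpt-4o", { apiKey: process.env.OPENAI_API_KEY! });
```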
15 changes: 7 additions & 8 deletions src/upstash-llm-client.ts → src/custom-llm-client.ts
@@ -2,10 +2,8 @@ import { ChatOpenAI } from "@langchain/openai";
 import { type BaseMessage } from "@langchain/core/messages";
 import { type ChatGeneration } from "@langchain/core/outputs";

-export type Model = "mistralai/Mistral-7B-Instruct-v0.2" | "meta-llama/Meta-Llama-3-8B-Instruct";
-
-export type UpstashLLMClientConfig = {
-  model: Model;
+export type LLMClientConfig = {
+  model: string;
   apiKey: string;
   maxTokens?: number;
   stop?: string[];
@@ -17,10 +15,11 @@ export type UpstashLLMClientConfig = {
   logitBias?: Record<string, number>;
   logProbs?: number;
   topLogprobs?: number;
+  baseUrl: string;
 };

-export class UpstashLLMClient extends ChatOpenAI {
-  modelName: Model;
+export class LLMClient extends ChatOpenAI {
+  modelName: string;
   apiKey: string;
   maxTokens?: number;
   stop?: string[];
@@ -33,7 +32,7 @@ export class UpstashLLMClient extends ChatOpenAI {
   logProbs?: number;
   topLogprobs?: number;

-  constructor(config: UpstashLLMClientConfig) {
+  constructor(config: LLMClientConfig) {
     super(
       {
         modelName: config.model,
@@ -49,7 +48,7 @@ export class UpstashLLMClient extends ChatOpenAI {
         stop: config.stop,
       },
       {
-        baseURL: "https://qstash.upstash.io/llm/v1",
+        baseURL: config.baseUrl,
       }
     );

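The rename is not cosmetic: `baseUrl` moves from a hardcoded QStash endpoint into `LLMClientConfig`, so the client can target any OpenAI-compatible chat-completions API. A minimal usage sketch (the endpoint URL and env var below are placeholders, not values from this PR):

```ts
import { LLMClient } from "./src/custom-llm-client";

// baseUrl is now caller-supplied rather than fixed to qstash.upstash.io,
// and model is a plain string instead of the removed Model union.
const client = new LLMClient({
  model: "mistralai/Mistral-7B-Instruct-v0.2",
  apiKey: process.env.PROVIDER_API_KEY!, // placeholder env var for the chosen provider
  baseUrl: "https://api.example-provider.com/v1", // placeholder endpoint
});
```

Since `LLMClient` still extends LangChain's `ChatOpenAI`, instances drop into any call site that accepted the old `UpstashLLMClient`.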
31 changes: 19 additions & 12 deletions src/models.ts
@@ -1,6 +1,6 @@
 import { ChatOpenAI } from "@langchain/openai";
-import type { UpstashLLMClientConfig } from "./upstash-llm-client";
-import { UpstashLLMClient } from "./upstash-llm-client";
+import type { LLMClientConfig } from "./custom-llm-client";
+import { LLMClient } from "./custom-llm-client";

 export type OpenAIChatModel =
   | "gpt-4-turbo"
@@ -10,6 +10,7 @@ export type OpenAIChatModel =
| "gpt-4-1106-preview"
| "gpt-4-vision-preview"
| "gpt-4"
| "gpt-4o"
| "gpt-4-0314"
| "gpt-4-0613"
| "gpt-4-32k"
@@ -27,21 +28,27 @@ export type UpstashChatModel =
| "mistralai/Mistral-7B-Instruct-v0.2"
| "meta-llama/Meta-Llama-3-8B-Instruct";

export const upstashModel = (
model: UpstashChatModel,
options?: Omit<UpstashLLMClientConfig, "model">
) => {
return new UpstashLLMClient({
type ModelOptions = Omit<LLMClientConfig, "model">;

export const upstashModel = (model: UpstashChatModel, options?: Omit<ModelOptions, "baseUrl">) => {
return new LLMClient({
model,
baseUrl: "https://qstash.upstash.io/llm/v1",
apiKey: options?.apiKey ?? "",
...options,
});
};

export const customModel = (model: string, options?: ModelOptions) => {
if (!options?.baseUrl) throw new Error("baseUrl cannot be empty or undefined.");

return new LLMClient({
model,
apiKey: process.env.QSTASH_TOKEN ?? options?.apiKey ?? "",
...options,
});
};

export const openaiModel = (
model: OpenAIChatModel,
options?: Omit<UpstashLLMClientConfig, "model">
) => {
export const openaiModel = (model: OpenAIChatModel, options?: Omit<ModelOptions, "baseUrl">) => {
return new ChatOpenAI({
modelName: model,
temperature: 0,
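`models.ts` now exposes three factory shapes: `upstashModel` pins `baseUrl` to the QStash endpoint and keeps the `UpstashChatModel` union, `customModel` takes any model string but throws without an explicit `baseUrl`, and `openaiModel` constructs a plain `ChatOpenAI`. A sketch of the three call shapes (the Together endpoint and model name mirror the test below; the env var names are assumptions):

```ts
import { customModel, openaiModel, upstashModel } from "./models";

// Upstash-hosted model: baseUrl is fixed to https://qstash.upstash.io/llm/v1.
const upstash = upstashModel("meta-llama/Meta-Llama-3-8B-Instruct", {
  apiKey: process.env.QSTASH_TOKEN!,
});

// Any OpenAI-compatible provider: baseUrl is mandatory.
const together = customModel("meta-llama/Llama-3-8b-chat-hf", {
  apiKey: process.env.TOGETHER_API_KEY!, // assumed env var
  baseUrl: "https://api.together.xyz",
});

// OpenAI directly, constrained to the OpenAIChatModel union.
const openai = openaiModel("gpt-4o", { apiKey: process.env.OPENAI_API_KEY! });
```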
41 changes: 41 additions & 0 deletions src/rag-chat.test.ts
@@ -8,6 +8,7 @@ import { afterAll, beforeAll, describe, expect, test } from "bun:test";
 import { RatelimitUpstashError } from "./error/ratelimit";
 import { RAGChat } from "./rag-chat";
 import { awaitUntilIndexed } from "./test-utils";
+import { customModel } from "./models";

 async function checkStream(
   stream: ReadableStream<string>,
@@ -462,3 +463,43 @@ describe("RAGChat init without model", () => {
     { timeout: 30_000 }
   );
 });
+
+describe("RAGChat init with custom model - todo", () => {
+  const namespace = "japan";
+  const vector = new Index({
+    token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
+    url: process.env.UPSTASH_VECTOR_REST_URL!,
+  });
+
+  const ragChat = new RAGChat({
+    vector,
+    model: customModel("meta-llama/Llama-3-8b-chat-hf", {
+      apiKey: "be4f7601b3fe999ccd4fced312c812f77db2121a1954646dda75097c14e3de7e",
+      baseUrl: "https://api.together.xyz",
+    }),
+  });
+
+  afterAll(async () => {
+    await vector.reset({ namespace });
+  });
+
+  test(
+    "should be able to insert data into a namespace and query it",
+    async () => {
+      await ragChat.context.add({
+        type: "text",
+        data: "Tokyo is the Capital of Japan.",
+        options: { namespace },
+      });
+      await awaitUntilIndexed(vector);
+
+      const result = await ragChat.chat("Where is the capital of Japan?", {
+        metadataKey: "text",
+        namespace,
+      });
+
+      expect(result.output).toContain("Tokyo");
+    },
+    { timeout: 30_000 }
+  );
+});
4 changes: 2 additions & 2 deletions src/types.ts
@@ -3,7 +3,7 @@ import type { Ratelimit } from "@upstash/ratelimit";
 import type { Redis } from "@upstash/redis";
 import type { Index } from "@upstash/vector";
 import type { CustomPrompt } from "./rag-chat-base";
-import type { UpstashLLMClient } from "./upstash-llm-client";
+import type { LLMClient } from "./custom-llm-client";

 declare const __brand: unique symbol;
 type Brand<B> = { [__brand]: B };
@@ -70,7 +70,7 @@ type RAGChatConfigCommon = {
         apiKey,
       })
    */
-  model?: UpstashLLMClient | ChatOpenAI;
+  model?: LLMClient | ChatOpenAI;
   /**
    * If no Index name or instance is provided, falls back to the default.
    * @default
123 changes: 0 additions & 123 deletions src/upstash-llm-client.test.ts

This file was deleted.