Skip to content

Commit

Permalink
feat: support gemini-pro-vision
Browse files Browse the repository at this point in the history
  • Loading branch information
Hk-Gosuto committed Feb 18, 2024
1 parent 5c389db commit e62b4c1
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 14 deletions.
42 changes: 36 additions & 6 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ import {
LLMUsage,
} from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import axios from "axios";

// Download a remote image and return its raw bytes as a base64-encoded string
// (used to build Gemini `inline_data` parts for vision requests).
// NOTE(review): relies on the Node `Buffer` global even though this file lives
// under client/ — presumably polyfilled by the bundler; confirm for browser builds.
// axios rejects on non-2xx responses, so callers see failures as thrown errors.
const getImageBase64Data = async (url: string) => {
  const { data } = await axios.get(url, { responseType: "arraybuffer" });
  return Buffer.from(data, "binary").toString("base64");
};

export class GeminiProApi implements LLMApi {
toolAgentChat(options: AgentChatOptions): Promise<void> {
Expand All @@ -28,11 +35,32 @@ export class GeminiProApi implements LLMApi {
);
}
async chat(options: ChatOptions): Promise<void> {
const apiClient = this;
const messages = options.messages.map((v) => ({
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
}));
const messages: any[] = [];
if (options.config.model.includes("vision")) {
for (const v of options.messages) {
let message: any = {
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
};
if (v.image_url) {
var base64Data = await getImageBase64Data(v.image_url);
message.parts.push({
inline_data: {
mime_type: "image/jpeg",
data: base64Data,
},
});
}
messages.push(message);
}
} else {
options.messages.map((v) =>
messages.push({
role: v.role.replace("assistant", "model").replace("system", "user"),
parts: [{ text: v.content }],
}),
);
}

// google requires that role in neighboring messages must not be the same
for (let i = 0; i < messages.length - 1; ) {
Expand Down Expand Up @@ -92,7 +120,9 @@ export class GeminiProApi implements LLMApi {
const controller = new AbortController();
options.onController?.(controller);
try {
const chatPath = this.path(Google.ChatPath);
const chatPath = this.path(
Google.ChatPath.replace("{{model}}", options.config.model),
);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
Expand Down
7 changes: 3 additions & 4 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ export class ChatGPTApi implements LLMApi {
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
max_tokens:
modelConfig.model == "gpt-4-vision-preview"
? modelConfig.max_tokens
: null,
max_tokens: modelConfig.model.includes("vision")
? modelConfig.max_tokens
: null,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
};
Expand Down
6 changes: 3 additions & 3 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ export function ChatActions(props: {
}
}
};
if (currentModel === "gpt-4-vision-preview") {
if (currentModel.includes("vision")) {
window.addEventListener("paste", onPaste);
return () => {
window.removeEventListener("paste", onPaste);
Expand Down Expand Up @@ -620,7 +620,7 @@ export function ChatActions(props: {
icon={usePlugins ? <EnablePluginIcon /> : <DisablePluginIcon />}
/>
)}
{currentModel == "gpt-4-vision-preview" && (
{currentModel.includes("vision") && (
<ChatAction
onClick={selectImage}
text="选择图片"
Expand Down Expand Up @@ -1412,7 +1412,7 @@ function _Chat() {
defaultShow={i >= messages.length - 6}
/>
</div>
{!isUser && message.model == "gpt-4-vision-preview" && (
{!isUser && message.model?.includes("vision") && (
<div
className={[
styles["chat-message-actions"],
Expand Down
11 changes: 10 additions & 1 deletion app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export const Azure = {

export const Google = {
ExampleEndpoint: "https://generativelanguage.googleapis.com/",
ChatPath: "v1beta/models/gemini-pro:generateContent",
ChatPath: "v1beta/models/{{model}}:generateContent",
};

// Default prompt template; supported placeholders appear to be {{input}} / {{time}} / {{model}} / {{lang}} — substitution happens elsewhere, verify against the template renderer
export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
Expand Down Expand Up @@ -253,6 +253,15 @@ export const DEFAULT_MODELS = [
providerType: "google",
},
},
{
name: "gemini-pro-vision",
available: true,
provider: {
id: "google",
providerName: "Google",
providerType: "google",
},
},
] as const;

export const CHAT_PAGE_SIZE = 15;
Expand Down

0 comments on commit e62b4c1

Please sign in to comment.