Skip to content

Commit

Permalink
Merge pull request #1 from zby909/feat/add-dalle
Browse files — browse the repository at this point in the history
Feat/add dalle
  • Loading branch information
CoreJa authored Apr 2, 2024
2 parents e38b527 + 8ea833c commit ba312bb
Show file tree
Hide file tree
Showing 12 changed files with 677 additions and 195 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
/node_modules
/.pnp
.pnp.js
.history

# testing
/coverage
Expand Down
2 changes: 2 additions & 0 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ export interface ChatOptions {
messages: RequestMessage[];
config: LLMConfig;

attachImages?: string[];
isSummarizeSession?: boolean;
onUpdate?: (message: string, chunk: string) => void;
onFinish: (message: string) => void;
onError?: (err: Error) => void;
Expand Down
175 changes: 146 additions & 29 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ import {
getMessageTextContent,
getMessageImages,
isVisionModel,
isDalleModel,
isDalle2Model,
isDalle3Model,
getBlobUrl2File,
base64ToFile,
createObjectURL,
} from "@/app/utils";

export interface OpenAIListModelResponse {
Expand Down Expand Up @@ -80,11 +86,34 @@ export class ChatGPTApi implements LLMApi {
}

/**
 * Extract a displayable (markdown) message string from an OpenAI response.
 *
 * Handles three response shapes:
 *  - Image generation responses (`res.data` is an array): renders each image's
 *    revised prompt plus an inline markdown image and a download link built
 *    from the base64 payload (the request asks for `response_format: "b64_json"`).
 *  - Chat completion responses (`res.choices`): returns the first choice's
 *    message content, or "" when absent.
 *  - Anything else: returned as-is (e.g. a plain error body).
 *
 * @param res - parsed JSON response body from the OpenAI API
 * @returns markdown string for chat rendering (or the raw value for unknown shapes)
 */
extractMessage(res: any) {
  if (Array.isArray(res.data)) {
    let rendered = "";
    (res.data as any[]).forEach((item: any, index: number) => {
      if (item.revised_prompt) {
        rendered += "\n\n" + item.revised_prompt;
      }
      if (item.b64_json) {
        // Turn the base64 payload into an object URL so markdown can embed it.
        const url = createObjectURL(
          base64ToFile("data:image/png;base64," + item.b64_json),
        );
        // Use empty alt text when revised_prompt is absent (dall-e-2 omits it);
        // the previous template interpolated the literal string "undefined".
        rendered += `\n\n![${item.revised_prompt ?? ""}](${url})`;
        rendered += `\n\n[Download ${index + 1}](${url})`;
      }
    });
    return rendered;
  } else if (res.choices) {
    // Standard chat-completion shape.
    return res.choices?.at(0)?.message?.content ?? "";
  } else {
    // Unknown shape: pass through unchanged so callers can surface it.
    return res;
  }
}

async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const dalleModel = isDalleModel(options.config.model);
const dalle2Model = isDalle2Model(options.config.model);
const dalle3Model = isDalle3Model(options.config.model);
const isSummarizeSession = options.isSummarizeSession;
const messages = options.messages.map((v) => ({
role: v.role,
content: visionModel ? v.content : getMessageTextContent(v),
Expand All @@ -98,26 +127,61 @@ export class ChatGPTApi implements LLMApi {
},
};

const requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
};
let requestPayload = {} as any;
if (dalleModel) {
requestPayload = {
model: modelConfig.model,
n: modelConfig.n,
size: modelConfig.size,
response_format: "b64_json", //url
prompt: messages.at(-1)?.content,
quality: modelConfig.quality,
style: modelConfig.style,
};
if (dalle2Model) {
delete requestPayload.style;
delete requestPayload.quality;
if (modelConfig.dall2Mode === "CreateVariation") {
delete requestPayload.prompt;
}
if (modelConfig.dall2Mode !== "Default" && options?.attachImages?.[0]) {
try {
requestPayload.image = await getBlobUrl2File(
options?.attachImages?.[0],
);
} catch {}
}
if (modelConfig.dall2Mode === "Edit" && options?.attachImages?.[1]) {
try {
requestPayload.mask = await getBlobUrl2File(
options?.attachImages?.[1],
);
} catch {}
}
}
options.config.stream = false;
} else {
requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
frequency_penalty: modelConfig.frequency_penalty,
top_p: modelConfig.top_p,
// max_tokens: Math.max(modelConfig.max_tokens, 1024),
// Please do not ask me why not send max_tokens, no reason, this param is just shit, I dont want to explain anymore.
};

// add max_tokens to vision model
if (visionModel) {
Object.defineProperty(requestPayload, "max_tokens", {
enumerable: true,
configurable: true,
writable: true,
value: modelConfig.max_tokens,
});
// add max_tokens to vision model
if (visionModel) {
Object.defineProperty(requestPayload, "max_tokens", {
enumerable: true,
configurable: true,
writable: true,
value: modelConfig.max_tokens,
});
}
}

console.log("[Request] openai payload: ", requestPayload);
Expand All @@ -127,12 +191,34 @@ export class ChatGPTApi implements LLMApi {
options.onController?.(controller);

try {
const chatPath = this.path(OpenaiPath.ChatPath);
let openaiUrl = "ChatPath" as keyof typeof OpenaiPath;
if (dalleModel) {
if (dalle3Model) {
openaiUrl = "createImgPath";
} else {
modelConfig.dall2Mode === "Default" && (openaiUrl = "createImgPath");
modelConfig.dall2Mode === "Edit" && (openaiUrl = "createEditPath");
modelConfig.dall2Mode === "CreateVariation" &&
(openaiUrl = "createVariationionsPath");
}
}
const chatPath = this.path(OpenaiPath[openaiUrl]);

const headers = getHeaders();
let body: any = JSON.stringify(requestPayload);
if (dalle2Model && modelConfig.dall2Mode !== "Default") {
delete headers["Content-Type"];
const formData = new FormData();
for (const key in requestPayload) {
formData.append(key, requestPayload[key]);
}
body = formData;
}
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
body: body,
signal: controller.signal,
headers: getHeaders(),
headers: headers,
};

// make a fetch request
Expand Down Expand Up @@ -217,7 +303,9 @@ export class ChatGPTApi implements LLMApi {
responseTexts.push(extraInfo);
}

responseText = responseTexts.join("\n\n");
responseText = isSummarizeSession
? ""
: responseTexts.join("\n\n");

return finish();
}
Expand Down Expand Up @@ -255,12 +343,41 @@ export class ChatGPTApi implements LLMApi {
openWhenHidden: true,
});
} else {
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
try {
const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
const responseTexts = [];
if (contentType?.startsWith("text/plain")) {
const responseText = await res.clone().text();
return options.onFinish(responseText);
}
if (!res.ok || res.status !== 200) {
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}

if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}

const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
if (extraInfo) {
responseTexts.push(extraInfo);
}
options.onFinish(
isSummarizeSession ? "" : responseTexts.join("\n\n"),
);
} else {
const response = await res.json();
const message = this.extractMessage(response);
options.onFinish(message);
}
} catch (e) {
clearTimeout(requestTimeoutId);
throw e;
}
}
} catch (e) {
console.log("[Request] failed to make a chat request", e);
Expand Down
62 changes: 53 additions & 9 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ import {
getMessageImages,
isVisionModel,
compressImage,
createObjectURL,
isNotDalle2DefaultMode,
isDalle2VariationMode,
isDalle2EditMode,
} from "../utils";

import dynamic from "next/dynamic";
Expand Down Expand Up @@ -447,6 +451,8 @@ export function ChatActions(props: {

// switch model
const currentModel = chatStore.currentSession().mask.modelConfig.model;
const currentDall2Mode =
chatStore.currentSession().mask.modelConfig.dall2Mode;
const allModels = useAllModels();
const models = useMemo(
() => allModels.filter((m) => m.available),
Expand All @@ -456,7 +462,11 @@ export function ChatActions(props: {
const [showUploadImage, setShowUploadImage] = useState(false);

useEffect(() => {
const show = isVisionModel(currentModel);
const isDalle2DefaultModeV = isNotDalle2DefaultMode(
currentModel,
currentDall2Mode,
);
const show = isVisionModel(currentModel) || isDalle2DefaultModeV;
setShowUploadImage(show);
if (!show) {
props.setAttachImages([]);
Expand Down Expand Up @@ -679,6 +689,18 @@ function _Chat() {
const navigate = useNavigate();
const [attachImages, setAttachImages] = useState<string[]>([]);
const [uploading, setUploading] = useState(false);
const isDalle2VariationModeV = isDalle2VariationMode(
session.mask.modelConfig.model,
session.mask.modelConfig.dall2Mode,
);
const isDalle2EditModeV = isDalle2EditMode(
session.mask.modelConfig.model,
session.mask.modelConfig.dall2Mode,
);
const isDalle2DefaultModeV = isNotDalle2DefaultMode(
session.mask.modelConfig.model,
session.mask.modelConfig.dall2Mode,
);

// prompt hints
const promptStore = usePromptStore();
Expand Down Expand Up @@ -747,7 +769,10 @@ function _Chat() {
};

const doSubmit = (userInput: string) => {
if (userInput.trim() === "") return;
if (userInput.trim() === "" && !isDalle2VariationModeV) return;
if (isDalle2DefaultModeV) {
if (attachImages.length === 0) return;
}
const matchCommand = chatCommands.match(userInput);
if (matchCommand.matched) {
setUserInput("");
Expand Down Expand Up @@ -1102,11 +1127,13 @@ function _Chat() {
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

const handlePaste = useCallback(
async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
const currentModel = chatStore.currentSession().mask.modelConfig.model;
if(!isVisionModel(currentModel)){return;}
if (!isVisionModel(currentModel)) {
return;
}
const items = (event.clipboardData || window.clipboardData).items;
for (const item of items) {
if (item.kind === "file" && item.type.startsWith("image/")) {
Expand All @@ -1119,7 +1146,13 @@ function _Chat() {
...(await new Promise<string[]>((res, rej) => {
setUploading(true);
const imagesData: string[] = [];
compressImage(file, 256 * 1024)
let promiseFn;
if (isDalle2DefaultModeV) {
promiseFn = Promise.resolve(createObjectURL(file));
} else {
promiseFn = compressImage(file, 256 * 1024);
}
promiseFn
.then((dataUrl) => {
imagesData.push(dataUrl);
setUploading(false);
Expand Down Expand Up @@ -1161,7 +1194,13 @@ function _Chat() {
const imagesData: string[] = [];
for (let i = 0; i < files.length; i++) {
const file = event.target.files[i];
compressImage(file, 256 * 1024)
let promiseFn;
if (isDalle2DefaultModeV) {
promiseFn = Promise.resolve(createObjectURL(file));
} else {
promiseFn = compressImage(file, 256 * 1024);
}
promiseFn
.then((dataUrl) => {
imagesData.push(dataUrl);
if (
Expand Down Expand Up @@ -1410,7 +1449,7 @@ function _Chat() {
<img
className={styles["chat-message-item-image"]}
src={getMessageImages(message)[0]}
alt=""
alt="img"
/>
)}
{getMessageImages(message).length > 1 && (
Expand All @@ -1430,7 +1469,7 @@ function _Chat() {
}
key={index}
src={image}
alt=""
alt="img"
/>
);
})}
Expand Down Expand Up @@ -1485,8 +1524,13 @@ function _Chat() {
<textarea
id="chat-input"
ref={inputRef}
disabled={isDalle2VariationModeV}
className={styles["chat-input"]}
placeholder={Locale.Chat.Input(submitKey)}
placeholder={Locale.Chat.Input(
submitKey,
isDalle2VariationModeV,
isDalle2EditModeV,
)}
onInput={(e) => onInput(e.currentTarget.value)}
value={userInput}
onKeyDown={onInputKeyDown}
Expand Down
1 change: 1 addition & 0 deletions app/components/markdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ function _MarkDownContent(props: { content: string }) {
return <a {...aProps} target={target} />;
},
}}
transformLinkUri={(uri) => uri}
>
{escapedContent}
</ReactMarkdown>
Expand Down
Loading

0 comments on commit ba312bb

Please sign in to comment.