diff --git a/packages/extensions/ai-widget/src/type.ts b/packages/extensions/ai-widget/src/type.ts index 67d6cda..458dea8 100644 --- a/packages/extensions/ai-widget/src/type.ts +++ b/packages/extensions/ai-widget/src/type.ts @@ -3,16 +3,16 @@ import { EditorView } from '@codemirror/view' export type ChatReq = { prompt: string refContent: string - extra?: {} + extra?: any } export type ChatRes = { status: 'success' | 'error' message: string - extra?: {} + extra?: any } -type EventType = +export type EventType = | 'widget.open' // {source: 'hotkey' | 'placeholder' | 'fix_sql_button' | ...} | 'no_use_db.error' | 'req.send' // {chatReq} diff --git a/packages/playground/package.json b/packages/playground/package.json index 409bfc0..7cff4f4 100644 --- a/packages/playground/package.json +++ b/packages/playground/package.json @@ -23,6 +23,7 @@ "@tanstack/react-query": "^5.45.1", "@tidbcloud/codemirror-extension-ai-widget": "workspace:^", "@tidbcloud/codemirror-extension-autocomplete": "workspace:^", + "@tidbcloud/codemirror-extension-cur-sql": "workspace:^", "@tidbcloud/codemirror-extension-cur-sql-gutter": "workspace:^", "@tidbcloud/codemirror-extension-linters": "workspace:^", "@tidbcloud/codemirror-extension-save-helper": "workspace:^", diff --git a/packages/playground/src/App.tsx b/packages/playground/src/App.tsx index 1172077..57ebd42 100644 --- a/packages/playground/src/App.tsx +++ b/packages/playground/src/App.tsx @@ -9,6 +9,7 @@ import { FilesProvider } from '@/contexts/files-context-provider' import { SchemaProvider } from '@/contexts/schema-context-provider' import { EditorExample } from '@/examples/editor-example' +import { ChatProvider } from './contexts/chat-context-provider' const queryClient = new QueryClient() @@ -20,7 +21,9 @@ function Full() { - + + + diff --git a/packages/playground/src/api/tidbcloud/chat-api.ts b/packages/playground/src/api/tidbcloud/chat-api.ts new file mode 100644 index 0000000..40cfcee --- /dev/null +++ b/packages/playground/src/api/tidbcloud/chat-api.ts @@ -0,0 +1,182 @@ +import { delay } from '@/lib/delay' +import { ChatRes } from '@tidbcloud/codemirror-extension-ai-widget' + +//--------------------------------- +// a simple way to cancel loop query + +const queryingJobs: Map = new Map() + +export function cancelChat(chatId: string) { + queryingJobs.delete(chatId) +} + +//--------------------------------- + +async function queryJobStatus(jobId: string) { + // res example: + // --------- + // for refine sql job: + // { + // "code": 200, + // "msg": "", + // "result": { + // "ended_at": 1719471325, + // "job_id": "f80099a82d20475d8da24b20dc817e67", + // "reason": "", + // "result": { + // "rewritten_sql": "SELECT * FROM `games` ORDER BY `recommendations` DESC LIMIT 10;", + // "solution": "To address the user feedback of limiting the results to 10 instead of 20, we can simply update the LIMIT clause in the SQL query." + // }, + // "status": "done" + // } + // } + // --------- + // for chat2data job: + // { + // "code": 200, + // "msg": "", + // "result": { + // "ended_at": 1719461464, + // "job_id": "f6f46a745a264a50bf0e60163b670b9d", + // "reason": "", + // "result": { + // "sql": "SELECT * FROM `games` ORDER BY `recommendations` DESC LIMIT 20;", + // "sql_error": null + // // ... + // }, + // "status": "done" + // } + // } + return fetch(`/api/jobs?job_id=${jobId}`) + .then((res) => { + if (res.status >= 400 || res.status < 200) { + return res.json().then((d) => { + throw new Error(d.msg) + }) + } + return res + }) + .then((res) => res.json()) + .then((d) => d.result) +} + +async function loopQueryJob(chatId: string, jobId: string): Promise { + // only try 5 times to reduce rate limit (current 100 times a day) + let i = 5 + while (i > 0) { + i-- + + // check whether job is canceled + if (!queryingJobs.get(chatId)) { + return { status: 'error', message: 'chat is canceled', extra: {} } + } + + const jobRes = await queryJobStatus(jobId) + if (jobRes.status === 'done') { + return { + status: 'success', + message: jobRes.result.rewritten_sql ?? jobRes.result.sql ?? '', + extra: {} + } + } else if (jobRes.status === 'failed') { + return { status: 'error', message: jobRes.reason, extra: {} } + } + await delay(10 * 1000) + } + throw new Error('Request timed out. Please try again.') +} + +//----------------- + +type Chat2DataReq = { + database: string + question: string +} + +export async function chat2data( + chatId: string, + params: Chat2DataReq +): Promise { + queryingJobs.set(chatId, true) + + try { + // res example: + // { + // "code": 200, + // "msg": "", + // "result": { + // "cluster_id": "xxx", + // "database": "game", + // "job_id": "yyy", + // "session_id": zzz + // } + // } + const res = await fetch(`/api/chat2data`, { + method: 'POST', + body: JSON.stringify(params) + }) + .then((res) => { + if (res.status >= 400 || res.status < 200) { + return res.json().then((d) => { + throw new Error(d.msg) + }) + } + return res + }) + .then((res) => res.json()) + .then((d) => d.result) + + const jobId = res.job_id + const jobRes = await loopQueryJob(chatId, jobId) + + return jobRes + } catch (error: any) { + return { status: 'error', message: error.message, extra: {} } + } +} + +type RefineSqlReq = { + database: string + sql: string + feedback: string +} + +export async function refineSql( + chatId: string, + params: RefineSqlReq +): Promise { + queryingJobs.set(chatId, true) + + try { + // res example: + // { + // "code": 200, + // "msg": "", + // "result": { + // "job_id": "xxx", + // "session_id": "yyy" + // } + // } + const res = await fetch(`/api/refineSql`, { + method: 'POST', + body: JSON.stringify(params) + }) + .then((res) => { + if (res.status >= 400 || res.status < 200) { + return res.json().then((d) => { + throw new Error(d.msg) + }) + } + return res + }) + .then((res) => res.json()) + .then((d) => d.result) + + const jobId = res.job_id + const jobRes = await loopQueryJob(chatId, jobId) + + return jobRes + } catch (error: any) { + return { status: 'error', message: error.message, extra: {} } + } +} diff --git a/packages/playground/src/api/tidbcloud/schema-api.ts b/packages/playground/src/api/tidbcloud/schema-api.ts index 230d018..8065d72 100644 --- a/packages/playground/src/api/tidbcloud/schema-api.ts +++ b/packages/playground/src/api/tidbcloud/schema-api.ts @@ -2,6 +2,14 @@ import { SchemaRes } from '@/contexts/schema-context' export async function getSchema(): Promise { return fetch(`/api/schema`) + .then((res) => { + if (res.status >= 400 || res.status < 200) { + return res.json().then((d) => { + throw new Error(d.message) + }) + } + return res + }) .then((res) => res.json()) .then((d) => d.data) } diff --git a/packages/playground/src/components/biz/editor-panel/editor.tsx b/packages/playground/src/components/biz/editor-panel/editor.tsx index c4ed91e..54d20b0 100644 --- a/packages/playground/src/components/biz/editor-panel/editor.tsx +++ b/packages/playground/src/components/biz/editor-panel/editor.tsx @@ -1,4 +1,5 @@ -import { useMemo } from 'react' +import { useMemo, useRef } from 'react' + import { EditorView, placeholder } from '@codemirror/view' import { EditorState } from '@codemirror/state' import { SQLConfig } from '@codemirror/lang-sql' @@ -16,11 +17,12 @@ import { aiWidget, isUnifiedMergeViewActive } from '@tidbcloud/codemirror-extension-ai-widget' +import { getCurDatabase } from '@tidbcloud/codemirror-extension-cur-sql' import { useFilesContext } from '@/contexts/files-context' import { useTheme } from '@/components/darkmode-toggle/theme-provider' import { SchemaRes, useSchemaContext } from '@/contexts/schema-context' -import { delay } from '@/lib/delay' +import { useChatContext } from '@/contexts/chat-context' function convertSchemaToSQLConfig(dbList: SchemaRes): SQLConfig { const schema: any = {} @@ -63,6 +65,12 @@ export function Editor() { () => convertSchemaToSQLConfig(schema ?? []), [schema] ) + const getDbListRef = useRef<() => string[]>() + getDbListRef.current = () => { + return schema?.map((d) => d.name) || [] + } + + const chatCtx = useChatContext() const activeFile = useMemo( () => openedFiles.find((f) => f.id === activeFileId), @@ -90,14 +98,16 @@ export function Editor() { }), fullWidthCharLinter(), aiWidget({ - chat: async () => { - await delay(2000) - return { status: 'success', message: 'select * from test;' } + chat(view, chatId, req) { + const db = getCurDatabase(view.state) + req['extra']['db'] = db + return chatCtx.chat(chatId, { ...req }) }, - cancelChat: () => {}, - getDbList: () => { - return ['test1', 'test2'] - } + cancelChat: chatCtx.cancelChat, + onEvent(_view, type, payload) { + chatCtx.onEvent(type, payload) + }, + getDbList: getDbListRef.current! }) ] } diff --git a/packages/playground/src/contexts/chat-context-provider.tsx b/packages/playground/src/contexts/chat-context-provider.tsx new file mode 100644 index 0000000..611f43d --- /dev/null +++ b/packages/playground/src/contexts/chat-context-provider.tsx @@ -0,0 +1,41 @@ +import { useMemo } from 'react' +import { ChatReq, EventType } from '@tidbcloud/codemirror-extension-ai-widget' + +import { cancelChat, chat2data, refineSql } from '@/api/tidbcloud/chat-api' +import { ChatContext } from './chat-context' + +function chat(chatId: string, req: ChatReq) { + if (req.refContent === '') { + return chat2data(chatId, { + database: req.extra?.db ?? '', + question: req.prompt + }) + } + return refineSql(chatId, { + database: req.extra?.db ?? '', + feedback: req.prompt, + sql: req.refContent + }) +} + +function onEvent(event: EventType, payload?: {}) { + console.log('event:', event) + console.log('payload:', payload) +} + +export function ChatProvider(props: { children: React.ReactNode }) { + const ctxValue = useMemo( + () => ({ + chat, + cancelChat, + onEvent + }), + [] + ) + + return ( + + {props.children} + + ) +} diff --git a/packages/playground/src/contexts/chat-context.tsx b/packages/playground/src/contexts/chat-context.tsx new file mode 100644 index 0000000..f900169 --- /dev/null +++ b/packages/playground/src/contexts/chat-context.tsx @@ -0,0 +1,25 @@ +import { createContext, useContext } from 'react' + +import { + ChatReq, + ChatRes, + EventType +} from '@tidbcloud/codemirror-extension-ai-widget' + +type ChatCtxValue = { + chat: (chatId: string, req: ChatReq) => Promise + cancelChat: (chatId: string) => void + onEvent: (event: EventType, payload?: {}) => void +} + +export const ChatContext = createContext(null) + +export const useChatContext = () => { + const context = useContext(ChatContext) + + if (!context) { + throw new Error('useChatContext must be used within a provider') + } + + return context +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3b3a938..4357ddf 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -426,6 +426,9 @@ importers: '@tidbcloud/codemirror-extension-autocomplete': specifier: workspace:^ version: link:../extensions/autocomplete + '@tidbcloud/codemirror-extension-cur-sql': + specifier: workspace:^ + version: link:../extensions/cur-sql '@tidbcloud/codemirror-extension-cur-sql-gutter': specifier: workspace:^ version: link:../extensions/cur-sql-gutter