diff --git a/packages/extensions/ai-widget/src/type.ts b/packages/extensions/ai-widget/src/type.ts
index 67d6cda..458dea8 100644
--- a/packages/extensions/ai-widget/src/type.ts
+++ b/packages/extensions/ai-widget/src/type.ts
@@ -3,16 +3,16 @@ import { EditorView } from '@codemirror/view'
export type ChatReq = {
prompt: string
refContent: string
- extra?: {}
+ extra?: any
}
export type ChatRes = {
status: 'success' | 'error'
message: string
- extra?: {}
+ extra?: any
}
-type EventType =
+export type EventType =
| 'widget.open' // {source: 'hotkey' | 'placeholder' | 'fix_sql_button' | ...}
| 'no_use_db.error'
| 'req.send' // {chatReq}
diff --git a/packages/playground/package.json b/packages/playground/package.json
index 409bfc0..7cff4f4 100644
--- a/packages/playground/package.json
+++ b/packages/playground/package.json
@@ -23,6 +23,7 @@
"@tanstack/react-query": "^5.45.1",
"@tidbcloud/codemirror-extension-ai-widget": "workspace:^",
"@tidbcloud/codemirror-extension-autocomplete": "workspace:^",
+ "@tidbcloud/codemirror-extension-cur-sql": "workspace:^",
"@tidbcloud/codemirror-extension-cur-sql-gutter": "workspace:^",
"@tidbcloud/codemirror-extension-linters": "workspace:^",
"@tidbcloud/codemirror-extension-save-helper": "workspace:^",
diff --git a/packages/playground/src/App.tsx b/packages/playground/src/App.tsx
index 1172077..57ebd42 100644
--- a/packages/playground/src/App.tsx
+++ b/packages/playground/src/App.tsx
@@ -9,6 +9,7 @@ import { FilesProvider } from '@/contexts/files-context-provider'
import { SchemaProvider } from '@/contexts/schema-context-provider'
import { EditorExample } from '@/examples/editor-example'
+import { ChatProvider } from './contexts/chat-context-provider'
const queryClient = new QueryClient()
@@ -20,7 +21,9 @@ function Full() {
-
+
+
+
diff --git a/packages/playground/src/api/tidbcloud/chat-api.ts b/packages/playground/src/api/tidbcloud/chat-api.ts
new file mode 100644
index 0000000..40cfcee
--- /dev/null
+++ b/packages/playground/src/api/tidbcloud/chat-api.ts
@@ -0,0 +1,182 @@
+import { delay } from '@/lib/delay'
+import { ChatRes } from '@tidbcloud/codemirror-extension-ai-widget'
+
+//---------------------------------
+// a simple way to cancel loop query
+
+const queryingJobs: Map = new Map()
+
+export function cancelChat(chatId: string) {
+ queryingJobs.delete(chatId)
+}
+
+//---------------------------------
+
+async function queryJobStatus(jobId: string) {
+ // res example:
+ // ---------
+ // for refine sql job:
+ // {
+ // "code": 200,
+ // "msg": "",
+ // "result": {
+ // "ended_at": 1719471325,
+ // "job_id": "f80099a82d20475d8da24b20dc817e67",
+ // "reason": "",
+ // "result": {
+ // "rewritten_sql": "SELECT * FROM `games` ORDER BY `recommendations` DESC LIMIT 10;",
+ // "solution": "To address the user feedback of limiting the results to 10 instead of 20, we can simply update the LIMIT clause in the SQL query."
+ // },
+ // "status": "done"
+ // }
+ // }
+ // ---------
+ // for chat2data job:
+ // {
+ // "code": 200,
+ // "msg": "",
+ // "result": {
+ // "ended_at": 1719461464,
+ // "job_id": "f6f46a745a264a50bf0e60163b670b9d",
+ // "reason": "",
+ // "result": {
+ // "sql": "SELECT * FROM `games` ORDER BY `recommendations` DESC LIMIT 20;",
+ // "sql_error": null
+ // // ...
+ // },
+ // "status": "done"
+ // }
+ // }
+ return fetch(`/api/jobs?job_id=${jobId}`)
+ .then((res) => {
+ if (res.status >= 400 || res.status < 200) {
+ return res.json().then((d) => {
+ throw new Error(d.msg)
+ })
+ }
+ return res
+ })
+ .then((res) => res.json())
+ .then((d) => d.result)
+}
+
+async function loopQueryJob(chatId: string, jobId: string): Promise {
+ // only try 5 times to reduce rate limit (current 100 times a day)
+ let i = 5
+ while (i > 0) {
+ i--
+
+ // check whether job is canceled
+ if (!queryingJobs.get(chatId)) {
+ return { status: 'error', message: 'chat is canceled', extra: {} }
+ }
+
+ const jobRes = await queryJobStatus(jobId)
+ if (jobRes.status === 'done') {
+ return {
+ status: 'success',
+ message: jobRes.result.rewritten_sql ?? jobRes.result.sql ?? '',
+ extra: {}
+ }
+ } else if (jobRes.status === 'failed') {
+ return { status: 'error', message: jobRes.reason, extra: {} }
+ }
+ await delay(10 * 1000)
+ }
+ throw new Error('Request timed out. Please try again.')
+}
+
+//-----------------
+
+type Chat2DataReq = {
+ database: string
+ question: string
+}
+
+export async function chat2data(
+ chatId: string,
+ params: Chat2DataReq
+): Promise {
+ queryingJobs.set(chatId, true)
+
+ try {
+ // res example:
+ // {
+ // "code": 200,
+ // "msg": "",
+ // "result": {
+ // "cluster_id": "xxx",
+ // "database": "game",
+ // "job_id": "yyy",
+ // "session_id": zzz
+ // }
+ // }
+ const res = await fetch(`/api/chat2data`, {
+ method: 'POST',
+ body: JSON.stringify(params)
+ })
+ .then((res) => {
+ if (res.status >= 400 || res.status < 200) {
+ return res.json().then((d) => {
+ throw new Error(d.msg)
+ })
+ }
+ return res
+ })
+ .then((res) => res.json())
+ .then((d) => d.result)
+
+ const jobId = res.job_id
+ const jobRes = await loopQueryJob(chatId, jobId)
+
+ return jobRes
+ } catch (error: any) {
+ return { status: 'error', message: error.message, extra: {} }
+ }
+}
+
+type RefineSqlReq = {
+ database: string
+ sql: string
+ feedback: string
+}
+
+export async function refineSql(
+ chatId: string,
+ params: RefineSqlReq
+): Promise {
+ queryingJobs.set(chatId, true)
+
+ try {
+ // res example:
+ // {
+ // "code": 200,
+ // "msg": "",
+ // "result": {
+ // "job_id": "xxx",
+ // "session_id": "yyy"
+ // }
+ // }
+ const res = await fetch(`/api/refineSql`, {
+ method: 'POST',
+ body: JSON.stringify(params)
+ })
+ .then((res) => {
+ if (res.status >= 400 || res.status < 200) {
+ return res.json().then((d) => {
+ throw new Error(d.msg)
+ })
+ }
+ return res
+ })
+ .then((res) => res.json())
+ .then((d) => d.result)
+
+ const jobId = res.job_id
+ const jobRes = await loopQueryJob(chatId, jobId)
+
+ return jobRes
+ } catch (error: any) {
+ return { status: 'error', message: error.message, extra: {} }
+ }
+}
diff --git a/packages/playground/src/api/tidbcloud/schema-api.ts b/packages/playground/src/api/tidbcloud/schema-api.ts
index 230d018..8065d72 100644
--- a/packages/playground/src/api/tidbcloud/schema-api.ts
+++ b/packages/playground/src/api/tidbcloud/schema-api.ts
@@ -2,6 +2,14 @@ import { SchemaRes } from '@/contexts/schema-context'
export async function getSchema(): Promise {
return fetch(`/api/schema`)
+ .then((res) => {
+ if (res.status >= 400 || res.status < 200) {
+ return res.json().then((d) => {
+ throw new Error(d.message)
+ })
+ }
+ return res
+ })
.then((res) => res.json())
.then((d) => d.data)
}
diff --git a/packages/playground/src/components/biz/editor-panel/editor.tsx b/packages/playground/src/components/biz/editor-panel/editor.tsx
index c4ed91e..54d20b0 100644
--- a/packages/playground/src/components/biz/editor-panel/editor.tsx
+++ b/packages/playground/src/components/biz/editor-panel/editor.tsx
@@ -1,4 +1,5 @@
-import { useMemo } from 'react'
+import { useMemo, useRef } from 'react'
+
import { EditorView, placeholder } from '@codemirror/view'
import { EditorState } from '@codemirror/state'
import { SQLConfig } from '@codemirror/lang-sql'
@@ -16,11 +17,12 @@ import {
aiWidget,
isUnifiedMergeViewActive
} from '@tidbcloud/codemirror-extension-ai-widget'
+import { getCurDatabase } from '@tidbcloud/codemirror-extension-cur-sql'
import { useFilesContext } from '@/contexts/files-context'
import { useTheme } from '@/components/darkmode-toggle/theme-provider'
import { SchemaRes, useSchemaContext } from '@/contexts/schema-context'
-import { delay } from '@/lib/delay'
+import { useChatContext } from '@/contexts/chat-context'
function convertSchemaToSQLConfig(dbList: SchemaRes): SQLConfig {
const schema: any = {}
@@ -63,6 +65,12 @@ export function Editor() {
() => convertSchemaToSQLConfig(schema ?? []),
[schema]
)
+ const getDbListRef = useRef<() => string[]>()
+ getDbListRef.current = () => {
+ return schema?.map((d) => d.name) || []
+ }
+
+ const chatCtx = useChatContext()
const activeFile = useMemo(
() => openedFiles.find((f) => f.id === activeFileId),
@@ -90,14 +98,16 @@ export function Editor() {
}),
fullWidthCharLinter(),
aiWidget({
- chat: async () => {
- await delay(2000)
- return { status: 'success', message: 'select * from test;' }
+ chat(view, chatId, req) {
+ const db = getCurDatabase(view.state)
+ req['extra']['db'] = db
+ return chatCtx.chat(chatId, { ...req })
},
- cancelChat: () => {},
- getDbList: () => {
- return ['test1', 'test2']
- }
+ cancelChat: chatCtx.cancelChat,
+ onEvent(_view, type, payload) {
+ chatCtx.onEvent(type, payload)
+ },
+ getDbList: getDbListRef.current!
})
]
}
diff --git a/packages/playground/src/contexts/chat-context-provider.tsx b/packages/playground/src/contexts/chat-context-provider.tsx
new file mode 100644
index 0000000..611f43d
--- /dev/null
+++ b/packages/playground/src/contexts/chat-context-provider.tsx
@@ -0,0 +1,41 @@
+import { useMemo } from 'react'
+import { ChatReq, EventType } from '@tidbcloud/codemirror-extension-ai-widget'
+
+import { cancelChat, chat2data, refineSql } from '@/api/tidbcloud/chat-api'
+import { ChatContext } from './chat-context'
+
+function chat(chatId: string, req: ChatReq) {
+ if (req.refContent === '') {
+ return chat2data(chatId, {
+ database: req.extra?.db ?? '',
+ question: req.prompt
+ })
+ }
+ return refineSql(chatId, {
+ database: req.extra?.db ?? '',
+ feedback: req.prompt,
+ sql: req.refContent
+ })
+}
+
+function onEvent(event: EventType, payload?: {}) {
+ console.log('event:', event)
+ console.log('payload:', payload)
+}
+
+export function ChatProvider(props: { children: React.ReactNode }) {
+ const ctxValue = useMemo(
+ () => ({
+ chat,
+ cancelChat,
+ onEvent
+ }),
+ []
+ )
+
+ return (
+
+ {props.children}
+
+ )
+}
diff --git a/packages/playground/src/contexts/chat-context.tsx b/packages/playground/src/contexts/chat-context.tsx
new file mode 100644
index 0000000..f900169
--- /dev/null
+++ b/packages/playground/src/contexts/chat-context.tsx
@@ -0,0 +1,25 @@
+import { createContext, useContext } from 'react'
+
+import {
+ ChatReq,
+ ChatRes,
+ EventType
+} from '@tidbcloud/codemirror-extension-ai-widget'
+
+type ChatCtxValue = {
+ chat: (chatId: string, req: ChatReq) => Promise
+ cancelChat: (chatId: string) => void
+ onEvent: (event: EventType, payload?: {}) => void
+}
+
+export const ChatContext = createContext(null)
+
+export const useChatContext = () => {
+ const context = useContext(ChatContext)
+
+ if (!context) {
+ throw new Error('useChatContext must be used within a provider')
+ }
+
+ return context
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 3b3a938..4357ddf 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -426,6 +426,9 @@ importers:
'@tidbcloud/codemirror-extension-autocomplete':
specifier: workspace:^
version: link:../extensions/autocomplete
+ '@tidbcloud/codemirror-extension-cur-sql':
+ specifier: workspace:^
+ version: link:../extensions/cur-sql
'@tidbcloud/codemirror-extension-cur-sql-gutter':
specifier: workspace:^
version: link:../extensions/cur-sql-gutter