diff --git a/src/app/(theme)/client/[[...driver]]/page-client.tsx b/src/app/(theme)/client/[[...driver]]/page-client.tsx index 0017a1a2..cafaf051 100644 --- a/src/app/(theme)/client/[[...driver]]/page-client.tsx +++ b/src/app/(theme)/client/[[...driver]]/page-client.tsx @@ -10,6 +10,7 @@ import ValtownDriver from "@/drivers/valtown-driver"; import MyStudio from "@/components/my-studio"; import CloudflareD1Driver from "@/drivers/cloudflare-d1-driver"; +import StarbaseDriver from "@/drivers/starbase-driver"; export default function ClientPageBody() { const driver = useMemo(() => { @@ -27,7 +28,13 @@ export default function ClientPageBody() { "x-account-id": config.username ?? "", "x-database-id": config.database ?? "", }); + } else if (config.driver === "starbase") { + return new StarbaseDriver("/proxy/starbase", { + Authorization: "Bearer " + (config.token ?? ""), + "x-starbase-url": config.url ?? "", + }); } + return new TursoDriver(config.url, config.token as string, true); }, []); diff --git a/src/app/(theme)/client/s/[[...driver]]/page-client.tsx b/src/app/(theme)/client/s/[[...driver]]/page-client.tsx index a3b25389..233a6e60 100644 --- a/src/app/(theme)/client/s/[[...driver]]/page-client.tsx +++ b/src/app/(theme)/client/s/[[...driver]]/page-client.tsx @@ -8,6 +8,7 @@ import { useMemo } from "react"; import MyStudio from "@/components/my-studio"; import IndexdbSavedDoc from "@/drivers/saved-doc/indexdb-saved-doc"; import CloudflareD1Driver from "@/drivers/cloudflare-d1-driver"; +import StarbaseDriver from "@/drivers/starbase-driver"; export default function ClientPageBody() { const params = useSearchParams(); @@ -35,6 +36,11 @@ export default function ClientPageBody() { "x-account-id": conn.config.username ?? "", "x-database-id": conn.config.database ?? "", }); + } else if (conn.driver === "starbase") { + return new StarbaseDriver("/proxy/starbase", { + Authorization: "Bearer " + (conn.config.token ?? ""), + "x-starbase-url": conn.config.url ?? "", + }); } return new TursoDriver(conn.config.url, conn.config.token, true); diff --git a/src/app/(theme)/connect/driver-dropdown.tsx b/src/app/(theme)/connect/driver-dropdown.tsx index b2d51e1b..de6ed00c 100644 --- a/src/app/(theme)/connect/driver-dropdown.tsx +++ b/src/app/(theme)/connect/driver-dropdown.tsx @@ -84,6 +84,17 @@ export default function DriverDropdown({ + { + onSelect("starbase"); + }} + > +
+ +
StarbaseDB
+
+
+ = }, ], }, + starbase: { + name: "starbase", + displayName: "Starbase", + icon: SQLiteIcon, + disableRemote: true, + prefill: "", + fields: [ + { + name: "url", + title: "Endpoint", + required: true, + type: "text", + secret: false, + invalidate: (url: string): null | string => { + const trimmedUrl = url.trim(); + const valid = + trimmedUrl.startsWith("https://") || + trimmedUrl.startsWith("http://"); + + if (!valid) { + return "Endpoint must start with https:// or http://"; + } + + return null; + }, + }, + { + name: "token", + title: "API Token", + required: true, + type: "text", + secret: true, + }, + ], + }, "cloudflare-d1": { name: "cloudflare-d1", displayName: "Cloudflare D1", @@ -211,8 +246,10 @@ export type SupportedDriver = | "turso" | "rqlite" | "valtown" + | "starbase" | "cloudflare-d1" | "sqlite-filehandler"; + export type SavedConnectionStorage = "remote" | "local"; export type SavedConnectionLabel = "gray" | "red" | "yellow" | "green" | "blue"; diff --git a/src/app/proxy/starbase/route.ts b/src/app/proxy/starbase/route.ts new file mode 100644 index 00000000..37178df8 --- /dev/null +++ b/src/app/proxy/starbase/route.ts @@ -0,0 +1,62 @@ +import { HttpStatus } from "@/constants/http-status"; +import { headers } from "next/headers"; +import { NextRequest, NextResponse } from "next/server"; + +export const runtime = "edge"; + +export async function POST(req: NextRequest) { + // Get the account id and database id from header + const endpoint = headers().get("x-starbase-url"); + + if (!endpoint) { + return NextResponse.json( + { + error: "Please provide account id or database id", + }, + { status: HttpStatus.BAD_REQUEST } + ); + } + + const authorizationHeader = headers().get("Authorization"); + if (!authorizationHeader) { + return NextResponse.json( + { + error: "Please provide authorization header", + }, + { status: HttpStatus.BAD_REQUEST } + ); + } + + try { + const url = `${endpoint.replace(/\/$/, "")}/query/raw`; + + const response: { errors: { message: string }[] } = await ( + await fetch(url, { + method: "POST", + headers: { + Authorization: authorizationHeader, + "Content-Type": "application/json", + }, + body: JSON.stringify(await req.json()), + }) + ).json(); + + if (response.errors && response.errors.length > 0) { + return NextResponse.json( + { + error: response.errors[0].message, + }, + { status: HttpStatus.INTERNAL_SERVER_ERROR } + ); + } + + return NextResponse.json(response); + } catch (e) { + return NextResponse.json( + { + error: (e as Error).message, + }, + { status: HttpStatus.BAD_REQUEST } + ); + } +} diff --git a/src/components/gui/main-connection.tsx b/src/components/gui/main-connection.tsx index 55fc1cfc..ad84ef15 100644 --- a/src/components/gui/main-connection.tsx +++ b/src/components/gui/main-connection.tsx @@ -38,7 +38,7 @@ function MainConnectionContainer() { */ useLayoutEffect(() => { console.info("Injecting message into window object"); - window.internalPubSub = new InternalPubSub(); + if (!window.internalPubSub) window.internalPubSub = new InternalPubSub(); }, [driver]); useEffect(() => { diff --git a/src/components/gui/query-result.tsx b/src/components/gui/query-result.tsx index c778c6ea..a037955c 100644 --- a/src/components/gui/query-result.tsx +++ b/src/components/gui/query-result.tsx @@ -20,7 +20,10 @@ export default function QueryResult({ return { _tag: "EXPLAIN", value: result.result } as const; } - const state = OptimizeTableState.createFromResult(result.result); + const state = OptimizeTableState.createFromResult( + databaseDriver, + result.result + ); state.setReadOnlyMode(true); state.mismatchDetection = databaseDriver.getFlags().mismatchDetection; return { _tag: "QUERY", value: state } as const; diff --git a/src/components/gui/table-optimized/OptimizeTableState.tsx b/src/components/gui/table-optimized/OptimizeTableState.tsx index 7e374d27..edb58c4b 100644 --- a/src/components/gui/table-optimized/OptimizeTableState.tsx +++ b/src/components/gui/table-optimized/OptimizeTableState.tsx @@ -2,6 +2,7 @@ import { selectArrayFromIndexList } from "@/components/lib/export-helper"; import { OptimizeTableHeaderProps } from "."; import { LucideKey, LucideKeySquare, LucideSigma } from "lucide-react"; import { + BaseDriver, DatabaseResultSet, DatabaseTableSchema, TableColumnDataType, @@ -37,14 +38,19 @@ export default class OptimizeTableState { protected changeLogs: Record = {}; static createFromResult( + driver: BaseDriver, dataResult: DatabaseResultSet, schemaResult?: DatabaseTableSchema ) { return new OptimizeTableState( dataResult.headers.map((header) => { + const headerData = schemaResult + ? schemaResult.columns.find((c) => c.name === header.name) + : undefined; + let initialSize = 150; const headerName = header.name; - const dataType = header.type; + const dataType = header.type ?? driver.inferTypeFromHeader(headerData); if ( dataType === TableColumnDataType.INTEGER || @@ -67,10 +73,6 @@ export default class OptimizeTableState { initialSize = Math.max(150, Math.min(500, maxSize * 8)); } - const headerData = schemaResult - ? schemaResult.columns.find((c) => c.name === header.name) - : undefined; - // -------------------------------------- // Matching foreign key // -------------------------------------- diff --git a/src/components/gui/tabs/table-data-tab.tsx b/src/components/gui/tabs/table-data-tab.tsx index 9ab0a704..a625fc0c 100644 --- a/src/components/gui/tabs/table-data-tab.tsx +++ b/src/components/gui/tabs/table-data-tab.tsx @@ -89,6 +89,7 @@ export default function TableDataWindow({ }); const tableState = OptimizeTableState.createFromResult( + databaseDriver, dataResult, schemaResult ); diff --git a/src/components/gui/toolbar.tsx b/src/components/gui/toolbar.tsx index 0c010496..4ba1997d 100644 --- a/src/components/gui/toolbar.tsx +++ b/src/components/gui/toolbar.tsx @@ -48,7 +48,7 @@ export function ToolbarButton({ if (tooltip) { return ( - {buttonContent} + {buttonContent} {tooltip} ); diff --git a/src/drivers/base-driver.ts b/src/drivers/base-driver.ts index f0990885..bdcb5b72 100644 --- a/src/drivers/base-driver.ts +++ b/src/drivers/base-driver.ts @@ -207,6 +207,12 @@ export interface DriverFlags { supportModifyColumn: boolean; mismatchDetection: boolean; dialect: SupportedDialect; + + // If database supports this, we don't need + // to make a separate request to get updated + // data when update + supportInsertReturning: boolean; + supportUpdateReturning: boolean; } export interface DatabaseTableColumnChange { @@ -252,6 +258,10 @@ export abstract class BaseDriver { tableName: string ): Promise; + abstract inferTypeFromHeader( + header?: DatabaseTableColumn + ): TableColumnDataType | undefined; + abstract trigger( schemaName: string, name: string diff --git a/src/drivers/common-sql-imp.ts b/src/drivers/common-sql-imp.ts index 217dffa7..aa8c1621 100644 --- a/src/drivers/common-sql-imp.ts +++ b/src/drivers/common-sql-imp.ts @@ -36,11 +36,25 @@ export default abstract class CommonSQLImplement extends BaseDriver { const sqls = ops.map((op) => { if (op.operation === "INSERT") - return insertInto(this, schemaName, tableName, op.values); + return insertInto( + this, + schemaName, + tableName, + op.values, + this.getFlags().supportInsertReturning + ); + if (op.operation === "DELETE") return deleteFrom(this, schemaName, tableName, op.where); - return updateTable(this, schemaName, tableName, op.values, op.where); + return updateTable( + this, + schemaName, + tableName, + op.values, + op.where, + this.getFlags().supportInsertReturning + ); }); const result = await this.transaction(sqls); @@ -57,18 +71,29 @@ export default abstract class CommonSQLImplement extends BaseDriver { } if (op.operation === "UPDATE") { - const selectResult = await this.findFirst( - schemaName, - tableName, - op.where - ); + if (r.rows.length === 1) + // This is when database support RETURNING + tmp.push({ + record: r.rows[0], + }); + else { + const selectResult = await this.findFirst( + schemaName, + tableName, + op.where + ); - tmp.push({ - lastId: r.lastInsertRowid, - record: selectResult.rows[0], - }); + tmp.push({ + lastId: r.lastInsertRowid, + record: selectResult.rows[0], + }); + } } else if (op.operation === "INSERT") { - if (op.autoIncrementPkColumn) { + if (r.rows.length === 1) { + tmp.push({ + record: r.rows[0], + }); + } else if (op.autoIncrementPkColumn) { const selectResult = await this.findFirst(schemaName, tableName, { [op.autoIncrementPkColumn]: r.lastInsertRowid, }); diff --git a/src/drivers/mysql/mysql-driver.ts b/src/drivers/mysql/mysql-driver.ts index 2ab51ef5..bf01efb8 100644 --- a/src/drivers/mysql/mysql-driver.ts +++ b/src/drivers/mysql/mysql-driver.ts @@ -5,6 +5,7 @@ import { DriverFlags, DatabaseSchemaItem, DatabaseTableColumn, + TableColumnDataType, } from "../base-driver"; import CommonSQLImplement from "../common-sql-imp"; import { escapeSqlValue } from "../sqlite/sql-helper"; @@ -53,6 +54,9 @@ export default abstract class MySQLLikeDriver extends CommonSQLImplement { mismatchDetection: false, supportCreateUpdateTable: false, dialect: "mysql", + + supportInsertReturning: false, + supportUpdateReturning: false, }; } @@ -153,4 +157,8 @@ export default abstract class MySQLLikeDriver extends CommonSQLImplement { createUpdateTableSchema(): string[] { throw new Error("Not implemented"); } + + inferTypeFromHeader(): TableColumnDataType | undefined { + return undefined; + } } diff --git a/src/drivers/query-builder.ts b/src/drivers/query-builder.ts index 34db1102..e6f82eab 100644 --- a/src/drivers/query-builder.ts +++ b/src/drivers/query-builder.ts @@ -64,12 +64,14 @@ export function insertInto( dialect: BaseDriver, schema: string, table: string, - value: Record + value: Record, + supportReturning: boolean ) { return [ "INSERT INTO", `${dialect.escapeId(schema)}.${dialect.escapeId(table)}`, generateInsertValue(dialect, value), + supportReturning ? "RETURNING *" : "", ].join(" "); } @@ -78,7 +80,8 @@ export function updateTable( schema: string, table: string, value: Record, - where: Record + where: Record, + supportReturning: boolean ): string { return [ "UPDATE", @@ -86,6 +89,7 @@ export function updateTable( "SET", generateSet(dialect, value), generateWhere(dialect, where), + supportReturning ? "RETURNING *" : "", ] .filter(Boolean) .join(" "); diff --git a/src/drivers/sqlite-base-driver.ts b/src/drivers/sqlite-base-driver.ts index 84d62c09..065ea806 100644 --- a/src/drivers/sqlite-base-driver.ts +++ b/src/drivers/sqlite-base-driver.ts @@ -9,8 +9,9 @@ import type { DatabaseValue, DriverFlags, SelectFromTableOptions, + TableColumnDataType, } from "./base-driver"; -import { escapeSqlValue } from "@/drivers/sqlite/sql-helper"; +import { convertSqliteType, escapeSqlValue } from "@/drivers/sqlite/sql-helper"; import { parseCreateTableScript } from "@/drivers/sqlite/sql-parse-table"; import { parseCreateTriggerScript } from "@/drivers/sqlite/sql-parse-trigger"; @@ -32,6 +33,8 @@ export abstract class SqliteLikeBaseDriver extends CommonSQLImplement { return { supportBigInt: false, supportModifyColumn: false, + supportInsertReturning: true, + supportUpdateReturning: true, defaultSchema: "main", optionalSchema: true, mismatchDetection: false, @@ -170,6 +173,13 @@ export abstract class SqliteLikeBaseDriver extends CommonSQLImplement { // do nothing } + inferTypeFromHeader( + header?: DatabaseTableColumn + ): TableColumnDataType | undefined { + if (!header) return undefined; + return convertSqliteType(header.type); + } + async tableSchema( schemaName: string, tableName: string diff --git a/src/drivers/sqljs-driver.ts b/src/drivers/sqljs-driver.ts index 950a682b..11e721e7 100644 --- a/src/drivers/sqljs-driver.ts +++ b/src/drivers/sqljs-driver.ts @@ -84,12 +84,6 @@ export default class SqljsDriver extends SqliteLikeBaseDriver { rowsWritten: null, queryDurationMs: endTime - startTime, }, - lastInsertRowid: - headers.length > 0 - ? undefined - : (this.db.exec("select last_insert_rowid();")[0].values[0][0] as - | number - | undefined), }; } diff --git a/src/drivers/starbase-driver.ts b/src/drivers/starbase-driver.ts new file mode 100644 index 00000000..ee5ba067 --- /dev/null +++ b/src/drivers/starbase-driver.ts @@ -0,0 +1,103 @@ +import { + DatabaseHeader, + DatabaseResultSet, + DatabaseRow, + TableColumnDataType, +} from "./base-driver"; +import { SqliteLikeBaseDriver } from "./sqlite-base-driver"; + +interface StarbaseResult { + columns: string[]; + rows: unknown[][]; + meta: { + rows_read: number; + rows_written: number; + }; +} + +interface StarbaseResponse { + result: StarbaseResult | StarbaseResult[]; +} + +function transformRawResult(raw: StarbaseResult): DatabaseResultSet { + const columns = raw.columns ?? []; + const values = raw.rows; + const headerSet = new Set(); + + const headers: DatabaseHeader[] = columns.map((colName) => { + let renameColName = colName; + + for (let i = 0; i < 20; i++) { + if (!headerSet.has(renameColName)) break; + renameColName = `__${colName}_${i}`; + } + + return { + name: renameColName, + displayName: colName, + originalType: "text", + type: TableColumnDataType.TEXT, + }; + }); + + const rows = values + ? values.map((r) => + headers.reduce((a, b, idx) => { + a[b.name] = r[idx]; + return a; + }, {} as DatabaseRow) + ) + : []; + + return { + rows, + stat: { + queryDurationMs: 0, + rowsAffected: 0, + rowsRead: raw.meta.rows_read, + rowsWritten: raw.meta.rows_written, + }, + headers, + }; +} + +export default class StarbaseDriver extends SqliteLikeBaseDriver { + supportPragmaList: boolean = false; + protected headers: Record = {}; + protected url: string; + + constructor(url: string, headers: Record) { + super(); + this.headers = headers; + this.url = url; + } + + async transaction(stmts: string[]): Promise { + const r = await fetch(this.url, { + method: "POST", + headers: { ...this.headers, "Content-Type": "application/json" }, + body: JSON.stringify({ + transaction: stmts.map((s) => ({ sql: s })), + }), + }); + + const json: StarbaseResponse = await r.json(); + return (Array.isArray(json.result) ? json.result : [json.result]).map( + transformRawResult + ); + } + + async query(stmt: string): Promise { + const r = await fetch(this.url, { + method: "POST", + headers: { ...this.headers, "Content-Type": "application/json" }, + body: JSON.stringify({ sql: stmt }), + }); + + const json: StarbaseResponse = await r.json(); + + return transformRawResult( + Array.isArray(json.result) ? json.result[0] : json.result + ); + } +}