From 23c9b509e67dc15186cb102c27693310adb11ff3 Mon Sep 17 00:00:00 2001 From: forehalo Date: Thu, 17 Oct 2024 19:32:44 +0800 Subject: [PATCH] feat(doc-storage): impl doc storages --- .eslintrc.js | 2 +- Cargo.lock | 1 + Cargo.toml | 1 + .../server/src/core/doc/storage/doc.ts | 29 + .../server/src/core/doc/storage/index.ts | 4 +- .../backend/server/src/core/sync/gateway.ts | 31 +- packages/common/doc-storage/package.json | 22 + .../common/doc-storage/src/impls/cloud/doc.ts | 257 +++++++++ .../doc-storage/src/impls/cloud/index.ts | 1 + .../common/doc-storage/src/impls/idb/db.ts | 32 ++ .../common/doc-storage/src/impls/idb/doc.ts | 167 ++++++ .../common/doc-storage/src/impls/idb/index.ts | 1 + .../doc-storage/src/impls/idb/schema.ts | 120 ++++ .../doc-storage/src/impls/sqlite/doc.ts | 111 ++++ .../doc-storage/src/impls/sqlite/index.ts | 1 + packages/common/doc-storage/src/index.ts | 1 + .../doc-storage/src/storage/connection.ts | 11 + .../common/doc-storage/src/storage/doc/doc.ts | 319 +++++++++++ .../doc-storage/src/storage/doc/index.ts | 2 + .../doc-storage/src/storage/doc/types.ts | 29 + .../common/doc-storage/src/storage/index.ts | 1 + .../common/doc-storage/src/storage/lock.ts | 42 ++ packages/common/doc-storage/tsconfig.json | 9 + packages/frontend/native/Cargo.toml | 1 + packages/frontend/native/build.rs | 4 +- packages/frontend/native/index.d.ts | 36 ++ packages/frontend/native/index.js | 1 + .../native/migrations/20240929082254_init.sql | 20 + .../native/src/sqlite/doc_storage/mod.rs | 122 ++++ .../native/src/sqlite/doc_storage/storage.rs | 406 ++++++++++++++ packages/frontend/native/src/sqlite/mod.rs | 520 +----------------- packages/frontend/native/src/sqlite/v1.rs | 518 +++++++++++++++++ tsconfig.json | 6 +- yarn.lock | 13 + 34 files changed, 2300 insertions(+), 541 deletions(-) create mode 100644 packages/common/doc-storage/package.json create mode 100644 packages/common/doc-storage/src/impls/cloud/doc.ts create mode 100644 packages/common/doc-storage/src/impls/cloud/index.ts create mode 100644 packages/common/doc-storage/src/impls/idb/db.ts create mode 100644 packages/common/doc-storage/src/impls/idb/doc.ts create mode 100644 packages/common/doc-storage/src/impls/idb/index.ts create mode 100644 packages/common/doc-storage/src/impls/idb/schema.ts create mode 100644 packages/common/doc-storage/src/impls/sqlite/doc.ts create mode 100644 packages/common/doc-storage/src/impls/sqlite/index.ts create mode 100644 packages/common/doc-storage/src/index.ts create mode 100644 packages/common/doc-storage/src/storage/connection.ts create mode 100644 packages/common/doc-storage/src/storage/doc/doc.ts create mode 100644 packages/common/doc-storage/src/storage/doc/index.ts create mode 100644 packages/common/doc-storage/src/storage/doc/types.ts create mode 100644 packages/common/doc-storage/src/storage/index.ts create mode 100644 packages/common/doc-storage/src/storage/lock.ts create mode 100644 packages/common/doc-storage/tsconfig.json create mode 100644 packages/frontend/native/migrations/20240929082254_init.sql create mode 100644 packages/frontend/native/src/sqlite/doc_storage/mod.rs create mode 100644 packages/frontend/native/src/sqlite/doc_storage/storage.rs create mode 100644 packages/frontend/native/src/sqlite/v1.rs diff --git a/.eslintrc.js b/.eslintrc.js index 8506d3f8e5388..fd02a843764ee 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -54,7 +54,7 @@ const allPackages = [ 'packages/common/debug', 'packages/common/env', 'packages/common/infra', - 'packages/common/theme', + 
'packages/common/doc-storage', 'tools/cli', ]; diff --git a/Cargo.lock b/Cargo.lock index 1cb306aa5025a..d99e392a8f6ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,6 +25,7 @@ dependencies = [ "anyhow", "chrono", "dotenv", + "log", "napi", "napi-build", "napi-derive", diff --git a/Cargo.toml b/Cargo.toml index 8c080a528e6a7..f0853de1e40fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ serde = "1" serde_json = "1" sha3 = "0.10" sqlx = { version = "0.8", default-features = false, features = ["chrono", "macros", "migrate", "runtime-tokio", "sqlite", "tls-rustls"] } +log = "0.4" tiktoken-rs = "0.5" tokio = "1.37" uuid = "1.8" diff --git a/packages/backend/server/src/core/doc/storage/doc.ts b/packages/backend/server/src/core/doc/storage/doc.ts index 6010855484812..4bcbf4067be68 100644 --- a/packages/backend/server/src/core/doc/storage/doc.ts +++ b/packages/backend/server/src/core/doc/storage/doc.ts @@ -1,8 +1,10 @@ import { applyUpdate, + diffUpdate, Doc, encodeStateAsUpdate, encodeStateVector, + encodeStateVectorFromUpdate, mergeUpdates, UndoManager, } from 'yjs'; @@ -19,6 +21,12 @@ export interface DocRecord { editor?: string; } +export interface DocDiff { + missing: Uint8Array; + state: Uint8Array; + timestamp: number; +} + export interface DocUpdate { bin: Uint8Array; timestamp: number; @@ -96,6 +104,27 @@ export abstract class DocStorageAdapter extends Connection { return snapshot; } + async getDocDiff( + spaceId: string, + docId: string, + stateVector?: Uint8Array + ): Promise { + const doc = await this.getDoc(spaceId, docId); + + if (!doc) { + return null; + } + + const missing = stateVector ? diffUpdate(doc.bin, stateVector) : doc.bin; + const state = encodeStateVectorFromUpdate(doc.bin); + + return { + missing, + state, + timestamp: doc.timestamp, + }; + } + abstract pushDocUpdates( spaceId: string, docId: string, diff --git a/packages/backend/server/src/core/doc/storage/index.ts b/packages/backend/server/src/core/doc/storage/index.ts index 6ba0e23dd1119..3e56264bd1ba5 100644 --- a/packages/backend/server/src/core/doc/storage/index.ts +++ b/packages/backend/server/src/core/doc/storage/index.ts @@ -1,4 +1,6 @@ -// TODO(@forehalo): share with frontend +// This is a totally copy of definitions in [@affine/doc-storage] +// because currently importing cross workspace package from [@affine/server] is not yet supported +// should be kept updated with the original definitions in [@affine/doc-storage] import type { BlobStorageAdapter } from './blob'; import { Connection } from './connection'; import type { DocStorageAdapter } from './doc'; diff --git a/packages/backend/server/src/core/sync/gateway.ts b/packages/backend/server/src/core/sync/gateway.ts index d5686c658fd42..b922e477bb6a3 100644 --- a/packages/backend/server/src/core/sync/gateway.ts +++ b/packages/backend/server/src/core/sync/gateway.ts @@ -8,7 +8,6 @@ import { WebSocketGateway, } from '@nestjs/websockets'; import { Socket } from 'socket.io'; -import { diffUpdate, encodeStateVectorFromUpdate } from 'yjs'; import { AlreadyInSpace, @@ -233,31 +232,25 @@ export class SpaceSyncGateway @MessageBody() { spaceType, spaceId, docId, stateVector }: LoadDocMessage ): Promise< - EventResponse<{ missing: string; state?: string; timestamp: number }> + EventResponse<{ missing: string; state: string; timestamp: number }> > { const adapter = this.selectAdapter(client, spaceType); adapter.assertIn(spaceId); - const doc = await adapter.get(spaceId, docId); + const doc = await adapter.diff( + spaceId, + docId, + stateVector ? 
Buffer.from(stateVector, 'base64') : undefined + ); if (!doc) { throw new DocNotFound({ spaceId, docId }); } - const missing = Buffer.from( - stateVector - ? diffUpdate(doc.bin, Buffer.from(stateVector, 'base64')) - : doc.bin - ).toString('base64'); - - const state = Buffer.from(encodeStateVectorFromUpdate(doc.bin)).toString( - 'base64' - ); - return { data: { - missing, - state, + missing: Buffer.from(doc.missing).toString('base64'), + state: Buffer.from(doc.state).toString('base64'), timestamp: doc.timestamp, }, }; @@ -600,9 +593,9 @@ abstract class SyncSocketAdapter { return this.storage.pushDocUpdates(spaceId, docId, updates, editorId); } - get(spaceId: string, docId: string) { + diff(spaceId: string, docId: string, stateVector?: Uint8Array) { this.assertIn(spaceId); - return this.storage.getDoc(spaceId, docId); + return this.storage.getDocDiff(spaceId, docId, stateVector); } getTimestamps(spaceId: string, timestamp?: number) { @@ -630,9 +623,9 @@ class WorkspaceSyncAdapter extends SyncSocketAdapter { return super.push(spaceId, id.guid, updates, editorId); } - override get(spaceId: string, docId: string) { + override diff(spaceId: string, docId: string, stateVector?: Uint8Array) { const id = new DocID(docId, spaceId); - return this.storage.getDoc(spaceId, id.guid); + return this.storage.getDocDiff(spaceId, id.guid, stateVector); } async assertAccessible( diff --git a/packages/common/doc-storage/package.json b/packages/common/doc-storage/package.json new file mode 100644 index 0000000000000..1a01a64f8805b --- /dev/null +++ b/packages/common/doc-storage/package.json @@ -0,0 +1,22 @@ +{ + "name": "@affine/doc-storage", + "type": "module", + "version": "0.15.0", + "private": true, + "sideEffects": false, + "exports": { + ".": "./index.ts", + "./impls/*": "./impls/*", + "./storage": "./storage/index.ts" + }, + "dependencies": { + "@affine/native": "workspace:*", + "idb": "^8.0.0", + "lodash-es": "^4.17.21", + "socket.io-client": "^4.7.5", + "yjs": "patch:yjs@npm%3A13.6.18#~/.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch" + }, + "devDependencies": { + "@types/lodash-es": "^4.17.12" + } +} diff --git a/packages/common/doc-storage/src/impls/cloud/doc.ts b/packages/common/doc-storage/src/impls/cloud/doc.ts new file mode 100644 index 0000000000000..232ddb8b81cbc --- /dev/null +++ b/packages/common/doc-storage/src/impls/cloud/doc.ts @@ -0,0 +1,257 @@ +import type { Socket } from 'socket.io-client'; + +import { DocStorage, type DocStorageOptions } from '../../storage'; + +// TODO(@forehalo): use [UserFriendlyError] +interface EventError { + name: string; + message: string; +} + +type WebsocketResponse = + | { + error: EventError; + } + | { + data: T; + }; + +interface ServerEvents { + 'space:broadcast-doc-updates': { + spaceType: string; + spaceId: string; + docId: string; + updates: string[]; + timestamp: number; + }; +} + +interface ClientEvents { + 'space:join': [ + { spaceType: string; spaceId: string; clientVersion: string }, + { clientId: string }, + ]; + 'space:leave': { spaceType: string; spaceId: string }; + 'space:push-doc-updates': [ + { spaceType: string; spaceId: string; docId: string; updates: string[] }, + { timestamp: number }, + ]; + 'space:load-doc-timestamps': [ + { + spaceType: string; + spaceId: string; + timestamp?: number; + }, + Record, + ]; + 'space:load-doc': [ + { + spaceType: string; + spaceId: string; + docId: string; + stateVector?: string; + }, + { + missing: string; + state: string; + timestamp: number; + }, + ]; +} + +type ServerEventsMap = { + [Key in keyof 
ServerEvents]: (data: ServerEvents[Key]) => void; +}; +type ClientEventsMap = { + [Key in keyof ClientEvents]: ClientEvents[Key] extends Array + ? ( + data: ClientEvents[Key][0], + ack: (res: WebsocketResponse) => void + ) => void + : (data: ClientEvents[Key]) => void; +}; + +interface CloudDocStorageOptions extends DocStorageOptions { + socket: Socket; +} + +export class CloudDocStorage extends DocStorage { + constructor(options: CloudDocStorageOptions) { + super(options); + } + + get name() { + // @ts-expect-error we need it + return this.options.socket.io.uri; + } + + get socket() { + return this.options.socket; + } + + override async connect(): Promise { + // the event will be polled, there is no need to wait for socket to be connected + await this.clientHandShake(); + this.socket.on('space:broadcast-doc-updates', this.onServerUpdates); + } + + private async clientHandShake() { + const res = await this.socket.emitWithAck('space:join', { + spaceType: this.spaceType, + spaceId: this.spaceId, + clientVersion: BUILD_CONFIG.appVersion, + }); + + if ('error' in res) { + // TODO(@forehalo): use [UserFriendlyError] + throw new Error(res.error.message); + } + } + + override async disconnect(): Promise { + this.socket.emit('space:leave', { + spaceType: this.spaceType, + spaceId: this.spaceId, + }); + this.socket.off('space:broadcast-doc-updates', this.onServerUpdates); + } + + onServerUpdates: ServerEventsMap['space:broadcast-doc-updates'] = message => { + if ( + this.spaceType === message.spaceType && + this.spaceId === message.spaceId + ) { + this.dispatchDocUpdatesListeners( + message.docId, + message.updates.map(base64ToUint8Array), + message.timestamp + ); + } + }; + + override async getDocSnapshot(docId: string) { + const response = await this.socket.emitWithAck('space:load-doc', { + spaceType: this.spaceType, + spaceId: this.spaceId, + docId, + }); + + if ('error' in response) { + // TODO: use [UserFriendlyError] + throw new Error(response.error.message); + } + + return { + spaceId: this.spaceId, + docId, + bin: base64ToUint8Array(response.data.missing), + timestamp: response.data.timestamp, + }; + } + + override async getDocDiff(docId: string, stateVector?: Uint8Array) { + const response = await this.socket.emitWithAck('space:load-doc', { + spaceType: this.spaceType, + spaceId: this.spaceId, + docId, + stateVector: stateVector ? 
await uint8ArrayToBase64(stateVector) : void 0, + }); + + if ('error' in response) { + // TODO: use [UserFriendlyError] + throw new Error(response.error.message); + } + + return { + missing: base64ToUint8Array(response.data.missing), + state: base64ToUint8Array(response.data.state), + timestamp: response.data.timestamp, + }; + } + + async pushDocUpdates(docId: string, updates: Uint8Array[]): Promise { + const response = await this.socket.emitWithAck('space:push-doc-updates', { + spaceType: this.spaceType, + spaceId: this.spaceId, + docId, + updates: await Promise.all(updates.map(uint8ArrayToBase64)), + }); + + if ('error' in response) { + // TODO(@forehalo): use [UserFriendlyError] + throw new Error(response.error.message); + } + + return response.data.timestamp; + } + + async getSpaceDocTimestamps( + after?: number + ): Promise | null> { + const response = await this.socket.emitWithAck( + 'space:load-doc-timestamps', + { + spaceType: this.spaceType, + spaceId: this.spaceId, + timestamp: after, + } + ); + + if ('error' in response) { + // TODO(@forehalo): use [UserFriendlyError] + throw new Error(response.error.message); + } + + return response.data; + } + + async deleteDoc(): Promise {} + async deleteSpace(): Promise {} + protected async setDocSnapshot() { + return false; + } + protected async getDocUpdates() { + return []; + } + protected async markUpdatesMerged() { + return 0; + } + override async listDocHistories() { + return []; + } + override async getDocHistory() { + return null; + } + protected override async createDocHistory() { + return false; + } +} + +export function uint8ArrayToBase64(array: Uint8Array): Promise { + return new Promise(resolve => { + // Create a blob from the Uint8Array + const blob = new Blob([array]); + + const reader = new FileReader(); + reader.onload = function () { + const dataUrl = reader.result as string | null; + if (!dataUrl) { + resolve(''); + return; + } + // The result includes the `data:` URL prefix and the MIME type. 
We only want the Base64 data + const base64 = dataUrl.split(',')[1]; + resolve(base64); + }; + + reader.readAsDataURL(blob); + }); +} + +export function base64ToUint8Array(base64: string) { + const binaryString = atob(base64); + const binaryArray = binaryString.split('').map(function (char) { + return char.charCodeAt(0); + }); + return new Uint8Array(binaryArray); +} diff --git a/packages/common/doc-storage/src/impls/cloud/index.ts b/packages/common/doc-storage/src/impls/cloud/index.ts new file mode 100644 index 0000000000000..f42f6dd754a91 --- /dev/null +++ b/packages/common/doc-storage/src/impls/cloud/index.ts @@ -0,0 +1 @@ +export * from './doc'; diff --git a/packages/common/doc-storage/src/impls/idb/db.ts b/packages/common/doc-storage/src/impls/idb/db.ts new file mode 100644 index 0000000000000..7790ab320056b --- /dev/null +++ b/packages/common/doc-storage/src/impls/idb/db.ts @@ -0,0 +1,32 @@ +import { type IDBPDatabase, openDB } from 'idb'; + +import { type DocStorageSchema, latestVersion, migrate } from './schema'; + +export type SpaceIDB = IDBPDatabase; + +export class SpaceIndexedDbManager { + private static db: SpaceIDB | null = null; + + static async open(name: string) { + if (this.db) { + return this.db; + } + + const blocking = () => { + // notify user this connection is blocking other tabs to upgrade db + this.db?.close(); + }; + + const blocked = () => { + // notify user there is tab opened with old version, close it first + }; + + this.db = await openDB(name, latestVersion, { + upgrade: migrate, + blocking, + blocked, + }); + + return this.db; + } +} diff --git a/packages/common/doc-storage/src/impls/idb/doc.ts b/packages/common/doc-storage/src/impls/idb/doc.ts new file mode 100644 index 0000000000000..caeeac8093cf2 --- /dev/null +++ b/packages/common/doc-storage/src/impls/idb/doc.ts @@ -0,0 +1,167 @@ +import { type DocRecord, DocStorage, type DocUpdate } from '../../storage'; +import { type SpaceIDB, SpaceIndexedDbManager } from './db'; + +export class IndexedDBDocStorage extends DocStorage { + private db!: SpaceIDB; + + get name() { + return 'idb'; + } + + override async connect(): Promise { + this.db = await SpaceIndexedDbManager.open( + `${this.spaceType}:${this.spaceId}` + ); + } + + override async disconnect(): Promise { + this.db.close(); + } + + override async pushDocUpdates( + docId: string, + updates: Uint8Array[] + ): Promise { + if (!updates.length) { + return 0; + } + + const trx = this.db.transaction(['updates', 'clocks'], 'readwrite'); + + const timestamp = Date.now(); + await Promise.all( + updates.map(async (update, i) => { + await trx.objectStore('updates').add({ + docId, + bin: update, + createdAt: timestamp + i, + }); + }) + ); + + await trx + .objectStore('clocks') + .put({ docId, timestamp: timestamp + updates.length - 1 }); + trx.commit(); + + return updates.length; + } + + protected override async getDocSnapshot( + docId: string + ): Promise { + const trx = this.db.transaction('snapshots', 'readonly'); + const record = await trx.store.get(docId); + trx.commit(); + + if (!record) { + return null; + } + + return { + spaceId: this.spaceId, + docId, + bin: record.bin, + timestamp: record.updatedAt, + }; + } + + override async deleteDoc(docId: string): Promise { + const trx = this.db.transaction( + ['snapshots', 'updates', 'clocks'], + 'readwrite' + ); + + const idx = trx.objectStore('updates').index('docId'); + const iter = idx.iterate(IDBKeyRange.only(docId)); + + for await (const { value } of iter) { + await 
trx.objectStore('updates').delete([value.docId, value.createdAt]); + } + + await trx.objectStore('snapshots').delete(docId); + await trx.objectStore('clocks').delete(docId); + trx.commit(); + } + + override async deleteSpace(): Promise { + for (const name of this.db.objectStoreNames) { + await this.db.clear(name); + } + } + + override async getSpaceDocTimestamps( + after: number = 0 + ): Promise> { + const trx = this.db.transaction('clocks', 'readonly'); + const record: Record = {}; + + const iter = trx.store.iterate(IDBKeyRange.lowerBound(after)); + + for await (const { value } of iter) { + record[value.docId] = value.timestamp; + } + + trx.commit(); + return record; + } + + protected override async setDocSnapshot( + snapshot: DocRecord + ): Promise { + const trx = this.db.transaction('snapshots', 'readwrite'); + const record = await trx.store.get(snapshot.docId); + + if (record && record.updatedAt < snapshot.timestamp) { + await trx.store.put({ + docId: snapshot.docId, + bin: snapshot.bin, + createdAt: record?.createdAt ?? snapshot.timestamp, + updatedAt: snapshot.timestamp, + }); + } + + trx.commit(); + return true; + } + + protected override async getDocUpdates(docId: string): Promise { + const trx = this.db.transaction('updates', 'readonly'); + const updates = await trx.store.index('docId').getAll(docId); + + trx.commit(); + + return updates.map(update => ({ + bin: update.bin, + timestamp: update.createdAt, + })); + } + + protected override async markUpdatesMerged( + docId: string, + updates: DocUpdate[] + ): Promise { + const trx = this.db.transaction('updates', 'readwrite'); + + await Promise.all( + updates.map(update => trx.store.delete([docId, update.timestamp])) + ); + + trx.commit(); + return updates.length; + } + + // history is not supported by idb yet + override listDocHistories() { + return Promise.resolve([]); + } + override getDocHistory() { + return Promise.resolve(null); + } + override rollbackDoc() { + return Promise.resolve(); + } + protected override createDocHistory() { + return Promise.resolve(false); + } +} diff --git a/packages/common/doc-storage/src/impls/idb/index.ts b/packages/common/doc-storage/src/impls/idb/index.ts new file mode 100644 index 0000000000000..f42f6dd754a91 --- /dev/null +++ b/packages/common/doc-storage/src/impls/idb/index.ts @@ -0,0 +1 @@ +export * from './doc'; diff --git a/packages/common/doc-storage/src/impls/idb/schema.ts b/packages/common/doc-storage/src/impls/idb/schema.ts new file mode 100644 index 0000000000000..075e8018ca98a --- /dev/null +++ b/packages/common/doc-storage/src/impls/idb/schema.ts @@ -0,0 +1,120 @@ +import type { DBSchema, OpenDBCallbacks } from 'idb'; +/** +IndexedDB + > DB(workspace:${workspaceId}) + > Table(Snapshots) + > Table(Updates) + > Table(...) 
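+
+One database is opened per space (named `${spaceType}:${spaceId}` by the
+IndexedDB implementation); its object stores are laid out as follows: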
+ +Table(Snapshots) +|docId|blob|createdAt|updatedAt| +|-----|----|---------|---------| +| str | bin| Date | Date | + +Table(Updates) +| id |docId|blob|createdAt| +|----|-----|----|---------| +|auto| str | bin| Date | + +Table(Clocks) +| docId | clock | +|-------|--------| +| str | number | + */ +export interface DocStorageSchema extends DBSchema { + snapshots: { + key: string; + value: { + docId: string; + bin: Uint8Array; + createdAt: number; + updatedAt: number; + }; + indexes: { + updatedAt: number; + }; + }; + updates: { + key: [string, number]; + value: { + docId: string; + bin: Uint8Array; + createdAt: number; + }; + indexes: { + docId: string; + }; + }; + clocks: { + key: string; + value: { + docId: string; + timestamp: number; + }; + indexes: { + timestamp: number; + }; + }; + peerClocks: { + key: [string, string]; + value: { + docId: string; + peerId: string; + clock: number; + }; + indexes: { + clock: number; + }; + }; +} + +export const migrate: OpenDBCallbacks['upgrade'] = ( + db, + oldVersion, + _newVersion, + trx +) => { + if (!oldVersion) { + oldVersion = 0; + } + + for (let i = oldVersion; i < migrations.length; i++) { + migrations[i](db, trx); + } +}; + +type MigrateParameters = Parameters< + NonNullable['upgrade']> +>; +type Migrate = (db: MigrateParameters[0], trx: MigrateParameters[3]) => void; + +// START REGION: migrations +const init: Migrate = db => { + const snapshots = db.createObjectStore('snapshots', { + keyPath: 'docId', + autoIncrement: false, + }); + + snapshots.createIndex('updatedAt', 'updatedAt', { unique: false }); + + const updates = db.createObjectStore('updates', { + keyPath: ['docId', 'createdAt'], + autoIncrement: false, + }); + + updates.createIndex('docId', 'docId', { unique: false }); + + const clocks = db.createObjectStore('clocks', { + keyPath: 'docId', + autoIncrement: false, + }); + + clocks.createIndex('timestamp', 'timestamp', { unique: false }); +}; +// END REGION + +// 1. 
all schema changed should be put in migrations +// 2.order matters +const migrations: Migrate[] = [init]; + +export const latestVersion = migrations.length; diff --git a/packages/common/doc-storage/src/impls/sqlite/doc.ts b/packages/common/doc-storage/src/impls/sqlite/doc.ts new file mode 100644 index 0000000000000..570f137325318 --- /dev/null +++ b/packages/common/doc-storage/src/impls/sqlite/doc.ts @@ -0,0 +1,111 @@ +import type { DocStorage as NativeDocStorage } from '@affine/native'; + +import { + type DocRecord, + DocStorage, + type DocStorageOptions, + type DocUpdate, +} from '../../storage'; + +interface SqliteDocStorageOptions extends DocStorageOptions { + db: NativeDocStorage; +} + +export class SqliteDocStorage extends DocStorage { + get name() { + return 'sqlite'; + } + + get db() { + return this.options.db; + } + + constructor(options: SqliteDocStorageOptions) { + super(options); + } + + override pushDocUpdates( + docId: string, + updates: Uint8Array[] + ): Promise { + return this.db.pushUpdates(docId, updates); + } + + override deleteDoc(docId: string): Promise { + return this.db.deleteDoc(docId); + } + + override async deleteSpace(): Promise { + await this.disconnect(); + // rm this.dbPath + } + + override async getSpaceDocTimestamps( + after?: number + ): Promise | null> { + const clocks = await this.db.getDocClocks(after); + + return clocks.reduce( + (ret, cur) => { + ret[cur.docId] = cur.timestamp.getTime(); + return ret; + }, + {} as Record + ); + } + + protected override async getDocSnapshot( + docId: string + ): Promise { + const snapshot = await this.db.getDocSnapshot(docId); + + if (!snapshot) { + return null; + } + + return { + spaceId: this.spaceId, + docId, + bin: snapshot.data, + timestamp: snapshot.timestamp.getTime(), + }; + } + + protected override setDocSnapshot(snapshot: DocRecord): Promise { + return this.db.setDocSnapshot({ + docId: snapshot.docId, + data: Buffer.from(snapshot.bin), + timestamp: new Date(snapshot.timestamp), + }); + } + + protected override async getDocUpdates(docId: string): Promise { + return this.db.getDocUpdates(docId).then(updates => + updates.map(update => ({ + bin: update.data, + timestamp: update.createdAt.getTime(), + })) + ); + } + + protected override markUpdatesMerged( + docId: string, + updates: DocUpdate[] + ): Promise { + return this.db.markUpdatesMerged( + docId, + updates.map(update => new Date(update.timestamp)) + ); + } + + override async listDocHistories() { + return []; + } + override async getDocHistory() { + return null; + } + + protected override async createDocHistory(): Promise { + return false; + } +} diff --git a/packages/common/doc-storage/src/impls/sqlite/index.ts b/packages/common/doc-storage/src/impls/sqlite/index.ts new file mode 100644 index 0000000000000..f42f6dd754a91 --- /dev/null +++ b/packages/common/doc-storage/src/impls/sqlite/index.ts @@ -0,0 +1 @@ +export * from './doc'; diff --git a/packages/common/doc-storage/src/index.ts b/packages/common/doc-storage/src/index.ts new file mode 100644 index 0000000000000..85674ee7cdb78 --- /dev/null +++ b/packages/common/doc-storage/src/index.ts @@ -0,0 +1 @@ +export * from './storage'; diff --git a/packages/common/doc-storage/src/storage/connection.ts b/packages/common/doc-storage/src/storage/connection.ts new file mode 100644 index 0000000000000..f82a72fbd3931 --- /dev/null +++ b/packages/common/doc-storage/src/storage/connection.ts @@ -0,0 +1,11 @@ +export class Connection { + protected connected: boolean = false; + connect(): Promise { + this.connected = 
true; + return Promise.resolve(); + } + disconnect(): Promise { + this.connected = false; + return Promise.resolve(); + } +} diff --git a/packages/common/doc-storage/src/storage/doc/doc.ts b/packages/common/doc-storage/src/storage/doc/doc.ts new file mode 100644 index 0000000000000..b6e22a92b7287 --- /dev/null +++ b/packages/common/doc-storage/src/storage/doc/doc.ts @@ -0,0 +1,319 @@ +import { + applyUpdate, + diffUpdate, + Doc, + encodeStateAsUpdate, + encodeStateVector, + encodeStateVectorFromUpdate, + mergeUpdates, + UndoManager, +} from 'yjs'; + +import { Connection } from '../connection'; +import { SingletonLocker } from '../lock'; +import type { + DocDiff, + DocRecord, + DocUpdate, + Editor, + HistoryFilter, +} from './types'; + +export type SpaceType = 'workspace' | 'userspace'; +export interface DocStorageOptions { + spaceType: string; + spaceId: string; + mergeUpdates?: (updates: Uint8Array[]) => Promise | Uint8Array; +} + +export abstract class DocStorage< + Opts extends DocStorageOptions = DocStorageOptions, +> extends Connection { + abstract get name(): string; + + public readonly options: Opts; + private readonly locker = new SingletonLocker(); + protected readonly updatesListeners = new Set< + (docId: string, updates: Uint8Array[], timestamp: number) => void + >(); + + get spaceType() { + return this.options.spaceType; + } + + get spaceId() { + return this.options.spaceId; + } + + constructor(options: Opts) { + super(); + this.options = options; + } + + // REGION: open apis + /** + * Tell a binary is empty yjs binary or not. + * + * NOTE: + * `[0, 0]` is empty yjs update binary + * `[0]` is empty yjs state vector binary + */ + isEmptyBin(bin: Uint8Array): boolean { + return ( + bin.length === 0 || + // 0x0 for state vector + (bin.length === 1 && bin[0] === 0) || + // 0x00 for update + (bin.length === 2 && bin[0] === 0 && bin[1] === 0) + ); + } + + /** + * Get a doc record with latest binary. + */ + async getDoc(docId: string): Promise { + await using _lock = await this.lockDocForUpdate(docId); + + const snapshot = await this.getDocSnapshot(docId); + const updates = await this.getDocUpdates(docId); + + if (updates.length) { + const { timestamp, bin, editor } = await this.squash( + snapshot ? [snapshot, ...updates] : updates + ); + + const newSnapshot = { + spaceId: this.spaceId, + docId, + bin, + timestamp, + editor, + }; + + const success = await this.setDocSnapshot(newSnapshot); + // if there is old snapshot, create a new history record + if (success && snapshot) { + await this.createDocHistory(snapshot); + } + + // always mark updates as merged unless throws + await this.markUpdatesMerged(docId, updates); + + return newSnapshot; + } + + return snapshot; + } + + /** + * Get a yjs binary diff with the given state vector. + */ + async getDocDiff( + docId: string, + stateVector?: Uint8Array + ): Promise { + const doc = await this.getDoc(docId); + + if (!doc) { + return null; + } + + const missing = stateVector ? 
diffUpdate(doc.bin, stateVector) : doc.bin;
+    const state = encodeStateVectorFromUpdate(doc.bin);
+
+    return {
+      missing,
+      state,
+      timestamp: doc.timestamp,
+    };
+  }
+
+  /**
+   * Push updates into storage
+   */
+  abstract pushDocUpdates(
+    docId: string,
+    updates: Uint8Array[],
+    editor?: string
+  ): Promise<number>;
+
+  /**
+   * Listen for pushed doc-update events
+   */
+  onReceiveDocUpdates(
+    listener: (docId: string, updates: Uint8Array[], timestamp: number) => void
+  ): () => void {
+    this.updatesListeners.add(listener);
+
+    return () => {
+      this.updatesListeners.delete(listener);
+    };
+  }
+
+  /**
+   * Delete a specific doc's data, including all snapshots and updates
+   */
+  abstract deleteDoc(docId: string): Promise<void>;
+  /**
+   * Delete the whole space's data, including all docs
+   */
+  abstract deleteSpace(): Promise<void>;
+
+  /**
+   * Get the latest timestamps of all docs. Especially useful in the sync process.
+   */
+  abstract getSpaceDocTimestamps(
+    after?: number
+  ): Promise<Record<string, number> | null>;
+
+  /**
+   * Roll back the doc by applying an update patch generated with [yjs.UndoManager].
+   */
+  async rollbackDoc(
+    docId: string,
+    timestamp: number,
+    editor?: string
+  ): Promise<void> {
+    await using _lock = await this.lockDocForUpdate(docId);
+    const toSnapshot = await this.getDocHistory(docId, timestamp);
+    if (!toSnapshot) {
+      throw new Error('Cannot find the version to rollback to.');
+    }
+
+    const fromSnapshot = await this.getDocSnapshot(docId);
+
+    if (!fromSnapshot) {
+      throw new Error('Cannot find the current version of the doc.');
+    }
+
+    const change = this.generateChangeUpdate(fromSnapshot.bin, toSnapshot.bin);
+    await this.pushDocUpdates(docId, [change], editor);
+    // force create a new history record after rollback
+    await this.createDocHistory(fromSnapshot, true);
+  }
+
+  /**
+   * List all history snapshots of a doc.
+   */
+  abstract listDocHistories(
+    docId: string,
+    query: HistoryFilter
+  ): Promise<{ timestamp: number; editor: Editor | null }[]>;
+
+  /**
+   * Get a history snapshot of a doc.
+   */
+  abstract getDocHistory(
+    docId: string,
+    timestamp: number
+  ): Promise<DocRecord | null>;
+
+  // ENDREGION
+
+  // REGION: api for internal usage
+  protected dispatchDocUpdatesListeners(
+    docId: string,
+    updates: Uint8Array[],
+    timestamp: number
+  ): void {
+    this.updatesListeners.forEach(cb => {
+      cb(docId, updates, timestamp);
+    });
+  }
+
+  /**
+   * Get a doc snapshot from storage
+   */
+  protected abstract getDocSnapshot(docId: string): Promise<DocRecord | null>;
+  /**
+   * Set the doc snapshot into storage
+   *
+   * @safety
+   * be careful when implementing this method.
+   *
+   * It might be called with an outdated snapshot when running in a multi-thread environment.
+   *
+   * A common solution is to update the snapshot record in DB only when the incoming one's timestamp is newer.
+   *
+   * @example
+   * ```ts
+   * await using _lock = await this.lockDocForUpdate(docId);
+   * // set snapshot
+   *
+   * ```
+   */
+  protected abstract setDocSnapshot(snapshot: DocRecord): Promise<boolean>;
+  /**
+   * Get all updates of a doc that haven't been merged into snapshot.
+   *
+   * Updates queue design exists for a performance concern:
+   * A huge amount of write time will be saved if we don't merge updates into snapshot immediately.
+   * Updates will be merged into snapshot when the latest doc is requested.
+   */
+  protected abstract getDocUpdates(docId: string): Promise<DocUpdate[]>;
+
+  /**
+   * Mark updates as merged into snapshot.
+   */
+  protected abstract markUpdatesMerged(
+    docId: string,
+    updates: DocUpdate[]
+  ): Promise<number>;
+
+  /**
+   * Create a new history record for a doc.
+   * Will always be called after the doc snapshot is updated.
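+   *
+   * A sketch of the expected call pattern (mirroring `getDoc` above;
+   * `docId` and `newSnapshot` are stand-ins):
+   *
+   * ```ts
+   * const prev = await this.getDocSnapshot(docId);
+   * if ((await this.setDocSnapshot(newSnapshot)) && prev) {
+   *   // keep the replaced snapshot as a history record
+   *   await this.createDocHistory(prev);
+   * }
+   * ```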
+   */
+  protected abstract createDocHistory(
+    snapshot: DocRecord,
+    force?: boolean
+  ): Promise<boolean>;
+
+  /**
+   * Merge doc updates into a single update.
+   */
+  protected async squash(updates: DocUpdate[]): Promise<DocUpdate> {
+    const merge = this.options?.mergeUpdates ?? mergeUpdates;
+    const lastUpdate = updates.at(-1);
+    if (!lastUpdate) {
+      throw new Error('No updates to be squashed.');
+    }
+
+    // fast return
+    if (updates.length === 1) {
+      return lastUpdate;
+    }
+
+    const finalUpdate = await merge(updates.map(u => u.bin));
+
+    return {
+      bin: finalUpdate,
+      timestamp: lastUpdate.timestamp,
+      editor: lastUpdate.editor,
+    };
+  }
+
+  protected generateChangeUpdate(newerBin: Uint8Array, olderBin: Uint8Array) {
+    const newerDoc = new Doc();
+    applyUpdate(newerDoc, newerBin);
+    const olderDoc = new Doc();
+    applyUpdate(olderDoc, olderBin);
+
+    const newerState = encodeStateVector(newerDoc);
+    const olderState = encodeStateVector(olderDoc);
+
+    const diff = encodeStateAsUpdate(newerDoc, olderState);
+
+    const undoManager = new UndoManager(Array.from(newerDoc.share.values()));
+
+    applyUpdate(olderDoc, diff);
+
+    undoManager.undo();
+
+    return encodeStateAsUpdate(olderDoc, newerState);
+  }
+
+  protected async lockDocForUpdate(docId: string): Promise<AsyncDisposable> {
+    return this.locker.lock(`workspace:${this.spaceId}:update`, docId);
+  }
+}
diff --git a/packages/common/doc-storage/src/storage/doc/index.ts b/packages/common/doc-storage/src/storage/doc/index.ts
new file mode 100644
index 0000000000000..1fa447519627c
--- /dev/null
+++ b/packages/common/doc-storage/src/storage/doc/index.ts
@@ -0,0 +1,2 @@
+export * from './doc';
+export * from './types';
diff --git a/packages/common/doc-storage/src/storage/doc/types.ts b/packages/common/doc-storage/src/storage/doc/types.ts
new file mode 100644
index 0000000000000..366eb280e4a25
--- /dev/null
+++ b/packages/common/doc-storage/src/storage/doc/types.ts
@@ -0,0 +1,29 @@
+export interface DocRecord {
+  spaceId: string;
+  docId: string;
+  bin: Uint8Array;
+  timestamp: number;
+  editor?: string;
+}
+
+export interface DocDiff {
+  missing: Uint8Array;
+  state: Uint8Array;
+  timestamp: number;
+}
+
+export interface DocUpdate {
+  bin: Uint8Array;
+  timestamp: number;
+  editor?: string;
+}
+
+export interface HistoryFilter {
+  before?: number;
+  limit?: number;
+}
+
+export interface Editor {
+  name: string;
+  avatarUrl: string | null;
+}
diff --git a/packages/common/doc-storage/src/storage/index.ts b/packages/common/doc-storage/src/storage/index.ts
new file mode 100644
index 0000000000000..f42f6dd754a91
--- /dev/null
+++ b/packages/common/doc-storage/src/storage/index.ts
@@ -0,0 +1 @@
+export * from './doc';
diff --git a/packages/common/doc-storage/src/storage/lock.ts b/packages/common/doc-storage/src/storage/lock.ts
new file mode 100644
index 0000000000000..c4fcf45f3e56d
--- /dev/null
+++ b/packages/common/doc-storage/src/storage/lock.ts
@@ -0,0 +1,42 @@
+export interface Locker {
+  lock(domain: string, resource: string): Promise<Lock>;
+}
+
+export class SingletonLocker implements Locker {
+  lockedResource = new Map<string, Lock>();
+  constructor() {}
+
+  async lock(domain: string, resource: string) {
+    let lock = this.lockedResource.get(`${domain}:${resource}`);
+
+    if (!lock) {
+      lock = new Lock();
+      // remember the lock; otherwise every caller would get a fresh,
+      // uncontended lock and no mutual exclusion would actually happen
+      this.lockedResource.set(`${domain}:${resource}`, lock);
+    }
+
+    await lock.acquire();
+
+    return lock;
+  }
+}
+
+export class Lock {
+  private inner: Promise<void> = Promise.resolve();
+  private release: () => void = () => {};
+
+  async acquire() {
+    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+    let release: () => void = null!;
+    const
nextLock = new Promise(resolve => { + release = resolve; + }); + + await this.inner; + this.inner = nextLock; + this.release = release; + } + + [Symbol.asyncDispose]() { + this.release(); + return Promise.resolve(); + } +} diff --git a/packages/common/doc-storage/tsconfig.json b/packages/common/doc-storage/tsconfig.json new file mode 100644 index 0000000000000..4bbd8d0b79dd9 --- /dev/null +++ b/packages/common/doc-storage/tsconfig.json @@ -0,0 +1,9 @@ +{ + "extends": "../../../tsconfig.json", + "include": ["./src"], + "compilerOptions": { + "composite": true, + "noEmit": false, + "outDir": "lib" + } +} diff --git a/packages/frontend/native/Cargo.toml b/packages/frontend/native/Cargo.toml index 204e9be6de17d..4c210c767d1a5 100644 --- a/packages/frontend/native/Cargo.toml +++ b/packages/frontend/native/Cargo.toml @@ -20,6 +20,7 @@ serde = { workspace = true } serde_json = { workspace = true } sha3 = { workspace = true } sqlx = { workspace = true, default-features = false, features = ["chrono", "macros", "migrate", "runtime-tokio", "sqlite", "tls-rustls"] } +log = { workspace = true } tokio = { workspace = true, features = ["full"] } uuid = { workspace = true, features = ["fast-rng", "serde", "v4"] } diff --git a/packages/frontend/native/build.rs b/packages/frontend/native/build.rs index 3a7a6c97b3b76..e6a55472c22c6 100644 --- a/packages/frontend/native/build.rs +++ b/packages/frontend/native/build.rs @@ -1,4 +1,4 @@ -use sqlx::sqlite::SqliteConnectOptions; +use sqlx::{migrate, sqlite::SqliteConnectOptions}; use std::fs; #[tokio::main] @@ -29,5 +29,7 @@ async fn main() -> Result<(), std::io::Error> { .execute(&pool) .await .unwrap(); + + migrate!().run(&pool).await.unwrap(); Ok(()) } diff --git a/packages/frontend/native/index.d.ts b/packages/frontend/native/index.d.ts index c3611af9bfb2e..ff6f82e16659c 100644 --- a/packages/frontend/native/index.d.ts +++ b/packages/frontend/native/index.d.ts @@ -1,5 +1,24 @@ /* auto-generated by NAPI-RS */ /* eslint-disable */ +export declare class DocStorage { + constructor(path: string) + connect(): Promise + close(): Promise + get isClosed(): Promise + pushUpdates(docId: string, updates: Array): Promise + getDocSnapshot(docId: string): Promise + setDocSnapshot(snapshot: DocRecord): Promise + getDocUpdates(docId: string): Promise> + markUpdatesMerged(docId: string, updates: Array): Promise + deleteDoc(docId: string): Promise + getDocClocks(after?: number | undefined | null): Promise> + /** + * Flush the WAL file to the database file. 
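+   * Typically called before the database file is copied or backed up, so that
+   * the copy is complete without its `-wal` side file.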
+ * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B + */ + checkpoint(): Promise +} + export declare class SqliteConnection { constructor(path: string) connect(): Promise @@ -43,6 +62,23 @@ export interface BlobRow { timestamp: Date } +export interface DocClock { + docId: string + timestamp: Date +} + +export interface DocRecord { + docId: string + data: Buffer + timestamp: Date +} + +export interface DocUpdate { + docId: string + createdAt: Date + data: Buffer +} + export interface InsertRow { docId?: string data: Uint8Array diff --git a/packages/frontend/native/index.js b/packages/frontend/native/index.js index c6c15fe67f888..69959813b9163 100644 --- a/packages/frontend/native/index.js +++ b/packages/frontend/native/index.js @@ -361,6 +361,7 @@ if (!nativeBinding) { throw new Error(`Failed to load native binding`) } +module.exports.DocStorage = nativeBinding.DocStorage module.exports.SqliteConnection = nativeBinding.SqliteConnection module.exports.mintChallengeResponse = nativeBinding.mintChallengeResponse module.exports.ValidationResult = nativeBinding.ValidationResult diff --git a/packages/frontend/native/migrations/20240929082254_init.sql b/packages/frontend/native/migrations/20240929082254_init.sql new file mode 100644 index 0000000000000..054f51b4abc04 --- /dev/null +++ b/packages/frontend/native/migrations/20240929082254_init.sql @@ -0,0 +1,20 @@ +CREATE TABLE "v2_snapshots" ( + doc_id TEXT PRIMARY KEY NOT NULL, + data BLOB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); +CREATE INDEX snapshots_doc_id ON v2_snapshots(doc_id); + +CREATE TABLE "v2_updates" ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT NOT NULL, + data BLOB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL +); +CREATE INDEX updates_doc_id ON v2_updates (doc_id); + +CREATE TABLE "v2_clocks" ( + doc_id TEXT PRIMARY KEY NOT NULL, + timestamp TIMESTAMP NOT NULL +) \ No newline at end of file diff --git a/packages/frontend/native/src/sqlite/doc_storage/mod.rs b/packages/frontend/native/src/sqlite/doc_storage/mod.rs new file mode 100644 index 0000000000000..7db36e95462e0 --- /dev/null +++ b/packages/frontend/native/src/sqlite/doc_storage/mod.rs @@ -0,0 +1,122 @@ +mod storage; + +use chrono::NaiveDateTime; +use napi::bindgen_prelude::{Buffer, Uint8Array}; +use napi_derive::napi; + +fn map_err(err: sqlx::Error) -> napi::Error { + napi::Error::from(anyhow::Error::from(err)) +} + +#[napi(object)] +pub struct DocUpdate { + pub doc_id: String, + pub created_at: NaiveDateTime, + pub data: Buffer, +} + +#[napi(object)] +pub struct DocRecord { + pub doc_id: String, + pub data: Buffer, + pub timestamp: NaiveDateTime, +} + +#[napi(object)] +pub struct DocClock { + pub doc_id: String, + pub timestamp: NaiveDateTime, +} + +#[napi] +pub struct DocStorage { + storage: storage::SqliteDocStorage, +} + +#[napi] +impl DocStorage { + #[napi(constructor, async_runtime)] + pub fn new(path: String) -> napi::Result { + Ok(Self { + storage: storage::SqliteDocStorage::new(path), + }) + } + + #[napi] + pub async fn connect(&self) -> napi::Result<()> { + self.storage.connect().await.map_err(map_err) + } + + #[napi] + pub async fn close(&self) -> napi::Result<()> { + self.storage.close().await; + + Ok(()) + } + + #[napi(getter)] + pub async fn is_closed(&self) -> napi::Result { + Ok(self.storage.is_closed()) + } + + #[napi] + pub async fn push_updates(&self, doc_id: String, updates: Vec) -> napi::Result { + let 
updates = updates.iter().map(|u| u.as_ref()).collect::>(); + self + .storage + .push_updates(doc_id, updates) + .await + .map_err(map_err) + } + + #[napi] + pub async fn get_doc_snapshot(&self, doc_id: String) -> napi::Result> { + self.storage.get_doc_snapshot(doc_id).await.map_err(map_err) + } + + #[napi] + pub async fn set_doc_snapshot(&self, snapshot: DocRecord) -> napi::Result { + self + .storage + .set_doc_snapshot(snapshot) + .await + .map_err(map_err) + } + + #[napi] + pub async fn get_doc_updates(&self, doc_id: String) -> napi::Result> { + self.storage.get_doc_updates(doc_id).await.map_err(map_err) + } + + #[napi] + pub async fn mark_updates_merged( + &self, + doc_id: String, + updates: Vec, + ) -> napi::Result { + self + .storage + .mark_updates_merged(doc_id, updates) + .await + .map_err(map_err) + } + + #[napi] + pub async fn delete_doc(&self, doc_id: String) -> napi::Result<()> { + self.storage.delete_doc(doc_id).await.map_err(map_err) + } + + #[napi] + pub async fn get_doc_clocks(&self, after: Option) -> napi::Result> { + self.storage.get_doc_clocks(after).await.map_err(map_err) + } + + /** + * Flush the WAL file to the database file. + * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B + */ + #[napi] + pub async fn checkpoint(&self) -> napi::Result<()> { + self.storage.checkpoint().await.map_err(map_err) + } +} diff --git a/packages/frontend/native/src/sqlite/doc_storage/storage.rs b/packages/frontend/native/src/sqlite/doc_storage/storage.rs new file mode 100644 index 0000000000000..c56f17af20313 --- /dev/null +++ b/packages/frontend/native/src/sqlite/doc_storage/storage.rs @@ -0,0 +1,406 @@ +use chrono::{DateTime, NaiveDateTime}; +use sqlx::{ + migrate::{MigrateDatabase, Migrator}, + sqlite::{Sqlite, SqliteConnectOptions, SqlitePoolOptions}, + ConnectOptions, Pool, QueryBuilder, Row, +}; + +use super::{DocClock, DocRecord, DocUpdate}; + +type Result = std::result::Result; + +pub struct SqliteDocStorage { + pool: Pool, + path: String, +} + +impl SqliteDocStorage { + pub fn new(path: String) -> Self { + let sqlite_options = SqliteConnectOptions::new() + .filename(&path) + .foreign_keys(false) + .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal) + .log_statements(log::LevelFilter::Trace); + + let mut pool_options = SqlitePoolOptions::new(); + + if cfg!(test) && path == ":memory:" { + pool_options = pool_options + .min_connections(1) + .max_connections(1) + .idle_timeout(None) + .max_lifetime(None); + } else { + pool_options = pool_options.max_connections(4); + } + + Self { + pool: pool_options.connect_lazy_with(sqlite_options), + path, + } + } + + pub async fn connect(&self) -> Result<()> { + if !Sqlite::database_exists(&self.path).await.unwrap_or(false) { + Sqlite::create_database(&self.path).await?; + }; + + let migrations = std::env::current_dir().unwrap().join("migrations"); + let migrator = Migrator::new(migrations).await?; + migrator.run(&self.pool).await?; + + Ok(()) + } + + pub async fn close(&self) { + self.pool.close().await + } + + pub fn is_closed(&self) -> bool { + self.pool.is_closed() + } + + pub async fn push_updates, Updates: AsRef<[Update]>>( + &self, + doc_id: String, + updates: Updates, + ) -> Result { + let mut cnt = 0; + + for chunk in updates.as_ref().chunks(10) { + self.batch_push_updates(&doc_id, chunk).await?; + cnt += chunk.len() as u32; + } + + Ok(cnt) + } + + pub async fn get_doc_snapshot(&self, doc_id: String) -> Result> { + sqlx::query_as!( + DocRecord, + "SELECT doc_id, data, updated_at as 
timestamp FROM v2_snapshots WHERE doc_id = ?",
+      doc_id
+    )
+    .fetch_optional(&self.pool)
+    .await
+  }
+
+  pub async fn set_doc_snapshot(&self, snapshot: DocRecord) -> Result<bool> {
+    let result = sqlx::query(
+      r#"
+      INSERT INTO v2_snapshots (doc_id, data, updated_at)
+      VALUES ($1, $2, $3)
+      ON CONFLICT(doc_id)
+      DO UPDATE SET data=$2, updated_at=$3
+      WHERE updated_at <= $3;"#,
+    )
+    .bind(snapshot.doc_id)
+    .bind(snapshot.data.as_ref())
+    .bind(snapshot.timestamp)
+    .execute(&self.pool)
+    .await?;
+
+    Ok(result.rows_affected() == 1)
+  }
+
+  pub async fn get_doc_updates(&self, doc_id: String) -> Result<Vec<DocUpdate>> {
+    sqlx::query_as!(
+      DocUpdate,
+      "SELECT doc_id, created_at, data FROM v2_updates WHERE doc_id = ?",
+      doc_id
+    )
+    .fetch_all(&self.pool)
+    .await
+  }
+
+  pub async fn mark_updates_merged(
+    &self,
+    doc_id: String,
+    updates: Vec<NaiveDateTime>,
+  ) -> Result<u32> {
+    let mut qb = QueryBuilder::new("DELETE FROM v2_updates");
+
+    qb.push(" WHERE doc_id = ");
+    qb.push_bind(doc_id);
+    qb.push(" AND created_at IN (");
+    let mut separated = qb.separated(", ");
+    updates.iter().for_each(|update| {
+      separated.push_bind(update);
+    });
+    qb.push(");");
+
+    let query = qb.build();
+
+    let result = query.execute(&self.pool).await?;
+
+    Ok(result.rows_affected() as u32)
+  }
+
+  async fn batch_push_updates<Update: AsRef<[u8]>>(
+    &self,
+    doc_id: &str,
+    updates: &[Update],
+  ) -> Result<()> {
+    // millisecond clock, bumped by 1 per update so `created_at` stays unique
+    // and ordered within one batch
+    let mut timestamp = chrono::Utc::now().timestamp_millis();
+
+    let mut qb = QueryBuilder::new("INSERT INTO v2_updates (doc_id, data, created_at) ");
+    qb.push_values(updates, |mut b, update| {
+      timestamp += 1;
+      b.push_bind(doc_id).push_bind(update.as_ref()).push_bind(
+        DateTime::from_timestamp_millis(timestamp)
+          .unwrap()
+          .naive_utc(),
+      );
+    });
+
+    let query = qb.build();
+
+    let mut tx = self.pool.begin().await?;
+    query.execute(&mut *tx).await?;
+
+    sqlx::query(
+      r#"
+      INSERT INTO v2_clocks (doc_id, timestamp) VALUES ($1, $2)
+      ON CONFLICT(doc_id)
+      DO UPDATE SET timestamp=$2;"#,
+    )
+    .bind(doc_id)
+    .bind(DateTime::from_timestamp_millis(timestamp).unwrap().to_utc())
+    .execute(&mut *tx)
+    .await?;
+
+    tx.commit().await
+  }
+
+  pub async fn delete_doc(&self, doc_id: String) -> Result<()> {
+    let mut tx = self.pool.begin().await?;
+
+    sqlx::query("DELETE FROM v2_updates WHERE doc_id = ?;")
+      .bind(&doc_id)
+      .execute(&mut *tx)
+      .await?;
+
+    sqlx::query("DELETE FROM v2_snapshots WHERE doc_id = ?;")
+      .bind(&doc_id)
+      .execute(&mut *tx)
+      .await?;
+
+    sqlx::query("DELETE FROM v2_clocks WHERE doc_id = ?;")
+      .bind(&doc_id)
+      .execute(&mut *tx)
+      .await?;
+
+    tx.commit().await
+  }
+
+  pub async fn get_doc_clocks(&self, after: Option<i64>) -> Result<Vec<DocClock>> {
+    let query = if let Some(after) = after {
+      sqlx::query("SELECT doc_id, timestamp FROM v2_clocks WHERE timestamp > $1")
+        .bind(DateTime::from_timestamp_millis(after).unwrap().naive_utc())
+    } else {
+      sqlx::query("SELECT doc_id, timestamp FROM v2_clocks")
+    };
+
+    let clocks = query.fetch_all(&self.pool).await?;
+
+    Ok(
+      clocks
+        .iter()
+        .map(|row| DocClock {
+          doc_id: row.get("doc_id"),
+          timestamp: row.get("timestamp"),
+        })
+        .collect(),
+    )
+  }
+
+  /**
+   * Flush the WAL file to the database file.
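+   * With `SqliteJournalMode::Wal` enabled in the connect options above,
+   * committed writes sit in the `<db>-wal` side file until a checkpoint
+   * copies them back into the main database file.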
+   * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B
+   */
+  pub async fn checkpoint(&self) -> Result<()> {
+    sqlx::query("PRAGMA wal_checkpoint(FULL);")
+      .execute(&self.pool)
+      .await?;
+
+    Ok(())
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  use chrono::Utc;
+  use napi::bindgen_prelude::Buffer;
+
+  use super::*;
+
+  async fn get_storage() -> SqliteDocStorage {
+    let storage = SqliteDocStorage::new(":memory:".to_string());
+    storage.connect().await.unwrap();
+
+    storage
+  }
+
+  #[tokio::test]
+  async fn init_tables() {
+    let storage = get_storage().await;
+
+    sqlx::query("INSERT INTO v2_snapshots (doc_id, data, updated_at) VALUES ($1, $2, $3);")
+      .bind("test")
+      .bind(vec![0, 0])
+      .bind(Utc::now())
+      .execute(&storage.pool)
+      .await
+      .unwrap();
+
+    sqlx::query_as!(
+      DocRecord,
+      "SELECT doc_id, data, updated_at as timestamp FROM v2_snapshots WHERE doc_id = 'test';"
+    )
+    .fetch_one(&storage.pool)
+    .await
+    .unwrap();
+  }
+
+  #[tokio::test]
+  async fn push_updates() {
+    let storage = get_storage().await;
+
+    let updates = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]];
+
+    storage
+      .push_updates("test".to_string(), &updates)
+      .await
+      .unwrap();
+
+    let result = storage.get_doc_updates("test".to_string()).await.unwrap();
+
+    assert_eq!(result.len(), 4);
+    assert_eq!(
+      result.iter().map(|u| u.data.as_ref()).collect::<Vec<_>>(),
+      updates
+    );
+  }
+
+  #[tokio::test]
+  async fn get_doc_snapshot() {
+    let storage = get_storage().await;
+
+    let none = storage.get_doc_snapshot("test".to_string()).await.unwrap();
+
+    assert!(none.is_none());
+
+    let snapshot = DocRecord {
+      doc_id: "test".to_string(),
+      data: Buffer::from(vec![0, 0]),
+      timestamp: Utc::now().naive_utc(),
+    };
+
+    storage.set_doc_snapshot(snapshot).await.unwrap();
+
+    let result = storage.get_doc_snapshot("test".to_string()).await.unwrap();
+
+    assert!(result.is_some());
+    assert_eq!(result.unwrap().data.as_ref(), vec![0, 0]);
+  }
+
+  #[tokio::test]
+  async fn set_doc_snapshot() {
+    let storage = get_storage().await;
+
+    let snapshot = DocRecord {
+      doc_id: "test".to_string(),
+      data: Buffer::from(vec![0, 0]),
+      timestamp: Utc::now().naive_utc(),
+    };
+
+    storage.set_doc_snapshot(snapshot).await.unwrap();
+
+    let result = storage.get_doc_snapshot("test".to_string()).await.unwrap();
+
+    assert!(result.is_some());
+    assert_eq!(result.unwrap().data.as_ref(), vec![0, 0]);
+
+    let snapshot = DocRecord {
+      doc_id: "test".to_string(),
+      data: Buffer::from(vec![0, 1]),
+      timestamp: DateTime::from_timestamp_millis(Utc::now().timestamp_millis() - 1000)
+        .unwrap()
+        .naive_utc(),
+    };
+
+    // can't update because its timestamp is older
+    storage.set_doc_snapshot(snapshot).await.unwrap();
+
+    let result = storage.get_doc_snapshot("test".to_string()).await.unwrap();
+
+    assert!(result.is_some());
+    assert_eq!(result.unwrap().data.as_ref(), vec![0, 0]);
+  }
+
+  #[tokio::test]
+  async fn get_doc_clocks() {
+    let storage = get_storage().await;
+
+    let clocks = storage.get_doc_clocks(None).await.unwrap();
+
+    assert_eq!(clocks.len(), 0);
+
+    // pushed sequentially on purpose; a concurrent join_all sketch follows below
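+    // (a concurrent variant would be possible, e.g. assuming the `futures`
+    // crate were added as a dev-dependency:
+    //
+    //   futures::future::join_all(
+    //     (1..5u32).map(|i| storage.push_updates(format!("test_{i}"), vec![vec![0, 0]])),
+    //   )
+    //   .await;
+    //
+    // kept sequential here so each doc's clock gets a distinct, ordered timestamp.)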
+ for i in 1..5u32 { + storage + .push_updates(format!("test_{i}"), vec![vec![0, 0]]) + .await + .unwrap(); + } + + let clocks = storage.get_doc_clocks(None).await.unwrap(); + + assert_eq!(clocks.len(), 4); + assert_eq!( + clocks.iter().map(|c| c.doc_id.as_str()).collect::>(), + vec!["test_1", "test_2", "test_3", "test_4"] + ); + + let clocks = storage + .get_doc_clocks(Some(Utc::now().timestamp_millis())) + .await + .unwrap(); + + assert_eq!(clocks.len(), 0); + } + + #[tokio::test] + async fn mark_updates_merged() { + let storage = get_storage().await; + + storage + .push_updates( + "test".to_string(), + vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1]], + ) + .await + .unwrap(); + + let updates = storage.get_doc_updates("test".to_string()).await.unwrap(); + + let result = storage + .mark_updates_merged( + "test".to_string(), + updates + .iter() + .skip(1) + .map(|u| u.created_at) + .collect::>(), + ) + .await + .unwrap(); + + assert_eq!(result, 3); + + let updates = storage.get_doc_updates("test".to_string()).await.unwrap(); + + assert_eq!(updates.len(), 1); + } +} diff --git a/packages/frontend/native/src/sqlite/mod.rs b/packages/frontend/native/src/sqlite/mod.rs index 17293d7ba9ff9..fceb3c7894f8b 100644 --- a/packages/frontend/native/src/sqlite/mod.rs +++ b/packages/frontend/native/src/sqlite/mod.rs @@ -1,518 +1,2 @@ -use chrono::NaiveDateTime; -use napi::bindgen_prelude::{Buffer, Uint8Array}; -use napi_derive::napi; -use sqlx::{ - migrate::MigrateDatabase, - sqlite::{Sqlite, SqliteConnectOptions, SqlitePoolOptions}, - Pool, Row, -}; - -// latest version -const LATEST_VERSION: i32 = 4; - -#[napi(object)] -pub struct BlobRow { - pub key: String, - pub data: Buffer, - pub timestamp: NaiveDateTime, -} - -#[napi(object)] -pub struct UpdateRow { - pub id: i64, - pub timestamp: NaiveDateTime, - pub data: Buffer, - pub doc_id: Option, -} - -#[napi(object)] -pub struct InsertRow { - pub doc_id: Option, - pub data: Uint8Array, -} - -#[napi] -pub struct SqliteConnection { - pool: Pool, - path: String, -} - -#[napi] -pub enum ValidationResult { - MissingTables, - MissingDocIdColumn, - MissingVersionColumn, - GeneralError, - Valid, -} - -#[napi] -impl SqliteConnection { - #[napi(constructor, async_runtime)] - pub fn new(path: String) -> napi::Result { - let sqlite_options = SqliteConnectOptions::new() - .filename(&path) - .foreign_keys(false) - .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal); - let pool = SqlitePoolOptions::new() - .max_connections(4) - .connect_lazy_with(sqlite_options); - Ok(Self { pool, path }) - } - - #[napi] - pub async fn connect(&self) -> napi::Result<()> { - if !Sqlite::database_exists(&self.path).await.unwrap_or(false) { - Sqlite::create_database(&self.path) - .await - .map_err(anyhow::Error::from)?; - }; - let mut connection = self.pool.acquire().await.map_err(anyhow::Error::from)?; - sqlx::query(affine_schema::SCHEMA) - .execute(connection.as_mut()) - .await - .map_err(anyhow::Error::from)?; - self.migrate_add_doc_id().await?; - self.migrate_add_doc_id_index().await?; - connection.detach(); - Ok(()) - } - - #[napi] - pub async fn add_blob(&self, key: String, blob: Uint8Array) -> napi::Result<()> { - let blob = blob.as_ref(); - sqlx::query_as!( - BlobRow, - "INSERT INTO blobs (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data", - key, - blob, - ) - .execute(&self.pool) - .await - .map_err(anyhow::Error::from)?; - Ok(()) - } - - #[napi] - pub async fn get_blob(&self, key: String) -> Option { - sqlx::query_as!( - BlobRow, - "SELECT key, 
diff --git a/packages/frontend/native/src/sqlite/mod.rs b/packages/frontend/native/src/sqlite/mod.rs
index 17293d7ba9ff9..fceb3c7894f8b 100644
--- a/packages/frontend/native/src/sqlite/mod.rs
+++ b/packages/frontend/native/src/sqlite/mod.rs
@@ -1,518 +1,2 @@
-use chrono::NaiveDateTime;
-use napi::bindgen_prelude::{Buffer, Uint8Array};
-use napi_derive::napi;
-use sqlx::{
-  migrate::MigrateDatabase,
-  sqlite::{Sqlite, SqliteConnectOptions, SqlitePoolOptions},
-  Pool, Row,
-};
-
-// latest version
-const LATEST_VERSION: i32 = 4;
-
-#[napi(object)]
-pub struct BlobRow {
-  pub key: String,
-  pub data: Buffer,
-  pub timestamp: NaiveDateTime,
-}
-
-#[napi(object)]
-pub struct UpdateRow {
-  pub id: i64,
-  pub timestamp: NaiveDateTime,
-  pub data: Buffer,
-  pub doc_id: Option<String>,
-}
-
-#[napi(object)]
-pub struct InsertRow {
-  pub doc_id: Option<String>,
-  pub data: Uint8Array,
-}
-
-#[napi]
-pub struct SqliteConnection {
-  pool: Pool<Sqlite>,
-  path: String,
-}
-
-#[napi]
-pub enum ValidationResult {
-  MissingTables,
-  MissingDocIdColumn,
-  MissingVersionColumn,
-  GeneralError,
-  Valid,
-}
-
-#[napi]
-impl SqliteConnection {
-  #[napi(constructor, async_runtime)]
-  pub fn new(path: String) -> napi::Result<Self> {
-    let sqlite_options = SqliteConnectOptions::new()
-      .filename(&path)
-      .foreign_keys(false)
-      .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
-    let pool = SqlitePoolOptions::new()
-      .max_connections(4)
-      .connect_lazy_with(sqlite_options);
-    Ok(Self { pool, path })
-  }
-
-  #[napi]
-  pub async fn connect(&self) -> napi::Result<()> {
-    if !Sqlite::database_exists(&self.path).await.unwrap_or(false) {
-      Sqlite::create_database(&self.path)
-        .await
-        .map_err(anyhow::Error::from)?;
-    };
-    let mut connection = self.pool.acquire().await.map_err(anyhow::Error::from)?;
-    sqlx::query(affine_schema::SCHEMA)
-      .execute(connection.as_mut())
-      .await
-      .map_err(anyhow::Error::from)?;
-    self.migrate_add_doc_id().await?;
-    self.migrate_add_doc_id_index().await?;
-    connection.detach();
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn add_blob(&self, key: String, blob: Uint8Array) -> napi::Result<()> {
-    let blob = blob.as_ref();
-    sqlx::query_as!(
-      BlobRow,
-      "INSERT INTO blobs (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
-      key,
-      blob,
-    )
-    .execute(&self.pool)
-    .await
-    .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_blob(&self, key: String) -> Option<BlobRow> {
-    sqlx::query_as!(
-      BlobRow,
-      "SELECT key, data, timestamp FROM blobs WHERE key = ?",
-      key
-    )
-    .fetch_one(&self.pool)
-    .await
-    .ok()
-  }
-
-  #[napi]
-  pub async fn delete_blob(&self, key: String) -> napi::Result<()> {
-    sqlx::query!("DELETE FROM blobs WHERE key = ?", key)
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_blob_keys(&self) -> napi::Result<Vec<String>> {
-    let keys = sqlx::query!("SELECT key FROM blobs")
-      .fetch_all(&self.pool)
-      .await
-      .map(|rows| rows.into_iter().map(|row| row.key).collect())
-      .map_err(anyhow::Error::from)?;
-    Ok(keys)
-  }
-
-  #[napi]
-  pub async fn get_updates(&self, doc_id: Option<String>) -> napi::Result<Vec<UpdateRow>> {
-    let updates = match doc_id {
-      Some(doc_id) => sqlx::query_as!(
-        UpdateRow,
-        "SELECT id, timestamp, data, doc_id FROM updates WHERE doc_id = ?",
-        doc_id
-      )
-      .fetch_all(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?,
-      None => sqlx::query_as!(
-        UpdateRow,
-        "SELECT id, timestamp, data, doc_id FROM updates WHERE doc_id is NULL",
-      )
-      .fetch_all(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?,
-    };
-    Ok(updates)
-  }
-
-  #[napi]
-  pub async fn delete_updates(&self, doc_id: Option<String>) -> napi::Result<()> {
-    match doc_id {
-      Some(doc_id) => {
-        sqlx::query!("DELETE FROM updates WHERE doc_id = ?", doc_id)
-          .execute(&self.pool)
-          .await
-          .map_err(anyhow::Error::from)?;
-      }
-      None => {
-        sqlx::query!("DELETE FROM updates WHERE doc_id is NULL")
-          .execute(&self.pool)
-          .await
-          .map_err(anyhow::Error::from)?;
-      }
-    };
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_updates_count(&self, doc_id: Option<String>) -> napi::Result<i32> {
-    let count = match doc_id {
-      Some(doc_id) => {
-        sqlx::query!(
-          "SELECT COUNT(*) as count FROM updates WHERE doc_id = ?",
-          doc_id
-        )
-        .fetch_one(&self.pool)
-        .await
-        .map_err(anyhow::Error::from)?
-        .count
-      }
-      None => {
-        sqlx::query!("SELECT COUNT(*) as count FROM updates WHERE doc_id is NULL")
-          .fetch_one(&self.pool)
-          .await
-          .map_err(anyhow::Error::from)?
-          .count
-      }
-    };
-    Ok(count)
-  }
-
-  #[napi]
-  pub async fn get_all_updates(&self) -> napi::Result<Vec<UpdateRow>> {
-    let updates = sqlx::query_as!(UpdateRow, "SELECT id, timestamp, data, doc_id FROM updates")
-      .fetch_all(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(updates)
-  }
-
-  #[napi]
-  pub async fn insert_updates(&self, updates: Vec<InsertRow>) -> napi::Result<()> {
-    let mut transaction = self.pool.begin().await.map_err(anyhow::Error::from)?;
-    for InsertRow { data, doc_id } in updates {
-      let update = data.as_ref();
-      sqlx::query_as!(
-        UpdateRow,
-        "INSERT INTO updates (data, doc_id) VALUES ($1, $2)",
-        update,
-        doc_id
-      )
-      .execute(&mut *transaction)
-      .await
-      .map_err(anyhow::Error::from)?;
-    }
-    transaction.commit().await.map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn replace_updates(
-    &self,
-    doc_id: Option<String>,
-    updates: Vec<InsertRow>,
-  ) -> napi::Result<()> {
-    let mut transaction = self.pool.begin().await.map_err(anyhow::Error::from)?;
-
-    match doc_id {
-      Some(doc_id) => sqlx::query!("DELETE FROM updates where doc_id = ?", doc_id)
-        .execute(&mut *transaction)
-        .await
-        .map_err(anyhow::Error::from)?,
-      None => sqlx::query!("DELETE FROM updates where doc_id is NULL",)
-        .execute(&mut *transaction)
-        .await
-        .map_err(anyhow::Error::from)?,
-    };
-
-    for InsertRow { data, doc_id } in updates {
-      let update = data.as_ref();
-      sqlx::query_as!(
-        UpdateRow,
-        "INSERT INTO updates (data, doc_id) VALUES ($1, $2)",
-        update,
-        doc_id
-      )
-      .execute(&mut *transaction)
-      .await
-      .map_err(anyhow::Error::from)?;
-    }
-    transaction.commit().await.map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_server_clock(&self, key: String) -> Option<BlobRow> {
-    sqlx::query_as!(
-      BlobRow,
-      "SELECT key, data, timestamp FROM server_clock WHERE key = ?",
-      key
-    )
-    .fetch_one(&self.pool)
-    .await
-    .ok()
-  }
-
-  #[napi]
-  pub async fn set_server_clock(&self, key: String, data: Uint8Array) -> napi::Result<()> {
-    let data = data.as_ref();
-    sqlx::query!(
-      "INSERT INTO server_clock (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
-      key,
-      data,
-    )
-    .execute(&self.pool)
-    .await
-    .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_server_clock_keys(&self) -> napi::Result<Vec<String>> {
-    let keys = sqlx::query!("SELECT key FROM server_clock")
-      .fetch_all(&self.pool)
-      .await
-      .map(|rows| rows.into_iter().map(|row| row.key).collect())
-      .map_err(anyhow::Error::from)?;
-    Ok(keys)
-  }
-
-  #[napi]
-  pub async fn clear_server_clock(&self) -> napi::Result<()> {
-    sqlx::query!("DELETE FROM server_clock")
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn del_server_clock(&self, key: String) -> napi::Result<()> {
-    sqlx::query!("DELETE FROM server_clock WHERE key = ?", key)
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_sync_metadata(&self, key: String) -> Option<BlobRow> {
-    sqlx::query_as!(
-      BlobRow,
-      "SELECT key, data, timestamp FROM sync_metadata WHERE key = ?",
-      key
-    )
-    .fetch_one(&self.pool)
-    .await
-    .ok()
-  }
-
-  #[napi]
-  pub async fn set_sync_metadata(&self, key: String, data: Uint8Array) -> napi::Result<()> {
-    let data = data.as_ref();
-    sqlx::query!(
-      "INSERT INTO sync_metadata (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
-      key,
-      data,
-    )
-    .execute(&self.pool)
-    .await
-    .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_sync_metadata_keys(&self) -> napi::Result<Vec<String>> {
-    let keys = sqlx::query!("SELECT key FROM sync_metadata")
-      .fetch_all(&self.pool)
-      .await
-      .map(|rows| rows.into_iter().map(|row| row.key).collect())
-      .map_err(anyhow::Error::from)?;
-    Ok(keys)
-  }
-
-  #[napi]
-  pub async fn clear_sync_metadata(&self) -> napi::Result<()> {
-    sqlx::query!("DELETE FROM sync_metadata")
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn del_sync_metadata(&self, key: String) -> napi::Result<()> {
-    sqlx::query!("DELETE FROM sync_metadata WHERE key = ?", key)
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn init_version(&self) -> napi::Result<()> {
-    // create version_info table
-    sqlx::query!(
-      "CREATE TABLE IF NOT EXISTS version_info (
-        version NUMBER NOT NULL,
-        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
-      )"
-    )
-    .execute(&self.pool)
-    .await
-    .map_err(anyhow::Error::from)?;
-    // `3` is the first version that has version_info table,
-    // do not modify the version number.
-    sqlx::query!("INSERT INTO version_info (version) VALUES (3)")
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn set_version(&self, version: i32) -> napi::Result<()> {
-    if version > LATEST_VERSION {
-      return Err(anyhow::Error::msg("Version is too new").into());
-    }
-    sqlx::query!("UPDATE version_info SET version = ?", version)
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  #[napi]
-  pub async fn get_max_version(&self) -> napi::Result<i32> {
-    // 4 is the current version
-    let version = sqlx::query!("SELECT COALESCE(MAX(version), 4) AS max_version FROM version_info")
-      .fetch_one(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?
-      .max_version;
-    Ok(version)
-  }
-
-  #[napi]
-  pub async fn close(&self) {
-    self.pool.close().await;
-  }
-
-  #[napi(getter)]
-  pub fn is_close(&self) -> bool {
-    self.pool.is_closed()
-  }
-
-  #[napi]
-  pub async fn validate(path: String) -> ValidationResult {
-    let pool = match SqlitePoolOptions::new()
-      .max_connections(1)
-      .connect(&path)
-      .await
-    {
-      Ok(pool) => pool,
-      Err(_) => return ValidationResult::GeneralError,
-    };
-
-    let tables_res = sqlx::query("SELECT name FROM sqlite_master WHERE type='table'")
-      .fetch_all(&pool)
-      .await;
-
-    let tables_exist = match tables_res {
-      Ok(res) => {
-        let names: Vec<String> = res.iter().map(|row| row.get(0)).collect();
-        names.contains(&"updates".to_string()) && names.contains(&"blobs".to_string())
-      }
-      Err(_) => return ValidationResult::GeneralError,
-    };
-
-    let tables_res = sqlx::query("SELECT name FROM sqlite_master WHERE type='table'")
-      .fetch_all(&pool)
-      .await;
-
-    let version_exist = match tables_res {
-      Ok(res) => {
-        let names: Vec<String> = res.iter().map(|row| row.get(0)).collect();
-        names.contains(&"version_info".to_string())
-      }
-      Err(_) => return ValidationResult::GeneralError,
-    };
-
-    let columns_res = sqlx::query("PRAGMA table_info(updates)")
-      .fetch_all(&pool)
-      .await;
-
-    let doc_id_exist = match columns_res {
-      Ok(res) => {
-        let names: Vec<String> = res.iter().map(|row| row.get(1)).collect();
-        names.contains(&"doc_id".to_string())
-      }
-      Err(_) => return ValidationResult::GeneralError,
-    };
-
-    if !tables_exist {
-      ValidationResult::MissingTables
-    } else if !doc_id_exist {
-      ValidationResult::MissingDocIdColumn
-    } else if !version_exist {
-      ValidationResult::MissingVersionColumn
-    } else {
-      ValidationResult::Valid
-    }
-  }
-
-  #[napi]
-  pub async fn migrate_add_doc_id(&self) -> napi::Result<()> {
-    // ignore errors
-    match sqlx::query("ALTER TABLE updates ADD COLUMN doc_id TEXT")
-      .execute(&self.pool)
-      .await
-    {
-      Ok(_) => Ok(()),
-      Err(err) => {
-        if err.to_string().contains("duplicate column name") {
-          Ok(()) // Ignore error if it's due to duplicate column
-        } else {
-          Err(anyhow::Error::from(err).into()) // Propagate other errors
-        }
-      }
-    }
-  }
-
-  /**
-   * Flush the WAL file to the database file.
-   * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B
-   */
-  #[napi]
-  pub async fn checkpoint(&self) -> napi::Result<()> {
-    sqlx::query("PRAGMA wal_checkpoint(FULL);")
-      .execute(&self.pool)
-      .await
-      .map_err(anyhow::Error::from)?;
-    Ok(())
-  }
-
-  pub async fn migrate_add_doc_id_index(&self) -> napi::Result<()> {
-    // ignore errors
-    match sqlx::query("CREATE INDEX IF NOT EXISTS idx_doc_id ON updates(doc_id);")
-      .execute(&self.pool)
-      .await
-    {
-      Ok(_) => Ok(()),
-      Err(err) => {
-        Err(anyhow::Error::from(err).into()) // Propagate other errors
-      }
-    }
-  }
-}
+pub mod doc_storage;
+pub mod v1;
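With the old implementation moved wholesale, mod.rs shrinks to two re-exports: the legacy store stays addressable as `v1` while the new `doc_storage` module lands beside it. A sketch of how both generations might be opened side by side during a migration window (the crate path `affine_native`, the `SqliteDocStorage` type name, and its constructor are assumptions; `SqliteConnection::new` and `connect` are from the v1 API reproduced below):

    use affine_native::sqlite::{doc_storage::SqliteDocStorage, v1::SqliteConnection};

    async fn open_both(path: String) -> anyhow::Result<(SqliteConnection, SqliteDocStorage)> {
      // Legacy updates/blobs tables, kept alive for not-yet-migrated data.
      let legacy = SqliteConnection::new(path.clone())?;
      legacy.connect().await?;
      // New per-doc storage; presumably applies the 20240929082254_init.sql
      // migration on open (constructor name assumed).
      let docs = SqliteDocStorage::new(path).await?;
      Ok((legacy, docs))
    }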
diff --git a/packages/frontend/native/src/sqlite/v1.rs b/packages/frontend/native/src/sqlite/v1.rs
new file mode 100644
index 0000000000000..17293d7ba9ff9
--- /dev/null
+++ b/packages/frontend/native/src/sqlite/v1.rs
@@ -0,0 +1,518 @@
+use chrono::NaiveDateTime;
+use napi::bindgen_prelude::{Buffer, Uint8Array};
+use napi_derive::napi;
+use sqlx::{
+  migrate::MigrateDatabase,
+  sqlite::{Sqlite, SqliteConnectOptions, SqlitePoolOptions},
+  Pool, Row,
+};
+
+// latest version
+const LATEST_VERSION: i32 = 4;
+
+#[napi(object)]
+pub struct BlobRow {
+  pub key: String,
+  pub data: Buffer,
+  pub timestamp: NaiveDateTime,
+}
+
+#[napi(object)]
+pub struct UpdateRow {
+  pub id: i64,
+  pub timestamp: NaiveDateTime,
+  pub data: Buffer,
+  pub doc_id: Option<String>,
+}
+
+#[napi(object)]
+pub struct InsertRow {
+  pub doc_id: Option<String>,
+  pub data: Uint8Array,
+}
+
+#[napi]
+pub struct SqliteConnection {
+  pool: Pool<Sqlite>,
+  path: String,
+}
+
+#[napi]
+pub enum ValidationResult {
+  MissingTables,
+  MissingDocIdColumn,
+  MissingVersionColumn,
+  GeneralError,
+  Valid,
+}
+
+#[napi]
+impl SqliteConnection {
+  #[napi(constructor, async_runtime)]
+  pub fn new(path: String) -> napi::Result<Self> {
+    let sqlite_options = SqliteConnectOptions::new()
+      .filename(&path)
+      .foreign_keys(false)
+      .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal);
+    let pool = SqlitePoolOptions::new()
+      .max_connections(4)
+      .connect_lazy_with(sqlite_options);
+    Ok(Self { pool, path })
+  }
+
+  #[napi]
+  pub async fn connect(&self) -> napi::Result<()> {
+    if !Sqlite::database_exists(&self.path).await.unwrap_or(false) {
+      Sqlite::create_database(&self.path)
+        .await
+        .map_err(anyhow::Error::from)?;
+    };
+    let mut connection = self.pool.acquire().await.map_err(anyhow::Error::from)?;
+    sqlx::query(affine_schema::SCHEMA)
+      .execute(connection.as_mut())
+      .await
+      .map_err(anyhow::Error::from)?;
+    self.migrate_add_doc_id().await?;
+    self.migrate_add_doc_id_index().await?;
+    connection.detach();
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn add_blob(&self, key: String, blob: Uint8Array) -> napi::Result<()> {
+    let blob = blob.as_ref();
+    sqlx::query_as!(
+      BlobRow,
+      "INSERT INTO blobs (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
+      key,
+      blob,
+    )
+    .execute(&self.pool)
+    .await
+    .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_blob(&self, key: String) -> Option<BlobRow> {
+    sqlx::query_as!(
+      BlobRow,
+      "SELECT key, data, timestamp FROM blobs WHERE key = ?",
+      key
+    )
+    .fetch_one(&self.pool)
+    .await
+    .ok()
+  }
+
+  #[napi]
+  pub async fn delete_blob(&self, key: String) -> napi::Result<()> {
+    sqlx::query!("DELETE FROM blobs WHERE key = ?", key)
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_blob_keys(&self) -> napi::Result<Vec<String>> {
+    let keys = sqlx::query!("SELECT key FROM blobs")
+      .fetch_all(&self.pool)
+      .await
+      .map(|rows| rows.into_iter().map(|row| row.key).collect())
+      .map_err(anyhow::Error::from)?;
+    Ok(keys)
+  }
+
+  #[napi]
+  pub async fn get_updates(&self, doc_id: Option<String>) -> napi::Result<Vec<UpdateRow>> {
+    let updates = match doc_id {
+      Some(doc_id) => sqlx::query_as!(
+        UpdateRow,
+        "SELECT id, timestamp, data, doc_id FROM updates WHERE doc_id = ?",
+        doc_id
+      )
+      .fetch_all(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?,
+      None => sqlx::query_as!(
+        UpdateRow,
+        "SELECT id, timestamp, data, doc_id FROM updates WHERE doc_id is NULL",
+      )
+      .fetch_all(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?,
+    };
+    Ok(updates)
+  }
+
+  #[napi]
+  pub async fn delete_updates(&self, doc_id: Option<String>) -> napi::Result<()> {
+    match doc_id {
+      Some(doc_id) => {
+        sqlx::query!("DELETE FROM updates WHERE doc_id = ?", doc_id)
+          .execute(&self.pool)
+          .await
+          .map_err(anyhow::Error::from)?;
+      }
+      None => {
+        sqlx::query!("DELETE FROM updates WHERE doc_id is NULL")
+          .execute(&self.pool)
+          .await
+          .map_err(anyhow::Error::from)?;
+      }
+    };
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_updates_count(&self, doc_id: Option<String>) -> napi::Result<i32> {
+    let count = match doc_id {
+      Some(doc_id) => {
+        sqlx::query!(
+          "SELECT COUNT(*) as count FROM updates WHERE doc_id = ?",
+          doc_id
+        )
+        .fetch_one(&self.pool)
+        .await
+        .map_err(anyhow::Error::from)?
+        .count
+      }
+      None => {
+        sqlx::query!("SELECT COUNT(*) as count FROM updates WHERE doc_id is NULL")
+          .fetch_one(&self.pool)
+          .await
+          .map_err(anyhow::Error::from)?
+          .count
+      }
+    };
+    Ok(count)
+  }
+
+  #[napi]
+  pub async fn get_all_updates(&self) -> napi::Result<Vec<UpdateRow>> {
+    let updates = sqlx::query_as!(UpdateRow, "SELECT id, timestamp, data, doc_id FROM updates")
+      .fetch_all(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(updates)
+  }
+
+  #[napi]
+  pub async fn insert_updates(&self, updates: Vec<InsertRow>) -> napi::Result<()> {
+    let mut transaction = self.pool.begin().await.map_err(anyhow::Error::from)?;
+    for InsertRow { data, doc_id } in updates {
+      let update = data.as_ref();
+      sqlx::query_as!(
+        UpdateRow,
+        "INSERT INTO updates (data, doc_id) VALUES ($1, $2)",
+        update,
+        doc_id
+      )
+      .execute(&mut *transaction)
+      .await
+      .map_err(anyhow::Error::from)?;
+    }
+    transaction.commit().await.map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn replace_updates(
+    &self,
+    doc_id: Option<String>,
+    updates: Vec<InsertRow>,
+  ) -> napi::Result<()> {
+    let mut transaction = self.pool.begin().await.map_err(anyhow::Error::from)?;
+
+    match doc_id {
+      Some(doc_id) => sqlx::query!("DELETE FROM updates where doc_id = ?", doc_id)
+        .execute(&mut *transaction)
+        .await
+        .map_err(anyhow::Error::from)?,
+      None => sqlx::query!("DELETE FROM updates where doc_id is NULL",)
+        .execute(&mut *transaction)
+        .await
+        .map_err(anyhow::Error::from)?,
+    };
+
+    for InsertRow { data, doc_id } in updates {
+      let update = data.as_ref();
+      sqlx::query_as!(
+        UpdateRow,
+        "INSERT INTO updates (data, doc_id) VALUES ($1, $2)",
+        update,
+        doc_id
+      )
+      .execute(&mut *transaction)
+      .await
+      .map_err(anyhow::Error::from)?;
+    }
+    transaction.commit().await.map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_server_clock(&self, key: String) -> Option<BlobRow> {
+    sqlx::query_as!(
+      BlobRow,
+      "SELECT key, data, timestamp FROM server_clock WHERE key = ?",
+      key
+    )
+    .fetch_one(&self.pool)
+    .await
+    .ok()
+  }
+
+  #[napi]
+  pub async fn set_server_clock(&self, key: String, data: Uint8Array) -> napi::Result<()> {
+    let data = data.as_ref();
+    sqlx::query!(
+      "INSERT INTO server_clock (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
+      key,
+      data,
+    )
+    .execute(&self.pool)
+    .await
+    .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_server_clock_keys(&self) -> napi::Result<Vec<String>> {
+    let keys = sqlx::query!("SELECT key FROM server_clock")
+      .fetch_all(&self.pool)
+      .await
+      .map(|rows| rows.into_iter().map(|row| row.key).collect())
+      .map_err(anyhow::Error::from)?;
+    Ok(keys)
+  }
+
+  #[napi]
+  pub async fn clear_server_clock(&self) -> napi::Result<()> {
+    sqlx::query!("DELETE FROM server_clock")
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn del_server_clock(&self, key: String) -> napi::Result<()> {
+    sqlx::query!("DELETE FROM server_clock WHERE key = ?", key)
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_sync_metadata(&self, key: String) -> Option<BlobRow> {
+    sqlx::query_as!(
+      BlobRow,
+      "SELECT key, data, timestamp FROM sync_metadata WHERE key = ?",
+      key
+    )
+    .fetch_one(&self.pool)
+    .await
+    .ok()
+  }
+
+  #[napi]
+  pub async fn set_sync_metadata(&self, key: String, data: Uint8Array) -> napi::Result<()> {
+    let data = data.as_ref();
+    sqlx::query!(
+      "INSERT INTO sync_metadata (key, data) VALUES ($1, $2) ON CONFLICT(key) DO UPDATE SET data = excluded.data",
+      key,
+      data,
+    )
+    .execute(&self.pool)
+    .await
+    .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_sync_metadata_keys(&self) -> napi::Result<Vec<String>> {
+    let keys = sqlx::query!("SELECT key FROM sync_metadata")
+      .fetch_all(&self.pool)
+      .await
+      .map(|rows| rows.into_iter().map(|row| row.key).collect())
+      .map_err(anyhow::Error::from)?;
+    Ok(keys)
+  }
+
+  #[napi]
+  pub async fn clear_sync_metadata(&self) -> napi::Result<()> {
+    sqlx::query!("DELETE FROM sync_metadata")
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn del_sync_metadata(&self, key: String) -> napi::Result<()> {
+    sqlx::query!("DELETE FROM sync_metadata WHERE key = ?", key)
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn init_version(&self) -> napi::Result<()> {
+    // create version_info table
+    sqlx::query!(
+      "CREATE TABLE IF NOT EXISTS version_info (
+        version NUMBER NOT NULL,
+        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
+      )"
+    )
+    .execute(&self.pool)
+    .await
+    .map_err(anyhow::Error::from)?;
+    // `3` is the first version that has version_info table,
+    // do not modify the version number.
+    sqlx::query!("INSERT INTO version_info (version) VALUES (3)")
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn set_version(&self, version: i32) -> napi::Result<()> {
+    if version > LATEST_VERSION {
+      return Err(anyhow::Error::msg("Version is too new").into());
+    }
+    sqlx::query!("UPDATE version_info SET version = ?", version)
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  #[napi]
+  pub async fn get_max_version(&self) -> napi::Result<i32> {
+    // 4 is the current version
+    let version = sqlx::query!("SELECT COALESCE(MAX(version), 4) AS max_version FROM version_info")
+      .fetch_one(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?
+      .max_version;
+    Ok(version)
+  }
+
+  #[napi]
+  pub async fn close(&self) {
+    self.pool.close().await;
+  }
+
+  #[napi(getter)]
+  pub fn is_close(&self) -> bool {
+    self.pool.is_closed()
+  }
+
+  #[napi]
+  pub async fn validate(path: String) -> ValidationResult {
+    let pool = match SqlitePoolOptions::new()
+      .max_connections(1)
+      .connect(&path)
+      .await
+    {
+      Ok(pool) => pool,
+      Err(_) => return ValidationResult::GeneralError,
+    };
+
+    let tables_res = sqlx::query("SELECT name FROM sqlite_master WHERE type='table'")
+      .fetch_all(&pool)
+      .await;
+
+    let tables_exist = match tables_res {
+      Ok(res) => {
+        let names: Vec<String> = res.iter().map(|row| row.get(0)).collect();
+        names.contains(&"updates".to_string()) && names.contains(&"blobs".to_string())
+      }
+      Err(_) => return ValidationResult::GeneralError,
+    };
+
+    let tables_res = sqlx::query("SELECT name FROM sqlite_master WHERE type='table'")
+      .fetch_all(&pool)
+      .await;
+
+    let version_exist = match tables_res {
+      Ok(res) => {
+        let names: Vec<String> = res.iter().map(|row| row.get(0)).collect();
+        names.contains(&"version_info".to_string())
+      }
+      Err(_) => return ValidationResult::GeneralError,
+    };
+
+    let columns_res = sqlx::query("PRAGMA table_info(updates)")
+      .fetch_all(&pool)
+      .await;
+
+    let doc_id_exist = match columns_res {
+      Ok(res) => {
+        let names: Vec<String> = res.iter().map(|row| row.get(1)).collect();
+        names.contains(&"doc_id".to_string())
+      }
+      Err(_) => return ValidationResult::GeneralError,
+    };
+
+    if !tables_exist {
+      ValidationResult::MissingTables
+    } else if !doc_id_exist {
+      ValidationResult::MissingDocIdColumn
+    } else if !version_exist {
+      ValidationResult::MissingVersionColumn
+    } else {
+      ValidationResult::Valid
+    }
+  }
+
+  #[napi]
+  pub async fn migrate_add_doc_id(&self) -> napi::Result<()> {
+    // ignore errors
+    match sqlx::query("ALTER TABLE updates ADD COLUMN doc_id TEXT")
+      .execute(&self.pool)
+      .await
+    {
+      Ok(_) => Ok(()),
+      Err(err) => {
+        if err.to_string().contains("duplicate column name") {
+          Ok(()) // Ignore error if it's due to duplicate column
+        } else {
+          Err(anyhow::Error::from(err).into()) // Propagate other errors
+        }
+      }
+    }
+  }
+
+  /**
+   * Flush the WAL file to the database file.
+   * See https://www.sqlite.org/pragma.html#pragma_wal_checkpoint:~:text=PRAGMA%20schema.wal_checkpoint%3B
+   */
+  #[napi]
+  pub async fn checkpoint(&self) -> napi::Result<()> {
+    sqlx::query("PRAGMA wal_checkpoint(FULL);")
+      .execute(&self.pool)
+      .await
+      .map_err(anyhow::Error::from)?;
+    Ok(())
+  }
+
+  pub async fn migrate_add_doc_id_index(&self) -> napi::Result<()> {
+    // ignore errors
+    match sqlx::query("CREATE INDEX IF NOT EXISTS idx_doc_id ON updates(doc_id);")
+      .execute(&self.pool)
+      .await
+    {
+      Ok(_) => Ok(()),
+      Err(err) => {
+        Err(anyhow::Error::from(err).into()) // Propagate other errors
+      }
+    }
+  }
+}
diff --git a/tsconfig.json b/tsconfig.json
index b99aa56c115ef..bdd3d734fe792 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -72,7 +72,8 @@
       "@affine/native/*": ["./packages/frontend/native/*"],
       "@affine/server-native": ["./packages/backend/native/index.d.ts"],
       // Development only
-      "@affine/electron/*": ["./packages/frontend/apps/electron/src/*"]
+      "@affine/electron/*": ["./packages/frontend/apps/electron/src/*"],
+      "@affine/doc-storage": ["./packages/common/doc-storage/src"]
     }
   },
   "include": [],
@@ -131,6 +132,9 @@
     {
       "path": "./packages/common/infra"
     },
+    {
+      "path": "./packages/common/doc-storage"
+    },
     // Tools
     {
       "path": "./tools/cli"
diff --git a/yarn.lock b/yarn.lock
index cf1963004df93..297009c8db2f6 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -469,6 +469,19 @@ __metadata:
   languageName: unknown
   linkType: soft

+"@affine/doc-storage@workspace:packages/common/doc-storage":
+  version: 0.0.0-use.local
+  resolution: "@affine/doc-storage@workspace:packages/common/doc-storage"
+  dependencies:
+    "@affine/native": "workspace:*"
+    "@types/lodash-es": "npm:^4.17.12"
+    idb: "npm:^8.0.0"
+    lodash-es: "npm:^4.17.21"
+    socket.io-client: "npm:^4.7.5"
+    yjs: "patch:yjs@npm%3A13.6.18#~/.yarn/patches/yjs-npm-13.6.18-ad0d5f7c43.patch"
+  languageName: unknown
+  linkType: soft
+
 "@affine/docs@workspace:docs/reference":
   version: 0.0.0-use.local
   resolution: "@affine/docs@workspace:docs/reference"
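One property worth noting survives the move into v1.rs unchanged: `validate` stays side-effect free, while `connect` and `init_version` remain the idempotent repair paths. A sketch of the open-time check this enables (the crate path `affine_native` is an assumption; everything else is the v1 API shown above, with error handling simplified):

    use affine_native::sqlite::v1::{SqliteConnection, ValidationResult};

    async fn open_validated(path: String) -> anyhow::Result<SqliteConnection> {
      match SqliteConnection::validate(path.clone()).await {
        // No updates/blobs tables, or the file is unreadable: not a workspace DB.
        ValidationResult::MissingTables | ValidationResult::GeneralError => {
          anyhow::bail!("not a usable workspace database")
        }
        result => {
          let conn = SqliteConnection::new(path)?;
          // connect() re-runs the idempotent ALTER TABLE / CREATE INDEX
          // migrations, which covers MissingDocIdColumn.
          conn.connect().await?;
          if matches!(result, ValidationResult::MissingVersionColumn) {
            // init_version() creates version_info and stamps version 3.
            conn.init_version().await?;
          }
          Ok(conn)
        }
      }
    }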