From 937b8bf166ef3cc2cb3f1c90c98ea46d23dc9863 Mon Sep 17 00:00:00 2001 From: darkskygit Date: Fri, 24 May 2024 08:00:05 +0000 Subject: [PATCH] feat: history cleanup (#7007) fix AFF-1069 --- .../migration.sql | 4 + .../server/migrations/migration_lock.toml | 2 +- packages/backend/server/schema.prisma | 15 +- .../server/src/plugins/copilot/resolver.ts | 43 +++- .../server/src/plugins/copilot/session.ts | 196 +++++++++++------- packages/backend/server/src/schema.gql | 9 + 6 files changed, 190 insertions(+), 79 deletions(-) create mode 100644 packages/backend/server/migrations/20240521100307_add_copilot_cost/migration.sql diff --git a/packages/backend/server/migrations/20240521100307_add_copilot_cost/migration.sql b/packages/backend/server/migrations/20240521100307_add_copilot_cost/migration.sql new file mode 100644 index 0000000000000..9b22ad31680f1 --- /dev/null +++ b/packages/backend/server/migrations/20240521100307_add_copilot_cost/migration.sql @@ -0,0 +1,4 @@ +-- AlterTable +ALTER TABLE "ai_sessions_metadata" ADD COLUMN "deleted_at" TIMESTAMPTZ(6), +ADD COLUMN "messageCost" INTEGER NOT NULL DEFAULT 0, +ADD COLUMN "tokenCost" INTEGER NOT NULL DEFAULT 0; diff --git a/packages/backend/server/migrations/migration_lock.toml b/packages/backend/server/migrations/migration_lock.toml index 99e4f20090794..fbffa92c2bb7c 100644 --- a/packages/backend/server/migrations/migration_lock.toml +++ b/packages/backend/server/migrations/migration_lock.toml @@ -1,3 +1,3 @@ # Please do not edit this file manually # It should be added in your version-control system (i.e. Git) -provider = "postgresql" +provider = "postgresql" \ No newline at end of file diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index f4293043065ac..3312f795ec81c 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -480,12 +480,15 @@ model AiSessionMessage { } model AiSession { - id String @id @default(uuid()) @db.VarChar(36) - userId String @map("user_id") @db.VarChar(36) - workspaceId String @map("workspace_id") @db.VarChar(36) - docId String @map("doc_id") @db.VarChar(36) - promptName String @map("prompt_name") @db.VarChar(32) - createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + id String @id @default(uuid()) @db.VarChar(36) + userId String @map("user_id") @db.VarChar(36) + workspaceId String @map("workspace_id") @db.VarChar(36) + docId String @map("doc_id") @db.VarChar(36) + promptName String @map("prompt_name") @db.VarChar(32) + messageCost Int @default(0) + tokenCost Int @default(0) + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + deletedAt DateTime? @map("deleted_at") @db.Timestamptz(6) user User @relation(fields: [userId], references: [id], onDelete: Cascade) prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade) diff --git a/packages/backend/server/src/plugins/copilot/resolver.ts b/packages/backend/server/src/plugins/copilot/resolver.ts index e003ed45a2d4b..582080021a06c 100644 --- a/packages/backend/server/src/plugins/copilot/resolver.ts +++ b/packages/backend/server/src/plugins/copilot/resolver.ts @@ -1,6 +1,6 @@ import { createHash } from 'node:crypto'; -import { BadRequestException, Logger } from '@nestjs/common'; +import { BadRequestException, Logger, NotFoundException } from '@nestjs/common'; import { Args, Field, @@ -55,6 +55,18 @@ class CreateChatSessionInput { promptName!: string; } +@InputType() +class DeleteSessionInput { + @Field(() => String) + workspaceId!: string; + + @Field(() => String) + docId!: string; + + @Field(() => [String]) + sessionIds!: string[]; +} + @InputType() class CreateChatMessageInput implements Omit { @Field(() => String) @@ -264,6 +276,35 @@ export class CopilotResolver { return session; } + @Mutation(() => String, { + description: 'Cleanup sessions', + }) + async cleanupCopilotSession( + @CurrentUser() user: CurrentUser, + @Args({ name: 'options', type: () => DeleteSessionInput }) + options: DeleteSessionInput + ) { + await this.permissions.checkCloudPagePermission( + options.workspaceId, + options.docId, + user.id + ); + if (!options.sessionIds.length) { + return new NotFoundException('Session not found'); + } + const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`; + await using lock = await this.mutex.lock(lockFlag); + if (!lock) { + return new TooManyRequestsException('Server is busy'); + } + + const ret = await this.chatSession.cleanup({ + ...options, + userId: user.id, + }); + return ret; + } + @Mutation(() => String, { description: 'Create a chat message', }) diff --git a/packages/backend/server/src/plugins/copilot/session.ts b/packages/backend/server/src/plugins/copilot/session.ts index 6df6b255c1db1..70722e2a2c114 100644 --- a/packages/backend/server/src/plugins/copilot/session.ts +++ b/packages/backend/server/src/plugins/copilot/session.ts @@ -186,7 +186,7 @@ export class ChatSessionService { // find existing session if session is chat session if (!state.prompt.action) { - const { id } = + const { id, deletedAt } = (await tx.aiSession.findFirst({ where: { userId: state.userId, @@ -194,8 +194,9 @@ export class ChatSessionService { docId: state.docId, prompt: { action: { equals: null } }, }, - select: { id: true }, + select: { id: true, deletedAt: true }, })) || {}; + if (deletedAt) throw new Error(`Session is deleted: ${id}`); if (id) sessionId = id; } @@ -219,6 +220,21 @@ export class ChatSessionService { sessionId, })), }); + + // only count message generated by user + const userMessages = state.messages.filter(m => m.role === 'user'); + await tx.aiSession.update({ + where: { id: sessionId }, + data: { + messageCost: { increment: userMessages.length }, + tokenCost: { + increment: this.calculateTokenSize( + userMessages, + state.prompt.model as AvailableModel + ), + }, + }, + }); } } else { await tx.aiSession.create({ @@ -242,21 +258,15 @@ export class ChatSessionService { ): Promise { return await this.db.aiSession .findUnique({ - where: { id: sessionId }, + where: { id: sessionId, deletedAt: null }, select: { id: true, userId: true, workspaceId: true, docId: true, messages: { - select: { - role: true, - content: true, - createdAt: true, - }, - orderBy: { - createdAt: 'asc', - }, + select: { role: true, content: true, createdAt: true }, + orderBy: { createdAt: 'asc' }, }, promptName: true, }, @@ -283,9 +293,18 @@ export class ChatSessionService { // after revert, we can retry the action async revertLatestMessage(sessionId: string) { await this.db.$transaction(async tx => { + const id = await tx.aiSession + .findUnique({ + where: { id: sessionId, deletedAt: null }, + select: { id: true }, + }) + .then(session => session?.id); + if (!id) { + throw new Error(`Session not found: ${sessionId}`); + } const ids = await tx.aiSessionMessage .findMany({ - where: { sessionId }, + where: { sessionId: id }, select: { id: true, role: true }, orderBy: { createdAt: 'asc' }, }) @@ -312,22 +331,14 @@ export class ChatSessionService { .reduce((total, length) => total + length, 0); } - private async countUserActions(userId: string): Promise { - return await this.db.aiSession.count({ - where: { userId, prompt: { action: { not: null } } }, - }); - } - - private async countUserChats(userId: string): Promise { - const chats = await this.db.aiSession.findMany({ - where: { userId, prompt: { action: null } }, - select: { - _count: { - select: { messages: { where: { role: AiPromptRole.user } } }, - }, - }, + private async countUserMessages(userId: string): Promise { + const sessions = await this.db.aiSession.findMany({ + where: { userId }, + select: { messageCost: true, prompt: { select: { action: true } } }, }); - return chats.reduce((prev, chat) => prev + chat._count.messages, 0); + return sessions + .map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost)) + .reduce((prev, cost) => prev + cost, 0); } async listSessions( @@ -344,6 +355,7 @@ export class ChatSessionService { prompt: { action: options?.action ? { not: null } : null, }, + deletedAt: null, }, select: { id: true }, }) @@ -367,10 +379,12 @@ export class ChatSessionService { action: options?.action ? { not: null } : null, }, id: options?.sessionId ? { equals: options.sessionId } : undefined, + deletedAt: null, }, select: { id: true, promptName: true, + tokenCost: true, createdAt: true, messages: { select: { @@ -391,50 +405,48 @@ export class ChatSessionService { }) .then(sessions => Promise.all( - sessions.map(async ({ id, promptName, messages, createdAt }) => { - try { - const ret = ChatMessageSchema.array().safeParse(messages); - if (ret.success) { - const prompt = await this.prompt.get(promptName); - if (!prompt) { - throw new Error(`Prompt not found: ${promptName}`); - } - const tokens = this.calculateTokenSize( - ret.data, - prompt.model as AvailableModel - ); - - // render system prompt - const preload = withPrompt - ? prompt - .finish(ret.data[0]?.params || {}, id) - .filter(({ role }) => role !== 'system') - : []; - - // `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages - (preload as ChatMessage[]).forEach((msg, i) => { - msg.createdAt = new Date( - createdAt.getTime() - preload.length - i - 1 + sessions.map( + async ({ id, promptName, tokenCost, messages, createdAt }) => { + try { + const ret = ChatMessageSchema.array().safeParse(messages); + if (ret.success) { + const prompt = await this.prompt.get(promptName); + if (!prompt) { + throw new Error(`Prompt not found: ${promptName}`); + } + + // render system prompt + const preload = withPrompt + ? prompt + .finish(ret.data[0]?.params || {}, id) + .filter(({ role }) => role !== 'system') + : []; + + // `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages + (preload as ChatMessage[]).forEach((msg, i) => { + msg.createdAt = new Date( + createdAt.getTime() - preload.length - i - 1 + ); + }); + + return { + sessionId: id, + action: prompt.action || undefined, + tokens: tokenCost, + createdAt, + messages: preload.concat(ret.data), + }; + } else { + this.logger.error( + `Unexpected message schema: ${JSON.stringify(ret.error)}` ); - }); - - return { - sessionId: id, - action: prompt.action || undefined, - tokens, - createdAt, - messages: preload.concat(ret.data), - }; - } else { - this.logger.error( - `Unexpected message schema: ${JSON.stringify(ret.error)}` - ); + } + } catch (e) { + this.logger.error('Unexpected error in listHistories', e); } - } catch (e) { - this.logger.error('Unexpected error in listHistories', e); + return undefined; } - return undefined; - }) + ) ) ) .then(histories => @@ -451,10 +463,9 @@ export class ChatSessionService { limit = quota.feature.copilotActionLimit; } - const actions = await this.countUserActions(userId); - const chats = await this.countUserChats(userId); + const used = await this.countUserMessages(userId); - return { limit, used: actions + chats }; + return { limit, used }; } async checkQuota(userId: string) { @@ -481,6 +492,49 @@ export class ChatSessionService { }); } + async cleanup( + options: Omit & { sessionIds: string[] } + ) { + return await this.db.$transaction(async tx => { + const sessions = await tx.aiSession.findMany({ + where: { + id: { in: options.sessionIds }, + userId: options.userId, + workspaceId: options.workspaceId, + docId: options.docId, + deletedAt: null, + }, + select: { id: true, promptName: true }, + }); + const sessionIds = sessions.map(({ id }) => id); + // cleanup all messages + await tx.aiSessionMessage.deleteMany({ + where: { sessionId: { in: sessionIds } }, + }); + + // only mark action session as deleted + // chat session always can be reuse + { + const actionIds = ( + await Promise.all( + sessions.map(({ id, promptName }) => + this.prompt + .get(promptName) + .then(prompt => ({ id, action: !!prompt?.action })) + ) + ) + ) + .filter(({ action }) => action) + .map(({ id }) => id); + + await tx.aiSession.updateMany({ + where: { id: { in: actionIds } }, + data: { deletedAt: new Date() }, + }); + } + }); + } + async createMessage(message: SubmittedMessage): Promise { return await this.messageCache.set(message); } diff --git a/packages/backend/server/src/schema.gql b/packages/backend/server/src/schema.gql index 303befc831dc5..cc1fc9a1a1cb8 100644 --- a/packages/backend/server/src/schema.gql +++ b/packages/backend/server/src/schema.gql @@ -76,6 +76,12 @@ type DeleteAccount { success: Boolean! } +input DeleteSessionInput { + docId: String! + sessionIds: [String!]! + workspaceId: String! +} + type DocHistoryType { id: String! timestamp: DateTime! @@ -184,6 +190,9 @@ type Mutation { changeEmail(email: String!, token: String!): UserType! changePassword(newPassword: String!, token: String!): UserType! + """Cleanup sessions""" + cleanupCopilotSession(options: DeleteSessionInput!): String! + """Create a subscription checkout link of stripe""" createCheckoutSession(input: CreateCheckoutSessionInput!): String!