Skip to content

Commit

Permalink
feat: history cleanup (#7007)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed May 24, 2024
1 parent 02564a8 commit 937b8bf
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 79 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- AlterTable
ALTER TABLE "ai_sessions_metadata" ADD COLUMN "deleted_at" TIMESTAMPTZ(6),
ADD COLUMN "messageCost" INTEGER NOT NULL DEFAULT 0,
ADD COLUMN "tokenCost" INTEGER NOT NULL DEFAULT 0;
2 changes: 1 addition & 1 deletion packages/backend/server/migrations/migration_lock.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Please do not edit this file manually
# It should be added in your version-control system (i.e. Git)
provider = "postgresql"
provider = "postgresql"
15 changes: 9 additions & 6 deletions packages/backend/server/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,15 @@ model AiSessionMessage {
}

model AiSession {
id String @id @default(uuid()) @db.VarChar(36)
userId String @map("user_id") @db.VarChar(36)
workspaceId String @map("workspace_id") @db.VarChar(36)
docId String @map("doc_id") @db.VarChar(36)
promptName String @map("prompt_name") @db.VarChar(32)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
id String @id @default(uuid()) @db.VarChar(36)
userId String @map("user_id") @db.VarChar(36)
workspaceId String @map("workspace_id") @db.VarChar(36)
docId String @map("doc_id") @db.VarChar(36)
promptName String @map("prompt_name") @db.VarChar(32)
messageCost Int @default(0)
tokenCost Int @default(0)
createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6)
deletedAt DateTime? @map("deleted_at") @db.Timestamptz(6)
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
prompt AiPrompt @relation(fields: [promptName], references: [name], onDelete: Cascade)
Expand Down
43 changes: 42 additions & 1 deletion packages/backend/server/src/plugins/copilot/resolver.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { createHash } from 'node:crypto';

import { BadRequestException, Logger } from '@nestjs/common';
import { BadRequestException, Logger, NotFoundException } from '@nestjs/common';
import {
Args,
Field,
Expand Down Expand Up @@ -55,6 +55,18 @@ class CreateChatSessionInput {
promptName!: string;
}

@InputType()
class DeleteSessionInput {
@Field(() => String)
workspaceId!: string;

@Field(() => String)
docId!: string;

@Field(() => [String])
sessionIds!: string[];
}

@InputType()
class CreateChatMessageInput implements Omit<SubmittedMessage, 'content'> {
@Field(() => String)
Expand Down Expand Up @@ -264,6 +276,35 @@ export class CopilotResolver {
return session;
}

@Mutation(() => String, {
description: 'Cleanup sessions',
})
async cleanupCopilotSession(
@CurrentUser() user: CurrentUser,
@Args({ name: 'options', type: () => DeleteSessionInput })
options: DeleteSessionInput
) {
await this.permissions.checkCloudPagePermission(
options.workspaceId,
options.docId,
user.id
);
if (!options.sessionIds.length) {
return new NotFoundException('Session not found');
}
const lockFlag = `${COPILOT_LOCKER}:session:${user.id}:${options.workspaceId}`;
await using lock = await this.mutex.lock(lockFlag);
if (!lock) {
return new TooManyRequestsException('Server is busy');
}

const ret = await this.chatSession.cleanup({
...options,
userId: user.id,
});
return ret;
}

@Mutation(() => String, {
description: 'Create a chat message',
})
Expand Down
196 changes: 125 additions & 71 deletions packages/backend/server/src/plugins/copilot/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,17 @@ export class ChatSessionService {

// find existing session if session is chat session
if (!state.prompt.action) {
const { id } =
const { id, deletedAt } =
(await tx.aiSession.findFirst({
where: {
userId: state.userId,
workspaceId: state.workspaceId,
docId: state.docId,
prompt: { action: { equals: null } },
},
select: { id: true },
select: { id: true, deletedAt: true },
})) || {};
if (deletedAt) throw new Error(`Session is deleted: ${id}`);
if (id) sessionId = id;
}

Expand All @@ -219,6 +220,21 @@ export class ChatSessionService {
sessionId,
})),
});

// only count message generated by user
const userMessages = state.messages.filter(m => m.role === 'user');
await tx.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: {
increment: this.calculateTokenSize(
userMessages,
state.prompt.model as AvailableModel
),
},
},
});
}
} else {
await tx.aiSession.create({
Expand All @@ -242,21 +258,15 @@ export class ChatSessionService {
): Promise<ChatSessionState | undefined> {
return await this.db.aiSession
.findUnique({
where: { id: sessionId },
where: { id: sessionId, deletedAt: null },
select: {
id: true,
userId: true,
workspaceId: true,
docId: true,
messages: {
select: {
role: true,
content: true,
createdAt: true,
},
orderBy: {
createdAt: 'asc',
},
select: { role: true, content: true, createdAt: true },
orderBy: { createdAt: 'asc' },
},
promptName: true,
},
Expand All @@ -283,9 +293,18 @@ export class ChatSessionService {
// after revert, we can retry the action
async revertLatestMessage(sessionId: string) {
await this.db.$transaction(async tx => {
const id = await tx.aiSession
.findUnique({
where: { id: sessionId, deletedAt: null },
select: { id: true },
})
.then(session => session?.id);
if (!id) {
throw new Error(`Session not found: ${sessionId}`);
}
const ids = await tx.aiSessionMessage
.findMany({
where: { sessionId },
where: { sessionId: id },
select: { id: true, role: true },
orderBy: { createdAt: 'asc' },
})
Expand All @@ -312,22 +331,14 @@ export class ChatSessionService {
.reduce((total, length) => total + length, 0);
}

private async countUserActions(userId: string): Promise<number> {
return await this.db.aiSession.count({
where: { userId, prompt: { action: { not: null } } },
});
}

private async countUserChats(userId: string): Promise<number> {
const chats = await this.db.aiSession.findMany({
where: { userId, prompt: { action: null } },
select: {
_count: {
select: { messages: { where: { role: AiPromptRole.user } } },
},
},
private async countUserMessages(userId: string): Promise<number> {
const sessions = await this.db.aiSession.findMany({
where: { userId },
select: { messageCost: true, prompt: { select: { action: true } } },
});
return chats.reduce((prev, chat) => prev + chat._count.messages, 0);
return sessions
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
.reduce((prev, cost) => prev + cost, 0);
}

async listSessions(
Expand All @@ -344,6 +355,7 @@ export class ChatSessionService {
prompt: {
action: options?.action ? { not: null } : null,
},
deletedAt: null,
},
select: { id: true },
})
Expand All @@ -367,10 +379,12 @@ export class ChatSessionService {
action: options?.action ? { not: null } : null,
},
id: options?.sessionId ? { equals: options.sessionId } : undefined,
deletedAt: null,
},
select: {
id: true,
promptName: true,
tokenCost: true,
createdAt: true,
messages: {
select: {
Expand All @@ -391,50 +405,48 @@ export class ChatSessionService {
})
.then(sessions =>
Promise.all(
sessions.map(async ({ id, promptName, messages, createdAt }) => {
try {
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new Error(`Prompt not found: ${promptName}`);
}
const tokens = this.calculateTokenSize(
ret.data,
prompt.model as AvailableModel
);

// render system prompt
const preload = withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: [];

// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
(preload as ChatMessage[]).forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
sessions.map(
async ({ id, promptName, tokenCost, messages, createdAt }) => {
try {
const ret = ChatMessageSchema.array().safeParse(messages);
if (ret.success) {
const prompt = await this.prompt.get(promptName);
if (!prompt) {
throw new Error(`Prompt not found: ${promptName}`);
}

// render system prompt
const preload = withPrompt
? prompt
.finish(ret.data[0]?.params || {}, id)
.filter(({ role }) => role !== 'system')
: [];

// `createdAt` is required for history sorting in frontend, let's fake the creating time of prompt messages
(preload as ChatMessage[]).forEach((msg, i) => {
msg.createdAt = new Date(
createdAt.getTime() - preload.length - i - 1
);
});

return {
sessionId: id,
action: prompt.action || undefined,
tokens: tokenCost,
createdAt,
messages: preload.concat(ret.data),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
});

return {
sessionId: id,
action: prompt.action || undefined,
tokens,
createdAt,
messages: preload.concat(ret.data),
};
} else {
this.logger.error(
`Unexpected message schema: ${JSON.stringify(ret.error)}`
);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
}
} catch (e) {
this.logger.error('Unexpected error in listHistories', e);
return undefined;
}
return undefined;
})
)
)
)
.then(histories =>
Expand All @@ -451,10 +463,9 @@ export class ChatSessionService {
limit = quota.feature.copilotActionLimit;
}

const actions = await this.countUserActions(userId);
const chats = await this.countUserChats(userId);
const used = await this.countUserMessages(userId);

return { limit, used: actions + chats };
return { limit, used };
}

async checkQuota(userId: string) {
Expand All @@ -481,6 +492,49 @@ export class ChatSessionService {
});
}

async cleanup(
options: Omit<ChatSessionOptions, 'promptName'> & { sessionIds: string[] }
) {
return await this.db.$transaction(async tx => {
const sessions = await tx.aiSession.findMany({
where: {
id: { in: options.sessionIds },
userId: options.userId,
workspaceId: options.workspaceId,
docId: options.docId,
deletedAt: null,
},
select: { id: true, promptName: true },
});
const sessionIds = sessions.map(({ id }) => id);
// cleanup all messages
await tx.aiSessionMessage.deleteMany({
where: { sessionId: { in: sessionIds } },
});

// only mark action session as deleted
// chat session always can be reuse
{
const actionIds = (
await Promise.all(
sessions.map(({ id, promptName }) =>
this.prompt
.get(promptName)
.then(prompt => ({ id, action: !!prompt?.action }))
)
)
)
.filter(({ action }) => action)
.map(({ id }) => id);

await tx.aiSession.updateMany({
where: { id: { in: actionIds } },
data: { deletedAt: new Date() },
});
}
});
}

async createMessage(message: SubmittedMessage): Promise<string | undefined> {
return await this.messageCache.set(message);
}
Expand Down
Loading

0 comments on commit 937b8bf

Please sign in to comment.