From d341452057f00f080d3b1923729505c40ab97a84 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 21 May 2024 16:10:33 +0700 Subject: [PATCH] chore: clean up chat stream --- cortex-js/src/command.module.ts | 1 - cortex-js/src/command.ts | 2 +- .../src/domain/abstracts/engine.abstract.ts | 5 +- .../src/domain/abstracts/oai.abstract.ts | 151 ++++-------------- .../infrastructure/commanders/chat.command.ts | 13 +- .../commanders/prompt-constants.ts | 2 +- .../commanders/serve.command.ts | 8 +- .../commanders/shortcuts/run.command.ts | 9 +- .../commanders/usecases/chat.cli.usecases.ts | 58 ++++--- .../usecases/cli.usecases.module.ts | 10 +- .../controllers/chat.controller.ts | 37 ++--- .../dtos/chat/chat-completion-response.dto.ts | 26 +++ .../infrastructure/dtos/chat/choice.dto.ts | 13 ++ .../infrastructure/dtos/chat/message.dto.ts | 9 ++ .../src/infrastructure/dtos/chat/usage.dto.ts | 12 ++ cortex-js/src/usecases/chat/chat.usecases.ts | 46 +++++- .../src/usecases/models/models.usecases.ts | 2 - 17 files changed, 206 insertions(+), 198 deletions(-) create mode 100644 cortex-js/src/infrastructure/dtos/chat/chat-completion-response.dto.ts create mode 100644 cortex-js/src/infrastructure/dtos/chat/choice.dto.ts create mode 100644 cortex-js/src/infrastructure/dtos/chat/message.dto.ts create mode 100644 cortex-js/src/infrastructure/dtos/chat/usage.dto.ts diff --git a/cortex-js/src/command.module.ts b/cortex-js/src/command.module.ts index 204bf1887..4b6c92ffd 100644 --- a/cortex-js/src/command.module.ts +++ b/cortex-js/src/command.module.ts @@ -33,7 +33,6 @@ import { ModelUpdateCommand } from './infrastructure/commanders/models/model-upd DatabaseModule, ModelsModule, CortexModule, - ChatModule, ExtensionModule, HttpModule, CliUsecasesModule, diff --git a/cortex-js/src/command.ts b/cortex-js/src/command.ts index 9f3ff4494..c71a1f934 100644 --- a/cortex-js/src/command.ts +++ b/cortex-js/src/command.ts @@ -3,7 +3,7 @@ import { CommandFactory } from 'nest-commander'; import { CommandModule } from './command.module'; async function bootstrap() { - await CommandFactory.run(CommandModule); + await CommandFactory.run(CommandModule, ['warn', 'error']); } bootstrap(); diff --git a/cortex-js/src/domain/abstracts/engine.abstract.ts b/cortex-js/src/domain/abstracts/engine.abstract.ts index f21f6664b..14f334140 100644 --- a/cortex-js/src/domain/abstracts/engine.abstract.ts +++ b/cortex-js/src/domain/abstracts/engine.abstract.ts @@ -1,11 +1,14 @@ /* eslint-disable no-unused-vars, @typescript-eslint/no-unused-vars */ +import stream from 'stream'; import { Model, ModelSettingParams } from '../models/model.interface'; import { Extension } from './extension.abstract'; export abstract class EngineExtension extends Extension { abstract provider: string; - abstract inference(completion: any, req: any, stream: any, res?: any): void; + abstract inference(dto: any, headers: Record): Promise; + + abstract inferenceStream(dto: any, headers: any): Promise; async loadModel( model: Model, diff --git a/cortex-js/src/domain/abstracts/oai.abstract.ts b/cortex-js/src/domain/abstracts/oai.abstract.ts index 2923c4277..d12360f67 100644 --- a/cortex-js/src/domain/abstracts/oai.abstract.ts +++ b/cortex-js/src/domain/abstracts/oai.abstract.ts @@ -1,12 +1,6 @@ import { HttpService } from '@nestjs/axios'; import { EngineExtension } from './engine.abstract'; -import { stdout } from 'process'; - -export type ChatStreamEvent = { - type: 'data' | 'error' | 'end'; - data?: any; - error?: any; -}; +import stream from 'stream'; export abstract class OAIEngineExtension extends EngineExtension { abstract apiUrl: string; @@ -15,120 +9,43 @@ export abstract class OAIEngineExtension extends EngineExtension { super(); } - inference( + override async inferenceStream( createChatDto: any, headers: Record, - writableStream: WritableStream, - res?: any, - ) { - if (createChatDto.stream === true) { - if (res) { - res.writeHead(200, { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - 'Access-Control-Allow-Origin': '*', - }); - this.httpService - .post(this.apiUrl, createChatDto, { - headers: { - 'Content-Type': headers['content-type'] ?? 'application/json', - Authorization: headers['authorization'], - }, - responseType: 'stream', - }) - .toPromise() - .then((response) => { - response?.data.pipe(res); - }); - } else { - const decoder = new TextDecoder('utf-8'); - const defaultWriter = writableStream.getWriter(); - defaultWriter.ready.then(() => { - this.httpService - .post(this.apiUrl, createChatDto, { - headers: { - 'Content-Type': headers['content-type'] ?? 'application/json', - Authorization: headers['authorization'], - }, - responseType: 'stream', - }) - .subscribe({ - next: (response) => { - response.data.on('data', (chunk: any) => { - let content = ''; - const text = decoder.decode(chunk); - const lines = text.trim().split('\n'); - let cachedLines = ''; - for (const line of lines) { - try { - const toParse = cachedLines + line; - if (!line.includes('data: [DONE]')) { - const data = JSON.parse(toParse.replace('data: ', '')); - content += data.choices[0]?.delta?.content ?? ''; - - if (content.startsWith('assistant: ')) { - content = content.replace('assistant: ', ''); - } - - if (content !== '') { - defaultWriter.write({ - type: 'data', - data: content, - }); - } - } - } catch { - cachedLines = line; - } - } - }); - - response.data.on('error', (error: any) => { - defaultWriter.write({ - type: 'error', - error, - }); - }); + ): Promise { + const response = await this.httpService + .post(this.apiUrl, createChatDto, { + headers: { + 'Content-Type': headers['content-type'] ?? 'application/json', + Authorization: headers['authorization'], + }, + responseType: 'stream', + }) + .toPromise(); + + if (!response) { + throw new Error('No response'); + } - response.data.on('end', () => { - // stdout.write('Stream end'); - defaultWriter.write({ - type: 'end', - }); - }); - }, + return response.data; + } - error: (error) => { - stdout.write('Stream error: ' + error); - }, - }); - }); - } - } else { - const defaultWriter = writableStream.getWriter(); - defaultWriter.ready.then(() => { - this.httpService - .post(this.apiUrl, createChatDto, { - headers: { - 'Content-Type': headers['content-type'] ?? 'application/json', - Authorization: headers['authorization'], - }, - }) - .toPromise() - .then((response) => { - defaultWriter.write({ - type: 'data', - data: response?.data, - }); - }) - .catch((error: any) => { - defaultWriter.write({ - type: 'error', - error, - }); - }); - }); + override async inference( + createChatDto: any, + headers: Record, + ): Promise { + const response = await this.httpService + .post(this.apiUrl, createChatDto, { + headers: { + 'Content-Type': headers['content-type'] ?? 'application/json', + Authorization: headers['authorization'], + }, + }) + .toPromise(); + if (!response) { + throw new Error('No response'); } + + return response.data; } } diff --git a/cortex-js/src/infrastructure/commanders/chat.command.ts b/cortex-js/src/infrastructure/commanders/chat.command.ts index faffa3ede..8295922e9 100644 --- a/cortex-js/src/infrastructure/commanders/chat.command.ts +++ b/cortex-js/src/infrastructure/commanders/chat.command.ts @@ -1,7 +1,5 @@ -import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { CommandRunner, SubCommand, Option } from 'nest-commander'; import { ChatCliUsecases } from './usecases/chat.cli.usecases'; -import { CortexUsecases } from '@/usecases/cortex/cortex.usecases'; import { exit } from 'node:process'; type ChatOptions = { @@ -10,10 +8,7 @@ type ChatOptions = { @SubCommand({ name: 'chat', description: 'Start a chat with a model' }) export class ChatCommand extends CommandRunner { - constructor( - private readonly chatUsecases: ChatUsecases, - private readonly cortexUsecases: CortexUsecases, - ) { + constructor(private readonly chatCliUsecases: ChatCliUsecases) { super(); } @@ -24,11 +19,7 @@ export class ChatCommand extends CommandRunner { exit(1); } - const chatCliUsecases = new ChatCliUsecases( - this.chatUsecases, - this.cortexUsecases, - ); - return chatCliUsecases.chat(modelId); + return this.chatCliUsecases.chat(modelId); } @Option({ diff --git a/cortex-js/src/infrastructure/commanders/prompt-constants.ts b/cortex-js/src/infrastructure/commanders/prompt-constants.ts index 969b24f43..585909b92 100644 --- a/cortex-js/src/infrastructure/commanders/prompt-constants.ts +++ b/cortex-js/src/infrastructure/commanders/prompt-constants.ts @@ -1,5 +1,5 @@ //// HF Chat template -export const OPEN_CHAT_3_5_JINJA = ``; +export const OPEN_CHAT_3_5_JINJA = `{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}`; export const ZEPHYR_JINJA = `{% for message in messages %} {% if message['role'] == 'user' %} diff --git a/cortex-js/src/infrastructure/commanders/serve.command.ts b/cortex-js/src/infrastructure/commanders/serve.command.ts index 6af783c76..7e49ad590 100644 --- a/cortex-js/src/infrastructure/commanders/serve.command.ts +++ b/cortex-js/src/infrastructure/commanders/serve.command.ts @@ -13,10 +13,6 @@ type ServeOptions = { description: 'Providing API endpoint for Cortex backend', }) export class ServeCommand extends CommandRunner { - constructor() { - super(); - } - async run(_input: string[], options?: ServeOptions): Promise { const host = options?.host || defaultCortexJsHost; const port = options?.port || defaultCortexJsPort; @@ -34,7 +30,7 @@ export class ServeCommand extends CommandRunner { } @Option({ - flags: '--host ', + flags: '-h, --host ', description: 'Host to serve the application', }) parseHost(value: string) { @@ -42,7 +38,7 @@ export class ServeCommand extends CommandRunner { } @Option({ - flags: '--port ', + flags: '-p, --port ', description: 'Port to serve the application', }) parsePort(value: string) { diff --git a/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts b/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts index e0f9b0152..42b68631a 100644 --- a/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts +++ b/cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts @@ -2,7 +2,6 @@ import { CortexUsecases } from '@/usecases/cortex/cortex.usecases'; import { ModelsUsecases } from '@/usecases/models/models.usecases'; import { CommandRunner, SubCommand, Option } from 'nest-commander'; import { exit } from 'node:process'; -import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { ChatCliUsecases } from '../usecases/chat.cli.usecases'; import { defaultCortexCppHost, defaultCortexCppPort } from 'constant'; @@ -18,7 +17,7 @@ export class RunCommand extends CommandRunner { constructor( private readonly modelsUsecases: ModelsUsecases, private readonly cortexUsecases: CortexUsecases, - private readonly chatUsecases: ChatUsecases, + private readonly chatCliUsecases: ChatCliUsecases, ) { super(); } @@ -36,11 +35,7 @@ export class RunCommand extends CommandRunner { false, ); await this.modelsUsecases.startModel(modelId); - const chatCliUsecases = new ChatCliUsecases( - this.chatUsecases, - this.cortexUsecases, - ); - await chatCliUsecases.chat(modelId); + await this.chatCliUsecases.chat(modelId); } @Option({ diff --git a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts index 9f7409cca..ac92ff7b9 100644 --- a/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts +++ b/cortex-js/src/infrastructure/commanders/usecases/chat.cli.usecases.ts @@ -2,12 +2,12 @@ import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { ChatCompletionRole } from '@/domain/models/message.interface'; import { exit, stdin, stdout } from 'node:process'; import * as readline from 'node:readline/promises'; -import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; import { ChatCompletionMessage } from '@/infrastructure/dtos/chat/chat-completion-message.dto'; import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto'; import { CortexUsecases } from '@/usecases/cortex/cortex.usecases'; +import { Injectable } from '@nestjs/common'; -// TODO: make this class injectable +@Injectable() export class ChatCliUsecases { private exitClause = 'exit()'; private userIndicator = '>> '; @@ -59,26 +59,44 @@ export class ChatCliUsecases { top_p: 0.7, }; - let llmFullResponse = ''; - const writableStream = new WritableStream({ - write(chunk) { - if (chunk.type === 'data') { - stdout.write(chunk.data ?? ''); - llmFullResponse += chunk.data ?? ''; - } else if (chunk.type === 'error') { - console.log('Error!!'); - } else { - messages.push({ - content: llmFullResponse, - role: ChatCompletionRole.Assistant, - }); - llmFullResponse = ''; - console.log('\n'); + const decoder = new TextDecoder('utf-8'); + this.chatUsecases.inferenceStream(chatDto, {}).then((response) => { + response.on('error', (error) => { + console.error(error); + rl.prompt(); + }); + + response.on('end', () => { + console.log('\n'); + rl.prompt(); + }); + + response.on('data', (chunk) => { + let content = ''; + const text = decoder.decode(chunk); + const lines = text.trim().split('\n'); + let cachedLines = ''; + for (const line of lines) { + try { + const toParse = cachedLines + line; + if (!line.includes('data: [DONE]')) { + const data = JSON.parse(toParse.replace('data: ', '')); + content += data.choices[0]?.delta?.content ?? ''; + + if (content.startsWith('assistant: ')) { + content = content.replace('assistant: ', ''); + } + + if (content.trim().length > 0) { + stdout.write(content); + } + } + } catch { + cachedLines = line; + } } - }, + }); }); - - this.chatUsecases.createChatCompletions(chatDto, {}, writableStream); }); } } diff --git a/cortex-js/src/infrastructure/commanders/usecases/cli.usecases.module.ts b/cortex-js/src/infrastructure/commanders/usecases/cli.usecases.module.ts index a82b60dd0..bb1dc7eeb 100644 --- a/cortex-js/src/infrastructure/commanders/usecases/cli.usecases.module.ts +++ b/cortex-js/src/infrastructure/commanders/usecases/cli.usecases.module.ts @@ -3,11 +3,13 @@ import { InitCliUsecases } from './init.cli.usecases'; import { HttpModule } from '@nestjs/axios'; import { ModelsCliUsecases } from './models.cli.usecases'; import { ModelsModule } from '@/usecases/models/models.module'; +import { ChatCliUsecases } from './chat.cli.usecases'; +import { ChatModule } from '@/usecases/chat/chat.module'; +import { CortexModule } from '@/usecases/cortex/cortex.module'; @Module({ - imports: [HttpModule, ModelsModule], - controllers: [], - providers: [InitCliUsecases, ModelsCliUsecases], - exports: [InitCliUsecases, ModelsCliUsecases], + imports: [HttpModule, ModelsModule, ChatModule, CortexModule], + providers: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases], + exports: [InitCliUsecases, ModelsCliUsecases, ChatCliUsecases], }) export class CliUsecasesModule {} diff --git a/cortex-js/src/infrastructure/controllers/chat.controller.ts b/cortex-js/src/infrastructure/controllers/chat.controller.ts index e9c50591e..d9c664915 100644 --- a/cortex-js/src/infrastructure/controllers/chat.controller.ts +++ b/cortex-js/src/infrastructure/controllers/chat.controller.ts @@ -1,38 +1,35 @@ -import { Body, Controller, Post, Headers, Res } from '@nestjs/common'; +import { Body, Controller, Post, Headers, Res, HttpCode } from '@nestjs/common'; import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto'; import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { Response } from 'express'; -import { ApiTags } from '@nestjs/swagger'; -import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; +import { ApiResponse, ApiTags } from '@nestjs/swagger'; +import { ChatCompletionResponseDto } from '../dtos/chat/chat-completion-response.dto'; @ApiTags('Inference') @Controller('chat') export class ChatController { constructor(private readonly chatService: ChatUsecases) {} + @HttpCode(200) + @ApiResponse({ + status: 200, + description: 'Chat completion response successfully', + type: ChatCompletionResponseDto, + }) @Post('completions') async create( @Headers() headers: Record, @Body() createChatDto: CreateChatCompletionDto, @Res() res: Response, ) { - const writableStream = new WritableStream({ - write(chunk) { - if (chunk.type === 'data') { - res.json(chunk.data ?? {}); - } else if (chunk.type === 'error') { - res.json(chunk.error ?? {}); - } else { - console.log('\n'); - } - }, - }); + const { stream } = createChatDto; - this.chatService.createChatCompletions( - createChatDto, - headers, - writableStream, - res, - ); + if (stream) { + this.chatService + .inferenceStream(createChatDto, headers) + .then((stream) => stream.pipe(res)); + } else { + res.json(await this.chatService.inference(createChatDto, headers)); + } } } diff --git a/cortex-js/src/infrastructure/dtos/chat/chat-completion-response.dto.ts b/cortex-js/src/infrastructure/dtos/chat/chat-completion-response.dto.ts new file mode 100644 index 000000000..a00fbc5a0 --- /dev/null +++ b/cortex-js/src/infrastructure/dtos/chat/chat-completion-response.dto.ts @@ -0,0 +1,26 @@ +import { ApiProperty } from '@nestjs/swagger'; +import { UsageDto } from './usage.dto'; +import { ChoiceDto } from './choice.dto'; + +export class ChatCompletionResponseDto { + @ApiProperty() + choices: ChoiceDto[]; + + @ApiProperty() + created: number; + + @ApiProperty() + id: string; + + @ApiProperty() + model: string; + + @ApiProperty() + object: string; + + @ApiProperty() + system_fingerprint: string; + + @ApiProperty() + usage: UsageDto; +} diff --git a/cortex-js/src/infrastructure/dtos/chat/choice.dto.ts b/cortex-js/src/infrastructure/dtos/chat/choice.dto.ts new file mode 100644 index 000000000..9a492dc57 --- /dev/null +++ b/cortex-js/src/infrastructure/dtos/chat/choice.dto.ts @@ -0,0 +1,13 @@ +import { ApiProperty } from '@nestjs/swagger'; +import { MessageDto } from './message.dto'; + +export class ChoiceDto { + @ApiProperty() + finish_reason: string; + + @ApiProperty() + index: number; + + @ApiProperty() + message: MessageDto; +} diff --git a/cortex-js/src/infrastructure/dtos/chat/message.dto.ts b/cortex-js/src/infrastructure/dtos/chat/message.dto.ts new file mode 100644 index 000000000..72fcaaef6 --- /dev/null +++ b/cortex-js/src/infrastructure/dtos/chat/message.dto.ts @@ -0,0 +1,9 @@ +import { ApiProperty } from '@nestjs/swagger'; + +export class MessageDto { + @ApiProperty() + content: string; + + @ApiProperty() + role: string; +} diff --git a/cortex-js/src/infrastructure/dtos/chat/usage.dto.ts b/cortex-js/src/infrastructure/dtos/chat/usage.dto.ts new file mode 100644 index 000000000..29354b785 --- /dev/null +++ b/cortex-js/src/infrastructure/dtos/chat/usage.dto.ts @@ -0,0 +1,12 @@ +import { ApiProperty } from '@nestjs/swagger'; + +export class UsageDto { + @ApiProperty() + completion_tokens: number; + + @ApiProperty() + prompt_tokens: number; + + @ApiProperty() + total_tokens: number; +} diff --git a/cortex-js/src/usecases/chat/chat.usecases.ts b/cortex-js/src/usecases/chat/chat.usecases.ts index 6386e57d8..72c44a1ed 100644 --- a/cortex-js/src/usecases/chat/chat.usecases.ts +++ b/cortex-js/src/usecases/chat/chat.usecases.ts @@ -4,7 +4,8 @@ import { ExtensionRepository } from '@/domain/repositories/extension.interface'; import { Repository } from 'typeorm'; import { ModelEntity } from '@/infrastructure/entities/model.entity'; import { EngineExtension } from '@/domain/abstracts/engine.abstract'; -import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; +import stream from 'stream'; +import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception'; @Injectable() export class ChatUsecases { @@ -14,19 +15,50 @@ export class ChatUsecases { private readonly extensionRepository: ExtensionRepository, ) {} - async createChatCompletions( + async inference( createChatDto: CreateChatCompletionDto, headers: Record, - stream: WritableStream, - res?: any, - ) { + ): Promise { + const { model: modelId } = createChatDto; const extensions = (await this.extensionRepository.findAll()) ?? []; const model = await this.modelRepository.findOne({ - where: { id: createChatDto.model }, + where: { id: modelId }, }); + + if (!model) { + throw new ModelNotFoundException(modelId); + } const engine = extensions.find((e: any) => e.provider === model?.engine) as | EngineExtension | undefined; - engine?.inference(createChatDto, headers, stream, res); + + if (engine == null) { + throw new Error(`No engine found with name: ${model.engine}`); + } + return engine.inference(createChatDto, headers); + } + + async inferenceStream( + createChatDto: CreateChatCompletionDto, + headers: Record, + ): Promise { + const { model: modelId } = createChatDto; + const extensions = (await this.extensionRepository.findAll()) ?? []; + const model = await this.modelRepository.findOne({ + where: { id: modelId }, + }); + + if (!model) { + throw new ModelNotFoundException(modelId); + } + + const engine = extensions.find((e: any) => e.provider === model.engine) as + | EngineExtension + | undefined; + if (engine == null) { + throw new Error(`No engine found with name: ${model.engine}`); + } + + return engine?.inferenceStream(createChatDto, headers); } } diff --git a/cortex-js/src/usecases/models/models.usecases.ts b/cortex-js/src/usecases/models/models.usecases.ts index 3f960497a..c5257648b 100644 --- a/cortex-js/src/usecases/models/models.usecases.ts +++ b/cortex-js/src/usecases/models/models.usecases.ts @@ -193,8 +193,6 @@ export class ModelsUsecases { throw new BadRequestException('Cannot download remote model'); } - // TODO: NamH download multiple files - const downloadUrl = model.sources[0].url; if (!this.isValidUrl(downloadUrl)) { throw new BadRequestException(`Invalid download URL: ${downloadUrl}`);