diff --git a/cortex-js/src/command.module.ts b/cortex-js/src/command.module.ts index fe23ca16c..f3e4efd42 100644 --- a/cortex-js/src/command.module.ts +++ b/cortex-js/src/command.module.ts @@ -9,6 +9,8 @@ import { PullCommand } from './infrastructure/commanders/pull.command'; import { InferenceCommand } from './infrastructure/commanders/inference.command'; import { ModelsCommand } from './infrastructure/commanders/models.command'; import { StartCommand } from './infrastructure/commanders/start.command'; +import { ExtensionModule } from './infrastructure/repositories/extensions/extension.module'; +import { ChatModule } from './usecases/chat/chat.module'; @Module({ imports: [ @@ -20,6 +22,8 @@ import { StartCommand } from './infrastructure/commanders/start.command'; DatabaseModule, ModelsModule, CortexModule, + ChatModule, + ExtensionModule, ], providers: [ BasicCommand, diff --git a/cortex-js/src/domain/abstracts/engine.abstract.ts b/cortex-js/src/domain/abstracts/engine.abstract.ts index 564faa2a1..596f5eadf 100644 --- a/cortex-js/src/domain/abstracts/engine.abstract.ts +++ b/cortex-js/src/domain/abstracts/engine.abstract.ts @@ -1,8 +1,12 @@ +import { Model } from '../models/model.interface'; import { Extension } from './extension.abstract'; export abstract class EngineExtension extends Extension { abstract provider: string; - abstract inference(completion: any, req: any, res: any): void; - abstract loadModel(loadModel: any): Promise<void>; - abstract unloadModel(modelId: string): Promise<void>; + + abstract inference(completion: any, req: any, stream: any, res?: any): void; + + async loadModel(model: Model): Promise<void> {} + + async unloadModel(modelId: string): Promise<void> {} } diff --git a/cortex-js/src/domain/abstracts/oai.abstract.ts b/cortex-js/src/domain/abstracts/oai.abstract.ts index 96748449d..2923c4277 100644 --- a/cortex-js/src/domain/abstracts/oai.abstract.ts +++ b/cortex-js/src/domain/abstracts/oai.abstract.ts @@ -1,6 +1,12 @@ -/* eslint-disable 
@typescript-eslint/no-unused-vars */ import { HttpService } from '@nestjs/axios'; import { EngineExtension } from './engine.abstract'; +import { stdout } from 'process'; + +export type ChatStreamEvent = { + type: 'data' | 'error' | 'end'; + data?: any; + error?: any; +}; export abstract class OAIEngineExtension extends EngineExtension { abstract apiUrl: string; @@ -9,44 +15,120 @@ export abstract class OAIEngineExtension extends EngineExtension { super(); } - async inference( + inference( createChatDto: any, headers: Record<string, string>, - res: any, + writableStream: WritableStream<ChatStreamEvent>, + res?: any, ) { if (createChatDto.stream === true) { - const response = await this.httpService - .post(this.apiUrl, createChatDto, { - headers: { - 'Content-Type': headers['content-type'] ?? 'application/json', - Authorization: headers['authorization'], - }, - responseType: 'stream', - }) - .toPromise(); - - res.writeHead(200, { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - 'Access-Control-Allow-Origin': '*', - }); + if (res) { + res.writeHead(200, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + 'Access-Control-Allow-Origin': '*', + }); + this.httpService + .post(this.apiUrl, createChatDto, { + headers: { + 'Content-Type': headers['content-type'] ?? 'application/json', + Authorization: headers['authorization'], + }, + responseType: 'stream', + }) + .toPromise() + .then((response) => { + response?.data.pipe(res); + }); + } else { + const decoder = new TextDecoder('utf-8'); + const defaultWriter = writableStream.getWriter(); + defaultWriter.ready.then(() => { + this.httpService + .post(this.apiUrl, createChatDto, { + headers: { + 'Content-Type': headers['content-type'] ?? 
'application/json', + Authorization: headers['authorization'], + }, + responseType: 'stream', + }) + .subscribe({ + next: (response) => { + response.data.on('data', (chunk: any) => { + let content = ''; + const text = decoder.decode(chunk); + const lines = text.trim().split('\n'); + let cachedLines = ''; + for (const line of lines) { + try { + const toParse = cachedLines + line; + if (!line.includes('data: [DONE]')) { + const data = JSON.parse(toParse.replace('data: ', '')); + content += data.choices[0]?.delta?.content ?? ''; + + if (content.startsWith('assistant: ')) { + content = content.replace('assistant: ', ''); + } + + if (content !== '') { + defaultWriter.write({ + type: 'data', + data: content, + }); + } + } + } catch { + cachedLines = line; + } + } + }); - response?.data.pipe(res); + response.data.on('error', (error: any) => { + defaultWriter.write({ + type: 'error', + error, + }); + }); + + response.data.on('end', () => { + // stdout.write('Stream end'); + defaultWriter.write({ + type: 'end', + }); + }); + }, + + error: (error) => { + stdout.write('Stream error: ' + error); + }, + }); + }); + } } else { - const response = await this.httpService - .post(this.apiUrl, createChatDto, { - headers: { - 'Content-Type': headers['content-type'] ?? 'application/json', - Authorization: headers['authorization'], - }, - }) - .toPromise(); - - res.json(response?.data); + const defaultWriter = writableStream.getWriter(); + defaultWriter.ready.then(() => { + this.httpService + .post(this.apiUrl, createChatDto, { + headers: { + 'Content-Type': headers['content-type'] ?? 
'application/json', + Authorization: headers['authorization'], + }, + }) + .toPromise() + .then((response) => { + defaultWriter.write({ + type: 'data', + data: response?.data, + }); + }) + .catch((error: any) => { + defaultWriter.write({ + type: 'error', + error, + }); + }); + }); } } - - async loadModel(_loadModel: any): Promise<void> {} - async unloadModel(_modelId: string): Promise<void> {} } diff --git a/cortex-js/src/infrastructure/commanders/inference.command.ts b/cortex-js/src/infrastructure/commanders/inference.command.ts index b5eba3988..fc94058df 100644 --- a/cortex-js/src/infrastructure/commanders/inference.command.ts +++ b/cortex-js/src/infrastructure/commanders/inference.command.ts @@ -1,25 +1,81 @@ +import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { CommandRunner, SubCommand } from 'nest-commander'; +import { CreateChatCompletionDto } from '../dtos/chat/create-chat-completion.dto'; +import { ChatCompletionRole } from '@/domain/models/message.interface'; +import { stdout } from 'process'; +import * as readline from 'node:readline/promises'; +import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; +import { ChatCompletionMessage } from '../dtos/chat/chat-completion-message.dto'; @SubCommand({ name: 'chat' }) export class InferenceCommand extends CommandRunner { - constructor() { + exitClause = 'exit()'; + userIndicator = '>> '; + exitMessage = 'Bye!'; + + constructor(private readonly chatUsecases: ChatUsecases) { super(); } - async run(_input: string[]): Promise<void> { - const lineByLine = require('readline'); - const lbl = lineByLine.createInterface({ + async run(): Promise<void> { + console.log(`Inorder to exit, type '${this.exitClause}'.`); + const messages: ChatCompletionMessage[] = []; + + const rl = readline.createInterface({ input: process.stdin, output: process.stdout, + prompt: this.userIndicator, + }); + rl.prompt(); + + rl.on('close', () => { + console.log(this.exitMessage); + process.exit(0); }); - lbl.on('line', (userInput: string) 
=> { - if (userInput.trim() === 'exit()') { - lbl.close(); + + rl.on('line', (userInput: string) => { + if (userInput.trim() === this.exitClause) { + rl.close(); return; } - console.log('Result:', userInput); - console.log('Enter another equation or type "exit()" to quit.'); + messages.push({ + content: userInput, + role: ChatCompletionRole.User, + }); + + const chatDto: CreateChatCompletionDto = { + messages, + model: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF', + stream: true, + max_tokens: 2048, + stop: [], + frequency_penalty: 0.7, + presence_penalty: 0.7, + temperature: 0.7, + top_p: 0.7, + }; + + let llmFullResponse = ''; + const writableStream = new WritableStream<ChatStreamEvent>({ + write(chunk) { + if (chunk.type === 'data') { + stdout.write(chunk.data ?? ''); + llmFullResponse += chunk.data ?? ''; + } else if (chunk.type === 'error') { + console.log('Error!!'); + } else { + messages.push({ + content: llmFullResponse, + role: ChatCompletionRole.Assistant, + }); + llmFullResponse = ''; + console.log('\n'); + } + }, + }); + + this.chatUsecases.createChatCompletions(chatDto, {}, writableStream); }); } } diff --git a/cortex-js/src/infrastructure/controllers/chat.controller.ts b/cortex-js/src/infrastructure/controllers/chat.controller.ts index dc9f7abda..e9c50591e 100644 --- a/cortex-js/src/infrastructure/controllers/chat.controller.ts +++ b/cortex-js/src/infrastructure/controllers/chat.controller.ts @@ -3,6 +3,7 @@ import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat- import { ChatUsecases } from '@/usecases/chat/chat.usecases'; import { Response } from 'express'; import { ApiTags } from '@nestjs/swagger'; +import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; @ApiTags('Inference') @Controller('chat') @@ -15,6 +16,23 @@ export class ChatController { @Body() createChatDto: CreateChatCompletionDto, @Res() res: Response, ) { - this.chatService.createChatCompletions(createChatDto, headers, res); + const writableStream = new WritableStream<ChatStreamEvent>({ + 
write(chunk) { + if (chunk.type === 'data') { + res.json(chunk.data ?? {}); + } else if (chunk.type === 'error') { + res.json(chunk.error ?? {}); + } else { + console.log('\n'); + } + }, + }); + + this.chatService.createChatCompletions( + createChatDto, + headers, + writableStream, + res, + ); } } diff --git a/cortex-js/src/usecases/chat/chat.module.ts b/cortex-js/src/usecases/chat/chat.module.ts index 1f7c70090..e69b10b73 100644 --- a/cortex-js/src/usecases/chat/chat.module.ts +++ b/cortex-js/src/usecases/chat/chat.module.ts @@ -8,5 +8,6 @@ import { ExtensionModule } from '@/infrastructure/repositories/extensions/extens imports: [DatabaseModule, ExtensionModule], controllers: [ChatController], providers: [ChatUsecases], + exports: [ChatUsecases], }) export class ChatModule {} diff --git a/cortex-js/src/usecases/chat/chat.usecases.ts b/cortex-js/src/usecases/chat/chat.usecases.ts index f4c338b0a..6386e57d8 100644 --- a/cortex-js/src/usecases/chat/chat.usecases.ts +++ b/cortex-js/src/usecases/chat/chat.usecases.ts @@ -1,10 +1,10 @@ import { Inject, Injectable } from '@nestjs/common'; import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto'; -import { Response } from 'express'; import { ExtensionRepository } from '@/domain/repositories/extension.interface'; import { Repository } from 'typeorm'; import { ModelEntity } from '@/infrastructure/entities/model.entity'; import { EngineExtension } from '@/domain/abstracts/engine.abstract'; +import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract'; @Injectable() export class ChatUsecases { @@ -17,7 +17,8 @@ export class ChatUsecases { async createChatCompletions( createChatDto: CreateChatCompletionDto, headers: Record<string, string>, - res: Response, + stream: WritableStream<ChatStreamEvent>, + res?: any, ) { const extensions = (await this.extensionRepository.findAll()) ?? 
[]; const model = await this.modelRepository.findOne({ @@ -26,6 +27,6 @@ export class ChatUsecases { const engine = extensions.find((e: any) => e.provider === model?.engine) as | EngineExtension | undefined; - await engine?.inference(createChatDto, headers, res); + engine?.inference(createChatDto, headers, stream, res); } }