Merge pull request #564 from janhq/feat/cli-chat
[WIP] feat: add CLI for chat
namchuai authored May 15, 2024
2 parents ca85eee + 177c0fb commit 14cfcfe
Showing 7 changed files with 215 additions and 49 deletions.
4 changes: 4 additions & 0 deletions cortex-js/src/command.module.ts
@@ -9,6 +9,8 @@ import { PullCommand } from './infrastructure/commanders/pull.command';
import { InferenceCommand } from './infrastructure/commanders/inference.command';
import { ModelsCommand } from './infrastructure/commanders/models.command';
import { StartCommand } from './infrastructure/commanders/start.command';
import { ExtensionModule } from './infrastructure/repositories/extensions/extension.module';
import { ChatModule } from './usecases/chat/chat.module';

@Module({
imports: [
@@ -20,6 +22,8 @@ import { StartCommand } from './infrastructure/commanders/start.command';
DatabaseModule,
ModelsModule,
CortexModule,
ChatModule,
ExtensionModule,
],
providers: [
BasicCommand,
10 changes: 7 additions & 3 deletions cortex-js/src/domain/abstracts/engine.abstract.ts
@@ -1,8 +1,12 @@
import { Model } from '../models/model.interface';
import { Extension } from './extension.abstract';

export abstract class EngineExtension extends Extension {
abstract provider: string;
abstract inference(completion: any, req: any, res: any): void;
abstract loadModel(loadModel: any): Promise<void>;
abstract unloadModel(modelId: string): Promise<void>;

abstract inference(completion: any, req: any, stream: any, res?: any): void;

async loadModel(model: Model): Promise<void> {}

async unloadModel(modelId: string): Promise<void> {}
}
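
Note: under the revised EngineExtension contract, inference stays abstract but now receives a stream sink, with the HTTP response demoted to an optional argument, while loadModel and unloadModel gain no-op default bodies that concrete engines override only when needed. A minimal sketch of an engine written against this contract follows; EchoEngineExtension and its behaviour are purely illustrative, and the sketch assumes the Extension base class requires no members beyond what this diff shows.

import { Model } from '../models/model.interface';
import { EngineExtension } from './engine.abstract';

// Illustrative sketch only: an engine that echoes the request back through the sink.
// (Assumes Extension adds no further abstract members.)
export class EchoEngineExtension extends EngineExtension {
  provider = 'echo'; // ChatUsecases matches this value against model.engine

  // New signature: the caller always supplies a stream sink; res is optional.
  inference(completion: any, req: any, stream: any, res?: any): void {
    const writer = stream.getWriter();
    writer.ready.then(() => {
      writer.write({ type: 'data', data: `echo: ${completion.messages?.length ?? 0} messages` });
      writer.write({ type: 'end' });
    });
  }

  // loadModel/unloadModel now default to no-ops; override only when the engine needs them.
  async loadModel(model: Model): Promise<void> {}
}
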
148 changes: 115 additions & 33 deletions cortex-js/src/domain/abstracts/oai.abstract.ts
@@ -1,6 +1,12 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
import { HttpService } from '@nestjs/axios';
import { EngineExtension } from './engine.abstract';
import { stdout } from 'process';

export type ChatStreamEvent = {
type: 'data' | 'error' | 'end';
data?: any;
error?: any;
};

export abstract class OAIEngineExtension extends EngineExtension {
abstract apiUrl: string;
@@ -9,44 +15,120 @@ export abstract class OAIEngineExtension extends EngineExtension {
super();
}

async inference(
inference(
createChatDto: any,
headers: Record<string, string>,
res: any,
writableStream: WritableStream<ChatStreamEvent>,
res?: any,
) {
if (createChatDto.stream === true) {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise();

res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
if (res) {
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'Access-Control-Allow-Origin': '*',
});
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.toPromise()
.then((response) => {
response?.data.pipe(res);
});
} else {
const decoder = new TextDecoder('utf-8');
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
responseType: 'stream',
})
.subscribe({
next: (response) => {
response.data.on('data', (chunk: any) => {
let content = '';
const text = decoder.decode(chunk);
const lines = text.trim().split('\n');
let cachedLines = '';
for (const line of lines) {
try {
const toParse = cachedLines + line;
if (!line.includes('data: [DONE]')) {
const data = JSON.parse(toParse.replace('data: ', ''));
content += data.choices[0]?.delta?.content ?? '';

if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '');
}

if (content !== '') {
defaultWriter.write({
type: 'data',
data: content,
});
}
}
} catch {
cachedLines = line;
}
}
});

response?.data.pipe(res);
response.data.on('error', (error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});

response.data.on('end', () => {
// stdout.write('Stream end');
defaultWriter.write({
type: 'end',
});
});
},

error: (error) => {
stdout.write('Stream error: ' + error);
},
});
});
}
} else {
const response = await this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise();

res.json(response?.data);
const defaultWriter = writableStream.getWriter();
defaultWriter.ready.then(() => {
this.httpService
.post(this.apiUrl, createChatDto, {
headers: {
'Content-Type': headers['content-type'] ?? 'application/json',
Authorization: headers['authorization'],
},
})
.toPromise()
.then((response) => {
defaultWriter.write({
type: 'data',
data: response?.data,
});
})
.catch((error: any) => {
defaultWriter.write({
type: 'error',
error,
});
});
});
}
}

async loadModel(_loadModel: any): Promise<void> {}
async unloadModel(_modelId: string): Promise<void> {}
}
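
Note: OAIEngineExtension now writes into a WritableStream of ChatStreamEvent and, when no Express response is supplied, parses the server-sent-event stream itself: each chunk is decoded, split into lines, the 'data: ' prefix is stripped, the delta content is pulled out of the JSON payload, and a line that fails to parse is cached so it can be completed by the next chunk. The sketch below restates that parsing step in isolation (with the carried fragment reset once it parses); parseSseChunk is a hypothetical helper for illustration, not something this commit adds, and it assumes an OpenAI-compatible delta payload.

// Hypothetical helper mirroring the inline SSE parsing above. Returns the text
// extracted from one chunk plus any partial trailing line to carry into the next call.
export function parseSseChunk(
  text: string,
  carry: string,
): { content: string; carry: string } {
  let content = '';
  let cached = carry;
  for (const line of text.trim().split('\n')) {
    if (line.includes('data: [DONE]')) continue; // end-of-stream sentinel
    const toParse = cached + line;
    try {
      const payload = JSON.parse(toParse.replace('data: ', ''));
      content += payload.choices[0]?.delta?.content ?? '';
      cached = ''; // the cached fragment was consumed
    } catch {
      cached = line; // incomplete JSON; wait for the rest in the next chunk
    }
  }
  return { content, carry: cached };
}

In the class above, each non-empty result is forwarded to the writer as a { type: 'data' } event, followed by { type: 'end' } when the upstream response finishes.
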
74 changes: 65 additions & 9 deletions cortex-js/src/infrastructure/commanders/inference.command.ts
@@ -1,25 +1,81 @@
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { CommandRunner, SubCommand } from 'nest-commander';
import { CreateChatCompletionDto } from '../dtos/chat/create-chat-completion.dto';
import { ChatCompletionRole } from '@/domain/models/message.interface';
import { stdout } from 'process';
import * as readline from 'node:readline/promises';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';
import { ChatCompletionMessage } from '../dtos/chat/chat-completion-message.dto';

@SubCommand({ name: 'chat' })
export class InferenceCommand extends CommandRunner {
constructor() {
exitClause = 'exit()';
userIndicator = '>> ';
exitMessage = 'Bye!';

constructor(private readonly chatUsecases: ChatUsecases) {
super();
}

async run(_input: string[]): Promise<void> {
const lineByLine = require('readline');
const lbl = lineByLine.createInterface({
async run(): Promise<void> {
console.log(`In order to exit, type '${this.exitClause}'.`);
const messages: ChatCompletionMessage[] = [];

const rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
prompt: this.userIndicator,
});
rl.prompt();

rl.on('close', () => {
console.log(this.exitMessage);
process.exit(0);
});
lbl.on('line', (userInput: string) => {
if (userInput.trim() === 'exit()') {
lbl.close();

rl.on('line', (userInput: string) => {
if (userInput.trim() === this.exitClause) {
rl.close();
return;
}

console.log('Result:', userInput);
console.log('Enter another equation or type "exit()" to quit.');
messages.push({
content: userInput,
role: ChatCompletionRole.User,
});

const chatDto: CreateChatCompletionDto = {
messages,
model: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF',
stream: true,
max_tokens: 2048,
stop: [],
frequency_penalty: 0.7,
presence_penalty: 0.7,
temperature: 0.7,
top_p: 0.7,
};

let llmFullResponse = '';
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
stdout.write(chunk.data ?? '');
llmFullResponse += chunk.data ?? '';
} else if (chunk.type === 'error') {
console.log('Error!!');
} else {
messages.push({
content: llmFullResponse,
role: ChatCompletionRole.Assistant,
});
llmFullResponse = '';
console.log('\n');
}
},
});

this.chatUsecases.createChatCompletions(chatDto, {}, writableStream);
});
}
}
20 changes: 19 additions & 1 deletion cortex-js/src/infrastructure/controllers/chat.controller.ts
@@ -3,6 +3,7 @@ import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { Response } from 'express';
import { ApiTags } from '@nestjs/swagger';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

@ApiTags('Inference')
@Controller('chat')
@@ -15,6 +16,23 @@ export class ChatController {
@Body() createChatDto: CreateChatCompletionDto,
@Res() res: Response,
) {
this.chatService.createChatCompletions(createChatDto, headers, res);
const writableStream = new WritableStream<ChatStreamEvent>({
write(chunk) {
if (chunk.type === 'data') {
res.json(chunk.data ?? {});
} else if (chunk.type === 'error') {
res.json(chunk.error ?? {});
} else {
console.log('\n');
}
},
});

this.chatService.createChatCompletions(
createChatDto,
headers,
writableStream,
res,
);
}
}
1 change: 1 addition & 0 deletions cortex-js/src/usecases/chat/chat.module.ts
@@ -8,5 +8,6 @@ import { ExtensionModule } from '@/infrastructure/repositories/extensions/extens
imports: [DatabaseModule, ExtensionModule],
controllers: [ChatController],
providers: [ChatUsecases],
exports: [ChatUsecases],
})
export class ChatModule {}
7 changes: 4 additions & 3 deletions cortex-js/src/usecases/chat/chat.usecases.ts
@@ -1,10 +1,10 @@
import { Inject, Injectable } from '@nestjs/common';
import { CreateChatCompletionDto } from '@/infrastructure/dtos/chat/create-chat-completion.dto';
import { Response } from 'express';
import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { Repository } from 'typeorm';
import { ModelEntity } from '@/infrastructure/entities/model.entity';
import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

@Injectable()
export class ChatUsecases {
@@ -17,7 +17,8 @@ export class ChatUsecases {
async createChatCompletions(
createChatDto: CreateChatCompletionDto,
headers: Record<string, string>,
res: Response,
stream: WritableStream<ChatStreamEvent>,
res?: any,
) {
const extensions = (await this.extensionRepository.findAll()) ?? [];
const model = await this.modelRepository.findOne({
@@ -26,6 +27,6 @@
const engine = extensions.find((e: any) => e.provider === model?.engine) as
| EngineExtension
| undefined;
await engine?.inference(createChatDto, headers, res);
engine?.inference(createChatDto, headers, stream, res);
}
}
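
Note: because ChatModule now exports ChatUsecases and createChatCompletions accepts any WritableStream<ChatStreamEvent> sink (the Express response is only needed for the HTTP streaming path), callers other than the controller can drive a completion directly, which is exactly what the CLI command does. A rough sketch of that pattern, collecting a streamed reply into a single string, follows; collectReply is a hypothetical helper, the model id just mirrors the one hard-coded in the CLI command, and the DTO is cast loosely since only the fields shown in this diff are known.

import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatStreamEvent } from '@/domain/abstracts/oai.abstract';

// Hypothetical caller: run one streamed completion and resolve with the full reply text.
export function collectReply(
  chatUsecases: ChatUsecases,
  messages: { content: string; role: any }[],
): Promise<string> {
  return new Promise((resolve, reject) => {
    let reply = '';
    const sink = new WritableStream<ChatStreamEvent>({
      write(event) {
        if (event.type === 'data') reply += event.data ?? '';
        else if (event.type === 'error') reject(event.error);
        else resolve(reply); // 'end': the stream is finished
      },
    });
    chatUsecases
      .createChatCompletions(
        // Only the fields used by the CLI command are assumed here.
        { messages, model: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF', stream: true } as any,
        {}, // headers
        sink,
      )
      .catch(reject);
  });
}
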
