chore: auto load model on /chat/completions request (#900)
louis-jan authored Jul 22, 2024
1 parent 9857448 commit 749cf0c
Showing 14 changed files with 154 additions and 30 deletions.
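With this change, a POST to the chat completions endpoint no longer requires the model to be started beforehand: the chat use case checks whether the requested model is running and starts it automatically if not. A minimal client-side sketch of the resulting behavior, assuming the default cortex-js port (1337, per this diff) and an OpenAI-compatible route; the model id is illustrative:

// Sketch: chat completion against a model that has not been started yet.
// The server auto-starts it before running inference.
const res = await fetch('http://127.0.0.1:1337/v1/chat/completions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    model: 'tinyllama', // hypothetical model id
    messages: [{ role: 'user', content: 'Hello!' }],
    stream: false,
  }),
});
console.log(await res.json());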
32 changes: 31 additions & 1 deletion cortex-js/src/domain/abstracts/engine.abstract.ts
@@ -3,6 +3,10 @@ import stream from 'stream';
import { Model, ModelSettingParams } from '../../domain/models/model.interface';
import { Extension } from './extension.abstract';

/**
* This class should be extended by any class that represents an engine extension.
* It provides methods for loading and unloading models, and for making inference requests.
*/
export abstract class EngineExtension extends Extension {
abstract onLoad(): void;

@@ -12,16 +16,42 @@ export abstract class EngineExtension extends Extension {

initalized: boolean = false;

/**
* Makes an inference request to the engine.
* @param dto
* @param headers
*/
abstract inference(
dto: any,
headers: Record<string, string>,
): Promise<stream.Readable | any>;

/**
* Checks whether a model is running on the engine.
* This method should check the model's run-time status, since a model
* can become corrupted while running, and return false if it is not running.
* @param modelId
*/
async isModelRunning(modelId: string): Promise<boolean> {
return true;
}

/**
* Loads a model into the engine.
* There are model settings such as `ngl` and `ctx_len` that can be passed to the engine.
* Applicable to local engines only
* @param model
* @param settingParams
*/
async loadModel(
model: Model,
settingParams?: ModelSettingParams,
): Promise<void> {}

/**
* Unloads a model from the engine.
* @param modelId
*/
async unloadModel(modelId: string): Promise<void> {}

}
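To sketch how this contract is consumed (not code from this commit; the model.id field and the settings values below are assumptions), an auto-load guard over an EngineExtension could look like:

// Sketch: start a model on an engine only when it is not already running.
// `model.id` and the ctx_len/ngl values are illustrative assumptions.
async function ensureModelLoaded(
  engine: EngineExtension,
  model: Model,
): Promise<void> {
  if (!(await engine.isModelRunning(model.id))) {
    await engine.loadModel(model, { ctx_len: 4096, ngl: 32 });
  }
}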
9 changes: 9 additions & 0 deletions cortex-js/src/domain/models/model.interface.ts
@@ -40,6 +40,10 @@ export interface ModelSettingParams {
* The number of layers to load onto the GPU for acceleration.
*/
ngl?: number;

/**
* Whether the model supports embeddings (legacy flag)
*/
embedding?: boolean;

/**
@@ -117,6 +121,11 @@ export interface ModelSettingParams {
* To enable mmap, default is true
*/
use_mmap?: boolean;

/**
* The model type to run: llm or embedding; defaults to llm (per the latest llama.cpp update)
*/
model_type?: string;
}

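Together, the two new fields let callers request embedding mode across both older and newer llama.cpp builds. A sketch of such a settings object (the same values chat.usecases.ts passes below):

// Sketch: ModelSettingParams for starting a model in embedding mode.
const embeddingSettings: ModelSettingParams = {
  embedding: true, // legacy flag, honored by older llama.cpp builds
  model_type: 'embedding', // newer builds select the mode from model_type
};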
4 changes: 3 additions & 1 deletion cortex-js/src/infrastructure/constants/cortex.ts
@@ -7,6 +7,8 @@ export const defaultCortexJsPort = 1337;

export const defaultCortexCppHost = '127.0.0.1';
export const defaultCortexCppPort = 3929;

export const defaultEmbeddingModel = 'nomic-embed-text-v1';
// CORTEX CPP
export const CORTEX_CPP_EMBEDDINGS_URL = (
host: string = defaultCortexCppHost,
@@ -50,4 +52,4 @@ export const CUDA_DOWNLOAD_URL =

export const telemetryServerUrl = 'https://telemetry.jan.ai';

export const MIN_CUDA_VERSION = '12.3';
@@ -9,6 +9,7 @@ import { DownloadManagerModule } from '@/infrastructure/services/download-manage
import { EventEmitterModule } from '@nestjs/event-emitter';
import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
import { FileManagerModule } from '../services/file-manager/file-manager.module';
import { ModelsModule } from '@/usecases/models/models.module';

describe('ChatController', () => {
let controller: ChatController;
@@ -25,6 +26,7 @@ describe('ChatController', () => {
EventEmitterModule.forRoot(),
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [ChatController],
providers: [ChatUsecases],
32 changes: 13 additions & 19 deletions cortex-js/src/infrastructure/controllers/chat.controller.ts
@@ -35,27 +35,21 @@ export class ChatController {
) {
const { stream } = createChatDto;

this.chatService
  .inference(createChatDto, extractCommonHeaders(headers))
  .then((response) => {
    if (stream) {
      res.header('Content-Type', 'text/event-stream');
      response.pipe(res);
    } else {
      res.header('Content-Type', 'application/json');
      res.json(response);
    }
  })
  .catch((error) =>
    res.status(error.statusCode ?? 400).send(error.message),
  );

this.telemetryUsecases.addEventToQueue({
name: EventName.CHAT,
modelId: createChatDto.model,
@@ -9,6 +9,7 @@ import { DownloadManagerModule } from '@/infrastructure/services/download-manage
import { EventEmitterModule } from '@nestjs/event-emitter';
import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
import { FileManagerModule } from '../services/file-manager/file-manager.module';
import { ModelsModule } from '@/usecases/models/models.module';

describe('EmbeddingsController', () => {
let controller: EmbeddingsController;
@@ -25,6 +26,7 @@ describe('EmbeddingsController', () => {
EventEmitterModule.forRoot(),
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [EmbeddingsController],
providers: [ChatUsecases],
@@ -1,4 +1,4 @@
import { Body, Controller, Post, HttpCode, Res } from '@nestjs/common';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ApiOperation, ApiTags, ApiResponse } from '@nestjs/swagger';
import { CreateEmbeddingsDto } from '../dtos/embeddings/embeddings-request.dto';
@@ -21,6 +21,7 @@ export class CreateMessageDto implements Partial<Message> {
example: 'user',
description: 'The sources of the messages.',
})
@IsString()
role: 'user' | 'assistant';

@ApiProperty({
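The added @IsString() rejects non-string roles at the DTO validation layer. A reduced sketch of the check in isolation (other CreateMessageDto fields omitted; the class name here is hypothetical):

import { IsString, validateSync } from 'class-validator';

// Sketch: only the validated field, outside the real DTO.
class RoleOnlyDto {
  @IsString()
  role!: 'user' | 'assistant';
}

const dto = new RoleOnlyDto();
(dto as any).role = 42; // wrong type on purpose
console.log(validateSync(dto).length > 0); // true: validation fails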
35 changes: 33 additions & 2 deletions cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts
@@ -1,10 +1,11 @@
import { HttpStatus, Injectable } from '@nestjs/common';
import { OAIEngineExtension } from '@/domain/abstracts/oai.abstract';
import { PromptTemplate } from '@/domain/models/prompt-template.interface';
import { join } from 'path';
import { Model, ModelSettingParams } from '@/domain/models/model.interface';
import { HttpService } from '@nestjs/axios';
import {
CORTEX_CPP_MODELS_URL,
defaultCortexCppHost,
defaultCortexCppPort,
} from '@/infrastructure/constants/cortex';
@@ -13,6 +14,11 @@ import { normalizeModelId } from '@/utils/normalize-model-id';
import { firstValueFrom } from 'rxjs';
import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';

export interface ModelStatResponse {
object: string;
data: any;
}

@Injectable()
export default class CortexProvider extends OAIEngineExtension {
apiUrl = `http://${defaultCortexCppHost}:${defaultCortexCppPort}/inferences/server/chat_completion`;
@@ -28,11 +34,12 @@ export default class CortexProvider extends OAIEngineExtension {

constructor(
protected readonly httpService: HttpService,
protected readonly fileManagerService: FileManagerService,
) {
super(httpService);
}

// Override the loadModel method to load a model into the cortex.cpp engine
override async loadModel(
model: Model,
settings?: ModelSettingParams,
@@ -92,6 +99,30 @@ export default class CortexProvider extends OAIEngineExtension {
).then(); // pipe error or void instead of throwing
}

// Override the isModelRunning method to check if the model is running
override async isModelRunning(modelId: string): Promise<boolean> {
const configs = await this.fileManagerService.getConfig();

return firstValueFrom(
this.httpService.get(
CORTEX_CPP_MODELS_URL(configs.cortexCppHost, configs.cortexCppPort),
),
)
.then((res) => {
const data = res.data as ModelStatResponse;
if (
res.status === HttpStatus.OK &&
data &&
Array.isArray(data.data) &&
data.data.length > 0
) {
return data.data.find((e) => e.id === modelId);
}
return false;
})
.catch(() => false);
}

private readonly promptTemplateConverter = (
promptTemplate: string,
): PromptTemplate => {
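One note on the isModelRunning override above: Array.prototype.find returns the matching element or undefined, not a boolean, so the promise resolves to a truthy object rather than true. A sketch of a strictly boolean variant:

// Sketch: the same status check coerced to an actual boolean.
const hasModel = (data: ModelStatResponse, modelId: string): boolean =>
  Array.isArray(data.data) && data.data.some((e: any) => e.id === modelId);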
2 changes: 2 additions & 0 deletions cortex-js/src/usecases/chat/chat.module.ts
@@ -6,6 +6,7 @@ import { ModelRepositoryModule } from '@/infrastructure/repositories/models/mode
import { HttpModule } from '@nestjs/axios';
import { TelemetryModule } from '../telemetry/telemetry.module';
import { FileManagerModule } from '@/infrastructure/services/file-manager/file-manager.module';
import { ModelsModule } from '../models/models.module';

@Module({
imports: [
@@ -15,6 +16,7 @@ import { FileManagerModule } from '@/infrastructure/services/file-manager/file-m
HttpModule,
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [],
providers: [ChatUsecases],
2 changes: 2 additions & 0 deletions cortex-js/src/usecases/chat/chat.usecases.spec.ts
@@ -8,6 +8,7 @@ import { HttpModule } from '@nestjs/axios';
import { DownloadManagerModule } from '@/infrastructure/services/download-manager/download-manager.module';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { FileManagerModule } from '@/infrastructure/services/file-manager/file-manager.module';
import { ModelsModule } from '../models/models.module';

describe('ChatService', () => {
let service: ChatUsecases;
@@ -24,6 +25,7 @@ describe('ChatService', () => {
DownloadManagerModule,
EventEmitterModule.forRoot(),
FileManagerModule,
ModelsModule,
],
providers: [ChatUsecases],
exports: [ChatUsecases],
38 changes: 34 additions & 4 deletions cortex-js/src/usecases/chat/chat.usecases.ts
@@ -4,21 +4,24 @@ import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception';
import { TelemetryUsecases } from '../telemetry/telemetry.usecases';
import { TelemetrySource } from '@/domain/telemetry/telemetry.interface';
import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { firstValueFrom } from 'rxjs';
import { HttpService } from '@nestjs/axios';
import {
CORTEX_CPP_EMBEDDINGS_URL,
defaultEmbeddingModel,
} from '@/infrastructure/constants/cortex';
import { CreateEmbeddingsDto } from '@/infrastructure/dtos/embeddings/embeddings-request.dto';
import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
import { Engines } from '@/infrastructure/commanders/types/engine.interface';
import { ModelsUsecases } from '../models/models.usecases';

@Injectable()
export class ChatUsecases {
constructor(
private readonly extensionRepository: ExtensionRepository,
private readonly telemetryUseCases: TelemetryUsecases,
private readonly modelsUsescases: ModelsUsecases,
private readonly httpService: HttpService,
private readonly fileService: FileManagerService,
) {}
@@ -28,10 +31,18 @@ export class ChatUsecases {
headers: Record<string, string>,
): Promise<any> {
const { model: modelId } = createChatDto;
const model = await this.modelsUsescases.findOne(modelId);
if (!model) {
throw new ModelNotFoundException(modelId);
}

const isModelRunning = await this.modelsUsescases.isModelRunning(modelId);
// If the model is not running, start it before inference
if (!isModelRunning) {
await this.modelsUsescases.startModel(modelId);
}

const engine = (await this.extensionRepository.findOne(
model!.engine ?? Engines.llamaCPP,
)) as EngineExtension | undefined;
@@ -59,7 +70,26 @@
* @returns Embedding vector.
*/
async embeddings(dto: CreateEmbeddingsDto) {
const modelId = dto.model ?? defaultEmbeddingModel;

if (modelId !== dto.model) dto = { ...dto, model: modelId };

if (!(await this.modelsUsescases.findOne(modelId))) {
await this.modelsUsescases.pullModel(modelId);
}

const isModelRunning = await this.modelsUsescases.isModelRunning(modelId);
// If the model is not running, start it in embedding mode
if (!isModelRunning) {
await this.modelsUsescases.startModel(modelId, {
embedding: true,
model_type: 'embedding',
});
}

const configs = await this.fileService.getConfig();

return firstValueFrom(
this.httpService.post(
CORTEX_CPP_EMBEDDINGS_URL(configs.cortexCppHost, configs.cortexCppPort),
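The embeddings path goes a step further than chat: a missing model is pulled first, then started with the embedding settings shown earlier, falling back to defaultEmbeddingModel when the request omits a model. A client-side sketch, assuming the default port and an OpenAI-compatible route (both assumptions, not confirmed by this diff):

// Sketch: embeddings request with no model specified; the server falls
// back to 'nomic-embed-text-v1', pulling and starting it if needed.
const res = await fetch('http://127.0.0.1:1337/v1/embeddings', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ input: 'Hello, world' }),
});
const { data } = await res.json();
console.log(data[0].embedding.length); // dimensionality of the vector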
21 changes: 19 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
@@ -344,11 +344,11 @@ export class ModelsUsecases {
await promises.mkdir(modelFolder, { recursive: true }).catch(() => {});

let files = (await fetchJanRepoData(originModelId)).siblings;

// HuggingFace GGUF Repo - Only one file is downloaded
if (originModelId.includes('/') && selection && files.length) {
try {
files = [await selection(files)];
} catch (e) {
const modelEvent: ModelEvent = {
model: modelId,
@@ -498,6 +498,23 @@
return this.activeModelStatuses;
}

/**
* Check whether the model is running in the Cortex C++ server
*/
async isModelRunning(modelId: string): Promise<boolean> {
const model = await this.getModelOrThrow(modelId).catch((e) => undefined);

if (!model) return false;

const engine = (await this.extensionRepository.findOne(
model.engine ?? Engines.llamaCPP,
)) as EngineExtension | undefined;

if (!engine) return false;

return engine.isModelRunning(modelId);
}

/**
* Check whether the download file is valid or not
* @param file