chore: auto load model on /chat/completions request #900

Merged: 1 commit merged on Jul 22, 2024
32 changes: 31 additions & 1 deletion cortex-js/src/domain/abstracts/engine.abstract.ts
@@ -3,6 +3,10 @@ import stream from 'stream';
import { Model, ModelSettingParams } from '../../domain/models/model.interface';
import { Extension } from './extension.abstract';

/**
* This class should be extended by any class that represents an engine extension.
* It provides methods for loading and unloading models, and for making inference requests.
*/
export abstract class EngineExtension extends Extension {
abstract onLoad(): void;

@@ -12,16 +16,42 @@ export abstract class EngineExtension extends Extension {

initalized: boolean = false;

/**
* Makes an inference request to the engine.
* @param dto
* @param headers
*/
abstract inference(
dto: any,
headers: Record<string, string>,
): Promise<stream.Readable | any>;

/**
* Checks whether a model is currently running in the engine.
* This method should check the run-time status of the model,
* since the model can become corrupted while running.
* It should return false if the model is not running.
* @param modelId
*/
async isModelRunning(modelId: string): Promise<boolean> {
return true;
}

/**
* Loads a model into the engine.
* Model settings such as `ngl` and `ctx_len` can be passed to the engine.
* Applicable to local engines only.
* @param model
* @param settingParams
*/
async loadModel(
model: Model,
settingParams?: ModelSettingParams,
): Promise<void> {}

/**
* Unloads a model from the engine.
* @param modelId
*/
async unloadModel(modelId: string): Promise<void> {}

}
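For orientation, a minimal concrete extension might look like the sketch below. The class name, the echoed response, and the assumption that `Extension` imposes no further abstract members are illustrative; only the overridden methods come from the contract above.

import stream from 'stream';
import { Model, ModelSettingParams } from '../../domain/models/model.interface';
import { EngineExtension } from '../../domain/abstracts/engine.abstract';

// Hypothetical engine used only to illustrate the abstract contract.
export class EchoEngineExtension extends EngineExtension {
  onLoad(): void {
    // One-time setup (binary checks, config loading) would go here.
  }

  async inference(
    dto: any,
    headers: Record<string, string>,
  ): Promise<stream.Readable | any> {
    // A real engine would forward `dto` to its inference server.
    return { echoed: dto };
  }

  override async loadModel(
    model: Model,
    settingParams?: ModelSettingParams,
  ): Promise<void> {
    // A local engine would spawn the model process here, passing e.g. ngl / ctx_len.
  }

  override async unloadModel(modelId: string): Promise<void> {
    // Stop the model process for `modelId`.
  }
}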
9 changes: 9 additions & 0 deletions cortex-js/src/domain/models/model.interface.ts
@@ -40,6 +40,10 @@ export interface ModelSettingParams {
* The number of layers to load onto the GPU for acceleration.
*/
ngl?: number;

/**
* Whether to enable embedding support (legacy flag)
*/
embedding?: boolean;

/**
@@ -117,6 +121,11 @@ export interface ModelSettingParams {
* To enable mmap, default is true
*/
use_mmap?: boolean;

/**
* The model type to run: `llm` or `embedding`. Defaults to `llm` (supported since a recent llama.cpp update).
*/
model_type?: string;
}

/**
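As a rough illustration, the two new fields sit alongside the existing llama.cpp settings. The values below are assumptions, not defaults taken from the codebase, and `ctx_len` is assumed to be part of ModelSettingParams as the doc comments suggest.

import { ModelSettingParams } from '@/domain/models/model.interface';

// Hypothetical settings for starting a model in embedding mode.
const embeddingSettings: ModelSettingParams = {
  ngl: 33,                 // layers offloaded to the GPU
  ctx_len: 2048,           // context length
  embedding: true,         // legacy flag for older cortex-cpp builds
  model_type: 'embedding', // newer llama.cpp builds pick the mode from this field
};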
4 changes: 3 additions & 1 deletion cortex-js/src/infrastructure/constants/cortex.ts
@@ -7,6 +7,8 @@ export const defaultCortexJsPort = 1337;

export const defaultCortexCppHost = '127.0.0.1';
export const defaultCortexCppPort = 3929;

export const defaultEmbeddingModel = 'nomic-embed-text-v1';
// CORTEX CPP
export const CORTEX_CPP_EMBEDDINGS_URL = (
host: string = defaultCortexCppHost,
@@ -50,4 +52,4 @@ export const CUDA_DOWNLOAD_URL =

export const telemetryServerUrl = 'https://telemetry.jan.ai';

export const MIN_CUDA_VERSION = '12.3';
export const MIN_CUDA_VERSION = '12.3';
@@ -9,6 +9,7 @@ import { DownloadManagerModule } from '@/infrastructure/services/download-manage
import { EventEmitterModule } from '@nestjs/event-emitter';
import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
import { FileManagerModule } from '../services/file-manager/file-manager.module';
import { ModelsModule } from '@/usecases/models/models.module';

describe('ChatController', () => {
let controller: ChatController;
@@ -25,6 +26,7 @@ describe('ChatController', () => {
EventEmitterModule.forRoot(),
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [ChatController],
providers: [ChatUsecases],
32 changes: 13 additions & 19 deletions cortex-js/src/infrastructure/controllers/chat.controller.ts
@@ -35,27 +35,21 @@ export class ChatController {
) {
const { stream } = createChatDto;

-    if (stream) {
-      this.chatService
-        .inference(createChatDto, extractCommonHeaders(headers))
-        .then((stream) => {
-          res.header('Content-Type', 'text/event-stream');
-          stream.pipe(res);
-        })
-        .catch((error) =>
-          res.status(error.statusCode ?? 400).send(error.message),
-        );
-    } else {
-      res.header('Content-Type', 'application/json');
-      this.chatService
-        .inference(createChatDto, extractCommonHeaders(headers))
-        .then((response) => {
-          res.json(response);
-        })
-        .catch((error) =>
-          res.status(error.statusCode ?? 400).send(error.message),
-        );
-    }
+    this.chatService
+      .inference(createChatDto, extractCommonHeaders(headers))
+      .then((response) => {
+        if (stream) {
+          res.header('Content-Type', 'text/event-stream');
+          response.pipe(res);
+        } else {
+          res.header('Content-Type', 'application/json');
+          res.json(response);
+        }
+      })
+      .catch((error) =>
+        res.status(error.statusCode ?? 400).send(error.message),
+      );

this.telemetryUsecases.addEventToQueue({
name: EventName.CHAT,
modelId: createChatDto.model,
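The net effect is that a client no longer has to start a model before chatting; the controller now makes a single inference call and the use case boots the model on demand. A rough client-side sketch follows; the base URL, route prefix, and model id are assumptions.

const baseUrl = 'http://127.0.0.1:1337'; // defaultCortexJsPort; route prefix may differ per deployment

async function chat(prompt: string, stream = false): Promise<void> {
  const res = await fetch(`${baseUrl}/chat/completions`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'tinyllama', // hypothetical model id; auto-loaded if not yet running
      messages: [{ role: 'user', content: prompt }],
      stream,
    }),
  });

  if (stream) {
    // text/event-stream: print chunks as they arrive
    const decoder = new TextDecoder();
    for await (const chunk of res.body as any) {
      process.stdout.write(decoder.decode(chunk, { stream: true }));
    }
  } else {
    // application/json: a single completion object
    console.log(await res.json());
  }
}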
@@ -9,6 +9,7 @@ import { DownloadManagerModule } from '@/infrastructure/services/download-manage
import { EventEmitterModule } from '@nestjs/event-emitter';
import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
import { FileManagerModule } from '../services/file-manager/file-manager.module';
import { ModelsModule } from '@/usecases/models/models.module';

describe('EmbeddingsController', () => {
let controller: EmbeddingsController;
@@ -25,6 +26,7 @@ describe('EmbeddingsController', () => {
EventEmitterModule.forRoot(),
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [EmbeddingsController],
providers: [ChatUsecases],
@@ -1,4 +1,4 @@
import { Body, Controller, Post, HttpCode } from '@nestjs/common';
import { Body, Controller, Post, HttpCode, Res } from '@nestjs/common';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ApiOperation, ApiTags, ApiResponse } from '@nestjs/swagger';
import { CreateEmbeddingsDto } from '../dtos/embeddings/embeddings-request.dto';
@@ -21,6 +21,7 @@ export class CreateMessageDto implements Partial<Message> {
example: 'user',
description: 'The sources of the messages.',
})
@IsString()
role: 'user' | 'assistant';

@ApiProperty({
35 changes: 33 additions & 2 deletions cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts
@@ -1,10 +1,11 @@
import { Injectable } from '@nestjs/common';
import { HttpStatus, Injectable } from '@nestjs/common';
import { OAIEngineExtension } from '@/domain/abstracts/oai.abstract';
import { PromptTemplate } from '@/domain/models/prompt-template.interface';
import { join } from 'path';
import { Model, ModelSettingParams } from '@/domain/models/model.interface';
import { HttpService } from '@nestjs/axios';
import {
CORTEX_CPP_MODELS_URL,
defaultCortexCppHost,
defaultCortexCppPort,
} from '@/infrastructure/constants/cortex';
@@ -13,6 +14,11 @@ import { normalizeModelId } from '@/utils/normalize-model-id';
import { firstValueFrom } from 'rxjs';
import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';

export interface ModelStatResponse {
object: string;
data: any;
}

@Injectable()
export default class CortexProvider extends OAIEngineExtension {
apiUrl = `http://${defaultCortexCppHost}:${defaultCortexCppPort}/inferences/server/chat_completion`;
@@ -28,11 +34,12 @@ export default class CortexProvider extends OAIEngineExtension {

constructor(
protected readonly httpService: HttpService,
private readonly fileManagerService: FileManagerService,
protected readonly fileManagerService: FileManagerService,
) {
super(httpService);
}

// Override the loadModel method to load a model into the cortex-cpp engine
override async loadModel(
model: Model,
settings?: ModelSettingParams,
@@ -92,6 +99,30 @@ export default class CortexProvider extends OAIEngineExtension {
).then(); // pipe error or void instead of throwing
}

// Override the isModelRunning method to check if the model is running
override async isModelRunning(modelId: string): Promise<boolean> {
const configs = await this.fileManagerService.getConfig();

return firstValueFrom(
this.httpService.get(
CORTEX_CPP_MODELS_URL(configs.cortexCppHost, configs.cortexCppPort),
),
)
.then((res) => {
const data = res.data as ModelStatResponse;
if (
res.status === HttpStatus.OK &&
data &&
Array.isArray(data.data) &&
data.data.length > 0
) {
return data.data.find((e) => e.id === modelId);
}
return false;
})
.catch(() => false);
}

private readonly promptTemplateConverter = (
promptTemplate: string,
): PromptTemplate => {
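For reference, the check above treats any entry in the cortex-cpp model list whose id matches as "running". A response might look roughly like the sketch below; field names beyond `object` and `data` are assumptions.

// Illustrative payload from CORTEX_CPP_MODELS_URL; only `object` and `data`
// are declared by ModelStatResponse above, the rest is a guess.
const exampleResponse: ModelStatResponse = {
  object: 'list',
  data: [{ id: 'tinyllama', engine: 'cortex.llamacpp' }],
};

// isModelRunning('tinyllama') then resolves truthy (a matching entry is found);
// a non-200 status, an empty list, or a request error resolves to false.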
2 changes: 2 additions & 0 deletions cortex-js/src/usecases/chat/chat.module.ts
@@ -6,6 +6,7 @@ import { ModelRepositoryModule } from '@/infrastructure/repositories/models/mode
import { HttpModule } from '@nestjs/axios';
import { TelemetryModule } from '../telemetry/telemetry.module';
import { FileManagerModule } from '@/infrastructure/services/file-manager/file-manager.module';
import { ModelsModule } from '../models/models.module';

@Module({
imports: [
@@ -15,6 +16,7 @@ import { FileManagerModule } from '@/infrastructure/services/file-manager/file-m
HttpModule,
TelemetryModule,
FileManagerModule,
ModelsModule,
],
controllers: [],
providers: [ChatUsecases],
2 changes: 2 additions & 0 deletions cortex-js/src/usecases/chat/chat.usecases.spec.ts
@@ -8,6 +8,7 @@ import { HttpModule } from '@nestjs/axios';
import { DownloadManagerModule } from '@/infrastructure/services/download-manager/download-manager.module';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { FileManagerModule } from '@/infrastructure/services/file-manager/file-manager.module';
import { ModelsModule } from '../models/models.module';

describe('ChatService', () => {
let service: ChatUsecases;
@@ -24,6 +25,7 @@ describe('ChatService', () => {
DownloadManagerModule,
EventEmitterModule.forRoot(),
FileManagerModule,
ModelsModule,
],
providers: [ChatUsecases],
exports: [ChatUsecases],
38 changes: 34 additions & 4 deletions cortex-js/src/usecases/chat/chat.usecases.ts
@@ -4,21 +4,24 @@ import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception';
import { TelemetryUsecases } from '../telemetry/telemetry.usecases';
import { TelemetrySource } from '@/domain/telemetry/telemetry.interface';
import { ModelRepository } from '@/domain/repositories/model.interface';
import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { firstValueFrom } from 'rxjs';
import { HttpService } from '@nestjs/axios';
import { CORTEX_CPP_EMBEDDINGS_URL } from '@/infrastructure/constants/cortex';
import {
CORTEX_CPP_EMBEDDINGS_URL,
defaultEmbeddingModel,
} from '@/infrastructure/constants/cortex';
import { CreateEmbeddingsDto } from '@/infrastructure/dtos/embeddings/embeddings-request.dto';
import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
import { Engines } from '@/infrastructure/commanders/types/engine.interface';
import { ModelsUsecases } from '../models/models.usecases';

@Injectable()
export class ChatUsecases {
constructor(
private readonly modelRepository: ModelRepository,
private readonly extensionRepository: ExtensionRepository,
private readonly telemetryUseCases: TelemetryUsecases,
private readonly modelsUsescases: ModelsUsecases,
private readonly httpService: HttpService,
private readonly fileService: FileManagerService,
) {}
@@ -28,10 +31,18 @@ export class ChatUsecases {
headers: Record<string, string>,
): Promise<any> {
const { model: modelId } = createChatDto;
const model = await this.modelRepository.findOne(modelId);
const model = await this.modelsUsescases.findOne(modelId);
if (!model) {
throw new ModelNotFoundException(modelId);
}

const isModelRunning = await this.modelsUsescases.isModelRunning(modelId);
// If model is not running
// Start the model
if (!isModelRunning) {
await this.modelsUsescases.startModel(modelId);
}

const engine = (await this.extensionRepository.findOne(
model!.engine ?? Engines.llamaCPP,
)) as EngineExtension | undefined;
@@ -59,7 +70,26 @@ export class ChatUsecases {
* @returns Embedding vector.
*/
async embeddings(dto: CreateEmbeddingsDto) {
const modelId = dto.model ?? defaultEmbeddingModel;

if (modelId !== dto.model) dto = { ...dto, model: modelId };

if (!(await this.modelsUsescases.findOne(modelId))) {
await this.modelsUsescases.pullModel(modelId);
}

const isModelRunning = await this.modelsUsescases.isModelRunning(modelId);
// If model is not running
// Start the model
if (!isModelRunning) {
await this.modelsUsescases.startModel(modelId, {
embedding: true,
model_type: 'embedding',
});
}

const configs = await this.fileService.getConfig();

return firstValueFrom(
this.httpService.post(
CORTEX_CPP_EMBEDDINGS_URL(configs.cortexCppHost, configs.cortexCppPort),
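The embeddings path gets the same treatment: if the request omits a model, or names one that is not downloaded or running, the use case falls back to `nomic-embed-text-v1`, pulls it if needed, and starts it in embedding mode. A client-side sketch follows; the URL and payload fields are assumptions.

// Hypothetical embeddings request; the model field is omitted on purpose so the
// server substitutes defaultEmbeddingModel and auto-starts it.
async function embed(text: string): Promise<void> {
  const res = await fetch('http://127.0.0.1:1337/embeddings', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ input: text }),
  });
  console.log(await res.json()); // expected to contain the embedding vector(s)
}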
21 changes: 19 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
@@ -344,11 +344,11 @@ export class ModelsUsecases {
await promises.mkdir(modelFolder, { recursive: true }).catch(() => {});

let files = (await fetchJanRepoData(originModelId)).siblings;

// HuggingFace GGUF Repo - Only one file is downloaded
if (originModelId.includes('/') && selection && files.length) {
try {
files = [await selection(files)];
files = [await selection(files)];
} catch (e) {
const modelEvent: ModelEvent = {
model: modelId,
@@ -498,6 +498,23 @@ export class ModelsUsecases {
return this.activeModelStatuses;
}

/**
* Check whether the model is running in the Cortex C++ server
*/
async isModelRunning(modelId: string): Promise<boolean> {
const model = await this.getModelOrThrow(modelId).catch((e) => undefined);

if (!model) return false;

const engine = (await this.extensionRepository.findOne(
model.engine ?? Engines.llamaCPP,
)) as EngineExtension | undefined;

if (!engine) return false;

return engine.isModelRunning(modelId);
}

/**
* Check whether the download file is valid or not
* @param file
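Callers outside the chat use case can apply the same guard before issuing requests. A hypothetical helper is sketched below; only `isModelRunning` and `startModel` come from the code above.

import { ModelsUsecases } from '@/usecases/models/models.usecases';

// Start a model only if the engine does not already report it as running.
async function ensureModelRunning(models: ModelsUsecases, modelId: string): Promise<void> {
  if (!(await models.isModelRunning(modelId))) {
    await models.startModel(modelId);
  }
}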