Skip to content

Commit

Permalink
feat: add engine init endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Jul 17, 2024
1 parent b562ecb commit 1890da8
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 125 deletions.
2 changes: 2 additions & 0 deletions cortex-js/src/domain/abstracts/engine.abstract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ export abstract class EngineExtension extends Extension {

transformResponse?: Function;

initalized: boolean = false;

abstract inference(
dto: any,
headers: Record<string, string>,
Expand Down
3 changes: 3 additions & 0 deletions cortex-js/src/domain/abstracts/extension.abstract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ export abstract class Extension {
/** @type {string} Extension's version. */
version?: string;

/** @type {boolean} Whether the extension is initialized or not. */
initalized: boolean;

/**
* Called when the extension is loaded.
* Any initialization logic for the extension should be put here.
Expand Down
40 changes: 22 additions & 18 deletions cortex-js/src/extensions/anthropic.engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export default class AnthropicEngineExtension extends OAIEngineExtension {
productName = 'Anthropic Inference Engine';
description = 'This extension enables Anthropic chat completion API calls';
version = '0.0.1';
initalized = true;
apiKey?: string;

constructor(
Expand All @@ -39,42 +40,45 @@ export default class AnthropicEngineExtension extends OAIEngineExtension {
this.apiKey = configs?.apiKey;
}

override async inference(dto: any, headers: Record<string, string>): Promise<stream.Readable | any> {
headers['x-api-key'] = this.apiKey as string
headers['Content-Type'] = 'application/json'
headers['anthropic-version'] = '2023-06-01'
return super.inference(dto, headers)
override async inference(
dto: any,
headers: Record<string, string>,
): Promise<stream.Readable | any> {
headers['x-api-key'] = this.apiKey as string;
headers['Content-Type'] = 'application/json';
headers['anthropic-version'] = '2023-06-01';
return super.inference(dto, headers);
}

transformPayload = (data: any): any => {
return _.pick(data, ['messages', 'model', 'stream', 'max_tokens']);
}
};

transformResponse = (data: any): string => {
// handling stream response
if (typeof data === 'string' && data.trim().length === 0) {
return '';
return '';
}
if (typeof data === 'string' && data.startsWith('event: ')) {
return ''
return '';
}
if (typeof data === 'string' && data.startsWith('data: ')) {
data = data.replace('data: ', '');
const parsedData = JSON.parse(data);
if (parsedData.type !== 'content_block_delta') {
return ''
return '';
}
const text = parsedData.delta?.text;
//convert to have this format data.choices[0]?.delta?.content
return JSON.stringify({
choices: [
{
delta: {
content: text
}
}
]
})
{
delta: {
content: text,
},
},
],
});
}
// non-stream response
if (data.content && data.content.length > 0 && data.content[0].text) {
Expand All @@ -88,8 +92,8 @@ export default class AnthropicEngineExtension extends OAIEngineExtension {
],
});
}

console.error('Invalid response format:', data);
return '';
}
};
}
1 change: 1 addition & 0 deletions cortex-js/src/extensions/groq.engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default class GroqEngineExtension extends OAIEngineExtension {
productName = 'Groq Inference Engine';
description = 'This extension enables fast Groq chat completion API calls';
version = '0.0.1';
initalized = true;
apiKey?: string;

constructor(
Expand Down
1 change: 1 addition & 0 deletions cortex-js/src/extensions/mistral.engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default class MistralEngineExtension extends OAIEngineExtension {
productName = 'Mistral Inference Engine';
description = 'This extension enables Mistral chat completion API calls';
version = '0.0.1';
initalized = true;
apiKey?: string;

constructor(
Expand Down
1 change: 1 addition & 0 deletions cortex-js/src/extensions/openai.engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default class OpenAIEngineExtension extends OAIEngineExtension {
productName = 'OpenAI Inference Engine';
description = 'This extension enables OpenAI chat completion API calls';
version = '0.0.1';
initalized = true;
apiKey?: string;

constructor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { FileManagerModule } from '@/infrastructure/services/file-manager/file-m
import { PSCliUsecases } from './ps.cli.usecases';
import { BenchmarkCliUsecases } from './benchmark.cli.usecases';
import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
import { DownloadManagerModule } from '@/infrastructure/services/download-manager/download-manager.module';

@Module({
imports: [
Expand All @@ -25,6 +26,7 @@ import { TelemetryModule } from '@/usecases/telemetry/telemetry.module';
MessagesModule,
FileManagerModule,
TelemetryModule,
DownloadManagerModule,
],
providers: [
InitCliUsecases,
Expand Down
143 changes: 45 additions & 98 deletions cortex-js/src/infrastructure/commanders/usecases/init.cli.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ import { Engines } from '../types/engine.interface';

import { cpuInfo } from 'cpu-instructions';
import ora from 'ora';
import { DownloadManagerService } from '@/infrastructure/services/download-manager/download-manager.service';
import { DownloadType } from '@/domain/models/download.interface';

@Injectable()
export class InitCliUsecases {
constructor(
private readonly httpService: HttpService,
private readonly fileManagerService: FileManagerService,
private readonly downloadManagerService: DownloadManagerService,
) {}

/**
Expand Down Expand Up @@ -135,53 +138,25 @@ export class InitCliUsecases {
const destination = join(dataFolderPath, 'cuda-toolkit.tar.gz');

console.log('Downloading CUDA Toolkit dependency...');
const download = await firstValueFrom(
this.httpService.get(url, {
responseType: 'stream',
}),
);

if (!download) {
console.log('Failed to download dependency');
process.exit(1);
}

await new Promise((resolve, reject) => {
const writer = createWriteStream(destination);
let receivedBytes = 0;
const totalBytes = download.headers['content-length'];

writer.on('finish', () => {
bar.stop();
resolve(true);
});

writer.on('error', (error) => {
bar.stop();
reject(error);
});

const bar = new SingleBar({}, Presets.shades_classic);
bar.start(100, 0);

download.data.on('data', (chunk: any) => {
receivedBytes += chunk.length;
bar.update(Math.floor((receivedBytes / totalBytes) * 100));
});

download.data.pipe(writer);
});

try {
await decompress(
destination,
await this.fileManagerService.getCortexCppEnginePath(),
);
} catch (e) {
console.log(e);
exit(1);
}
await rm(destination, { force: true });
await this.downloadManagerService.submitDownloadRequest(
url,
'Cuda Toolkit Dependencies',
DownloadType.Miscelanous,
{ [url]: destination },
async () => {
try {
await decompress(
destination,
await this.fileManagerService.getCortexCppEnginePath(),
);
} catch (e) {
console.log(e);
exit(1);
}
await rm(destination, { force: true });
},
);
};

private detectInstructions = (): Promise<
Expand Down Expand Up @@ -247,59 +222,31 @@ export class InitCliUsecases {

if (!existsSync(engineDir)) mkdirSync(engineDir, { recursive: true });

const download = await firstValueFrom(
this.httpService.get(toDownloadAsset.browser_download_url, {
responseType: 'stream',
}),
);
if (!download) {
console.log('Failed to download model');
process.exit(1);
}

const destination = join(engineDir, toDownloadAsset.name);

await new Promise((resolve, reject) => {
const writer = createWriteStream(destination);
let receivedBytes = 0;
const totalBytes = download.headers['content-length'];

writer.on('finish', () => {
bar.stop();
resolve(true);
});

writer.on('error', (error) => {
bar.stop();
reject(error);
});

const bar = new SingleBar({}, Presets.shades_classic);
bar.start(100, 0);

download.data.on('data', (chunk: any) => {
receivedBytes += chunk.length;
bar.update(Math.floor((receivedBytes / totalBytes) * 100));
});

download.data.pipe(writer);
});

const decompressIndicator = ora('Decompressing engine...').start();
try {
await decompress(destination, engineDir);
} catch (e) {
console.error('Error decompressing file', e);
exit(1);
}
await rm(destination, { force: true });

// Copy the additional files to the cortex-cpp directory
for (const file of readdirSync(join(engineDir, engine))) {
if (!file.includes('engine')) {
await cpSync(join(engineDir, engine, file), join(engineDir, file));
}
}
decompressIndicator.succeed('Engine decompressed');
await this.downloadManagerService.submitDownloadRequest(
toDownloadAsset.browser_download_url,
engine,
DownloadType.Miscelanous,
{ [toDownloadAsset.browser_download_url]: destination },
async () => {
const decompressIndicator = ora('Decompressing engine...').start();
try {
await decompress(destination, engineDir);
} catch (e) {
console.error('Error decompressing file', e);
exit(1);
}
await rm(destination, { force: true });

// Copy the additional files to the cortex-cpp directory
for (const file of readdirSync(join(engineDir, engine))) {
if (!file.includes('engine')) {
await cpSync(join(engineDir, engine, file), join(engineDir, file));
}
}
decompressIndicator.succeed('Engine decompressed');
},
);
}
}
32 changes: 31 additions & 1 deletion cortex-js/src/infrastructure/controllers/engines.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@ import {
Param,
HttpCode,
UseInterceptors,
Post,
} from '@nestjs/common';
import { ApiOperation, ApiParam, ApiTags, ApiResponse } from '@nestjs/swagger';
import { TransformInterceptor } from '../interceptors/transform.interceptor';
import { EnginesUsecases } from '@/usecases/engines/engines.usecase';
import { EngineDto } from '../dtos/engines/engines.dto';
import { InitCliUsecases } from '../commanders/usecases/init.cli.usecases';
import { CommonResponseDto } from '../dtos/common/common-response.dto';

@ApiTags('Engines')
@Controller('engines')
@UseInterceptors(TransformInterceptor)
export class EnginesController {
constructor(private readonly enginesUsecases: EnginesUsecases) {}
constructor(
private readonly enginesUsecases: EnginesUsecases,
private readonly initUsescases: InitCliUsecases,
) {}

@HttpCode(200)
@ApiResponse({
Expand Down Expand Up @@ -52,4 +58,28 @@ export class EnginesController {
findOne(@Param('name') name: string) {
return this.enginesUsecases.getEngine(name);
}

@HttpCode(200)
@ApiResponse({
status: 200,
description: 'Ok',
type: CommonResponseDto,
})
@ApiOperation({
summary: 'Initialize an engine',
description:
'Initializes an engine instance with the given name. It will download the engine if it is not available locally.',
})
@ApiParam({
name: 'name',
required: true,
description: 'The unique identifier of the engine.',
})
@Post(':name(*)/init')
initialize(@Param('name') name: string) {
this.initUsescases.installEngine(undefined, 'latest', name, true);
return {
message: 'Engine initialization started successfully.',
};
}
}
13 changes: 10 additions & 3 deletions cortex-js/src/infrastructure/dtos/engines/engines.dto.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import { Extension } from '@/domain/abstracts/extension.abstract';
import { ApiProperty } from '@nestjs/swagger';
import { IsOptional, IsString } from 'class-validator';
import { IsBoolean, IsOptional, IsString } from 'class-validator';

export class EngineDto implements Partial<Extension> {
@ApiProperty({
type: String,
example: 'cortex.llamacpp',
description:
'The name of the engine that you want to retrieve.',
description: 'The name of the engine that you want to retrieve.',
})
@IsString()
name: string;
Expand Down Expand Up @@ -39,4 +38,12 @@ export class EngineDto implements Partial<Extension> {
@IsString()
@IsOptional()
version?: string;

@ApiProperty({
type: String,
example: true,
description: 'Whether the engine is initialized or not.',
})
@IsBoolean()
initalized?: boolean;
}
Loading

0 comments on commit 1890da8

Please sign in to comment.