fix: 963 can not run openai models on windows (#974)
louis-jan authored Dec 13, 2023
1 parent f7c7ad5 commit 3266014
Showing 3 changed files with 61 additions and 37 deletions.
5 changes: 4 additions & 1 deletion extensions/inference-openai-extension/src/index.ts
@@ -195,7 +195,10 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
     requestInference(
       data?.messages ?? [],
       this._engineSettings,
-      JanInferenceOpenAIExtension._currentModel,
+      {
+        ...JanInferenceOpenAIExtension._currentModel,
+        parameters: data.model.parameters,
+      },
       instance.controller
     ).subscribe({
       next: (content) => {
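The hunk above is the heart of the fix for issue 963: rather than handing requestInference the statically cached _currentModel, the extension now spreads that model into a fresh object and overrides its parameters with the ones carried on the incoming request. A minimal sketch of the pattern, using simplified stand-in types rather than the real Model and MessageRequest shapes from @janhq/core:

    // Simplified stand-ins for the @janhq/core types (illustrative only).
    type Model = { id: string; parameters: Record<string, unknown> };
    type MessageRequest = { model: Model; messages?: unknown[] };

    const currentModel: Model = { id: "gpt-3.5-turbo", parameters: {} };

    // The cached model is spread first, so the per-request
    // parameters take precedence over whatever was cached.
    function effectiveModel(data: MessageRequest): Model {
      return {
        ...currentModel,
        parameters: data.model.parameters,
      };
    }

The same substitution appears in the Triton TensorRT-LLM extension below.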
92 changes: 57 additions & 35 deletions extensions/inference-triton-trtllm-extension/src/index.ts
@@ -31,14 +31,16 @@ import { EngineSettings } from "./@types/global";
  * The class provides methods for initializing and stopping a model, and for making inference requests.
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
  */
-export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
-  private static readonly _homeDir = 'engines'
-  private static readonly _engineMetadataFileName = 'triton_trtllm.json'
-
+export default class JanInferenceTritonTrtLLMExtension
+  implements InferenceExtension
+{
+  private static readonly _homeDir = "engines";
+  private static readonly _engineMetadataFileName = "triton_trtllm.json";
+
   static _currentModel: Model;

   static _engineSettings: EngineSettings = {
-    "base_url": "",
+    base_url: "",
   };

   controller = new AbortController();
@@ -56,8 +58,8 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
-    fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir)
-    JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
+    fs.mkdir(JanInferenceTritonTrtLLMExtension._homeDir);
+    JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();

     // Events subscription
     events.on(EventName.OnMessageSent, (data) =>
@@ -87,20 +89,31 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
     modelId: string,
     settings?: ModelSettingParams
   ): Promise<void> {
-    return
+    return;
   }

   static async writeDefaultEngineSettings() {
     try {
-      const engine_json = join(JanInferenceTritonTrtLLMExtension._homeDir, JanInferenceTritonTrtLLMExtension._engineMetadataFileName)
+      const engine_json = join(
+        JanInferenceTritonTrtLLMExtension._homeDir,
+        JanInferenceTritonTrtLLMExtension._engineMetadataFileName
+      );
       if (await fs.exists(engine_json)) {
-        JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
-      }
-      else {
-        await fs.writeFile(engine_json, JSON.stringify(JanInferenceTritonTrtLLMExtension._engineSettings, null, 2))
+        JanInferenceTritonTrtLLMExtension._engineSettings = JSON.parse(
+          await fs.readFile(engine_json)
+        );
+      } else {
+        await fs.writeFile(
+          engine_json,
+          JSON.stringify(
+            JanInferenceTritonTrtLLMExtension._engineSettings,
+            null,
+            2
+          )
+        );
       }
     } catch (err) {
-      console.error(err)
+      console.error(err);
     }
   }
   /**
@@ -137,35 +150,39 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
     };

     return new Promise(async (resolve, reject) => {
-      requestInference(data.messages ?? [],
-        JanInferenceTritonTrtLLMExtension._engineSettings,
-        JanInferenceTritonTrtLLMExtension._currentModel)
-        .subscribe({
-          next: (_content) => {},
-          complete: async () => {
-            resolve(message);
-          },
-          error: async (err) => {
-            reject(err);
-          },
+      requestInference(
+        data.messages ?? [],
+        JanInferenceTritonTrtLLMExtension._engineSettings,
+        JanInferenceTritonTrtLLMExtension._currentModel
+      ).subscribe({
+        next: (_content) => {},
+        complete: async () => {
+          resolve(message);
+        },
+        error: async (err) => {
+          reject(err);
+        },
       });
     });
   }

   private static async handleModelInit(model: Model) {
-    if (model.engine !== 'triton_trtllm') { return }
-    else {
-      JanInferenceTritonTrtLLMExtension._currentModel = model
-      JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
+    if (model.engine !== "triton_trtllm") {
+      return;
+    } else {
+      JanInferenceTritonTrtLLMExtension._currentModel = model;
+      JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings();
       // Todo: Check model list with API key
-      events.emit(EventName.OnModelReady, model)
+      events.emit(EventName.OnModelReady, model);
       // events.emit(EventName.OnModelFail, model)
     }
   }

   private static async handleModelStop(model: Model) {
-    if (model.engine !== 'triton_trtllm') { return }
-    events.emit(EventName.OnModelStopped, model)
+    if (model.engine !== "triton_trtllm") {
+      return;
+    }
+    events.emit(EventName.OnModelStopped, model);
   }

   /**
@@ -178,8 +195,10 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
     data: MessageRequest,
     instance: JanInferenceTritonTrtLLMExtension
   ) {
-    if (data.model.engine !== 'triton_trtllm') { return }
-
+    if (data.model.engine !== "triton_trtllm") {
+      return;
+    }
+
     const timestamp = Date.now();
     const message: ThreadMessage = {
       id: ulid(),
@@ -200,7 +219,10 @@ export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension {
     requestInference(
       data?.messages ?? [],
       this._engineSettings,
-      JanInferenceTritonTrtLLMExtension._currentModel,
+      {
+        ...JanInferenceTritonTrtLLMExtension._currentModel,
+        parameters: data.model.parameters,
+      },
       instance.controller
     ).subscribe({
       next: (content) => {
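Beyond the parameter fix, most of this file is a formatting pass (Prettier-style double quotes, semicolons, and wrapped call sites). The one piece of logic worth calling out is writeDefaultEngineSettings, which loads engines/triton_trtllm.json if it exists and otherwise seeds it with the defaults. A rough standalone equivalent of that read-or-seed pattern, written against plain Node.js APIs rather than the fs wrapper from @janhq/core that the extension actually uses, with a hypothetical helper name:

    import { promises as fs } from "node:fs";
    import { join } from "node:path";

    // Hypothetical helper; the extension keeps this logic inline.
    async function loadOrSeedSettings<T>(
      dir: string,
      file: string,
      defaults: T
    ): Promise<T> {
      const path = join(dir, file);
      try {
        // If the settings file exists, the stored values win.
        return JSON.parse(await fs.readFile(path, "utf8")) as T;
      } catch {
        // Otherwise persist the defaults so later runs find them.
        await fs.mkdir(dir, { recursive: true });
        await fs.writeFile(path, JSON.stringify(defaults, null, 2));
        return defaults;
      }
    }

    // e.g. await loadOrSeedSettings("engines", "triton_trtllm.json", { base_url: "" })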
1 change: 0 additions & 1 deletion web/hooks/useCreateNewThread.ts
@@ -85,7 +85,6 @@ export const useCreateNewThread = () => {
     created: createdAt,
     updated: createdAt,
   }

-  setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)

   // add the new thread on top of the thread list to the state