
Merge pull request #149 from janhq/hotfix_embedding
Hotfix embedding
tikikun authored Nov 17, 2023
2 parents 0f66207 + 6aa879b commit 20d3be8
Showing 1 changed file with 40 additions and 5 deletions.
controllers/llamaCPP.cc (40 additions, 5 deletions)
@@ -27,6 +27,40 @@ std::shared_ptr<State> createState(int task_id, llamaCPP *instance) {
 
 // --------------------------------------------
 
+std::string create_embedding_payload(const std::vector<float> &embedding,
+                                     int prompt_tokens) {
+  Json::Value root;
+
+  root["object"] = "list";
+
+  Json::Value dataArray(Json::arrayValue);
+  Json::Value dataItem;
+
+  dataItem["object"] = "embedding";
+
+  Json::Value embeddingArray(Json::arrayValue);
+  for (const auto &value : embedding) {
+    embeddingArray.append(value);
+  }
+  dataItem["embedding"] = embeddingArray;
+  dataItem["index"] = 0;
+
+  dataArray.append(dataItem);
+  root["data"] = dataArray;
+
+  root["model"] = "_";
+
+  Json::Value usage;
+  usage["prompt_tokens"] = prompt_tokens;
+  usage["total_tokens"] = prompt_tokens; // Assuming total tokens equals prompt
+                                         // tokens in this context
+  root["usage"] = usage;
+
+  Json::StreamWriterBuilder writer;
+  writer["indentation"] = ""; // Compact output
+  return Json::writeString(writer, root);
+}
+
 std::string create_full_return_json(const std::string &id,
                                     const std::string &model,
                                     const std::string &content,
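
For context, the new helper serializes an OpenAI-style embeddings response. Below is a minimal sketch of a standalone driver (hypothetical, not part of this commit): it only declares the helper and assumes the definition above is compiled and linked in.

#include <iostream>
#include <string>
#include <vector>

// Declaration only; the definition is the helper added in controllers/llamaCPP.cc.
std::string create_embedding_payload(const std::vector<float> &embedding,
                                     int prompt_tokens);

int main() {
  std::vector<float> embedding = {0.5f, 0.25f, 0.125f};
  std::cout << create_embedding_payload(embedding, 3) << "\n";
  // Json::Value keeps members in sorted order and indentation is set to "",
  // so the output is one compact line, roughly:
  // {"data":[{"embedding":[0.5,0.25,0.125],"index":0,"object":"embedding"}],"model":"_","object":"list","usage":{"prompt_tokens":3,"total_tokens":3}}
  return 0;
}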
@@ -245,17 +279,18 @@ void llamaCPP::embedding(
   const auto &jsonBody = req->getJsonObject();
 
   json prompt;
-  if (jsonBody->isMember("content") != 0) {
-    prompt = (*jsonBody)["content"].asString();
+  if (jsonBody->isMember("input") != 0) {
+    prompt = (*jsonBody)["input"].asString();
   } else {
     prompt = "";
   }
   const int task_id = llama.request_completion(
       {{"prompt", prompt}, {"n_predict", 0}}, false, true);
   task_result result = llama.next_result(task_id);
-  std::string embeddingResp = result.result_json.dump();
+  std::vector<float> embedding_result = result.result_json["embedding"];
   auto resp = nitro_utils::nitroHttpResponse();
-  resp->setBody(embeddingResp);
+  std::string embedding_resp = create_embedding_payload(embedding_result, 0);
+  resp->setBody(embedding_resp);
   resp->setContentTypeString("application/json");
   callback(resp);
   return;
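
With this hunk the endpoint reads the OpenAI-compatible "input" field rather than "content", and wraps the raw embedding vector in the payload built above (note that prompt_tokens is passed as 0 here). A minimal sketch of client-side code that builds the request body the handler now expects (hypothetical; the route named in the comment is illustrative, not taken from this diff):

#include <json/json.h>
#include <iostream>
#include <string>

int main() {
  Json::Value body;
  body["input"] = "hello world"; // previously the handler looked for "content"

  Json::StreamWriterBuilder writer;
  writer["indentation"] = ""; // compact, single-line JSON
  std::string payload = Json::writeString(writer, body);

  // POST `payload` to the embedding route of a running server,
  // e.g. an endpoint like /inferences/llamacpp/embedding.
  std::cout << payload << "\n"; // {"input":"hello world"}
  return 0;
}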
@@ -363,7 +398,7 @@ void llamaCPP::loadModel(
   llama.initialize();
 
   Json::Value jsonResp;
-  jsonResp["message"] = "Failed to load model";
+  jsonResp["message"] = "Model loaded successfully";
   model_loaded = true;
   auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
 
