Skip to content

Commit

Permalink
feat: update embedding
Browse files Browse the repository at this point in the history
- add pooling_type & embd_normalize in init / embedding method
- make n_ubatch same with n_batch if embedding enabled
  • Loading branch information
jhen0409 committed Dec 31, 2024
1 parent 55b3f9f commit 8f8e9c8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
2 changes: 2 additions & 0 deletions lib/binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export type ChatMessage = {
export type LlamaModelOptions = {
model: string
embedding?: boolean
embd_normalize?: number
pooling_type?: number
n_ctx?: number
n_batch?: number
n_threads?: number
Expand Down
20 changes: 15 additions & 5 deletions src/EmbeddingWorker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#include "LlamaContext.h"

EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info,
LlamaSessionPtr &sess, std::string text)
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {}
LlamaSessionPtr &sess, std::string text, common_params &params)
: AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text), _params(params) {}

void EmbeddingWorker::Execute() {
llama_kv_cache_clear(_sess->context());
Expand All @@ -14,20 +14,30 @@ void EmbeddingWorker::Execute() {
}
const int n_embd = llama_n_embd(_sess->model());
do {
auto ctx = _sess->context();
int ret =
llama_decode(_sess->context(),
llama_decode(ctx,
llama_batch_get_one(tokens.data(), tokens.size()));
if (ret < 0) {
SetError("Failed to inference, code: " + std::to_string(ret));
break;
}
const float *embd = llama_get_embeddings_seq(_sess->context(), 0);

float *embd;
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
embd = llama_get_embeddings(ctx);
} else {
embd = llama_get_embeddings_seq(ctx, 0);
}
if (embd == nullptr) {
SetError("Failed to get embeddings");
break;
}
_result.embedding.resize(n_embd);
memcpy(_result.embedding.data(), embd, n_embd * sizeof(float));
std::vector<float> embedding(embd, embd + n_embd), out(embd, embd + n_embd);
common_embd_normalize(embedding.data(), out.data(), n_embd, _params.embd_normalize);
memcpy(_result.embedding.data(), out.data(), n_embd * sizeof(float));
} while (false);
}

Expand Down
3 changes: 2 additions & 1 deletion src/EmbeddingWorker.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class EmbeddingWorker : public Napi::AsyncWorker,
public Napi::Promise::Deferred {
public:
EmbeddingWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
std::string text);
std::string text, common_params &params);

protected:
void Execute();
Expand All @@ -19,5 +19,6 @@ class EmbeddingWorker : public Napi::AsyncWorker,
private:
LlamaSessionPtr _sess;
std::string _text;
common_params _params;
EmbeddingResult _result;
};
21 changes: 19 additions & 2 deletions src/LlamaContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,18 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
if (params.model.empty()) {
Napi::TypeError::New(env, "Model is required").ThrowAsJavaScriptException();
}
params.embedding = get_option<bool>(options, "embedding", false);

params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
params.embedding = get_option<bool>(options, "embedding", false);
if (params.embedding) {
// For non-causal models, batch size must be equal to ubatch size
params.n_ubatch = params.n_batch;
}
params.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
int32_t pooling_type = get_option<int32_t>(options, "pooling_type", -1);
params.pooling_type = (enum llama_pooling_type) pooling_type;

params.cpuparams.n_threads =
get_option<int32_t>(options, "n_threads", cpu_get_num_math() / 2);
params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
Expand Down Expand Up @@ -243,8 +252,16 @@ Napi::Value LlamaContext::Embedding(const Napi::CallbackInfo &info) {
Napi::TypeError::New(env, "Context is disposed")
.ThrowAsJavaScriptException();
}
auto options = Napi::Object::New(env);
if (info.Length() >= 2 && info[1].IsObject()) {
options = info[1].As<Napi::Object>();
}

common_params embdParams;
embdParams.embedding = true;
embdParams.embd_normalize = get_option<int32_t>(options, "embd_normalize", 2);
auto text = info[0].ToString().Utf8Value();
auto *worker = new EmbeddingWorker(info, _sess, text);
auto *worker = new EmbeddingWorker(info, _sess, text, embdParams);
worker->Queue();
return worker->Promise();
}
Expand Down

0 comments on commit 8f8e9c8

Please sign in to comment.