Skip to content

Commit

Permalink
llama : add support for GritLM (ggerganov#5959)
Browse files Browse the repository at this point in the history
* add gritlm example

* gritlm results match

* tabs to spaces

* comment out debug printing

* rebase to new embed

* gritlm embeddings are back babeee

* add to gitignore

* allow to toggle embedding mode

* Clean-up GritLM sample code.

* Fix types.

* Flush stdout and output ending newline if streaming.

* mostly style fixes; correct KQ_mask comment

* add causal_attn flag to llama_cparams

* gritml : minor

* llama : minor

---------

Co-authored-by: Douglas Hanley <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
3 people authored Mar 10, 2024
1 parent 2960eae commit bcebd7d
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ models-mnt
/embedding
/gguf
/gguf-llama-simple
/gritlm
/imatrix
/infill
/libllama.so
Expand Down
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
Expand Down Expand Up @@ -724,6 +724,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
Expand Down
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(gritlm)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
Expand Down
5 changes: 5 additions & 0 deletions examples/gritlm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(TARGET gritlm)
add_executable(${TARGET} gritlm.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
229 changes: 229 additions & 0 deletions examples/gritlm/gritlm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// #define GRIT_DEBUG

static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
float dot = 0.0f;
for (uint64_t i = 0; i < v1.size(); ++i) {
dot += v1[i] * v2[i];
}
return dot;
}

static float norm(const std::vector<float> & v) {
return std::sqrt(dot_product(v, v));
}

static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
return dot_product(v1, v2) / (norm(v1) * norm(v2));
}

static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
std::vector<std::vector<float>> result;

const llama_model * mdl = llama_get_model(ctx);

llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

for (uint64_t i = 0; i < sentences.size(); i++) {
llama_batch_clear(batch);

const std::string input_string = instruction + sentences[i];

std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);

const int32_t n_toks = inputs.size();

// GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(mdl));

// we want to ignore instruction tokens for mean pooling
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();

#ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
});
std::printf("\n");
#endif

// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
}

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, false);

// run model
llama_decode(ctx, batch);

// get embedding dimensions
uint64_t n_embd = llama_n_embd(mdl);

// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);

// sum up all token embeddings
for (int32_t k = n_inst; k < n_toks; k++) {
float * emb = llama_get_embeddings_ith(ctx, k);
for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] += emb[j];
}
}

// divide by number of tokens (mean pooling)
{
const uint64_t n_sent = n_toks - n_inst;

for (uint64_t j = 0; j < n_embd; j++) {
emb_unorm[j] /= n_sent;
}
}

std::vector<float> emb_norm(emb_unorm.size());
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
result.push_back(emb_norm);

#ifdef GRIT_DEBUG
// print out emb_norm
std::printf("embedding %ld: ", i);
for (uint64_t j = 0; j < n_embd; j++) {
std::printf("%.5f ", emb_norm[j]);
}
std::printf("\n\n");
#endif
}

llama_batch_free(batch);

return result;
}

static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
std::string result;

const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);

llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
int32_t i_current_token = 0;

while (true) {
llama_batch_clear(bat);
auto n_inputs = (int32_t)inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
}
inputs.clear();

llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
auto n_candidates = (int32_t)candidates.size();
for (int32_t token = 0; token < n_candidates; token++) {
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
}
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };

llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
break;
}

std::string piece = llama_token_to_piece(ctx, token);
if (stream) {
std::printf("%s", piece.c_str());
std::fflush(stdout);
}

inputs.push_back(token);

result += piece;
}

if (stream) {
std::printf("\n");
}

llama_batch_free(bat);

return result;
}

static std::string gritlm_instruction(const std::string & instruction) {
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
}

int main(int argc, char * argv[]) {
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}

llama_model_params mparams = llama_model_params_from_gpt_params(params);
llama_context_params cparams = llama_context_params_from_gpt_params(params);

llama_backend_init();

llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

// create new context - set to embedding mode
cparams.embeddings = true;
llama_context * ctx = llama_new_context_with_model(mdl, cparams);

// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
{
const std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";

const std::vector<std::string> queries = {
"Bitcoin: A Peer-to-Peer Electronic Cash System",
"Generative Representational Instruction Tuning",
};

const std::vector<std::string> documents = {
"A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
"All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
};

// No need to add instruction for retrieval documents
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);

std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}

// ### Generation ###
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(ctx, prompt, true);
}

llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();

return 0;
}
25 changes: 22 additions & 3 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,7 @@ struct llama_cparams {
float defrag_thold;

bool embeddings;
bool causal_attn;
bool offload_kqv;

enum llama_pooling_type pooling_type;
Expand Down Expand Up @@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: causal attm = %d\n", __func__, hparams.causal_attn);
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
Expand Down Expand Up @@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}

if (hparams.causal_attn) {
GGML_ASSERT(
(hparams.causal_attn || !cparams.causal_attn) &&
"non-causal attention with generative models is not supported"
);

// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;

Expand Down Expand Up @@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
} else {
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;

assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

Expand All @@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
}

for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
}
}
}
Expand Down Expand Up @@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}

cparams.causal_attn = hparams.causal_attn;

if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
Expand Down Expand Up @@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}

void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn;
}

struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
Expand Down
4 changes: 4 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ extern "C" {
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);

// Set whether to use causal attention or not
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);

Expand Down

0 comments on commit bcebd7d

Please sign in to comment.