From 81679eac0bc2cd0fb3d87f8b62910081d894db0c Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 8 Mar 2024 21:46:23 +0000 Subject: [PATCH] implement missing softmax in llama.cpp sampling (doh!) --- rllm/rllm-base/src/logits.rs | 3 ++- rllm/rllm-llamacpp/src/llamacpp/tmodel.rs | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/rllm/rllm-base/src/logits.rs b/rllm/rllm-base/src/logits.rs index bb635682..274e798d 100644 --- a/rllm/rllm-base/src/logits.rs +++ b/rllm/rllm-base/src/logits.rs @@ -18,7 +18,8 @@ impl LogitsProcessor { }; Self { - rng: rand::rngs::StdRng::seed_from_u64(42), + rng: rand::rngs::StdRng::from_entropy(), + // seed_from_u64(42), temperature, top_p: sampling_params.top_p, } diff --git a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs index 64e99e8d..af212e94 100644 --- a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs +++ b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs @@ -150,9 +150,14 @@ impl ModelExec for TModel { None => self.sample_argmax(&logits), Some(temperature) => { let mut prs: Vec = logits.to_vec1(); + let max_logit = prs.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); let temp = (1.0 / temperature) as f32; for idx in 0..prs.len() { - prs[idx] *= temp; + prs[idx] = ((prs[idx] - max_logit) * temp).exp(); + } + let sum = prs.iter().sum::(); + for idx in 0..prs.len() { + prs[idx] /= sum; } let top_p = state.top_p; if top_p <= 0.0 || top_p >= 1.0 {