Skip to content

Commit

Permalink
implement missing softmax in llama.cpp sampling (doh!)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Mar 8, 2024
1 parent b93225f commit 81679ea
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion rllm/rllm-base/src/logits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ impl LogitsProcessor {
};

Self {
rng: rand::rngs::StdRng::seed_from_u64(42),
rng: rand::rngs::StdRng::from_entropy(),
// seed_from_u64(42),
temperature,
top_p: sampling_params.top_p,
}
Expand Down
7 changes: 6 additions & 1 deletion rllm/rllm-llamacpp/src/llamacpp/tmodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,14 @@ impl ModelExec for TModel {
None => self.sample_argmax(&logits),
Some(temperature) => {
let mut prs: Vec<f32> = logits.to_vec1();
let max_logit = prs.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let temp = (1.0 / temperature) as f32;
for idx in 0..prs.len() {
prs[idx] *= temp;
prs[idx] = ((prs[idx] - max_logit) * temp).exp();
}
let sum = prs.iter().sum::<f32>();
for idx in 0..prs.len() {
prs[idx] /= sum;
}
let top_p = state.top_p;
if top_p <= 0.0 || top_p >= 1.0 {
Expand Down

0 comments on commit 81679ea

Please sign in to comment.