diff --git a/rllm/rllm-base/src/logits.rs b/rllm/rllm-base/src/logits.rs index bb635682..274e798d 100644 --- a/rllm/rllm-base/src/logits.rs +++ b/rllm/rllm-base/src/logits.rs @@ -18,7 +18,8 @@ impl LogitsProcessor { }; Self { - rng: rand::rngs::StdRng::seed_from_u64(42), + rng: rand::rngs::StdRng::from_entropy(), + // seed_from_u64(42), temperature, top_p: sampling_params.top_p, } diff --git a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs index 64e99e8d..af212e94 100644 --- a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs +++ b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs @@ -150,9 +150,14 @@ impl ModelExec for TModel { None => self.sample_argmax(&logits), Some(temperature) => { let mut prs: Vec<f32> = logits.to_vec1(); + let max_logit = prs.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); let temp = (1.0 / temperature) as f32; for idx in 0..prs.len() { - prs[idx] *= temp; + prs[idx] = ((prs[idx] - max_logit) * temp).exp(); + } + let sum = prs.iter().sum::<f32>(); + for idx in 0..prs.len() { + prs[idx] /= sum; } let top_p = state.top_p; if top_p <= 0.0 || top_p >= 1.0 {