Address AICI bias masks by index instead of by byte offset.

- rllm-base/engine.rs: the bias slice is taken from the shared-memory
  region at mid_res.first_mask_byte_offset and sized as
  mid_res.mask_num_elts * mid_res.num_masks; sample_mask is passed to
  AiciBias::apply directly as the row index (no division by vocab_bytes).
- rllm-cuda / rllm-llamacpp tmodel.rs: apply() only touches as many bias
  entries as there are logits.
- scripts/test-parallel.sh: new helper that starts N (default 10) copies
  of scripts/test-pyctrl.sh in the background, one per second, and exits 1
  if any of them failed.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy.
Line breaks follow the hunk headers; the indentation of the Rust lines and
the `<f32>` type argument on slice_at_byte_offset are reconstructions --
confirm against the repository (or apply with --ignore-whitespace).

diff --git a/rllm/rllm-base/src/engine.rs b/rllm/rllm-base/src/engine.rs
index 74fbcba7..029f7baa 100644
--- a/rllm/rllm-base/src/engine.rs
+++ b/rllm/rllm-base/src/engine.rs
@@ -397,12 +397,11 @@ impl RllmEngine {
                 sg.seqs.extend(to_add);
             }
 
-        let vocab_bytes = vocab_size * 4;
-        let num_seqs = mid_res.mask_num_bytes / vocab_bytes;
         let shm = &self.aicirt.as_mut().unwrap().bin_shm;
-        let slice = shm.slice_at_byte_offset::<f32>(0, num_seqs * vocab_size);
+        let slice = shm.slice_at_byte_offset::<f32>(mid_res.first_mask_byte_offset,
+            mid_res.mask_num_elts * mid_res.num_masks);
         Ok((
-            self.tmodel.new_bias(slice, num_seqs, vocab_size),
+            self.tmodel.new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts),
             seq_id_mapping,
         ))
     }
@@ -487,8 +486,6 @@ impl RllmEngine {
         let (aici_bias, seq_id_mapping) =
             with_timer!(self.tim_aici_bias, self.aici_bias(sched_out)?);
 
-        let vocab_bytes = self.tok_trie.vocab_size() * 4;
-
         for sg in sched_out.next_seq_groups.iter_mut() {
             for seq in sg.seqs.iter_mut() {
                 if seq.sched_phase != SchedulingPhase::Running {
@@ -512,7 +509,7 @@ impl RllmEngine {
                     _ => {
                         match &seq.aici_sampling {
                             Some(b) => {
-                                let seq_idx = b.sample_mask.unwrap() / vocab_bytes;
+                                let seq_idx = b.sample_mask.unwrap();
                                 aici_bias.apply(&mut logits, seq_idx);
                             }
                             None => {}
diff --git a/rllm/rllm-cuda/src/llm/tmodel.rs b/rllm/rllm-cuda/src/llm/tmodel.rs
index 4e412d16..e8aba4b2 100644
--- a/rllm/rllm-cuda/src/llm/tmodel.rs
+++ b/rllm/rllm-cuda/src/llm/tmodel.rs
@@ -277,7 +277,7 @@ pub struct TchAiciBias {
 impl AiciBias for TchAiciBias {
     fn apply(&self, logits: &mut Tensor, seq_id: usize) {
         let bias = self.bias.as_ref().unwrap();
-        let bias = bias.i((seq_id as i64, ..));
+        let bias = bias.i((seq_id as i64, 0..logits.size()[0]));
         *logits = &*logits + bias;
     }
 }
diff --git a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs
index af212e94..89472173 100644
--- a/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs
+++ b/rllm/rllm-llamacpp/src/llamacpp/tmodel.rs
@@ -271,7 +271,7 @@ impl AiciBias for CppAiciBias {
         let sp = seq_id * self.vocab_size;
         let logits = logits.as_mut_slice();
         let bias = bias.as_slice();
-        for i in 0..self.vocab_size {
+        for i in 0..logits.len() {
             logits[i] += bias[sp + i];
         }
     }
diff --git a/scripts/test-parallel.sh b/scripts/test-parallel.sh
new file mode 100755
index 00000000..498e0fcc
--- /dev/null
+++ b/scripts/test-parallel.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+N=10
+if [ -n "$1" ]; then
+    N=$1
+fi
+
+mkdir -p tmp
+rm -f tmp/fail
+
+for n in $(seq $N) ; do
+    echo "Start $n"
+    if ./scripts/test-pyctrl.sh > tmp/logs-$n.txt 2>&1 ; then
+        echo "Passed test $n"
+    else
+        echo "Failed test $n; see tmp/logs-$n.txt"
+        echo $n >> tmp/fail
+    fi &
+    sleep 1
+done
+
+wait
+
+if [ -f tmp/fail ]; then
+    echo "Some tests failed; see tmp/fail"
+    exit 1
+fi