Skip to content

Commit

Permalink
use new bias type in rllm
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Apr 19, 2024
1 parent 221fbe2 commit a449bdb
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
11 changes: 4 additions & 7 deletions rllm/rllm-base/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,11 @@ impl<ME: ModelExec> RllmEngine<ME> {
sg.seqs.extend(to_add);
}

let vocab_bytes = vocab_size * 4;
let num_seqs = mid_res.mask_num_bytes / vocab_bytes;
let shm = &self.aicirt.as_mut().unwrap().bin_shm;
let slice = shm.slice_at_byte_offset::<f32>(0, num_seqs * vocab_size);
let slice = shm.slice_at_byte_offset::<f32>(mid_res.first_mask_byte_offset,
mid_res.mask_num_elts * mid_res.num_masks);
Ok((
self.tmodel.new_bias(slice, num_seqs, vocab_size),
self.tmodel.new_bias(slice, mid_res.num_masks, mid_res.mask_num_elts),
seq_id_mapping,
))
}
Expand Down Expand Up @@ -487,8 +486,6 @@ impl<ME: ModelExec> RllmEngine<ME> {
let (aici_bias, seq_id_mapping) =
with_timer!(self.tim_aici_bias, self.aici_bias(sched_out)?);

let vocab_bytes = self.tok_trie.vocab_size() * 4;

for sg in sched_out.next_seq_groups.iter_mut() {
for seq in sg.seqs.iter_mut() {
if seq.sched_phase != SchedulingPhase::Running {
Expand All @@ -512,7 +509,7 @@ impl<ME: ModelExec> RllmEngine<ME> {
_ => {
match &seq.aici_sampling {
Some(b) => {
let seq_idx = b.sample_mask.unwrap() / vocab_bytes;
let seq_idx = b.sample_mask.unwrap();
aici_bias.apply(&mut logits, seq_idx);
}
None => {}
Expand Down
2 changes: 1 addition & 1 deletion rllm/rllm-cuda/src/llm/tmodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ pub struct TchAiciBias {
impl AiciBias<Tensor> for TchAiciBias {
fn apply(&self, logits: &mut Tensor, seq_id: usize) {
let bias = self.bias.as_ref().unwrap();
let bias = bias.i((seq_id as i64, ..));
let bias = bias.i((seq_id as i64, 0..logits.size()[0]));
*logits = &*logits + bias;
}
}
2 changes: 1 addition & 1 deletion rllm/rllm-llamacpp/src/llamacpp/tmodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ impl AiciBias<Tensor> for CppAiciBias {
let sp = seq_id * self.vocab_size;
let logits = logits.as_mut_slice();
let bias = bias.as_slice();
for i in 0..self.vocab_size {
for i in 0..logits.len() {
logits[i] += bias[sp + i];
}
}
Expand Down
27 changes: 27 additions & 0 deletions scripts/test-parallel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
#
# Start several runs of scripts/test-pyctrl.sh as background jobs and report
# which ones passed.
#
# Usage: scripts/test-parallel.sh [N]
#   N   number of runs to start (default: 10)
#
# Output of run <n> (stdout and stderr) is written to tmp/logs-<n>.txt.
# The number of every failed run is appended to tmp/fail.
#
# Exit status:
#   0   no run failed
#   1   at least one run failed, or tmp/ could not be created
#   2   N is not a non-negative integer

N=${1:-10}

# Reject anything that is not a plain non-negative integer. Otherwise `seq`
# prints an error and emits nothing, the loop starts zero runs, and the script
# exits 0 as if everything had passed.
case "$N" in
    '' | *[!0-9]*)
        echo "usage: $0 [N]  (N must be a non-negative integer, got '$N')" >&2
        exit 2
        ;;
esac

# Without tmp/ neither the logs nor the failure marker can be written, which
# would make failures invisible to the final check below.
mkdir -p tmp || exit 1

# Remove the failure marker left behind by a previous invocation.
rm -f tmp/fail

for n in $(seq "$N"); do
    echo "Start $n"
    # The whole if/else is one background job, so the pass/fail message for a
    # run is printed when that run finishes, not when it is started.
    if ./scripts/test-pyctrl.sh > "tmp/logs-$n.txt" 2>&1; then
        echo "Passed test $n"
    else
        echo "Failed test $n; see tmp/logs-$n.txt"
        echo "$n" >> tmp/fail
    fi &
    # Space the starts of consecutive runs one second apart.
    sleep 1
done

# Block until every background run has finished.
wait

if [ -f tmp/fail ]; then
    echo "Some tests failed; see tmp/fail"
    exit 1
fi

0 comments on commit a449bdb

Please sign in to comment.