diff --git a/aicirt/src/main.rs b/aicirt/src/main.rs index 5bb2c87f..4eaaa77b 100644 --- a/aicirt/src/main.rs +++ b/aicirt/src/main.rs @@ -398,7 +398,9 @@ impl ModuleRegistry { fn run_main(&self, req_id: &String) -> Result<()> { let req_instances = self.req_instances.lock().unwrap(); - let inst = req_instances.get(req_id).ok_or_else(|| anyhow!("invalid req_id"))?; + let inst = req_instances + .get(req_id) + .ok_or_else(|| anyhow!("invalid req_id"))?; inst.run_main() } @@ -715,6 +717,7 @@ impl Stepper { } fn aici_post_process(&mut self, req: AiciPostProcessReq) -> Result { + let timing = false; let t0 = Instant::now(); // this if forking due to n= parameter in sampling @@ -727,13 +730,12 @@ impl Stepper { let mut used_ids = Vec::new(); let mut outputs = HashMap::default(); - log::warn!( - "post_process0: {:?} {} {}", - t0.elapsed(), - self.instances.len(), - req.ops.len() - ); + let mut all_dur = Vec::new(); + if timing { + all_dur.push(t0.elapsed()); + } + let mut prev = Instant::now(); for op in req.ops.into_iter() { let instid = op.id; if let Ok(h) = self.get_worker(instid) { @@ -751,15 +753,15 @@ impl Stepper { } else { log::warn!("invalid id {}", instid); } + if timing { + all_dur.push(prev.elapsed()); + prev = Instant::now(); + } } let deadline = Instant::now() + std::time::Duration::from_millis(self.limits.max_pre_step_ms); - let mut all_dur = Vec::new(); - all_dur.push(t0.elapsed()); - - let mut prev = Instant::now(); for id in used_ids { let h = self.get_worker(id).unwrap(); let timeout = deadline.saturating_duration_since(Instant::now()); @@ -769,11 +771,19 @@ impl Stepper { } Err(e) => self.worker_error(id, &mut outputs, e), } - all_dur.push(prev.elapsed()); - prev = Instant::now(); + if timing { + all_dur.push(prev.elapsed()); + prev = Instant::now(); + } } - log::warn!("post_process: {:?}", all_dur); + if timing { + log::warn!( + "post_process time: {:?} {:?}", + t0.elapsed(), + all_dur.iter().map(|x| x.as_micros()).collect::>() + ); + } Ok(AiciPostProcessResp { seqs: outputs }) } @@ -955,7 +965,7 @@ impl CmdRespChannel { } fn bench_hashmap() { - let mut h = HashMap::::default(); + let mut h = HashMap::::default(); for x in 10..50 { h.insert(x, x * x); } diff --git a/harness/bench_server.py b/harness/bench_server.py index e2d44bae..57deea07 100644 --- a/harness/bench_server.py +++ b/harness/bench_server.py @@ -10,7 +10,7 @@ bench_py = """ import pyaici.server as aici async def main(): - await aici.gen_tokens(max_tokens=5) + await aici.gen_tokens(max_tokens=45) aici.start(main()) """