diff --git a/.devcontainer/Dockerfile-vllm b/.devcontainer/Dockerfile-vllm
index 36f20884..f57e5225 100644
--- a/.devcontainer/Dockerfile-vllm
+++ b/.devcontainer/Dockerfile-vllm
@@ -1,7 +1,7 @@
 # syntax = edrevo/dockerfile-plus
 # ^^^ this line enables the INCLUDE+ directive
-FROM nvcr.io/nvidia/pytorch:23.09-py3
+FROM nvcr.io/nvidia/pytorch:23.10-py3
 
 INCLUDE+ cuda-settings.dockerfile
 INCLUDE+ common.dockerfile
 
@@ -15,11 +15,9 @@ RUN pip install -r /tmp/requirements.txt
 
 # RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers # takes forever!
-RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+# RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
 # RUN pip install typing_extensions==4.5.0
-
-RUN pip install -U flash-attn
-
+# RUN pip install -U flash-attn
 
 # RUN pip install torch==2.1.0 nvidia-cuda-runtime # the .so file seems to be missing
diff --git a/.devcontainer/vllm-requirements.txt b/.devcontainer/vllm-requirements.txt
index 4126a0fe..3a453b68 100644
--- a/.devcontainer/vllm-requirements.txt
+++ b/.devcontainer/vllm-requirements.txt
@@ -1,19 +1,25 @@
+# vllm: requirements.txt
 ninja # For faster builds.
 psutil
-ray >= 2.5.1
-pandas # Required for Ray data.
+ray >= 2.9
 sentencepiece # Required for LLaMA tokenizer.
 numpy
-torch == 2.1.0
-transformers >= 4.33.1 # Required for Code Llama.
-xformers >= 0.0.21
+torch == 2.1.2
+transformers >= 4.37.0 # Required for Qwen2
+xformers == 0.0.23.post1 # Required for CUDA 12.1.
 fastapi
-uvicorn
-pydantic < 2 # Required for OpenAI server.
+uvicorn[standard]
+pydantic >= 2.0 # Required for OpenAI server.
+aioprometheus[starlette]
+pynvml == 11.5.0
+triton >= 2.1.0
+cupy-cuda12x == 12.3.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
 
+# vllm: requirements-dev.txt
 # formatting
 yapf==0.32.0
-pylint==2.8.2
+toml==0.10.2
+ruff==0.1.5
 
 # type checking
 mypy==0.991
@@ -24,8 +30,24 @@ types-setuptools
 
 # testing
 pytest
 pytest-forked
+pytest-asyncio
+httpx
+einops # required for MPT
+flash_attn # required for HuggingFace's llama implementation
+openai
+requests
+# ray - XXX
+# vllm: requirements-build.txt
+# Should be mirrored in pyproject.toml
+ninja
+packaging
+setuptools>=49.4.0
+# torch==2.1.2 - XXX
+wheel
+
+# non-vllm:
 ujson
 posix_ipc
 accelerate
-fschat
\ No newline at end of file
+fschat
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 90b1a349..7a27eca8 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -67,6 +67,7 @@
         "rust"
     ],
     "cSpell.words": [
+        "actix",
         "aici",
         "aicirt",
         "avgtol",
diff --git a/aicirt/src/api.rs b/aicirt/src/api.rs
index c0fdcc8c..78e035eb 100644
--- a/aicirt/src/api.rs
+++ b/aicirt/src/api.rs
@@ -100,6 +100,18 @@ impl SequenceResult {
             micros: self.micros,
         }
     }
+    pub fn map_result<S, F>(self, f: F) -> SequenceResult<S>
+    where
+        F: FnOnce(T) -> S,
+    {
+        SequenceResult {
+            error: self.error,
+            result: self.result.map(f),
+            storage: self.storage,
+            logs: self.logs,
+            micros: self.micros,
+        }
+    }
 }
 
 #[derive(Serialize, Deserialize)]
diff --git a/aicirt/src/lib.rs b/aicirt/src/lib.rs
index b1d65b40..36bdf196 100644
--- a/aicirt/src/lib.rs
+++ b/aicirt/src/lib.rs
@@ -9,8 +9,11 @@ pub mod shm;
 #[cfg(target_os = "macos")]
 mod macos;
 
+use std::fmt::Write;
+
 use anyhow::Result;
 pub use bench::*;
+use flexi_logger::style;
 use flexi_logger::{DeferredNow, Logger, WriteMode};
 use log::Record;
 use thread_priority::{
@@ -27,6 +30,55 @@ pub enum LogMode {
     Daemon,
 }
 
+struct LimitedWrite {
+    limit: usize,
+    dst: String,
+}
+
+impl Write for LimitedWrite {
+    fn write_str(&mut self, s: &str) -> std::fmt::Result {
+        if self.dst.len() > self.limit {
+            return Err(std::fmt::Error);
+        }
+        if self.dst.len() + s.len() < self.limit {
+            self.dst.push_str(s);
+            Ok(())
+        } else {
+            let remaining = self.limit - self.dst.len();
+            self.dst.push_str(&s[..remaining]);
+            self.dst.push_str(" (...)");
+            Err(std::fmt::Error)
+        }
+    }
+}
+
+fn args_to_str(limit: usize, args: &std::fmt::Arguments) -> String {
+    // let capacity = args.estimated_capacity();
+    let mut output = LimitedWrite {
+        limit,
+        dst: String::with_capacity(128),
+    };
+    if output.write_fmt(*args).is_err() {
+        assert!(output.dst.len() > limit);
+    }
+    output.dst
+}
+
+fn truncated_format(
+    w: &mut dyn std::io::Write,
+    _now: &mut DeferredNow,
+    record: &Record,
+) -> Result<(), std::io::Error> {
+    let level = record.level();
+    write!(
+        w,
+        "{} [{}] {}",
+        style(level).paint(level.to_string()),
+        record.module_path().unwrap_or(""),
+        style(level).paint(args_to_str(1000, record.args()))
+    )
+}
+
 fn daemon_format(
     w: &mut dyn std::io::Write,
     now: &mut DeferredNow,
@@ -34,16 +86,19 @@
 ) -> Result<(), std::io::Error> {
     write!(
         w,
-        "[{}] {} {}",
+        "{} {} [{}] {}",
         now.format("%Y-%m-%d %H:%M:%S%.3f"),
         record.level(),
-        &record.args()
+        record.module_path().unwrap_or(""),
+        args_to_str(5000, record.args())
     )
 }
 
 pub fn init_log(mode: LogMode) -> Result<()> {
     let logger = match mode {
-        LogMode::Normal => Logger::try_with_env_or_str("info")?.log_to_stdout(),
+        LogMode::Normal => Logger::try_with_env_or_str("info")?
+            .format(truncated_format)
+            .log_to_stdout(),
         LogMode::Test => {
             Logger::try_with_env_or_str("debug")?.write_mode(WriteMode::SupportCapture)
         }
diff --git a/aicirt/src/main.rs b/aicirt/src/main.rs
index 298adb7c..5c0cb6db 100644
--- a/aicirt/src/main.rs
+++ b/aicirt/src/main.rs
@@ -476,8 +476,16 @@ impl ModuleRegistry {
             })
             .collect::>();
-        ensure_user!(wasm_files.len() > 0, "no wasm files found (selector={:?})", selector);
-        ensure_user!(wasm_files.len() == 1, "too many wasm files found (selector={:?})", selector);
+        ensure_user!(
+            wasm_files.len() > 0,
+            "no wasm files found (selector={:?})",
+            selector
+        );
+        ensure_user!(
+            wasm_files.len() == 1,
+            "too many wasm files found (selector={:?})",
+            selector
+        );
 
         let wasm_file = wasm_files[0];
 
         let upd = wasm_file["updated_at"]
@@ -700,15 +708,17 @@ impl Stepper {
                     }
                     outputs.insert(
                         id,
-                        data.json.clone_with(Some(AiciPreProcessResultInner {
-                            suspend: data.suspend,
-                            num_forks: data.num_forks,
-                            ff_tokens: data.ff_tokens,
-                        })),
+                        data.map_result(|pp| {
+                            if pp.suspend {
+                                assert!(pp.num_forks == 1);
+                            }
+                            AiciPreProcessResultInner {
+                                suspend: pp.suspend,
+                                num_forks: pp.num_forks,
+                                ff_tokens: pp.ff_tokens,
+                            }
+                        }),
                     );
-                    if data.suspend {
-                        assert!(data.num_forks == 1);
-                    }
                 }
                 Err(e) => self.worker_error(id, &mut outputs, e),
             }
diff --git a/aicirt/src/moduleinstance.rs b/aicirt/src/moduleinstance.rs
index 95b8d6d4..1154c7e7 100644
--- a/aicirt/src/moduleinstance.rs
+++ b/aicirt/src/moduleinstance.rs
@@ -4,15 +4,18 @@ use crate::{
         setup_linker, AiciLimits, GlobalInfo, ModuleData, LOGIT_BIAS_ALLOW, LOGIT_BIAS_DISALLOW,
     },
     shm::Shm,
-    worker::{GroupHandle, RtMidProcessArg, RtPreProcessResult},
+    worker::{GroupHandle, RtMidProcessArg},
     TimerSet, UserError,
 };
 use aici_abi::{
-    toktree::TokTrie, InitPromptArg, InitPromptResult, MidProcessResult, PostProcessArg,
-    PostProcessResult, PreProcessArg, PreProcessResult, TokenId,
+    toktree::TokTrie, InitPromptArg, MidProcessResult, PostProcessArg, PostProcessResult,
+    PreProcessArg, PreProcessResult, TokenId,
 };
 use aicirt::{
-    api::{AiciMidProcessResultInner, AiciPostProcessResultInner, SequenceResult}, bail_user, bintokens::ByteTokenizer, user_error
+    api::{AiciMidProcessResultInner, AiciPostProcessResultInner, SequenceResult},
+    bail_user,
+    bintokens::ByteTokenizer,
+    user_error,
 };
 use anyhow::{anyhow, bail, ensure, Result};
 use serde::Deserialize;
@@ -98,6 +101,7 @@ pub struct ModuleInstance {
     store: wasmtime::Store,
     memory: wasmtime::Memory,
     instance: wasmtime::Instance,
+    pending_pre_result: Option,
     handle: WasmAici,
     #[allow(dead_code)]
     limits: AiciLimits,
@@ -213,6 +217,7 @@ impl ModuleInstance {
             memory,
             instance,
             limits: ctx.limits,
+            pending_pre_result: None,
         })
     }
 
@@ -262,8 +267,11 @@
 
         let mut res = self.do_pre_process_inner(&rtarg)?;
         if res.ff_tokens.len() > 0 {
-            ensure!(res.num_forks == 1);
-            ensure!(res.suspend == false);
+            ensure!(res.num_forks == 1, "can't fork when returning ff_tokens");
+            ensure!(
+                res.suspend == false,
+                "can't suspend when returning ff_tokens"
+            );
             ff_tokens.extend_from_slice(&res.ff_tokens);
             let r_post = self.do_post_process(PostProcessArg {
                 tokens: res.ff_tokens.clone(),
@@ -363,7 +371,7 @@
         Ok(AiciPostProcessResultInner { stop: res.stop })
     }
 
-    fn json_result(
+    fn seq_result(
        &mut self,
        lbl: &str,
        t0: Instant,
@@ -397,16 +405,19 @@
         }
     }
 
-    pub fn pre_process(&mut self, op: PreProcessArg) -> RtPreProcessResult {
+    pub fn pre_process(&mut self, op: PreProcessArg) -> SequenceResult {
         let t0 = Instant::now();
+        if let Some(pp) = self.pending_pre_result.clone() {
+            return self.seq_result("pre-cached", t0, Ok(Some(pp)));
+        }
         match self.do_pre_process(op) {
-            Err(e) => RtPreProcessResult::just_json(self.json_result("pre0", t0, Err(e))),
-            Ok(res) => RtPreProcessResult {
-                json: self.json_result("pre", t0, Ok(None)),
-                suspend: res.suspend,
-                num_forks: res.num_forks,
-                ff_tokens: res.ff_tokens,
-            },
+            Err(e) => self.seq_result("pre0", t0, Err(e)),
+            Ok(pp) => {
+                if pp.ff_tokens.len() > 0 {
+                    self.pending_pre_result = Some(pp.clone());
+                }
+                self.seq_result("pre", t0, Ok(Some(pp)))
+            }
         }
     }
 
@@ -416,9 +427,10 @@
         shm: &Shm,
     ) -> SequenceResult {
         let t0 = Instant::now();
+        self.pending_pre_result = None;
         let res = self.do_mid_process(op, shm);
         // log::info!("mid_process: {:?}", t0.elapsed());
-        self.json_result("mid", t0, res)
+        self.seq_result("mid", t0, res)
     }
 
     pub fn post_process(
@@ -426,14 +438,15 @@
         op: PostProcessArg,
     ) -> SequenceResult {
         let t0 = Instant::now();
+        self.pending_pre_result = None;
         let res = self.do_post_process(op);
         match res {
             Err(e) => {
-                let mut r = self.json_result("post", t0, Err(e));
+                let mut r = self.seq_result("post", t0, Err(e));
                 r.result = Some(AiciPostProcessResultInner { stop: true });
                 r
             }
-            Ok(res) => self.json_result("post", t0, Ok(Some(res))),
+            Ok(res) => self.seq_result("post", t0, Ok(Some(res))),
         }
     }
 
@@ -441,7 +454,7 @@
         self.store.data_mut().tokenize(s)
     }
 
-    pub fn setup(&mut self, prompt: Vec) -> Result {
+    fn setup_inner(&mut self, prompt: Vec) -> Result<()> {
         self.run_init()?;
 
         self.handle = self.call_func::<(), WasmAici>("aici_create", ())?;
@@ -451,7 +464,27 @@
             .set_process_arg(serde_json::to_vec(&InitPromptArg { prompt })?);
 
         self.call_func::("aici_init_prompt", self.handle)?;
-        let res: InitPromptResult = self.proc_result()?;
-        Ok(res)
+        Ok(())
+    }
+
+    pub fn setup(&mut self, prompt: Vec) -> SequenceResult {
+        let t0 = Instant::now();
+        match self.setup_inner(prompt) {
+            Err(err) => self.seq_result("setup", t0, Err(err)),
+            Ok(()) => match self.do_pre_process(PreProcessArg {}) {
+                Err(e) => self.seq_result("setup-pre", t0, Err(e)),
+                Ok(pp) if pp.suspend || pp.num_forks == 0 => self.seq_result(
+                    "setup-pre",
+                    t0,
+                    Err(user_error!("setup-pre asked for suspend")),
+                ),
+                Ok(mut pp) => {
+                    // force a single fork; we expect another fork request when pre() runs in its own turn
+                    pp.num_forks = 1;
+                    // don't set self.pending_pre_result -> we assume the results are always handled
+                    self.seq_result("setup-pre", t0, Ok(Some(pp)))
+                }
+            },
+        }
     }
 }
diff --git a/aicirt/src/worker.rs b/aicirt/src/worker.rs
index 1a1722d2..8acbdd2b 100644
--- a/aicirt/src/worker.rs
+++ b/aicirt/src/worker.rs
@@ -7,7 +7,7 @@ use crate::{
     with_timer, InstantiateReq, TimerRef, UserError,
 };
 use aici_abi::{
-    InitPromptResult, MidProcessArg, PostProcessArg, PreProcessArg, StorageCmd, StorageOp,
+    MidProcessArg, PostProcessArg, PreProcessArg, PreProcessResult, StorageCmd, StorageOp,
     StorageResp, TokenId,
 };
 use aicirt::{
@@ -176,51 +176,15 @@ impl SeqCmd {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug)]
-pub struct RtPreProcessResult {
-    pub json: SequenceResult,
-    pub suspend: bool,
-    pub num_forks: usize,
-    pub ff_tokens: Vec,
-}
-
-impl RtPreProcessResult {
-    pub fn just_json(json: SequenceResult) -> Self {
-        RtPreProcessResult {
-            json,
-            suspend: false,
-            num_forks: 1,
-            ff_tokens: Vec::new(),
-        }
-    }
-}
-
 #[derive(Serialize, Deserialize, Debug)]
 enum SeqResp {
-    Fork {
-        handle: WireSeqHandle,
-    },
+    Fork { handle: WireSeqHandle },
     Ok {},
-    InitPrompt {
-        json: String,
-    },
-    PostPreProcess {
-        post_json: String,
-        pre_json: String,
-        suspend: bool,
-        num_forks: usize,
-        ff_tokens: Vec,
-    },
-    MidProcess {
-        json: String,
-    },
-    Compile {
-        binary: Vec,
-    },
-    Error {
-        msg: String,
-        is_user_error: bool,
-    },
+    InitPrompt { json: String },
+    PostPreProcess { post_json: String, pre_json: String },
+    MidProcess { json: String },
+    Compile { binary: Vec },
+    Error { msg: String, is_user_error: bool },
 }
 
 impl ProcessHandle
 where
@@ -389,7 +353,7 @@ impl SeqCtx {
             inst.tokenize(&p)?
         };
         self.modinst = Some(inst);
-        let r = self.mutinst().setup(prompt_toks)?;
+        let r = self.mutinst().setup(prompt_toks);
         Ok(SeqResp::InitPrompt {
             json: serde_json::to_string(&r)?,
         })
@@ -415,10 +379,7 @@
                 let res = with_timer!(self.pre_timer, self.mutinst().pre_process(data.pre_op));
                 Ok(SeqResp::PostPreProcess {
                     post_json,
-                    pre_json: serde_json::to_string(&res.json)?,
-                    suspend: res.suspend,
-                    num_forks: res.num_forks,
-                    ff_tokens: res.ff_tokens,
+                    pre_json: serde_json::to_string(&res)?,
                 })
             }
             SeqCmd::MidProcess { data } => {
@@ -552,28 +513,20 @@ impl SeqWorkerHandle {
         timer: &TimerRef,
     ) -> Result<(
         Option>,
-        RtPreProcessResult,
+        SequenceResult,
     )> {
         let r = timer.with(|| self.handle.seq_recv_with_timeout("r-pre_process", timeout));
         match r {
             Ok(SeqResp::PostPreProcess {
                 post_json,
                 pre_json,
-                suspend,
-                num_forks,
-                ff_tokens,
             }) => Ok((
                 if post_json.len() > 0 {
                     Some(serde_json::from_str(&post_json)?)
                 } else {
                     None
                 },
-                RtPreProcessResult {
-                    json: serde_json::from_str(&pre_json)?,
-                    suspend,
-                    num_forks,
-                    ff_tokens,
-                },
+                serde_json::from_str(&pre_json)?,
             )),
             Ok(r) => Err(anyhow!("unexpected response (pre_process) {r:?}")),
             Err(e) => Err(e.into()),
@@ -796,7 +749,7 @@ impl WorkerForker {
         &self,
         req: InstantiateReq,
         module_path: PathBuf,
-    ) -> Result<(SeqWorkerHandle, InitPromptResult)> {
+    ) -> Result<(SeqWorkerHandle, SequenceResult)> {
         let module_arg = match req.module_arg.as_str() {
             Some(a) => a.to_string(),
             None => serde_json::to_string(&req.module_arg)?,
@@ -842,7 +795,7 @@ impl WorkerForker {
             Duration::from_millis(self.limits.max_init_ms),
         )? {
             SeqResp::InitPrompt { json } => {
-                let r: InitPromptResult = serde_json::from_str(&json)?;
+                let r: SequenceResult = serde_json::from_str(&json)?;
                 Ok((res, r))
             }
             r => Err(anyhow!("unexpected response (init prompt) {r:?}")),
diff --git a/controllers/aici_abi/src/lib.rs b/controllers/aici_abi/src/lib.rs
index 69625545..ae6a418f 100644
--- a/controllers/aici_abi/src/lib.rs
+++ b/controllers/aici_abi/src/lib.rs
@@ -41,7 +41,7 @@ pub struct SeqId(pub u32);
 #[derive(Serialize, Deserialize, Debug)]
 pub struct PreProcessArg {}
 
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct PreProcessResult {
     /// If 0 - stop the sequence.
     /// If 1 - just continue.
@@ -55,6 +55,16 @@ pub struct PreProcessResult {
     pub ff_tokens: Vec,
 }
 
+impl Default for PreProcessResult {
+    fn default() -> Self {
+        PreProcessResult {
+            num_forks: 1,
+            suspend: false,
+            ff_tokens: vec![],
+        }
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct MidProcessArg {
     /// fork_group.len() == num_forks.
diff --git a/py/pyaici/__init__.py b/py/pyaici/__init__.py index fdcb9d30..b164c315 100644 --- a/py/pyaici/__init__.py +++ b/py/pyaici/__init__.py @@ -1,11 +1,37 @@ import argparse + def runner_from_cli(args): from pyaici.comms import AiciRunner + tokenizer = args.aici_tokenizer + + # when no explicit --aici-tokenizer, we look for: + # --tokenizer + --tokenizer-revision + # --model + --revision + if not tokenizer: + model_tokenizer = getattr(args, "tokenizer", None) + if model_tokenizer: + rev = getattr(args, "tokenizer_revision", None) + if rev: + model_tokenizer += f"@{rev}" + tokenizer = model_tokenizer + else: + model = getattr(args, "model", None) + if model: + rev = getattr(args, "revision", None) + if rev: + model += f"@{rev}" + tokenizer = model + + if not tokenizer: + raise ValueError("No AICIrt tokenizer specified") + if not args.aici_rt: + raise ValueError("No AICIrt path specified") + aici = AiciRunner( rtpath=args.aici_rt, - tokenizer=args.aici_tokenizer, + tokenizer=tokenizer, trace_file=args.aici_trace, rtargs=args.aici_rtarg, ) @@ -22,8 +48,8 @@ def add_cli_args(parser: argparse.ArgumentParser, single=False): parser.add_argument( "--aici-tokenizer", type=str, - default="llama", - help="tokenizer to use; llama, gpt4, ...", + default="", + help="tokenizer to use; llama, phi, ...; can also use HF tokenizer name", ) parser.add_argument( "--aici-trace", @@ -52,3 +78,5 @@ def add_cli_args(parser: argparse.ArgumentParser, single=False): default="", help="arg passed to module (filename)", ) + + return parser diff --git a/py/pyaici/comms.py b/py/pyaici/comms.py index 43732560..6100807b 100644 --- a/py/pyaici/comms.py +++ b/py/pyaici/comms.py @@ -159,7 +159,7 @@ def send(self, data): self.cmd_pending = True self.cmd_ch.send_json(data) - async def exec_async(self, op: str, data={}): + async def exec_async(self, op: str, data={}, auth_info=None): loop = asyncio.get_running_loop() if self.executor is None: @@ -181,6 +181,8 @@ def bg_reader(): rid = os.urandom(8).hex() data["op"] = op data["$rid"] = rid + if auth_info: + data["$auth"] = auth_info req = PendingRequest(cmd=data) self.pending_reqs[rid] = req @@ -238,16 +240,6 @@ def expect(self, ctx): class AiciRunner: instance = None - @staticmethod - def from_cli(args): - aici = AiciRunner( - rtpath=args.aici_rt, - tokenizer=args.aici_tokenizer, - trace_file=args.aici_trace, - rtargs=args.aici_rtarg, - ) - return aici - def __init__( self, rtpath, @@ -288,6 +280,7 @@ def __init__( self.pre_ops = [] self.mid_ops = [] self.freed_seq_ids = [] + self.pending_instantiate_results = {} self.wasm_pre_timer = BenchTimer("wasm_pre") self.wasm_pre_timer_send = BenchTimer("wasm_pre_send") @@ -366,9 +359,54 @@ def replay(self, prev_trace: str): ch.send(obj["cmd"]) ch.expect("replay") - async def upload_module_async(self, wasm: bytes): + async def upload_module_async(self, wasm: bytes, auth_info = None): b64 = base64.b64encode(wasm).decode("utf-8") - return await self.side_cmd.exec_async("mk_module", {"binary": b64}) + return await self.side_cmd.exec_async("mk_module", {"binary": b64}, auth_info=auth_info) + + async def get_tags(self, auth_info = None): + return await self.side_cmd.exec_async("get_tags", {}, auth_info=auth_info) + + async def set_tags(self, module_id: str, tags: Union[str, list[str]], auth_info = None): + if isinstance(tags, str): + tags = [tags] + op = {"module_id": module_id, "tags": tags} + return await self.side_cmd.exec_async("set_tags", op, auth_info=auth_info) + + def usage_json(self, ff_tokens: int, sampled_tokens: int): + 
return { + "sampled_tokens": sampled_tokens, + "ff_tokens": ff_tokens, + "cost": 2 * sampled_tokens + ff_tokens, + } + + def run_json(self, forks: List[dict], usage: dict): + return { + "object": "run", + "forks": forks, + "usage": usage, + } + + def _save_instantiate_result(self, req_id: str, res: dict): + if res["error"]: + r = self._fork_result(0, [res], finish_reason="fail") + return self.run_json([r], self.usage_json(0, 0)) + else: + self.pending_instantiate_results[req_id] = res + return res["result"]["ff_tokens"] + + def initial_json(self, req_id: str, model: str): + return { + "object": "initial-run", + "id": req_id, + "created": int(time.monotonic()), + "model": model, + } + + def final_data(self): + return "data: [DONE]\n\n" + + def data_line(self, data: dict): + return f"data: {json.dumps(data)}\n\n" async def instantiate_async( self, @@ -386,14 +424,17 @@ async def instantiate_async( module_id (str): The ID of the WASM constraint module (SHA256 hash). module_arg (str or dict): The argument for the module. """ - return await self.side_cmd.exec_async( - "instantiate", - { - "req_id": req_id, - "prompt": prompt, - "module_id": module_id, - "module_arg": module_arg, - }, + return self._save_instantiate_result( + req_id, + await self.side_cmd.exec_async( + "instantiate", + { + "req_id": req_id, + "prompt": prompt, + "module_id": module_id, + "module_arg": module_arg, + }, + ), ) def instantiate( @@ -412,20 +453,27 @@ def instantiate( module_id (str): The ID of the WASM constraint module (SHA256 hash). module_arg (str or dict): The argument for the module. """ - return self.side_cmd.exec( - "instantiate", - { - "req_id": req_id, - "prompt": prompt, - "module_id": module_id, - "module_arg": module_arg, - }, + return self._save_instantiate_result( + req_id, + self.side_cmd.exec( + "instantiate", + { + "req_id": req_id, + "prompt": prompt, + "module_id": module_id, + "module_arg": module_arg, + }, + ), ) def assign_seq_id(self, req_id: str, seq_id: int): """ Assign a sequence ID (number) to a given request ID (passed to .instantiate() before). """ + if req_id in self.pending_instantiate_results: + res = self.pending_instantiate_results[req_id] + del self.pending_instantiate_results[req_id] + self.logs_by_seqid[str(seq_id)] = [res] self.pre_ops.append({"req_id": req_id, "id": seq_id}) def tokens_generated(self, seq_id: int, tokens: List[int], backtrack: int = 0): @@ -498,7 +546,7 @@ def pre_status(self, seq_id: int) -> Tuple[bool, int, List[int]]: res.get("ff_tokens", []), ) else: - return True, 0, [] + return False, 0, [] def mid_status(self, seq_id: int) -> Tuple[List[int], int]: """ @@ -512,7 +560,17 @@ def mid_status(self, seq_id: int) -> Tuple[List[int], int]: res.get("backtrack", 0), ) - def seq_logs(self, seq_id: int) -> dict: + def _fork_result(self, index: int, lst: List[dict], text="", finish_reason=None) -> dict: + return { + "index": index, + "finish_reason": finish_reason, + "text": text, + "error": "\n".join([e["error"] for e in lst if e["error"]]), + "logs": "\n".join([e["logs"] for e in lst if e["logs"]]), + "storage": [q for e in lst for q in e["storage"]], + } + + def seq_logs(self, seq_id: int, index=0, text="", finish_reason=None) -> dict: """ Get the logs for a given sequence ID. 
""" @@ -522,35 +580,30 @@ def seq_logs(self, seq_id: int) -> dict: del self.logs_by_seqid[ss] else: lst = [] - r = { - "index": seq_id, # this is not quite right, but close - # "finish_reason": None, - # "text": "", - "error": "\n".join([e["error"] for e in lst if e["error"]]), - "logs": "\n".join([e["logs"] for e in lst if e["logs"]]), - "storage": [q for e in lst for q in e["storage"]], - } - return r + return self._fork_result(index, lst, text=text, finish_reason=finish_reason) def pending_logs(self): """ Get the logs for the last step. """ return [int(q) for q in self.logs_by_seqid.keys()] - - def print_logs(self): + + def print_logs_for(self, seq_id:int, r = None): + r = r or self.seq_logs(seq_id) + lines: str = r["logs"] + if lines: + for line in lines.split("\n"): + if line: + print(f"[{seq_id}] {line}") + lines: str = r["error"] + if lines: + for line in lines.split("\n"): + if line: + print(f"[{seq_id}] ERR {line}") + + def print_logs(self): for seq_id in self.pending_logs(): - r = self.seq_logs(seq_id) - lines: str = r["logs"] - if lines: - for line in lines.split("\n"): - if line: - print(f"[{seq_id}] {line}") - lines: str = r["error"] - if lines: - for line in lines.split("\n"): - if line: - print(f"[{seq_id}] ERR {line}") + self.print_logs_for(seq_id) def add_mid(self, id: int, clone_id: Optional[int] = None): assert not self.logit_pending @@ -559,6 +612,9 @@ def add_mid(self, id: int, clone_id: Optional[int] = None): obj["clone_id"] = clone_id self.mid_ops.append(obj) + def needs_exec_mid(self): + return len(self.mid_ops) > 0 + def exec_mid(self): assert not self.logit_pending cmd = { @@ -567,7 +623,6 @@ def exec_mid(self): } self.cmd.send(cmd) self.logit_pending = True - return True def flush_logit_bias(self): """ @@ -589,11 +644,12 @@ def recv_logit_bias(self): self._add_logs(self.last_resp) n: int = data["num_seqs"] assert len(self.mid_ops) == n + seq_id_to_idx = {int(q["id"]): i for i, q in enumerate(self.mid_ops)} self.mid_ops = [] arr = np.frombuffer( self.bin_shm, dtype=np.float32, offset=0, count=n * self.vocab_size ).reshape([n, self.vocab_size]) - return arr + return seq_id_to_idx, arr def stop(self): """ diff --git a/py/vllm b/py/vllm index 6c5c6957..ac0a23a6 160000 --- a/py/vllm +++ b/py/vllm @@ -1 +1 @@ -Subproject commit 6c5c6957494511b99230ab0831b2fea3b1829d87 +Subproject commit ac0a23a63e93088434c25f39d014322417f68c88 diff --git a/rllm/rllm-base/src/engine.rs b/rllm/rllm-base/src/engine.rs index ad9c50b9..ff52c8b2 100644 --- a/rllm/rllm-base/src/engine.rs +++ b/rllm/rllm-base/src/engine.rs @@ -45,6 +45,7 @@ pub struct AddRequest { pub prompt: Vec, pub sampling_params: SamplingParams, pub expected: Option, + pub init_result: Option, } pub enum Repo { @@ -285,6 +286,10 @@ impl RllmEngine { pub fn queue_request(&mut self, req: AddRequest) -> Result<()> { let mut seq = Sequence::new(self.seq_mgr.new_sequence(), &req.prompt); + match req.init_result { + Some(r) => seq.aici_logs.push(r.clone()), + None => {} + } seq.expected = req.expected; seq.pending_fork_ids = (1..req.sampling_params.n) .map(|_| self.seq_mgr.new_sequence()) @@ -327,6 +332,7 @@ impl RllmEngine { ..SamplingParams::default() }, expected: Some(exp_gen), + init_result: None, }) } @@ -342,6 +348,7 @@ impl RllmEngine { prompt: tokens, sampling_params, expected: None, + init_result: None, }) } diff --git a/rllm/rllm-base/src/iface.rs b/rllm/rllm-base/src/iface.rs index 4c7637ce..587a6f00 100644 --- a/rllm/rllm-base/src/iface.rs +++ b/rllm/rllm-base/src/iface.rs @@ -2,12 +2,12 @@ use crate::HashMap; use 
aici_abi::{ bytes::{limit_bytes, limit_str}, toktree::TokTrie, - InitPromptResult, }; use aicirt::{ api::{ AiciMidProcessReq, AiciMidProcessResp, AiciPostPreProcessReq, AiciPostPreProcessResp, - AuthInfo, GetTagsResp, InstantiateReq, MkModuleReq, MkModuleResp, SetTagsReq, TokensResp, + AiciPreProcessResultInner, AuthInfo, GetTagsResp, InstantiateReq, MkModuleReq, + MkModuleResp, SequenceResult, SetTagsReq, TokensResp, }, futexshm::ClientChannel, msgchannel::MessageChannel, @@ -286,7 +286,7 @@ impl AsyncCmdChannel { &self, req: InstantiateReq, authinfo: AuthInfo, - ) -> Result { + ) -> Result> { self.exec("instantiate", req, authinfo).await } diff --git a/rllm/rllm-base/src/server/completion.rs b/rllm/rllm-base/src/server/completion.rs index 6ffec59e..aa9a9a55 100644 --- a/rllm/rllm-base/src/server/completion.rs +++ b/rllm/rllm-base/src/server/completion.rs @@ -1,3 +1,4 @@ +use crate::seq::{FinishReason, RequestOutput, SeqOutput}; use crate::server::{auth_info, APIError, AiciServerData, InferenceResult}; use crate::{config::SamplingParams, seq::Token, AddRequest}; use actix_web::{post, web, web::Bytes, HttpResponse}; @@ -75,7 +76,7 @@ async fn run_controller( let token_ids = check_length(&request, &data); bail_if_error!(token_ids); - let (max_tokens, token_ids) = token_ids.unwrap(); + let (max_tokens, mut token_ids) = token_ids.unwrap(); let request_id = format!("run-{}", Uuid::new_v4()); @@ -95,7 +96,7 @@ async fn run_controller( bail_if_error!(sampling_params.verify_args()); - if let Some(mod_id) = sampling_params.controller.as_ref() { + let init_result = if let Some(mod_id) = sampling_params.controller.as_ref() { let inst = data .side_cmd_ch .instantiate( @@ -109,17 +110,53 @@ async fn run_controller( ) .await; bail_if_error!(inst); - } - - let rx = data.worker.lock().unwrap().add_request(AddRequest { - request_id: request_id.clone(), - prompt: token_ids, - sampling_params, - expected: None, - }); + let inst = inst.unwrap(); + match &inst.result { + Some(r) => { + assert!(r.num_forks == 1); + assert!(r.suspend == false); + token_ids.extend_from_slice(&r.ff_tokens); + } + None => {} + } + Some(inst.clone_with(Some(()))) + } else { + None + }; - bail_if_error!(rx); - let rx = rx.unwrap(); + let rx = match init_result { + Some(r) if r.error.len() > 0 => { + let outp = RequestOutput { + request_id: request_id.clone(), + usage: Default::default(), + seq_outputs: vec![SeqOutput { + seq_id: 0, + index: 0, + new_output_tokens: vec![], + new_text: String::new(), + output_tokens: vec![], + finish_reason: Some(FinishReason::Failed), + aici_logs: vec![r], + }], + is_final: true, + }; + let (tx, rx) = tokio::sync::mpsc::channel(1); + tx.send(Ok(outp)).await.unwrap(); + rx + } + _ => { + let rx = data.worker.lock().unwrap().add_request(AddRequest { + request_id: request_id.clone(), + prompt: token_ids, + sampling_params, + expected: None, + init_result, + }); + + bail_if_error!(rx); + rx.unwrap() + } + }; return Ok(HttpResponse::Ok() .append_header(("content-type", "text/event-stream")) diff --git a/rllm/rllm-base/src/server/mod.rs b/rllm/rllm-base/src/server/mod.rs index a4d1dce6..08580459 100644 --- a/rllm/rllm-base/src/server/mod.rs +++ b/rllm/rllm-base/src/server/mod.rs @@ -527,7 +527,7 @@ pub async fn server_main( mut args: RllmCliArgs, mut model_args: ME::ModelLoaderArgs, ) -> () { - // we setenv, so that aicirt process also gets it + // we set env, so that aicirt process also gets it match &args.log { Some(v) => std::env::set_var("RUST_LOG", v), None => {} diff --git a/scripts/hf.sh 
b/scripts/hf.sh index c1f9ccfa..ea8563c4 100755 --- a/scripts/hf.sh +++ b/scripts/hf.sh @@ -5,7 +5,7 @@ set -x RUST_LOG=info,tokenizers=error,rllm=debug,aicirt=info \ PYTHONPATH=py \ -python3 scripts/py/run_hf_low.py \ +python3 scripts/py/run_hf.py \ --aici-rt ./target/release/aicirt \ --controller gh:microsoft/aici/pyctrl \ --controller-arg controllers/pyctrl/samples/test.py \ diff --git a/scripts/py/run_hf.py b/scripts/py/run_hf.py index 7f9d3bfa..005d0b80 100644 --- a/scripts/py/run_hf.py +++ b/scripts/py/run_hf.py @@ -56,20 +56,20 @@ def main(args): arg = f.read() req_id = "r1" # arbitrary string seq_id = 1 # there can be multiple sequences in a single request - runner.instantiate(req_id, empty_tokens, args.controller, arg) + res = runner.instantiate(req_id, empty_tokens, args.controller, arg) + if isinstance(res, list): + ff_tokens = res + else: + runner.print_logs_for(seq_id, res["forks"][0]) + exit(1) + runner.assign_seq_id(req_id, seq_id) runner.print_logs() - # we execute first post_pre here, so we get the initial ff_tokens - runner.exec_post_pre() - runner.print_logs() - suspend, num_forks, ff_tokens = runner.pre_status(seq_id) to_stop = runner.get_seqs_to_stop() if seq_id in to_stop: print("AICI decided to stop") exit(1) - assert not suspend, "forking not implemented" - assert num_forks <= 1, "forking not implemented" prompt = torch.tensor( empty_tokens + ff_tokens, dtype=torch.long, device=model.device diff --git a/scripts/py/run_hf_generate.py b/scripts/py/run_hf_generate.py deleted file mode 100644 index a39a1525..00000000 --- a/scripts/py/run_hf_generate.py +++ /dev/null @@ -1,167 +0,0 @@ -# This example embeds AICI in HuggingFace transformers when using -# model.generate() with Streamer and LogitsProcessor callbacks. -# This does not support backtracking or fast-forward tokens, -# see run_hf.py for an example calling directly into model with -# fast-forward and backtracking. 
- -import argparse - -from typing import cast, Optional, Union, List -import torch -import pyaici -import pyaici.comms - -from transformers import ( - AutoTokenizer, - PreTrainedModel, - LogitsProcessor, - LogitsProcessorList, - AutoModelForCausalLM, - PreTrainedTokenizer, -) - -from transformers.generation.streamers import BaseStreamer - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -class StopGeneration(Exception): - pass - - -class AsyncLogitProcessor(LogitsProcessor, BaseStreamer): - def __init__(self, seq_id: int, runner: pyaici.comms.AiciRunner) -> None: - super().__init__() - self.runner = runner - self._idx = 0 - self.seq_id = seq_id - runner.add_mid(seq_id) - runner.exec_mid() - - def _check_stop(self): - to_stop = self.runner.get_seqs_to_stop() - if self.seq_id in to_stop: - raise StopGeneration - - def put(self, value: torch.LongTensor): - if self._idx == 0: - self._idx += 1 - return # prompt - runner = self.runner - seq_id = self.seq_id - runner.tokens_generated(seq_id, value.tolist()) - runner.exec_post_pre() - runner.print_logs() - - self._check_stop() - suspend, num_forks, ff_tokens = self.runner.pre_status(self.seq_id) - assert not suspend, "forking not implemented" - assert num_forks <= 1, "forking not implemented" - assert len(ff_tokens) == 0, "ff_tokens not implemented" - - self._idx += 1 - runner.add_mid(seq_id) - runner.exec_mid() - - def end(self): - self._idx = 0 - self.runner.seq_freed(self.seq_id) - self.seq_id += 1 - self.runner.flush_logit_bias() - self.runner.print_logs() - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.Tensor - ) -> torch.Tensor: - runner = self.runner - bias_tensor = runner.recv_logit_bias() - runner.print_logs() - self._check_stop() - ff_tokens, backtrack = runner.mid_status(self.seq_id) - assert backtrack == 0, "backtrack not implemented" - assert len(ff_tokens) == 0, "ff_tokens not implemented" - bias_tensor = torch.from_numpy(bias_tensor).to(scores.device).to(scores.dtype) - # print(bias_tensor.shape, scores.shape, input_ids.shape) - vocab_size = bias_tensor.shape[1] - # scores should be the size of vocabulary but some models (phi-2) make it slightly bigger - assert scores.shape[1] <= vocab_size + 1000 - scores = scores[:, 0:vocab_size] - assert scores.shape == bias_tensor.shape - assert input_ids.shape[0] == 1 # and self._idx == input_ids.shape[1] - return bias_tensor + scores - - -def main(args): - tokenizer = cast( - PreTrainedTokenizer, AutoTokenizer.from_pretrained(args.tokenizer or args.model) - ) - model = AutoModelForCausalLM.from_pretrained( - args.model, - device_map="auto", - torch_dtype=torch.bfloat16, - ) - model = cast(PreTrainedModel, model) - empty_tokens = cast( - List[int], tokenizer.convert_tokens_to_ids(tokenizer.tokenize("")) - ) - - runner = pyaici.runner_from_cli(args) - - arg = "" - if args.controller_arg: - with open(args.controller_arg) as f: - arg = f.read() - req_id = "r1" - seq_id = 1 # there can be multiple sequences in a single request - runner.instantiate(req_id, empty_tokens, args.controller, arg) - runner.assign_seq_id(req_id, seq_id) - runner.print_logs() - # we execute first post_pre here, so we get the initial ff_tokens - runner.exec_post_pre() - runner.print_logs() - suspend, num_forks, ff_tokens = runner.pre_status(seq_id) - to_stop = runner.get_seqs_to_stop() - if seq_id in to_stop: - print("AICI decided to stop") - exit(1) - assert not suspend, "forking not implemented" - assert num_forks <= 1, "forking not implemented" - - prompt = torch.tensor( - empty_tokens + 
ff_tokens, dtype=torch.long, device=model.device - ).unsqueeze(0) - attn_mask = torch.ones(prompt.shape, dtype=torch.long, device=model.device) - - wproc = AsyncLogitProcessor(seq_id, runner) - proc = LogitsProcessorList() - proc.append(wproc) - try: - model.generate( - input_ids=prompt, - attention_mask=attn_mask, - logits_processor=proc, - streamer=wproc, - max_new_tokens=2000, - temperature=0.01, - do_sample=True, - ) - runner.print_logs() # just in case - except StopGeneration: - runner.print_logs() - print("AICI stop") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Demo on using HF Transformers with aicirt" - ) - parser.add_argument("--model", type=str, required=True, help="model to use") - parser.add_argument( - "--tokenizer", - type=str, - default="", - help="tokenizer to use; defaults to model name", - ) - pyaici.add_cli_args(parser, single=True) - args = parser.parse_args() - main(args) diff --git a/scripts/vllm-direct.sh b/scripts/vllm-direct.sh deleted file mode 100755 index 200a006a..00000000 --- a/scripts/vllm-direct.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/sh - -set -e -set -x -(cd declctrl && ./wasm.sh cache) -mod=`cat tmp/runlog.txt |grep '^[a-f0-9]\{64\}$'` - -RUST_LOG=info \ -PYTHONPATH=py:py/vllm \ -python3 scripts/py/run_vllm.py \ - --aici-rt ./aicirt/target/release/aicirt \ - --aici-module $mod \ - --aici-module-arg declctrl/arg2.json \ - --aici-tokenizer llama \ - --model NousResearch/Llama-2-7b-chat-hf \ - --tokenizer hf-internal-testing/llama-tokenizer diff --git a/scripts/vllm-init.sh b/scripts/vllm-init.sh index 30a83a0b..9ded0555 100755 --- a/scripts/vllm-init.sh +++ b/scripts/vllm-init.sh @@ -3,10 +3,10 @@ set -x set -e mkdir -p tmp -if test -f vllm/setup.py; then +if test -f py/vllm/setup.py; then : else git submodule update --init --recursive fi -cd vllm -pip install --verbose -e . +cd py/vllm +python setup.py develop diff --git a/scripts/vllm-server.sh b/scripts/vllm-server.sh index f3e2bba4..7f82bf70 100755 --- a/scripts/vllm-server.sh +++ b/scripts/vllm-server.sh @@ -3,27 +3,23 @@ set -e set -x -MODEL=NousResearch/Llama-2-7b-chat-hf -TOK=llama - -#MODEL=codellama/CodeLlama-34b-Instruct-hf -MODEL=codellama/CodeLlama-13b-Instruct-hf -TOK=llama16 +MODEL="microsoft/Orca-2-13b" +MODEL_REV="refs/pr/22" +AICI_TOK=orca (cd aicirt && cargo build --release) -RUST_LOG=info \ +RUST_LOG=info,tokenizers=error,aicirt=trace \ PYTHONPATH=py:py/vllm \ -python3 scripts/py/vllm_server.py \ - --aici-rt ./aicirt/target/release/aicirt \ - --aici-tokenizer $TOK \ - --aici-trace tmp/trace.jsonl \ +python3 -m vllm.entrypoints.openai.api_server \ + --aici-rt ./target/release/aicirt \ + --aici-tokenizer $AICI_TOK \ --model $MODEL \ - --aici-rtarg="--wasm-max-pre-step-time=10" \ - --tokenizer hf-internal-testing/llama-tokenizer \ + --revision $MODEL_REV \ --port 4242 --host 127.0.0.1 # --aici-rtarg="--wasm-max-step-time=50" \ # --aici-rtarg="--wasm-max-pre-step-time=2" \ # --aici-rtarg="--wasm-max-init-time=1000" \ # --aici-rtarg="--wasm-max-memory=64" \ +# --aici-rtarg="--wasm-max-pre-step-time=10" \