more sensible cmd line parsing
mmoskal committed Jan 23, 2024
1 parent c4bd495 commit 791c74c
Showing 7 changed files with 141 additions and 83 deletions.
13 changes: 7 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default.

18 changes: 13 additions & 5 deletions aicirt/src/bintokens.rs
@@ -88,11 +88,22 @@ pub fn tokenizers() -> Vec<Tokenizer> {
         ),
         tok!("falcon", "used by Falcon 7b, 40b, etc."),
         tok!("mpt", "MPT"),
-        tok!("phi", "Phi 1.5"),
+        tok!("phi", "Phi 1.5 and Phi 2"),
         tok!("gpt2", "GPT-2"),
     ]
 }
 
+pub fn list_tokenizers() -> String {
+    format!(
+        "Available tokenizers for -t or --tokenizer:\n{}",
+        tokenizers()
+            .iter()
+            .map(|t| format!(" -t {:16} {}", t.name, t.description))
+            .collect::<Vec<_>>()
+            .join("\n")
+    )
+}
+
 pub fn find_tokenizer(name: &str) -> Result<Tokenizer> {
     for mut t in tokenizers() {
         if t.name == name {
@@ -102,10 +113,7 @@ pub fn find_tokenizer(name: &str) -> Result<Tokenizer> {
     }
 
     println!("unknown tokenizer: {}", name);
-    println!("available tokenizers:");
-    for t in tokenizers() {
-        println!(" {:20} {}", t.name, t.description);
-    }
+    println!("{}", list_tokenizers());
     return Err(anyhow!("unknown tokenizer: {}", name));
 }
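For reference, the new list_tokenizers() helper renders one line per known tokenizer, in the same form the -t flag expects. Given just the entries visible in this hunk, its output would look roughly like the following (the full list, including entries hidden by the collapsed part of the diff, is longer):

    Available tokenizers for -t or --tokenizer:
     -t falcon           used by Falcon 7b, 40b, etc.
     -t mpt              MPT
     -t phi              Phi 1.5 and Phi 2
     -t gpt2             GPT-2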
1 change: 1 addition & 0 deletions cpp-rllm/Cargo.toml
@@ -5,6 +5,7 @@ edition = "2021"
 
 [dependencies]
 actix-web = "4.4.0"
+clap = { version = "4.4.18", features = ["derive"] }
 llama_cpp_low = { path = "../llama-cpp-low" }
 rllm = { path = "../rllm", default-features = false, features = ["llamacpp"] }
19 changes: 18 additions & 1 deletion cpp-rllm/src/cpp-rllm.rs
@@ -1,4 +1,21 @@
+use clap::Parser;
+use rllm::util::parse_with_settings;
+
+/// Serve LLMs with AICI over HTTP with llama.cpp backend.
+#[derive(Parser, Debug)]
+#[command(version, about, long_about = None)]
+pub struct CppArgs {
+    #[clap(flatten)]
+    pub args: rllm::server::RllmCliArgs,
+
+    /// Name of .gguf file inside of the model folder/repo.
+    #[arg(long, help_heading = "Model")]
+    pub gguf: Option<String>,
+}
+
 #[actix_web::main]
 async fn main() -> () {
-    rllm::server::server_main().await;
+    let mut args = parse_with_settings::<CppArgs>();
+    args.args.gguf = args.gguf;
+    rllm::server::server_main(args.args).await;
 }
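The shape of the change: each binary now declares its own small Parser struct, flattens the shared rllm::server::RllmCliArgs into it, and copies any binary-specific flag (here --gguf) into the shared struct before handing it to server_main. A hypothetical invocation of the llama.cpp binary, using only flags defined in this diff (the model path and .gguf file name are made up for illustration):

    cpp-rllm --model ./phi-2 --gguf model-q4k.gguf --tokenizer phi --aicirt ./aicirt

Keeping gguf on the shared struct as #[arg(skip)] (see the server/mod.rs hunk below) lets server_main stay backend-agnostic while only the llama.cpp binary actually exposes the flag.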
14 changes: 13 additions & 1 deletion rllm/src/driver.rs
@@ -1,4 +1,16 @@
+use clap::Parser;
+use rllm::util::parse_with_settings;
+
+/// Serve LLMs with AICI over HTTP with tch (torch) backend.
+#[derive(Parser, Debug)]
+#[command(version, about, long_about = None)]
+pub struct DriverArgs {
+    #[clap(flatten)]
+    pub args: rllm::server::RllmCliArgs,
+}
+
 #[actix_web::main]
 async fn main() -> () {
-    rllm::server::server_main().await;
+    let args = parse_with_settings::<DriverArgs>();
+    rllm::server::server_main(args.args).await;
 }
133 changes: 67 additions & 66 deletions rllm/src/server/mod.rs
@@ -1,3 +1,10 @@
+use crate::{
+    config::{ModelConfig, SamplingParams},
+    iface::{kill_self, AiciRtIface, AsyncCmdChannel},
+    seq::RequestOutput,
+    util::apply_settings,
+    AddRequest, DType, HashMap, LoaderArgs, RllmEngine,
+};
 use actix_web::{middleware::Logger, web, App, HttpServer};
 use aici_abi::toktree::TokTrie;
 use aicirt::{
@@ -6,15 +13,8 @@ use aicirt::{
 };
 use anyhow::Result;
 use base64::Engine;
-use clap::Parser;
+use clap::Args;
 use openai::responses::APIError;
-use crate::{
-    config::{ModelConfig, SamplingParams},
-    iface::{kill_self, AiciRtIface, AsyncCmdChannel},
-    seq::RequestOutput,
-    util::apply_settings,
-    AddRequest, DType, HashMap, LoaderArgs, RllmEngine,
-};
 use std::{
     fmt::Display,
     sync::{Arc, Mutex},
@@ -54,85 +54,85 @@ pub struct OpenAIServerData {
     pub stats: Arc<Mutex<ServerStats>>,
 }
 
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-pub struct Args {
-    /// Port to serve on (localhost:port)
-    #[arg(long, default_value_t = 8080)]
-    port: u16,
-
-    /// Set verbose mode (print all requests)
-    #[arg(long, default_value_t = false)]
-    verbose: bool,
+#[derive(Args, Debug)]
+pub struct RllmCliArgs {
+    /// Set engine setting (see below or in --help for list)
+    #[arg(long, short, name = "NAME=VALUE")]
+    pub setting: Vec<String>,
 
-    /// HuggingFace model name; can be also path starting with "./"
-    #[arg(short, long)]
-    model: String,
+    /// HuggingFace model name, URL or path starting with "./"
+    #[arg(short, long, help_heading = "Model")]
+    pub model: String,
 
     /// HuggingFace model revision; --model foo/bar@revision is also possible
-    #[arg(long)]
-    revision: Option<String>,
+    #[arg(long, help_heading = "Model")]
+    pub revision: Option<String>,
 
     /// The folder name that contains safetensor weights and json files
     /// (same structure as HuggingFace online)
-    #[arg(long)]
-    local_weights: Option<String>,
+    #[arg(long, help_heading = "Model")]
+    pub local_weights: Option<String>,
 
+    /// Tokenizer to use (see below or in --help for list)
+    #[arg(short, long, default_value = "llama", help_heading = "Model")]
+    pub tokenizer: String,
+
+    /// Specify which type to use in the model (bf16, f16, f32)
+    #[arg(long, default_value = "", help_heading = "Model")]
+    pub dtype: String,
+
+    /// Port to serve on (localhost:port)
+    #[arg(long, default_value_t = 8080, help_heading = "Server")]
+    pub port: u16,
 
-    /// Name of .gguf file inside of the model folder/repo.
-    #[arg(long)]
-    gguf: Option<String>,
+    /// Set verbose mode (print all requests)
+    #[arg(long, default_value_t = false, help_heading = "Server")]
+    pub verbose: bool,
 
-    /// Tokenizer to use; try --tokenizer list to see options
-    #[arg(short, long, default_value = "llama")]
-    tokenizer: String,
+    /// Enable daemon mode (log timestamps)
+    #[arg(long, default_value_t = false, help_heading = "Server")]
+    pub daemon: bool,
 
     /// Path to the aicirt binary.
-    #[arg(long)]
-    aicirt: String,
+    #[arg(long, help_heading = "AICI settings")]
+    pub aicirt: String,
 
     /// Size of JSON comm buffer in megabytes
-    #[arg(long, default_value = "128")]
-    json_size: usize,
+    #[arg(long, default_value = "128", help_heading = "AICI settings")]
+    pub json_size: usize,
 
     /// Size of binary comm buffer in megabytes
-    #[arg(long, default_value = "32")]
-    bin_size: usize,
+    #[arg(long, default_value = "32", help_heading = "AICI settings")]
+    pub bin_size: usize,
 
     /// How many milliseconds to spin-wait for a message over IPC and SHM.
-    #[arg(long, default_value = "200")]
-    busy_wait_time: u64,
+    #[arg(long, default_value = "200", help_heading = "AICI settings")]
+    pub busy_wait_time: u64,
 
     /// Shm/semaphore name prefix
-    #[arg(long, default_value = "/aici0-")]
-    shm_prefix: String,
+    #[arg(long, default_value = "/aici0-", help_heading = "AICI settings")]
+    pub shm_prefix: String,
 
     /// Enable nvprof profiling for given engine step
-    #[arg(long, default_value_t = 0)]
-    profile_step: usize,
-
-    /// Specify which type to use in the model (bf16, f16, f32)
-    #[arg(long, default_value = "")]
-    dtype: String,
+    #[cfg(feature = "cuda")]
+    #[arg(long, default_value_t = 0, help_heading = "Development")]
+    pub profile_step: usize,
 
     /// Specify test-cases (expected/*/*.safetensors)
-    #[arg(long)]
-    test: Vec<String>,
+    #[arg(long, help_heading = "Development")]
+    pub test: Vec<String>,
 
     /// Specify warm-up request (expected/*/*.safetensors or "off")
-    #[arg(long, short)]
-    warmup: Option<String>,
+    #[arg(long, short, help_heading = "Development")]
+    pub warmup: Option<String>,
 
     /// Exit after processing warmup request
-    #[arg(long, default_value_t = false)]
-    warmup_only: bool,
+    #[arg(long, default_value_t = false, help_heading = "Development")]
+    pub warmup_only: bool,
 
-    /// Set engine setting; try '--setting help' to list them
-    #[arg(long, short, name = "NAME=VALUE")]
-    setting: Vec<String>,
-
-    /// Enable daemon mode (log timestamps)
-    #[arg(long, default_value_t = false)]
-    daemon: bool,
+    // these are copied from command-specific parsers
+    #[arg(skip)]
+    pub gguf: Option<String>,
 }
 
 #[actix_web::get("/v1/aici_modules/tags")]
@@ -344,12 +344,12 @@ fn inference_loop(
 }
 
 #[cfg(not(feature = "tch"))]
-fn run_tests(_args: &Args, _loader_args: LoaderArgs) {
+fn run_tests(_args: &RllmCliArgs, _loader_args: LoaderArgs) {
     panic!("tests not supported without tch feature")
 }
 
 #[cfg(feature = "tch")]
-fn run_tests(args: &Args, loader_args: LoaderArgs) {
+fn run_tests(args: &RllmCliArgs, loader_args: LoaderArgs) {
     let mut engine = RllmEngine::load(loader_args).expect("failed to load model");
     let mut tests = args.test.clone();
 
@@ -384,7 +384,7 @@ fn run_tests(args: &Args, loader_args: LoaderArgs) {
 }
 
 fn spawn_inference_loop(
-    args: &Args,
+    args: &RllmCliArgs,
     loader_args: LoaderArgs,
     iface: AiciRtIface,
     stats: Arc<Mutex<ServerStats>>,
@@ -394,7 +394,10 @@
     let handle = handle_res.clone();
 
     // prep for move
+    #[cfg(feature = "cuda")]
     let profile_step = args.profile_step;
+    #[cfg(not(feature = "cuda"))]
+    let profile_step = 0;
     let warmup = args.warmup.clone();
     let warmup_only = args.warmup_only.clone();
@@ -452,9 +455,7 @@ fn url_decode(encoded_str: &str) -> String {
 }
 
 // #[actix_web::main]
-pub async fn server_main() -> () {
-    let mut args = Args::parse();
-
+pub async fn server_main(mut args: RllmCliArgs) -> () {
     aicirt::init_log(if args.daemon {
         aicirt::LogMode::Deamon
     } else {
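A side effect worth noting: the help_heading attributes introduced above make clap group the flags under named sections in --help instead of printing one flat list. Sketched from the headings in this diff (flag summaries abbreviated; exact rendering depends on the clap version):

    Model:
      -m, --model <MODEL>            HuggingFace model name, URL or path starting with "./"
          --revision <REVISION>
          --local-weights <LOCAL_WEIGHTS>
      -t, --tokenizer <TOKENIZER>    [default: llama]
          --dtype <DTYPE>

    Server:
          --port <PORT>              [default: 8080]
          --verbose
          --daemon

    AICI settings:
          --aicirt <AICIRT>
          --json-size <JSON_SIZE>    [default: 128]
          --bin-size <BIN_SIZE>      [default: 32]
          --busy-wait-time <BUSY_WAIT_TIME>  [default: 200]
          --shm-prefix <SHM_PREFIX>  [default: /aici0-]

    Development:
          --test <TEST>
      -w, --warmup <WARMUP>
          --warmup-only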
26 changes: 22 additions & 4 deletions rllm/src/util.rs
@@ -1,5 +1,7 @@
 use crate::HashMap;
+use aicirt::bintokens::list_tokenizers;
 use anyhow::{bail, Result};
+use clap::{Args, Command, Parser};
 use std::time::Instant;
 
 const SETTINGS: [(&'static str, &'static str, f64); 4] = [
@@ -16,9 +18,12 @@ lazy_static::lazy_static! {
 }
 
 pub fn all_settings() -> String {
-    SETTINGS
-        .map(|(k, d, v)| format!("{}: {} (default={})", k, d, v))
-        .join("\n")
+    format!(
+        "Settings available via -s or --setting (with their default values):\n{all}\n",
+        all = SETTINGS
+            .map(|(k, d, v)| format!(" -s {:20} {}", format!("{}={}", k, v), d))
+            .join("\n")
+    )
 }
 
 pub fn set_setting(name: &str, val: f64) -> Result<()> {
@@ -56,7 +61,7 @@ pub fn apply_settings(settings: &Vec<String>) -> Result<()> {
             Ok(_) => {}
             Err(e) => {
                 bail!(
-                    "all settings:\n{all}\nfailed to set setting {s}: {e}",
+                    "{all}\nfailed to set setting {s}: {e}",
                     all = all_settings()
                 );
             }
@@ -65,6 +70,19 @@
     Ok(())
 }
 
+pub fn parse_with_settings<T>() -> T
+where
+    T: Parser + Args,
+{
+    let cli =
+        Command::new("CLI").after_help(format!("\n{}\n{}", all_settings(), list_tokenizers()));
+    let cli = T::augment_args(cli);
+    let matches = cli.get_matches();
+    T::from_arg_matches(&matches)
+        .map_err(|err| err.exit())
+        .unwrap()
+}
+
 pub fn limit_str(s: &str, max_len: usize) -> String {
     limit_bytes(s.as_bytes(), max_len)
 }
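parse_with_settings is what ties the pieces together: it creates a bare clap Command, sets its after_help to the concatenation of all_settings() and list_tokenizers(), merges the caller's derived arguments in via Args::augment_args, and parses, so --help for either binary now ends with both listings. With the new all_settings() format, the settings portion renders along these lines (the actual names and defaults live in the SETTINGS table, which this diff leaves collapsed, so the entry below is a placeholder):

    Settings available via -s or --setting (with their default values):
     -s <name>=<default>      <what the setting does>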
