From 3a3388375be05758d5f5574cce1d0f8eb7bdd604 Mon Sep 17 00:00:00 2001 From: sigoden Date: Wed, 4 Dec 2024 21:03:59 +0800 Subject: [PATCH] refactor: improve retrieve model (#1036) - check the model type while retrieve model - select chat/reranker model even if it is missed in client models - find predefined-models for openai-compatible client with startsWith - remove client::ApiType --- src/client/bedrock.rs | 4 +- src/client/common.rs | 38 +++---------- src/client/ernie.rs | 8 +-- src/client/macros.rs | 45 +++++++++------ src/client/model.rs | 124 +++++++++++++++++++++++++++-------------- src/client/vertexai.rs | 6 +- src/config/agent.rs | 2 +- src/config/mod.rs | 19 ++++--- src/config/session.rs | 2 +- src/main.rs | 6 +- src/rag/mod.rs | 11 ++-- src/serve.rs | 8 ++- 12 files changed, 155 insertions(+), 118 deletions(-) diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index 081a65b6..7cd289c5 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -67,7 +67,7 @@ impl BedrockClient { let body = build_chat_completions_body(data, &self.model)?; let mut request_data = RequestData::new("", body); - self.patch_request_data(&mut request_data, ApiType::ChatCompletions); + self.patch_request_data(&mut request_data); let RequestData { url: _, headers, @@ -118,7 +118,7 @@ impl BedrockClient { }); let mut request_data = RequestData::new("", body); - self.patch_request_data(&mut request_data, ApiType::Embeddings); + self.patch_request_data(&mut request_data); let RequestData { url: _, headers, diff --git a/src/client/common.rs b/src/client/common.rs index 9c516c83..e0ec8619 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -19,7 +19,7 @@ use tokio::sync::mpsc::unbounded_channel; const MODELS_YAML: &str = include_str!("../../models.yaml"); lazy_static::lazy_static! { - pub static ref ALL_MODELS: Vec = serde_yaml::from_str(MODELS_YAML).unwrap(); + pub static ref ALL_PREDEFINED_MODELS: Vec = serde_yaml::from_str(MODELS_YAML).unwrap(); static ref ESCAPE_SLASH_RE: Regex = Regex::new(r"(? RequestBuilder { - self.patch_request_data(&mut request_data, api_type); + self.patch_request_data(&mut request_data); request_data.into_builder(client) } - fn patch_request_data(&self, request_data: &mut RequestData, api_type: ApiType) { + fn patch_request_data(&self, request_data: &mut RequestData) { + let model_type = self.model().model_type(); let map = std::env::var(get_env_name(&format!( "patch_{}_{}", self.model().client_name(), - api_type.name(), + model_type.api_name(), ))) .ok() .and_then(|v| serde_json::from_str(&v).ok()) .or_else(|| { self.patch_config() - .and_then(|v| api_type.extract_patch(v)) + .and_then(|v| model_type.extract_patch(v)) .cloned() }); let map = match map { @@ -200,30 +200,6 @@ pub struct RequestPatch { pub type ApiPatch = IndexMap; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ApiType { - ChatCompletions, - Embeddings, - Rerank, -} - -impl ApiType { - pub fn name(&self) -> &str { - match self { - ApiType::ChatCompletions => "chat_completions", - ApiType::Embeddings => "embeddings", - ApiType::Rerank => "rerank", - } - } - pub fn extract_patch<'a>(&self, patch: &'a RequestPatch) -> Option<&'a ApiPatch> { - match self { - ApiType::ChatCompletions => patch.chat_completions.as_ref(), - ApiType::Embeddings => patch.embeddings.as_ref(), - ApiType::Rerank => patch.rerank.as_ref(), - } - } -} - pub struct RequestData { pub url: String, pub headers: IndexMap, @@ -383,7 +359,7 @@ pub fn create_openai_compatible_client_config(client: &str) -> Result Result { prepare_access_token(self, client).await?; let request_data = prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); chat_completions(builder, &self.model).await } @@ -53,7 +53,7 @@ impl Client for ErnieClient { ) -> Result<()> { prepare_access_token(self, client).await?; let request_data = prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); chat_completions_streaming(builder, handler, &self.model).await } @@ -64,7 +64,7 @@ impl Client for ErnieClient { ) -> Result { prepare_access_token(self, client).await?; let request_data = prepare_embeddings(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Embeddings); + let builder = self.request_builder(client, request_data); embeddings(builder, &self.model).await } @@ -75,7 +75,7 @@ impl Client for ErnieClient { ) -> Result { prepare_access_token(self, client).await?; let request_data = prepare_rerank(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Rerank); + let builder = self.request_builder(client, request_data); rerank(builder, &self.model).await } } diff --git a/src/client/macros.rs b/src/client/macros.rs index 0543cd66..4f52044f 100644 --- a/src/client/macros.rs +++ b/src/client/macros.rs @@ -52,9 +52,10 @@ macro_rules! register_client { pub fn list_models(local_config: &$config) -> Vec { let client_name = Self::name(local_config); if local_config.models.is_empty() { - if let Some(models) = $crate::client::ALL_MODELS.iter().find(|v| { + if let Some(models) = $crate::client::ALL_PREDEFINED_MODELS.iter().find(|v| { v.platform == $name || - ($name == OpenAICompatibleClient::NAME && local_config.name.as_deref() == Some(&v.platform)) + ($name == OpenAICompatibleClient::NAME + && local_config.name.as_ref().map(|name| name.starts_with(&v.platform)).unwrap_or_default()) }) { return Model::from_config(client_name, &models.models); } @@ -98,32 +99,40 @@ macro_rules! register_client { anyhow::bail!("Unknown client '{}'", client) } - static ALL_CLIENT_MODELS: std::sync::OnceLock> = std::sync::OnceLock::new(); + static ALL_CLIENT_NAMES: std::sync::OnceLock> = std::sync::OnceLock::new(); - pub fn list_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - let models = ALL_CLIENT_MODELS.get_or_init(|| { + pub fn list_client_names(config: &$crate::config::Config) -> Vec<&'static String> { + let names = ALL_CLIENT_NAMES.get_or_init(|| { config .clients .iter() .flat_map(|v| match v { - $(ClientConfig::$config(c) => $client::list_models(c),)+ + $(ClientConfig::$config(c) => vec![$client::name(c).to_string()],)+ ClientConfig::Unknown => vec![], }) .collect() }); - models.iter().collect() + names.iter().collect() } - pub fn list_chat_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "chat").collect() - } + static ALL_MODELS: std::sync::OnceLock> = std::sync::OnceLock::new(); - pub fn list_embedding_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "embedding").collect() + pub fn list_all_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { + let models = ALL_MODELS.get_or_init(|| { + config + .clients + .iter() + .flat_map(|v| match v { + $(ClientConfig::$config(c) => $client::list_models(c),)+ + ClientConfig::Unknown => vec![], + }) + .collect() + }); + models.iter().collect() } - pub fn list_reranker_models(config: &$crate::config::Config) -> Vec<&'static $crate::client::Model> { - list_models(config).into_iter().filter(|v| v.model_type() == "reranker").collect() + pub fn list_models(config: &$crate::config::Config, model_type: $crate::client::ModelType) -> Vec<&'static $crate::client::Model> { + list_all_models(config).into_iter().filter(|v| v.model_type() == model_type).collect() } }; } @@ -175,7 +184,7 @@ macro_rules! impl_client_trait { data: $crate::client::ChatCompletionsData, ) -> anyhow::Result<$crate::client::ChatCompletionsOutput> { let request_data = $prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); $chat_completions(builder, self.model()).await } @@ -186,7 +195,7 @@ macro_rules! impl_client_trait { data: $crate::client::ChatCompletionsData, ) -> Result<()> { let request_data = $prepare_chat_completions(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); $chat_completions_streaming(builder, handler, self.model()).await } @@ -196,7 +205,7 @@ macro_rules! impl_client_trait { data: &$crate::client::EmbeddingsData, ) -> Result<$crate::client::EmbeddingsOutput> { let request_data = $prepare_embeddings(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Embeddings); + let builder = self.request_builder(client, request_data); $embeddings(builder, self.model()).await } @@ -206,7 +215,7 @@ macro_rules! impl_client_trait { data: &$crate::client::RerankData, ) -> Result<$crate::client::RerankOutput> { let request_data = $prepare_rerank(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Rerank); + let builder = self.request_builder(client, request_data); $rerank(builder, self.model()).await } } diff --git a/src/client/model.rs b/src/client/model.rs index 5d496d23..47f65988 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -1,7 +1,7 @@ use super::{ - list_chat_models, list_embedding_models, list_reranker_models, + list_all_models, list_client_names, message::{Message, MessageContent, MessageContentPart}, - MessageContentToolCalls, + ApiPatch, MessageContentToolCalls, RequestPatch, }; use crate::config::Config; @@ -9,6 +9,7 @@ use crate::utils::{estimate_token_length, format_option_value}; use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; +use std::fmt::Display; const PER_MESSAGES_TOKENS: usize = 5; const BASIS_TOKENS: usize = 2; @@ -43,29 +44,8 @@ impl Model { .collect() } - pub fn retrieve_chat(config: &Config, model_id: &str) -> Result { - match Self::find(&list_chat_models(config), model_id) { - Some(v) => Ok(v), - None => bail!("Unknown chat model '{model_id}'"), - } - } - - pub fn retrieve_embedding(config: &Config, model_id: &str) -> Result { - match Self::find(&list_embedding_models(config), model_id) { - Some(v) => Ok(v), - None => bail!("Unknown embedding model '{model_id}'"), - } - } - - pub fn retrieve_reranker(config: &Config, model_id: &str) -> Result { - match Self::find(&list_reranker_models(config), model_id) { - Some(v) => Ok(v), - None => bail!("Unknown reranker model '{model_id}'"), - } - } - - pub fn find(models: &[&Self], model_id: &str) -> Option { - let mut model = None; + pub fn retrieve_model(config: &Config, model_id: &str, model_type: ModelType) -> Result { + let models = list_all_models(config); let (client_name, model_name) = match model_id.split_once(':') { Some((client_name, model_name)) => { if model_name.is_empty() { @@ -78,21 +58,33 @@ impl Model { }; match model_name { Some(model_name) => { - if let Some(found) = models.iter().find(|v| v.id() == model_id) { - model = Some((*found).clone()); - } else if let Some(found) = models.iter().find(|v| v.client_name == client_name) { - let mut found = (*found).clone(); - found.data.name = model_name.to_string(); - model = Some(found) + if let Some(model) = models.iter().find(|v| v.id() == model_id) { + if model.model_type() == model_type { + return Ok((*model).clone()); + } else { + bail!("Model '{model_id}' is not a {model_type} model") + } + } + if list_client_names(config) + .into_iter() + .any(|v| *v == client_name) + && model_type.can_create_from_name() + { + let mut new_model = Self::new(client_name, model_name); + new_model.data.model_type = model_type.to_string(); + return Ok(new_model); } } None => { - if let Some(found) = models.iter().find(|v| v.client_name == client_name) { - model = Some((*found).clone()); + if let Some(found) = models + .iter() + .find(|v| v.client_name == client_name && v.model_type() == model_type) + { + return Ok((*found).clone()); } } - } - model + }; + bail!("Unknown {model_type} model '{model_id}'") } pub fn id(&self) -> String { @@ -111,8 +103,14 @@ impl Model { &self.data.name } - pub fn model_type(&self) -> &str { - &self.data.model_type + pub fn model_type(&self) -> ModelType { + if self.data.model_type.starts_with("embed") { + ModelType::Embedding + } else if self.data.model_type.starts_with("rerank") { + ModelType::Reranker + } else { + ModelType::Chat + } } pub fn data(&self) -> &ModelData { @@ -125,7 +123,7 @@ impl Model { pub fn description(&self) -> String { match self.model_type() { - "chat" => { + ModelType::Chat => { let ModelData { max_input_tokens, max_output_tokens, @@ -156,7 +154,7 @@ impl Model { max_input_tokens, max_output_tokens, input_price, output_price, capabilities ) } - "embedding" => { + ModelType::Embedding => { let ModelData { input_price, max_tokens_per_chunk, @@ -168,7 +166,7 @@ impl Model { let price = format_option_value(input_price); format!("max-tokens:{max_tokens};max-batch:{max_batch};price:{price}") } - _ => String::new(), + ModelType::Reranker => String::new(), } } @@ -310,13 +308,14 @@ impl ModelData { pub fn new(name: &str) -> Self { Self { name: name.to_string(), + model_type: default_model_type(), ..Default::default() } } } #[derive(Debug, Clone, Deserialize)] -pub struct BuiltinModels { +pub struct PredefinedModels { pub platform: String, pub models: Vec, } @@ -324,3 +323,46 @@ pub struct BuiltinModels { fn default_model_type() -> String { "chat".into() } + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ModelType { + Chat, + Embedding, + Reranker, +} + +impl Display for ModelType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ModelType::Chat => write!(f, "chat"), + ModelType::Embedding => write!(f, "embedding"), + ModelType::Reranker => write!(f, "reranker"), + } + } +} + +impl ModelType { + pub fn can_create_from_name(self) -> bool { + match self { + ModelType::Chat => true, + ModelType::Embedding => false, + ModelType::Reranker => true, + } + } + + pub fn api_name(self) -> &'static str { + match self { + ModelType::Chat => "chat_completions", + ModelType::Embedding => "embeddings", + ModelType::Reranker => "rerank", + } + } + + pub fn extract_patch(self, patch: &RequestPatch) -> Option<&ApiPatch> { + match self { + ModelType::Chat => patch.chat_completions.as_ref(), + ModelType::Embedding => patch.embeddings.as_ref(), + ModelType::Reranker => patch.rerank.as_ref(), + } + } +} diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 07ff2a43..7b731648 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -45,7 +45,7 @@ impl Client for VertexAIClient { let model = self.model(); let model_category = ModelCategory::from_str(model.name())?; let request_data = prepare_chat_completions(self, data, &model_category)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); match model_category { ModelCategory::Gemini => gemini_chat_completions(builder, model).await, ModelCategory::Claude => claude_chat_completions(builder, model).await, @@ -63,7 +63,7 @@ impl Client for VertexAIClient { let model = self.model(); let model_category = ModelCategory::from_str(model.name())?; let request_data = prepare_chat_completions(self, data, &model_category)?; - let builder = self.request_builder(client, request_data, ApiType::ChatCompletions); + let builder = self.request_builder(client, request_data); match model_category { ModelCategory::Gemini => { gemini_chat_completions_streaming(builder, handler, model).await @@ -84,7 +84,7 @@ impl Client for VertexAIClient { ) -> Result>> { prepare_gcloud_access_token(client, self.name(), &self.config.adc_file).await?; let request_data = prepare_embeddings(self, data)?; - let builder = self.request_builder(client, request_data, ApiType::Embeddings); + let builder = self.request_builder(client, request_data); embeddings(builder, self.model()).await } } diff --git a/src/config/agent.rs b/src/config/agent.rs index 5e54b269..c81ed119 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -61,7 +61,7 @@ impl Agent { let model = { let config = config.read(); match agent_config.model_id.as_ref() { - Some(model_id) => Model::retrieve_chat(&config, model_id)?, + Some(model_id) => Model::retrieve_model(&config, model_id, ModelType::Chat)?, None => config.current_model().clone(), } }; diff --git a/src/config/mod.rs b/src/config/mod.rs index db210ab6..4c6a321a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,8 +11,8 @@ pub use self::role::{ use self::session::Session; use crate::client::{ - create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig, - MessageContentToolCalls, Model, OPENAI_COMPATIBLE_PLATFORMS, + create_client_config, list_client_types, list_models, ClientConfig, MessageContentToolCalls, + Model, ModelType, OPENAI_COMPATIBLE_PLATFORMS, }; use crate::function::{FunctionDeclaration, Functions, ToolResult}; use crate::rag::Rag; @@ -775,7 +775,7 @@ impl Config { pub fn set_rag_reranker_model(config: &GlobalConfig, value: Option) -> Result<()> { if let Some(id) = &value { - Model::retrieve_reranker(&config.read(), id)?; + Model::retrieve_model(&config.read(), id, ModelType::Reranker)?; } let has_rag = config.read().rag.is_some(); match has_rag { @@ -822,7 +822,7 @@ impl Config { } pub fn set_model(&mut self, model_id: &str) -> Result<()> { - let model = Model::retrieve_chat(self, model_id)?; + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; match self.role_like_mut() { Some(role_like) => role_like.set_model(&model), None => { @@ -893,7 +893,7 @@ impl Config { match role.model_id() { Some(model_id) => { if self.model.id() != model_id { - let model = Model::retrieve_chat(self, model_id)?; + let model = Model::retrieve_model(self, model_id, ModelType::Chat)?; role.set_model(&model); } else { role.set_model(&self.model); @@ -1666,7 +1666,7 @@ impl Config { if args.len() == 1 { values = match cmd { ".role" => map_completion_values(Self::list_roles(true)), - ".model" => list_chat_models(self) + ".model" => list_models(self, ModelType::Chat) .into_iter() .map(|v| (v.id(), Some(v.description()))) .collect(), @@ -1761,7 +1761,10 @@ impl Config { }; complete_option_bool(save_session) } - "rag_reranker_model" => list_reranker_models(self).iter().map(|v| v.id()).collect(), + "rag_reranker_model" => list_models(self, ModelType::Reranker) + .iter() + .map(|v| v.id()) + .collect(), "highlight" => complete_bool(self.highlight), _ => vec![], }; @@ -2268,7 +2271,7 @@ impl Config { fn setup_model(&mut self) -> Result<()> { let mut model_id = self.model_id.clone(); if model_id.is_empty() { - let models = list_chat_models(self); + let models = list_models(self, ModelType::Chat); if models.is_empty() { bail!("No available model"); } diff --git a/src/config/session.rs b/src/config/session.rs index 820fa123..227a23c2 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -82,7 +82,7 @@ impl Session { let mut session: Self = serde_yaml::from_str(&content).with_context(|| format!("Invalid session {}", name))?; - session.model = Model::retrieve_chat(config, &session.model_id)?; + session.model = Model::retrieve_model(config, &session.model_id, ModelType::Chat)?; if let Some(autoname) = name.strip_prefix("_/") { session.name = TEMP_SESSION_NAME.to_string(); diff --git a/src/main.rs b/src/main.rs index 77b72c22..97fbf72d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,9 @@ mod utils; extern crate log; use crate::cli::Cli; -use crate::client::{call_chat_completions, call_chat_completions_streaming, list_chat_models}; +use crate::client::{ + call_chat_completions, call_chat_completions_streaming, list_models, ModelType, +}; use crate::config::{ ensure_parent_exists, list_agents, load_env_file, Config, GlobalConfig, Input, WorkingMode, CODE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE, TEMP_SESSION_NAME, @@ -69,7 +71,7 @@ async fn run(config: GlobalConfig, cli: Cli, text: Option) -> Result<()> } if cli.list_models { - for model in list_chat_models(&config.read()) { + for model in list_models(&config.read(), ModelType::Chat) { println!("{}", model.id()); } return Ok(()); diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 03b8f7e6..ceb83f69 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -110,7 +110,8 @@ impl Rag { pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result { let hnsw = data.build_hnsw(); let bm25 = data.build_bm25(); - let embedding_model = Model::retrieve_embedding(&config.read(), &data.embedding_model)?; + let embedding_model = + Model::retrieve_model(&config.read(), &data.embedding_model, ModelType::Embedding)?; let rag = Rag { config: config.clone(), name: name.to_string(), @@ -164,14 +165,15 @@ impl Rag { value } None => { - let models = list_embedding_models(&config.read()); + let models = list_models(&config.read(), ModelType::Embedding); if models.is_empty() { bail!("No available embedding model"); } select_embedding_model(&models)? } }; - let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?; + let embedding_model = + Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?; let chunk_size = match chunk_size { Some(value) => { @@ -516,7 +518,8 @@ impl Rag { let ids = match rerank_model { Some(model_id) => { - let model = Model::retrieve_reranker(&self.config.read(), model_id)?; + let model = + Model::retrieve_model(&self.config.read(), model_id, ModelType::Reranker)?; let client = init_client(&self.config, Some(model))?; let ids: IndexSet = [vector_search_ids, keyword_search_ids] .concat() diff --git a/src/serve.rs b/src/serve.rs index c19be1f5..b3fe5e0e 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -75,7 +75,7 @@ impl Server { fn new(config: &GlobalConfig) -> Self { let mut config = config.read().clone(); config.functions = Functions::default(); - let mut models = list_models(&config); + let mut models = list_all_models(&config); let mut default_model = config.model.clone(); default_model.data_mut().name = DEFAULT_MODEL_NAME.into(); models.insert(0, &default_model); @@ -483,7 +483,8 @@ impl Server { let config = Arc::new(RwLock::new(self.config.clone())); - let embedding_model = Model::retrieve_embedding(&config.read(), &embedding_model_id)?; + let embedding_model = + Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?; let texts = match input { EmbeddingsReqBodyInput::Single(v) => vec![v], @@ -542,7 +543,8 @@ impl Server { let config = Arc::new(RwLock::new(self.config.clone())); - let reranker_model = Model::retrieve_embedding(&config.read(), &reranker_model_id)?; + let reranker_model = + Model::retrieve_model(&config.read(), &reranker_model_id, ModelType::Reranker)?; let client = init_client(&config, Some(reranker_model))?; let data = client