From 7b70dd3cdfb494dba5a3f74e4219b39799c20125 Mon Sep 17 00:00:00 2001 From: LongYinan Date: Tue, 14 May 2024 16:12:24 +0800 Subject: [PATCH] Update deps and fix everything --- .github/workflows/ci.yml | 11 ++- tiktoken-rs/Cargo.toml | 12 ++-- tiktoken-rs/README.md | 23 +++---- tiktoken-rs/src/api.rs | 105 ++++++++++++++++++++++++----- tiktoken-rs/src/vendor_tiktoken.rs | 28 +++++--- vendor/tiktoken | 2 +- 6 files changed, 127 insertions(+), 54 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 350e212..c97b671 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,12 +36,10 @@ jobs: submodules: recursive - name: Install rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - target: ${{ matrix.target }} - toolchain: nightly - profile: minimal - override: true + targets: ${{ matrix.target }} + toolchain: stable - uses: Swatinem/rust-cache@v2 @@ -59,10 +57,9 @@ jobs: submodules: recursive - name: Install rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: toolchain: nightly - override: true components: rustfmt, clippy - uses: Swatinem/rust-cache@v2 diff --git a/tiktoken-rs/Cargo.toml b/tiktoken-rs/Cargo.toml index de6d649..a4f4abd 100644 --- a/tiktoken-rs/Cargo.toml +++ b/tiktoken-rs/Cargo.toml @@ -13,20 +13,16 @@ documentation = "https://docs.rs/crate/tiktoken-rs/" license = "MIT" readme = "../README.md" -[profile.release] -incremental = true -debug = 1 - [dependencies] anyhow = "1.0.76" -async-openai = { version = "0.14.2", optional = true } -base64 = "0.21.5" +async-openai = { version = "0.21.0", optional = true } +base64 = "0.22.1" bstr = "1.6.2" dhat = { version = "0.3.2", optional = true } -fancy-regex = "0.12.0" +fancy-regex = "0.13.0" lazy_static = "1.4.0" parking_lot = "0.12.1" -pyo3 = { version = "0.19.2", optional = true } +pyo3 = { version = "0.21.2", optional = true } rustc-hash = "1.1.0" [features] diff --git a/tiktoken-rs/README.md b/tiktoken-rs/README.md index 345f447..5c2e219 100644 --- a/tiktoken-rs/README.md +++ b/tiktoken-rs/README.md @@ -75,27 +75,24 @@ Need to enable the `async-openai` feature in your `Cargo.toml` file. ```rust use tiktoken_rs::async_openai::get_chat_completion_max_tokens; -use async_openai::types::{ChatCompletionRequestMessage, Role}; +use async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, Role}; let messages = vec![ - ChatCompletionRequestMessage { - content: Some("You are a helpful assistant that only speaks French.".to_string()), + ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content: "You are a helpful assistant that only speaks French.".to_string(), role: Role::System, name: None, - function_call: None, - }, - ChatCompletionRequestMessage { - content: Some("Hello, how are you?".to_string()), + }), + ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text("Hello, how are you?".to_string()), role: Role::User, name: None, - function_call: None, - }, - ChatCompletionRequestMessage { - content: Some("Parlez-vous francais?".to_string()), + }), + ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content: "Parlez-vous francais?".to_string(), role: Role::System, name: None, - function_call: None, - }, + }), ]; let max_tokens = get_chat_completion_max_tokens("gpt-4", &messages).unwrap(); println!("max_tokens: {}", max_tokens); diff --git a/tiktoken-rs/src/api.rs b/tiktoken-rs/src/api.rs index 27cc7e4..adafe9b 100644 --- a/tiktoken-rs/src/api.rs +++ b/tiktoken-rs/src/api.rs @@ -57,7 +57,7 @@ pub struct FunctionCall { #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ChatCompletionRequestMessage { - /// The role of the messages author. One of `system`, `user`, `assistant`, or `function`. + /// The role of the messages author. One of `system`, `user`, `assistant`, `tool`, or `function`. pub role: String, /// The contents of the message. /// `content` is required for all messages except assistant messages with function calls. @@ -379,10 +379,22 @@ pub mod async_openai { { fn from(m: &async_openai::types::ChatCompletionRequestMessage) -> Self { Self { - role: m.role.to_string(), - name: m.name.clone(), - content: m.content.clone(), - function_call: m.function_call.as_ref().map(|f| f.into()), + role: m.role().to_string(), + name: m.name().map(|x| x.to_owned()), + content: m.content(), + function_call: if let async_openai::types::ChatCompletionRequestMessage::Function( + async_openai::types::ChatCompletionRequestFunctionMessage { + name, content, .. + }, + ) = m + { + Some(super::FunctionCall { + name: name.clone(), + arguments: content.clone().unwrap_or_default(), + }) + } else { + None + }, } } } @@ -423,6 +435,69 @@ pub mod async_openai { super::get_chat_completion_max_tokens(model, &messages) } + trait ChatCompletionRequestMessageCommon { + fn role(&self) -> &str; + fn name(&self) -> Option<&str>; + fn content(&self) -> Option; + } + + impl ChatCompletionRequestMessageCommon for async_openai::types::ChatCompletionRequestMessage { + fn role(&self) -> &str { + match self { + async_openai::types::ChatCompletionRequestMessage::System(_) => "system", + async_openai::types::ChatCompletionRequestMessage::User(_) => "user", + async_openai::types::ChatCompletionRequestMessage::Assistant(_) => "assistant", + async_openai::types::ChatCompletionRequestMessage::Tool(_) => "tool", + async_openai::types::ChatCompletionRequestMessage::Function(_) => "function", + } + } + + fn name(&self) -> Option<&str> { + match self { + async_openai::types::ChatCompletionRequestMessage::System( + async_openai::types::ChatCompletionRequestSystemMessage { name, .. }, + ) => name.as_deref(), + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { name, .. }, + ) => name.as_deref(), + async_openai::types::ChatCompletionRequestMessage::Assistant( + async_openai::types::ChatCompletionRequestAssistantMessage { name, .. }, + ) => name.as_deref(), + async_openai::types::ChatCompletionRequestMessage::Function( + async_openai::types::ChatCompletionRequestFunctionMessage { name, .. }, + ) => Some(name.as_str()), + _ => None, + } + } + + fn content(&self) -> Option { + match self { + async_openai::types::ChatCompletionRequestMessage::System( + async_openai::types::ChatCompletionRequestSystemMessage { content, .. }, + ) => Some(content.clone()), + async_openai::types::ChatCompletionRequestMessage::User( + async_openai::types::ChatCompletionRequestUserMessage { content, .. }, + ) => match content { + async_openai::types::ChatCompletionRequestUserMessageContent::Text(s) => { + Some(s.clone()) + } + async_openai::types::ChatCompletionRequestUserMessageContent::Array(m) => { + Some(m.iter().filter_map(|x| if let async_openai::types::ChatCompletionRequestMessageContentPart::Text(async_openai::types::ChatCompletionRequestMessageContentPartText { text, .. }) = x { Some(text.as_str()) } else { None }).collect::>().as_slice().join("")) + }, + }, + async_openai::types::ChatCompletionRequestMessage::Assistant( + async_openai::types::ChatCompletionRequestAssistantMessage { content, .. }, + ) => content.clone(), + async_openai::types::ChatCompletionRequestMessage::Tool(async_openai::types::ChatCompletionRequestToolMessage { content, .. }) => { + Some(content.clone()) + } + async_openai::types::ChatCompletionRequestMessage::Function( + async_openai::types::ChatCompletionRequestFunctionMessage { content, .. }, + ) => content.clone(), + } + } + } + #[cfg(test)] mod tests { use super::*; @@ -430,12 +505,11 @@ pub mod async_openai { #[test] fn test_num_tokens_from_messages() { let model = "gpt-3.5-turbo-0301"; - let messages = &[async_openai::types::ChatCompletionRequestMessage { + let messages = &[async_openai::types::ChatCompletionRequestMessage::System(async_openai::types::ChatCompletionRequestSystemMessage { role: async_openai::types::Role::System, name: None, - content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()), - function_call: None, - }]; + content: "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string(), + })]; let num_tokens = num_tokens_from_messages(model, messages).unwrap(); assert!(num_tokens > 0); } @@ -443,12 +517,13 @@ pub mod async_openai { #[test] fn test_get_chat_completion_max_tokens() { let model = "gpt-3.5-turbo"; - let messages = &[async_openai::types::ChatCompletionRequestMessage { - content: Some("You are a helpful assistant that only speaks French.".to_string()), - role: async_openai::types::Role::System, - name: None, - function_call: None, - }]; + let messages = &[async_openai::types::ChatCompletionRequestMessage::System( + async_openai::types::ChatCompletionRequestSystemMessage { + content: "You are a helpful assistant that only speaks French.".to_string(), + role: async_openai::types::Role::System, + name: None, + }, + )]; let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap(); assert!(max_tokens > 0); } diff --git a/tiktoken-rs/src/vendor_tiktoken.rs b/tiktoken-rs/src/vendor_tiktoken.rs index aec4a48..daa37f9 100644 --- a/tiktoken-rs/src/vendor_tiktoken.rs +++ b/tiktoken-rs/src/vendor_tiktoken.rs @@ -694,12 +694,20 @@ impl CoreBPE { &self, py: Python, text: &str, - allowed_special: HashSet<&str>, + allowed_special: HashSet, ) -> Py { - let (tokens, completions) = - py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); - let py_completions = - PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + let (tokens, completions) = py.allow_threads(|| { + self._encode_unstable_native( + text, + &HashSet::from_iter(allowed_special.iter().map(|x| x.as_str())), + ) + }); + let py_completions = PyList::new_bound( + py, + completions + .iter() + .map(|seq| PyList::new_bound(py, &seq[..])), + ); (tokens, py_completions).into_py(py) } @@ -728,15 +736,15 @@ impl CoreBPE { pub fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { let bytes = py.allow_threads(|| self._decode_native(&tokens)); - PyBytes::new(py, &bytes).into() + PyBytes::new_bound(py, &bytes).into() } pub fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { if let Some(bytes) = self.decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + return Ok(PyBytes::new_bound(py, bytes).into()); } if let Some(bytes) = self.special_tokens_decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); + return Ok(PyBytes::new_bound(py, bytes).into()); } Err(PyErr::new::(token.to_string())) } @@ -748,14 +756,14 @@ impl CoreBPE { pub fn token_byte_values(&self, py: Python) -> Vec> { self.sorted_token_bytes .iter() - .map(|x| PyBytes::new(py, x).into()) + .map(|x| PyBytes::new_bound(py, x).into()) .collect() } } #[cfg(feature = "python")] #[pymodule] -fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { +fn _tiktoken(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; Ok(()) } diff --git a/vendor/tiktoken b/vendor/tiktoken index 5d970c1..c0ba74c 160000 --- a/vendor/tiktoken +++ b/vendor/tiktoken @@ -1 +1 @@ -Subproject commit 5d970c1100d3210b42497203d6b5c1e30cfda6cb +Subproject commit c0ba74c238d18b4824c25f3c27fc8698055b9a76