Update deps and fix everything
Brooooooklyn committed May 14, 2024
1 parent ccbcb50 commit 7b70dd3
Showing 6 changed files with 127 additions and 54 deletions.
11 changes: 4 additions & 7 deletions .github/workflows/ci.yml
@@ -36,12 +36,10 @@ jobs:
submodules: recursive

- name: Install rust
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
target: ${{ matrix.target }}
toolchain: nightly
profile: minimal
override: true
targets: ${{ matrix.target }}
toolchain: stable

- uses: Swatinem/rust-cache@v2

@@ -59,10 +57,9 @@ jobs:
submodules: recursive

- name: Install rust
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: nightly
override: true
components: rustfmt, clippy

- uses: Swatinem/rust-cache@v2
12 changes: 4 additions & 8 deletions tiktoken-rs/Cargo.toml
@@ -13,20 +13,16 @@ documentation = "https://docs.rs/crate/tiktoken-rs/"
license = "MIT"
readme = "../README.md"

[profile.release]
incremental = true
debug = 1

[dependencies]
anyhow = "1.0.76"
async-openai = { version = "0.14.2", optional = true }
base64 = "0.21.5"
async-openai = { version = "0.21.0", optional = true }
base64 = "0.22.1"
bstr = "1.6.2"
dhat = { version = "0.3.2", optional = true }
fancy-regex = "0.12.0"
fancy-regex = "0.13.0"
lazy_static = "1.4.0"
parking_lot = "0.12.1"
pyo3 = { version = "0.19.2", optional = true }
pyo3 = { version = "0.21.2", optional = true }
rustc-hash = "1.1.0"

[features]
23 changes: 10 additions & 13 deletions tiktoken-rs/README.md
@@ -75,27 +75,24 @@ Need to enable the `async-openai` feature in your `Cargo.toml` file.

```rust
use tiktoken_rs::async_openai::get_chat_completion_max_tokens;
use async_openai::types::{ChatCompletionRequestMessage, Role};
use async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, Role};

let messages = vec![
ChatCompletionRequestMessage {
content: Some("You are a helpful assistant that only speaks French.".to_string()),
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: "You are a helpful assistant that only speaks French.".to_string(),
role: Role::System,
name: None,
function_call: None,
},
ChatCompletionRequestMessage {
content: Some("Hello, how are you?".to_string()),
}),
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello, how are you?".to_string()),
role: Role::User,
name: None,
function_call: None,
},
ChatCompletionRequestMessage {
content: Some("Parlez-vous francais?".to_string()),
}),
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: "Parlez-vous francais?".to_string(),
role: Role::System,
name: None,
function_call: None,
},
}),
];
let max_tokens = get_chat_completion_max_tokens("gpt-4", &messages).unwrap();
println!("max_tokens: {}", max_tokens);
105 changes: 90 additions & 15 deletions tiktoken-rs/src/api.rs
@@ -57,7 +57,7 @@ pub struct FunctionCall {

#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ChatCompletionRequestMessage {
/// The role of the messages author. One of `system`, `user`, `assistant`, or `function`.
/// The role of the messages author. One of `system`, `user`, `assistant`, `tool`, or `function`.
pub role: String,
/// The contents of the message.
/// `content` is required for all messages except assistant messages with function calls.
@@ -379,10 +379,22 @@ pub mod async_openai {
{
fn from(m: &async_openai::types::ChatCompletionRequestMessage) -> Self {
Self {
role: m.role.to_string(),
name: m.name.clone(),
content: m.content.clone(),
function_call: m.function_call.as_ref().map(|f| f.into()),
role: m.role().to_string(),
name: m.name().map(|x| x.to_owned()),
content: m.content(),
function_call: if let async_openai::types::ChatCompletionRequestMessage::Function(
async_openai::types::ChatCompletionRequestFunctionMessage {
name, content, ..
},
) = m
{
Some(super::FunctionCall {
name: name.clone(),
arguments: content.clone().unwrap_or_default(),
})
} else {
None
},
}
}
}
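
For illustration only: a minimal sketch of using the conversion above from downstream code, assuming the `async-openai` feature is enabled, async-openai 0.21 types, and that the internal struct is exported as `tiktoken_rs::ChatCompletionRequestMessage`; the helper name `to_internal` is hypothetical.

```rust
use async_openai::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, Role,
};
use tiktoken_rs::ChatCompletionRequestMessage as InternalMessage;

// Hypothetical helper: the From impl flattens the async-openai enum variants
// into the crate's internal struct; Function variants additionally get a
// FunctionCall synthesized from the function name and content.
fn to_internal(msg: &ChatCompletionRequestMessage) -> InternalMessage {
    msg.into()
}

fn example() -> InternalMessage {
    to_internal(&ChatCompletionRequestMessage::System(
        ChatCompletionRequestSystemMessage {
            content: "You are a helpful assistant.".to_string(),
            role: Role::System,
            name: None,
        },
    ))
}
```
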
@@ -423,32 +435,95 @@ pub mod async_openai {
super::get_chat_completion_max_tokens(model, &messages)
}

trait ChatCompletionRequestMessageCommon {
fn role(&self) -> &str;
fn name(&self) -> Option<&str>;
fn content(&self) -> Option<String>;
}

impl ChatCompletionRequestMessageCommon for async_openai::types::ChatCompletionRequestMessage {
fn role(&self) -> &str {
match self {
async_openai::types::ChatCompletionRequestMessage::System(_) => "system",
async_openai::types::ChatCompletionRequestMessage::User(_) => "user",
async_openai::types::ChatCompletionRequestMessage::Assistant(_) => "assistant",
async_openai::types::ChatCompletionRequestMessage::Tool(_) => "tool",
async_openai::types::ChatCompletionRequestMessage::Function(_) => "function",
}
}

fn name(&self) -> Option<&str> {
match self {
async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessage { name, .. },
) => name.as_deref(),
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { name, .. },
) => name.as_deref(),
async_openai::types::ChatCompletionRequestMessage::Assistant(
async_openai::types::ChatCompletionRequestAssistantMessage { name, .. },
) => name.as_deref(),
async_openai::types::ChatCompletionRequestMessage::Function(
async_openai::types::ChatCompletionRequestFunctionMessage { name, .. },
) => Some(name.as_str()),
_ => None,
}
}

fn content(&self) -> Option<String> {
match self {
async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessage { content, .. },
) => Some(content.clone()),
async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { content, .. },
) => match content {
async_openai::types::ChatCompletionRequestUserMessageContent::Text(s) => {
Some(s.clone())
}
async_openai::types::ChatCompletionRequestUserMessageContent::Array(m) => {
Some(m.iter().filter_map(|x| if let async_openai::types::ChatCompletionRequestMessageContentPart::Text(async_openai::types::ChatCompletionRequestMessageContentPartText { text, .. }) = x { Some(text.as_str()) } else { None }).collect::<Vec<_>>().as_slice().join(""))
},
},
async_openai::types::ChatCompletionRequestMessage::Assistant(
async_openai::types::ChatCompletionRequestAssistantMessage { content, .. },
) => content.clone(),
async_openai::types::ChatCompletionRequestMessage::Tool(async_openai::types::ChatCompletionRequestToolMessage { content, .. }) => {
Some(content.clone())
}
async_openai::types::ChatCompletionRequestMessage::Function(
async_openai::types::ChatCompletionRequestFunctionMessage { content, .. },
) => content.clone(),
}
}
}
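
As a rough illustration of what these accessors yield (usable only inside this module, since the trait is private; the formatting helper below is made up):

```rust
// Sketch: role() maps each variant to its wire name, name() borrows the optional
// author name, and content() flattens array-style user content by concatenating
// only its text parts.
fn describe(msg: &async_openai::types::ChatCompletionRequestMessage) -> String {
    format!(
        "{} ({}): {}",
        msg.role(),
        msg.name().unwrap_or("unnamed"),
        msg.content().unwrap_or_default()
    )
}
```
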

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_num_tokens_from_messages() {
let model = "gpt-3.5-turbo-0301";
let messages = &[async_openai::types::ChatCompletionRequestMessage {
let messages = &[async_openai::types::ChatCompletionRequestMessage::System(async_openai::types::ChatCompletionRequestSystemMessage {
role: async_openai::types::Role::System,
name: None,
content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
function_call: None,
}];
content: "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string(),
})];
let num_tokens = num_tokens_from_messages(model, messages).unwrap();
assert!(num_tokens > 0);
}

#[test]
fn test_get_chat_completion_max_tokens() {
let model = "gpt-3.5-turbo";
let messages = &[async_openai::types::ChatCompletionRequestMessage {
content: Some("You are a helpful assistant that only speaks French.".to_string()),
role: async_openai::types::Role::System,
name: None,
function_call: None,
}];
let messages = &[async_openai::types::ChatCompletionRequestMessage::System(
async_openai::types::ChatCompletionRequestSystemMessage {
content: "You are a helpful assistant that only speaks French.".to_string(),
role: async_openai::types::Role::System,
name: None,
},
)];
let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
assert!(max_tokens > 0);
}
28 changes: 18 additions & 10 deletions tiktoken-rs/src/vendor_tiktoken.rs
@@ -694,12 +694,20 @@ impl CoreBPE {
&self,
py: Python,
text: &str,
allowed_special: HashSet<&str>,
allowed_special: HashSet<String>,
) -> Py<PyTuple> {
let (tokens, completions) =
py.allow_threads(|| self._encode_unstable_native(text, &allowed_special));
let py_completions =
PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..])));
let (tokens, completions) = py.allow_threads(|| {
self._encode_unstable_native(
text,
&HashSet::from_iter(allowed_special.iter().map(|x| x.as_str())),
)
});
let py_completions = PyList::new_bound(
py,
completions
.iter()
.map(|seq| PyList::new_bound(py, &seq[..])),
);
(tokens, py_completions).into_py(py)
}
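
The signature change above (pyo3 now hands over an owned `HashSet<String>`, while the native encoder still takes `&HashSet<&str>`) reduces to a small re-borrowing step; a standalone sketch of that conversion with a hypothetical helper name:

```rust
use std::collections::HashSet;

// Hypothetical helper: build a borrowed view of an owned set so the Python-facing
// signature can own its strings while the native encoder keeps borrowing them.
fn as_borrowed(owned: &HashSet<String>) -> HashSet<&str> {
    owned.iter().map(String::as_str).collect()
}
```
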

@@ -728,15 +736,15 @@ impl CoreBPE {

pub fn decode_bytes(&self, py: Python, tokens: Vec<usize>) -> Py<PyBytes> {
let bytes = py.allow_threads(|| self._decode_native(&tokens));
PyBytes::new(py, &bytes).into()
PyBytes::new_bound(py, &bytes).into()
}

pub fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult<Py<PyBytes>> {
if let Some(bytes) = self.decoder.get(&token) {
return Ok(PyBytes::new(py, bytes).into());
return Ok(PyBytes::new_bound(py, bytes).into());
}
if let Some(bytes) = self.special_tokens_decoder.get(&token) {
return Ok(PyBytes::new(py, bytes).into());
return Ok(PyBytes::new_bound(py, bytes).into());
}
Err(PyErr::new::<exceptions::PyKeyError, _>(token.to_string()))
}
@@ -748,14 +756,14 @@ impl CoreBPE {
pub fn token_byte_values(&self, py: Python) -> Vec<Py<PyBytes>> {
self.sorted_token_bytes
.iter()
.map(|x| PyBytes::new(py, x).into())
.map(|x| PyBytes::new_bound(py, x).into())
.collect()
}
}

#[cfg(feature = "python")]
#[pymodule]
fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> {
fn _tiktoken(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<CoreBPE>()?;
Ok(())
}
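
The pyo3 changes in this file follow one migration pattern: the GIL-ref constructors (`PyBytes::new`, `PyList::new`) become `*_bound` variants returning `Bound<'py, T>`, and `#[pymodule]` functions take `&Bound<'_, PyModule>`. A minimal standalone sketch of that pattern, assuming pyo3 0.21 and a hypothetical function name:

```rust
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList};

// Sketch: *_bound constructors yield GIL-bound handles; converting to Py<T>
// (via .into()/unbind()) detaches them from the GIL lifetime for returning to Python.
fn token_bytes_to_list(py: Python<'_>, chunks: &[Vec<u8>]) -> Py<PyList> {
    let items = chunks.iter().map(|b| PyBytes::new_bound(py, b));
    PyList::new_bound(py, items).into()
}
```
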
