From a48aeb1e40183ecae47abe980118bcea3b087c0f Mon Sep 17 00:00:00 2001 From: Gil Emmanuel Bancud Date: Wed, 10 Jul 2024 23:38:03 +0800 Subject: [PATCH 1/3] Add ToolChoice for message requests --- src/messages.rs | 1 + src/messages/messages_request_body.rs | 83 +++++++++++++++++++- src/messages/tool.rs | 106 ++++++++++++++++++++++++++ 3 files changed, 187 insertions(+), 3 deletions(-) diff --git a/src/messages.rs b/src/messages.rs index 7ed4aa1..4661d0c 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -68,6 +68,7 @@ pub use system_prompt::SystemPrompt; pub use temperature::Temperature; pub use tool::AsyncTool; pub use tool::Tool; +pub use tool::ToolChoice; pub use tool::ToolDefinition; pub use tool::ToolList; pub use tool::ToolResult; diff --git a/src/messages/messages_request_body.rs b/src/messages/messages_request_body.rs index 1d307a7..2d4fc71 100644 --- a/src/messages/messages_request_body.rs +++ b/src/messages/messages_request_body.rs @@ -1,7 +1,7 @@ use crate::macros::impl_display_for_serialize; use crate::messages::{ ClaudeModel, MaxTokens, Message, Metadata, StopSequence, StreamOption, - SystemPrompt, Temperature, ToolDefinition, TopK, TopP, + SystemPrompt, Temperature, ToolChoice, ToolDefinition, TopK, TopP, }; use crate::ValidationError; @@ -82,6 +82,17 @@ pub struct MessagesRequestBody { /// Recommended for advanced use cases only. You usually only need to use temperature. #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, + /// How the model should use the provided tools. The model can use a specific tool, any available tool, or decide by itself. + /// + /// This field is an object with the following possible structures: + /// - `{"type": "auto"}`: Allows the model to decide whether to use tools. + /// - `{"type": "any"}`: Tells the model it must use one of the provided tools, but doesn't specify which. + /// - `{"type": "tool", "name": ""}`: Forces the model to use the specified tool. + /// + /// The `type` field is required and must be one of: "auto", "any", or "tool". + /// When `type` is "tool", the `name` field is also required. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, } impl_display_for_serialize!(MessagesRequestBody); @@ -244,6 +255,15 @@ impl MessagesRequestBuilder { self } + /// Sets the tool choice. + pub fn tool_choice( + mut self, + tool_choice: ToolChoice, + ) -> Self { + self.request_body.tool_choice = Some(tool_choice); + self + } + /// Builds the MessagesRequestBody. pub fn build(self) -> MessagesRequestBody { self.request_body @@ -282,6 +302,7 @@ mod tests { assert_eq!(messages_request_body.temperature, None); assert_eq!(messages_request_body.top_p, None); assert_eq!(messages_request_body.top_k, None); + assert_eq!(messages_request_body.tool_choice, None); } #[test] @@ -307,6 +328,7 @@ mod tests { assert_eq!(messages_request_body.tools, None); assert_eq!(messages_request_body.top_p, None); assert_eq!(messages_request_body.top_k, None); + assert_eq!(messages_request_body.tool_choice, None); } #[test] @@ -343,10 +365,11 @@ mod tests { tools: None, top_p: Some(TopP::new(0.5).unwrap()), top_k: Some(TopK::new(50)), + tool_choice: Some(ToolChoice::any), }; assert_eq!( serde_json::to_string(&messages_request_body).unwrap(), - "{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50}" + "{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50,\"tool_choice\":{\"type\":\"any\"}}" ); } @@ -373,11 +396,12 @@ mod tests { stream: Some(StreamOption::ReturnOnce), temperature: Some(Temperature::new(0.5).unwrap()), tools: None, + tool_choice: Some(ToolChoice::auto), top_p: Some(TopP::new(0.5).unwrap()), top_k: Some(TopK::new(50)), }; assert_eq!( - serde_json::from_str::("{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50}").unwrap(), + serde_json::from_str::("{\"model\":\"claude-3-sonnet-20240229\",\"messages\":[],\"system\":\"system-prompt\",\"max_tokens\":16,\"metadata\":{\"user_id\":\"metadata\"},\"stop_sequences\":[\"stop-sequence\"],\"stream\":false,\"temperature\":0.5,\"top_p\":0.5,\"top_k\":50,\"tool_choice\":{\"type\":\"auto\"}}").unwrap(), messages_request_body ); } @@ -407,6 +431,9 @@ mod tests { }]) .top_p(TopP::new(0.5).unwrap()) .top_k(TopK::new(50)) + .tool_choice(ToolChoice::tool { + name: "tool".into(), + }) .build(); assert_eq!( @@ -458,6 +485,12 @@ mod tests { messages_request_body.top_k, Some(TopK::new(50)) ); + assert_eq!( + messages_request_body.tool_choice, + Some(ToolChoice::tool { + name: "tool".into(), + }) + ); } #[test] @@ -536,5 +569,49 @@ mod tests { messages_request_body.top_k, Some(TopK::new(50)) ); + assert_eq!(messages_request_body.tool_choice, None); + } + + #[test] + fn test_messages_request_body_with_tool_choice() { + let request_body = + MessagesRequestBuilder::new(ClaudeModel::Claude3Sonnet20240229) + .messages(vec![Message::user( + "Hello, Claude!", + )]) + .max_tokens( + MaxTokens::new(100, ClaudeModel::Claude3Sonnet20240229) + .unwrap(), + ) + .tool_choice(ToolChoice::auto) + .build(); + + assert_eq!( + request_body.tool_choice, + Some(ToolChoice::auto) + ); + + let json = serde_json::to_string(&request_body).unwrap(); + assert!(json.contains(r#""tool_choice":{"type":"auto"}"#)); + } + + #[test] + fn test_messages_request_body_serialization_with_tool_choice() { + let request_body = MessagesRequestBody { + model: ClaudeModel::Claude3Sonnet20240229, + messages: vec![Message::user( + "Hello", + )], + max_tokens: MaxTokens::new(100, ClaudeModel::Claude3Sonnet20240229) + .unwrap(), + tool_choice: Some(ToolChoice::tool { + name: "get_weather".to_string(), + }), + ..Default::default() + }; + + let json = serde_json::to_string(&request_body).unwrap(); + assert!(json + .contains(r#""tool_choice":{"type":"tool","name":"get_weather"}"#)); } } diff --git a/src/messages/tool.rs b/src/messages/tool.rs index 0546f61..3080f87 100644 --- a/src/messages/tool.rs +++ b/src/messages/tool.rs @@ -214,6 +214,22 @@ impl ToolList { } } +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type")] +pub enum ToolChoice { + /// Allows Claude to decide whether to call any provided tools or not. + auto, + /// Tells Claude that it must use one of the provided tools, but doesn't force a particular tool. + any, + /// Forces Claude to always use a particular tool. + tool { + name: String, + }, +} + +impl_display_for_serialize!(ToolChoice); + #[cfg(test)] mod tests { use super::*; @@ -574,4 +590,94 @@ mod tests { let tool_result = tool_list.call(tool_use); assert!(tool_result.is_err()) } + #[test] + fn test_tool_choice_serialization() { + assert_eq!( + serde_json::to_value(&ToolChoice::auto).unwrap(), + serde_json::json!({"type": "auto"}) + ); + assert_eq!( + serde_json::to_value(&ToolChoice::any).unwrap(), + serde_json::json!({"type": "any"}) + ); + assert_eq!( + serde_json::to_value(&ToolChoice::tool { + name: "get_weather".to_string() + }) + .unwrap(), + serde_json::json!({"type": "tool", "name": "get_weather"}) + ); + } + + #[test] + fn test_tool_choice_deserialization() { + assert_eq!( + serde_json::from_value::( + serde_json::json!({"type": "auto"}) + ) + .unwrap(), + ToolChoice::auto + ); + assert_eq!( + serde_json::from_value::( + serde_json::json!({"type": "any"}) + ) + .unwrap(), + ToolChoice::any + ); + assert_eq!( + serde_json::from_value::( + serde_json::json!({"type": "tool", "name": "get_weather"}) + ) + .unwrap(), + ToolChoice::tool { + name: "get_weather".to_string() + } + ); + } + + #[test] + fn test_tool_choice_display() { + assert_eq!( + ToolChoice::auto.to_string(), + "{\n \"type\": \"auto\"\n}" + ); + assert_eq!( + ToolChoice::any.to_string(), + "{\n \"type\": \"any\"\n}" + ); + assert_eq!( + ToolChoice::tool { + name: "get_weather".to_string() + } + .to_string(), + "{\n \"type\": \"tool\",\n \"name\": \"get_weather\"\n}" + ); + } + + #[test] + fn test_tool_choice_roundtrip() { + let auto_choice = ToolChoice::auto; + let any_choice = ToolChoice::any; + let tool = ToolChoice::tool { + name: "get_weather".to_string(), + }; + + let auto_json = serde_json::to_string(&auto_choice).unwrap(); + let any_json = serde_json::to_string(&any_choice).unwrap(); + let tool_json = serde_json::to_string(&tool).unwrap(); + + assert_eq!( + serde_json::from_str::(&auto_json).unwrap(), + auto_choice + ); + assert_eq!( + serde_json::from_str::(&any_json).unwrap(), + any_choice + ); + assert_eq!( + serde_json::from_str::(&tool_json).unwrap(), + tool + ); + } } From 8e468d2916551eb8bbd282c138376ef83bc24cab Mon Sep 17 00:00:00 2001 From: Gil Emmanuel Bancud Date: Thu, 11 Jul 2024 02:01:01 +0800 Subject: [PATCH 2/3] Add ToolUse raw info extraction methods --- src/messages/messages_response_body.rs | 177 ++++++++++++++++++++++++- 1 file changed, 172 insertions(+), 5 deletions(-) diff --git a/src/messages/messages_response_body.rs b/src/messages/messages_response_body.rs index dd433f7..7196828 100644 --- a/src/messages/messages_response_body.rs +++ b/src/messages/messages_response_body.rs @@ -4,8 +4,8 @@ use crate::macros::{ impl_display_for_serialize, impl_enum_string_serialization, }; use crate::messages::{ - ClaudeModel, Content, Message, Role, StopReason, - StopSequence, Usage, + ClaudeModel, Content, ContentBlock, Message, Role, StopReason, + StopSequence, ToolUse, Usage, }; /// The response body for the Messages API. @@ -76,12 +76,35 @@ impl_display_for_serialize!(MessagesResponseBody); impl MessagesResponseBody { /// Creates `Message` from the response body. - pub fn crate_message(self) -> Message { + pub fn create_message(self) -> Message { Message { role: self.role, content: self.content, } } + + /// Extracts tool use information from the response body. + pub fn extract_tool_use(&self) -> Option<&ToolUse> { + if let Content::MultipleBlocks(blocks) = &self.content { + for block in blocks { + if let ContentBlock::ToolUse(tool_use_block) = block { + return Some(&tool_use_block.tool_use); + } + } + } + None + } + + /// Extracts tool name and input from the response body. + pub fn extract_tool_fields(&self) -> Option<(String, serde_json::Value)> { + self.extract_tool_use() + .map(|tool_use| { + ( + tool_use.name.clone(), + tool_use.input.clone(), + ) + }) + } } /// The object type for message. @@ -215,7 +238,7 @@ mod tests { MessageObjectType::Message ); } - + #[test] fn create_message() { let response = MessagesResponseBody { @@ -224,8 +247,152 @@ mod tests { }; assert_eq!( - response.crate_message(), + response.create_message(), Message::assistant("content") ); } + + #[test] + fn test_extract_tool_use() { + // Test case 1: Valid tool use + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ + ContentBlock::ToolUse(ToolUseContentBlock { + _type: ContentType::ToolUse, + tool_use: ToolUse { + id: "tool_id".to_string(), + name: "test_tool".to_string(), + input: serde_json::json!({"key": "value"}), + }, + }), + ]), + ..Default::default() + }; + let tool_use = response.extract_tool_use(); + assert!(tool_use.is_some()); + assert_eq!(tool_use.unwrap().name, "test_tool"); + + // Test case 2: No tool use + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ContentBlock::Text( + TextContentBlock::new("text"), + )]), + ..Default::default() + }; + let tool_use = response.extract_tool_use(); + assert!(tool_use.is_none()); + + // Test case 3: Empty content + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![]), + ..Default::default() + }; + let tool_use = response.extract_tool_use(); + assert!(tool_use.is_none()); + + // Test case 4: Multiple tool uses (should return the first one) + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ + ContentBlock::ToolUse(ToolUseContentBlock { + _type: ContentType::ToolUse, + tool_use: ToolUse { + id: "tool_id_1".to_string(), + name: "test_tool_1".to_string(), + input: serde_json::json!({"key": "value1"}), + }, + }), + ContentBlock::ToolUse(ToolUseContentBlock { + _type: ContentType::ToolUse, + tool_use: ToolUse { + id: "tool_id_2".to_string(), + name: "test_tool_2".to_string(), + input: serde_json::json!({"key": "value2"}), + }, + }), + ]), + ..Default::default() + }; + let tool_use = response.extract_tool_use(); + assert!(tool_use.is_some()); + assert_eq!(tool_use.unwrap().name, "test_tool_1"); + } + + #[test] + fn test_extract_tool_fields() { + // Test case 1: Valid tool use + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ + ContentBlock::ToolUse(ToolUseContentBlock { + _type: ContentType::ToolUse, + tool_use: ToolUse { + id: "tool_id".to_string(), + name: "test_tool".to_string(), + input: serde_json::json!({"key": "value"}), + }, + }), + ]), + ..Default::default() + }; + let fields = response.extract_tool_fields(); + assert!(fields.is_some()); + let (name, input) = fields.unwrap(); + assert_eq!(name, "test_tool"); + assert_eq!( + input, + serde_json::json!({"key": "value"}) + ); + + // Test case 2: No tool use + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ContentBlock::Text( + TextContentBlock::new("text"), + )]), + ..Default::default() + }; + let fields = response.extract_tool_fields(); + assert!(fields.is_none()); + + // Test case 3: Empty content + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![]), + ..Default::default() + }; + let fields = response.extract_tool_fields(); + assert!(fields.is_none()); + + // Test case 4: Complex input + let response = MessagesResponseBody { + content: Content::MultipleBlocks(vec![ + ContentBlock::ToolUse(ToolUseContentBlock { + _type: ContentType::ToolUse, + tool_use: ToolUse { + id: "tool_id".to_string(), + name: "complex_tool".to_string(), + input: serde_json::json!({ + "nested": { + "array": [1, 2, 3], + "object": {"a": "b"} + }, + "boolean": true + }), + }, + }), + ]), + ..Default::default() + }; + let fields = response.extract_tool_fields(); + assert!(fields.is_some()); + let (name, input) = fields.unwrap(); + assert_eq!(name, "complex_tool"); + assert_eq!( + input, + serde_json::json!({ + "nested": { + "array": [1, 2, 3], + "object": {"a": "b"} + }, + "boolean": true + }) + ); + } } From 47e76d5f3516958219051590d6fa82774f8693f7 Mon Sep 17 00:00:00 2001 From: Gil Emmanuel Bancud Date: Thu, 11 Jul 2024 02:24:17 +0800 Subject: [PATCH 3/3] Fix function spelling on examples --- examples/conversation.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/conversation.rs b/examples/conversation.rs index 0541819..52091ff 100644 --- a/examples/conversation.rs +++ b/examples/conversation.rs @@ -71,7 +71,7 @@ async fn main() -> anyhow::Result<()> { // 5. Store the first assistant message. request_body .messages - .push(response.crate_message()); + .push(response.create_message()); // 6. Add the second user message. request_body @@ -96,7 +96,7 @@ async fn main() -> anyhow::Result<()> { // 9. Store the second assistant message. request_body .messages - .push(response.crate_message()); + .push(response.create_message()); // Continue the conversation...