diff --git a/CHANGELOG.md b/CHANGELOG.md index d0ac13b..26ba6e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `clust::messages::StopReason::ToolUse`. - Add tools for `clust::messages::MessagesRequestBody`. - Add tool use content block for `clust::messages::Content`. +- Add too result content block for `clust::messages::Content`. ### Removed diff --git a/clust_macros/Cargo.toml b/clust_macros/Cargo.toml index b3523e6..303fd7a 100644 --- a/clust_macros/Cargo.toml +++ b/clust_macros/Cargo.toml @@ -22,6 +22,8 @@ proc-macro = true quote = "1.0.*" syn = { version = "2.0.*", features = ["full"] } proc-macro2 = "1.0.*" +valico = "4.0.*" +serde_json = "1.0.*" [dev-dependencies] tokio = { version = "1.37.0", features = ["macros"] } diff --git a/clust_macros/rust-toolchain.toml b/clust_macros/rust-toolchain.toml index 12163da..6cd16e1 100644 --- a/clust_macros/rust-toolchain.toml +++ b/clust_macros/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.76.0" -components = ["rls", "rust-analysis", "rust-src", "rustfmt", "clippy", "cargo-expand"] +channel = "1.77.2" +components = ["rls", "rust-analysis", "rust-src", "rustfmt", "clippy"] diff --git a/clust_macros/src/check_result.rs b/clust_macros/src/check_result.rs deleted file mode 100644 index c53c25e..0000000 --- a/clust_macros/src/check_result.rs +++ /dev/null @@ -1,69 +0,0 @@ -use syn::Type; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub(crate) enum ReturnType { - Value, - Result, -} - -pub(crate) fn get_return_type(ty: &Type) -> ReturnType { - if is_result_type(ty) { - ReturnType::Result - } else { - ReturnType::Value - } -} - -fn is_result_type(ty: &Type) -> bool { - match ty { - | Type::Path(type_path) => { - let path_segments = &type_path.path.segments; - path_segments.last().map_or(false, |last_segment| { - if last_segment.ident == "Result" { - match &last_segment.arguments { - syn::PathArguments::AngleBracketed(args) => args.args.len() == 2, - _ => false, - } - } else if path_segments.len() >= 2 { - path_segments - .iter() - .rev() - .nth(1) - .map_or(false, |second_last_segment| { - second_last_segment.ident == "result" - && match &last_segment.arguments { - syn::PathArguments::AngleBracketed(args) => args.args.len() == 2, - _ => false, - } - }) - } else { - false - } - }) - }, - | _ => false, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - type TestResult = Result; - - #[test] - fn test_get_return_type() { - let ty = syn::parse_str::("Result").unwrap(); - assert_eq!(get_return_type(&ty), ReturnType::Result); - - let ty = - syn::parse_str::("std::result::Result").unwrap(); - assert_eq!(get_return_type(&ty), ReturnType::Result); - - let ty = syn::parse_str::("i32").unwrap(); - assert_eq!(get_return_type(&ty), ReturnType::Value); - - let ty: Type = syn::parse_str::("TestResult").unwrap(); - assert_eq!(get_return_type(&ty), ReturnType::Value); - } -} diff --git a/clust_macros/src/lib.rs b/clust_macros/src/lib.rs index 356a6ac..4a384af 100644 --- a/clust_macros/src/lib.rs +++ b/clust_macros/src/lib.rs @@ -3,8 +3,9 @@ use crate::tool::impl_tool; use proc_macro::TokenStream; -mod check_result; +mod return_type; mod tool; +mod parameter_type; /// A procedural macro that generates a `clust::messages::Tool` or `clust::messages::AsyncTool` /// implementation for the annotated function with documentation. @@ -56,9 +57,7 @@ mod tool; /// /// ```rust /// use clust_macros::clust_tool; -/// use std::collections::BTreeMap; -/// use std::iter::FromIterator; -/// use clust::messages::{FunctionCalls, Invoke, Tool}; +/// use clust::messages::{ToolUse, Tool}; /// /// /// Increments the argument by 1. /// /// @@ -73,17 +72,15 @@ mod tool; /// /// let description = tool.description(); /// -/// let function_calls = FunctionCalls { -/// invoke: Invoke { -/// tool_name: String::from("incrementer"), -/// parameters: BTreeMap::from_iter(vec![( -/// "value".to_string(), -/// "42".to_string(), -/// )]), -/// }, -/// }; -/// -/// let result = tool.call(function_calls).unwrap(); +/// let tool_use = ToolUse::new( +/// "toolu_XXXX", +/// "incrementer", +/// serde_json::json!({ +/// "value": 42 +/// }), +/// ); +/// +/// let result = tool.call(tool_use).unwrap(); /// ``` /// /// Generated XML tool description from above implementation is as follows: diff --git a/clust_macros/src/parameter_type.rs b/clust_macros/src/parameter_type.rs new file mode 100644 index 0000000..d9cf650 --- /dev/null +++ b/clust_macros/src/parameter_type.rs @@ -0,0 +1,354 @@ +use std::fmt::Display; +use syn::{ + GenericArgument, PathArguments, Type, TypeArray, TypeParen, TypeSlice, +}; +use valico::json_schema::PrimitiveType; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) enum ParameterType { + Null, + Boolean, + Integer, + Number, + String, + Array(Box), + Option(Box), + //Enum(Vec), TODO: + Object, +} + +impl Default for ParameterType { + fn default() -> Self { + Self::Null + } +} + +impl Display for ParameterType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + | ParameterType::Null => write!(f, "null"), + | ParameterType::Boolean => write!(f, "boolean"), + | ParameterType::Integer => write!(f, "integer"), + | ParameterType::Number => write!(f, "number"), + | ParameterType::String => write!(f, "string"), + | ParameterType::Array(inner) => { + write!(f, "array of {}", inner) + }, + | ParameterType::Option(inner) => { + write!(f, "option of {}", inner) + }, + | ParameterType::Object => write!(f, "object"), + } + } +} + +impl ParameterType { + pub(crate) fn from_syn_type(ty: &Type) -> Self { + match ty { + | Type::Path(type_path) => { + let path_segments = &type_path.path.segments; + if let Some(first) = path_segments.first() { + if first.ident == "Option" { + if let PathArguments::AngleBracketed(args) = + first.arguments.clone() + { + if let Some(arg) = args.args.last() { + if let GenericArgument::Type(ty) = arg { + return Self::Option(Box::new( + ParameterType::from_syn_type(ty), + )); + } + } + } + } + + if first.ident == "Vec" { + if let PathArguments::AngleBracketed(args) = + first.arguments.clone() + { + if let Some(arg) = args.args.last() { + if let GenericArgument::Type(ty) = arg { + return Self::Array(Box::new( + ParameterType::from_syn_type(ty), + )); + } + } + } + } + } + + path_segments + .last() + .map_or( + Self::Object, + |last_segment| match last_segment + .ident + .to_string() + .as_str() + { + | "i8" | "i16" | "i32" | "i64" | "i128" + | "isize" | "u8" | "u16" | "u32" | "u64" + | "u128" | "usize" => Self::Integer, + | "f32" | "f64" => Self::Number, + | "bool" => Self::Boolean, + | "String" => Self::String, + | _ => Self::Object, + }, + ) + }, + // Fixed array type like [T; N] + | Type::Array(TypeArray { + elem: element_type, + .. + }) => Self::Array(Box::new(Self::from_syn_type( + element_type.as_ref(), + ))), + // Slice type like [T] + | Type::Slice(slice_type) => { + // [str] + if Self::is_string_slice(slice_type) { + return Self::String; + } + + // Other slices + Self::Array(Box::new(Self::from_syn_type( + slice_type.elem.as_ref(), + ))) + }, + // Parenthesized type like (T) + | Type::Paren(TypeParen { + elem: element_type, + .. + }) => Self::from_syn_type(element_type.as_ref()), + // Pointer type like *const T or *mut T + | Type::Ptr(pointer_type) => { + Self::from_syn_type(pointer_type.elem.as_ref()) + }, + // Reference type like &T or &mut T + | Type::Reference(reference_type) => { + Self::from_syn_type(reference_type.elem.as_ref()) + }, + // Tuple type or other Object types + | _ => Self::Object, + } + } + + fn is_string_slice(ty: &TypeSlice) -> bool { + if let Type::Path(type_path) = ty.elem.as_ref() { + let path_segments = &type_path.path.segments; + if let Some(last) = path_segments.last() { + if last.ident == "str" { + return true; + } + } + } + + false + } + + pub(crate) fn to_primitive_type(&self) -> PrimitiveType { + match self { + | ParameterType::Null => PrimitiveType::Null, + | ParameterType::Boolean => PrimitiveType::Boolean, + | ParameterType::Integer => PrimitiveType::Integer, + | ParameterType::Number => PrimitiveType::Number, + | ParameterType::String => PrimitiveType::String, + | ParameterType::Array(_) => PrimitiveType::Array, + | ParameterType::Option(inner) => inner.to_primitive_type(), + | ParameterType::Object => PrimitiveType::Object, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn integer() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("i8").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("i16").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("i32").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("i64").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("i128").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("isize").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("u8").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("u16").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("u32").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("u64").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("u128").unwrap() + ), + ParameterType::Integer + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("usize").unwrap() + ), + ParameterType::Integer + ); + } + + #[test] + fn number() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("f32").unwrap() + ), + ParameterType::Number + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("f64").unwrap() + ), + ParameterType::Number + ); + } + + #[test] + fn boolean() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("bool").unwrap() + ), + ParameterType::Boolean + ); + } + + #[test] + fn string() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("String").unwrap() + ), + ParameterType::String + ); + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("&[str]").unwrap() + ), + ParameterType::String + ); + } + + #[test] + fn array() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("Vec").unwrap() + ), + ParameterType::Array(Box::new(ParameterType::Integer)) + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("&[i32]").unwrap() + ), + ParameterType::Array(Box::new(ParameterType::Integer)) + ); + } + + #[test] + fn option() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("Option").unwrap() + ), + ParameterType::Option(Box::new(ParameterType::Integer)) + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("Option").unwrap() + ), + ParameterType::Option(Box::new(ParameterType::Boolean)) + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("Option").unwrap() + ), + ParameterType::Option(Box::new(ParameterType::String)) + ); + } + + #[test] + fn object() { + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("JsonValueType").unwrap() + ), + ParameterType::Object, + ); + + assert_eq!( + ParameterType::from_syn_type( + &syn::parse_str::("(u32, bool)").unwrap() + ), + ParameterType::Object, + ); + } +} diff --git a/clust_macros/src/return_type.rs b/clust_macros/src/return_type.rs new file mode 100644 index 0000000..717cf1c --- /dev/null +++ b/clust_macros/src/return_type.rs @@ -0,0 +1,110 @@ +use syn::Type; + +/// Return type of function. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) enum ReturnType { + /// Returns a value. + Value, + /// Returns a Result. + Result, + /// Returns nothing. + None, +} + +impl ReturnType { + pub(crate) fn from_syn(ty: &syn::ReturnType) -> ReturnType { + match ty { + | syn::ReturnType::Default => ReturnType::None, + | syn::ReturnType::Type(_, ty) => match ty.as_ref() { + | Type::Path(type_path) => { + return if Self::is_result_at_first_segment(&type_path) + || Self::is_result_by_full_path(&type_path) + { + ReturnType::Result + } else { + ReturnType::Value + } + }, + | _ => ReturnType::Value, + }, + } + } + + // Result + fn is_result_at_first_segment(type_path: &syn::TypePath) -> bool { + if let Some(first) = type_path + .path + .segments + .first() + { + first.ident == "Result" + } else { + false + } + } + + // std::result::Result + fn is_result_by_full_path(type_path: &syn::TypePath) -> bool { + let mut segments = type_path.path.segments.iter(); + if let Some(first) = segments.next() { + if first.ident == "std" { + if let Some(second) = segments.next() { + if second.ident == "result" { + if let Some(third) = segments.next() { + if third.ident == "Result" { + return true; + } + } + } + } + } + } + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use syn::Token; + + type TestResult = Result; + + #[test] + fn test_get_return_type() { + assert_eq!( + ReturnType::from_syn(&syn::ReturnType::Default), + ReturnType::None + ); + + let ty = syn::parse_str::("Result").unwrap(); + let r_array = syn::parse_str::]>("->").unwrap(); + let return_type = syn::ReturnType::Type(r_array, Box::new(ty)); + assert_eq!( + ReturnType::from_syn(&return_type), + ReturnType::Result + ); + + let ty = + syn::parse_str::("std::result::Result").unwrap(); + let return_type = syn::ReturnType::Type(r_array, Box::new(ty)); + assert_eq!( + ReturnType::from_syn(&return_type), + ReturnType::Result + ); + + let ty = syn::parse_str::("i32").unwrap(); + let return_type = syn::ReturnType::Type(r_array, Box::new(ty)); + assert_eq!( + ReturnType::from_syn(&return_type), + ReturnType::Value + ); + + let ty: Type = syn::parse_str::("TestResult").unwrap(); + let return_type = syn::ReturnType::Type(r_array, Box::new(ty)); + assert_eq!( + ReturnType::from_syn(&return_type), + ReturnType::Value + ); + } +} diff --git a/clust_macros/src/tool.rs b/clust_macros/src/tool.rs index 0152494..40f2072 100644 --- a/clust_macros/src/tool.rs +++ b/clust_macros/src/tool.rs @@ -1,58 +1,88 @@ extern crate proc_macro; use proc_macro::TokenStream; -use proc_macro2::{Ident, Span}; use std::collections::BTreeMap; -use crate::check_result::{get_return_type, ReturnType}; +use proc_macro2::{Ident, Span}; use quote::{quote, ToTokens}; use syn::{AttrStyle, Expr, ItemFn, Meta}; +use crate::parameter_type::ParameterType; +use crate::return_type::ReturnType; + #[derive(Debug, Clone)] struct DocComments { - description: String, + description: Option, parameters: BTreeMap, } #[derive(Debug, Clone)] -struct ParameterType { +struct ParameterWithNoDescription { name: String, - _type: String, + _type: ParameterType, } #[derive(Debug, Clone)] struct Parameter { - pub name: String, - pub _type: String, - pub description: String, -} - -impl ToTokens for Parameter { - fn to_tokens( - &self, - tokens: &mut proc_macro2::TokenStream, - ) { - let name = &self.name; - let _type = &self._type; - let description = &self.description; - - tokens.extend(quote! { - clust::messages::Parameter { - name: format!(r#"{}"#, #name), - _type: format!(r#"{}"#, #_type), - description: format!(r#"{}"#, #description), - }, - }); - } + name: String, + _type: ParameterType, + description: Option, } #[derive(Debug, Clone)] struct ToolInformation { name: String, - description: String, + description: Option, parameters: Vec, } +impl ToolInformation { + fn build_json_schema(&self) -> serde_json::Value { + let mut builder = valico::json_schema::Builder::new(); + builder.object(); + + if let Some(description) = &self.description { + builder.desc(&description.clone()); + } + + let mut required = Vec::new(); + + for parameter in &self.parameters { + builder.properties(|properties| { + properties.insert(¶meter.name, |property| { + if let Some(description) = ¶meter.description { + property.desc(&description.clone()); + } + property.type_( + parameter + ._type + .to_primitive_type(), + ); + + // "items" for array + if let ParameterType::Array(item_type) = + parameter._type.clone() + { + property.items_schema(|items| { + items.type_(item_type.to_primitive_type()); + }); + } + }); + }); + + if let ParameterType::Option(_) = parameter._type { + // Do nothing + } else { + required.push(parameter.name.clone()); + } + } + + builder.required(required); + + builder.into_json() + } +} + fn get_doc_comments(func: &ItemFn) -> Vec { func.attrs .iter() @@ -163,27 +193,33 @@ fn parse_doc_comments(docs: Vec) -> DocComments { } } + let description = if description.is_empty() { + None + } else { + Some(description) + }; + DocComments { - description: format!(r#"{}"#, description), + description, parameters, } } -fn get_parameter_types(func: &ItemFn) -> Vec { +fn get_parameter_types(func: &ItemFn) -> Vec { func.sig.inputs.iter().map(|input| { match input { | syn::FnArg::Typed(pat) => { match pat.pat.as_ref() { | syn::Pat::Ident(ident) => { - ParameterType { + ParameterWithNoDescription { name: ident.ident.to_string(), - _type: pat.ty.to_token_stream().to_string(), + _type: ParameterType::from_syn_type(&pat.ty), } }, | _ => panic!("Tool trait requires named fields"), } }, - | _ => panic!("Tool trait can only be derived for functions with named fields, not for methods."), + | _ => panic!("Tool trait can only be derived for functions with named fields."), } }).collect() } @@ -193,25 +229,28 @@ fn get_tool_information(func: &ItemFn) -> ToolInformation { let doc_comments = parse_doc_comments(doc_comments); let parameter_types = get_parameter_types(&func); - let parameters = doc_comments - .parameters + let parameters = parameter_types .iter() - .map( - |(parameter_name, parameter_description)| { - let parameter_type = parameter_types - .iter() - .find(|parameter_type| { - parameter_type.name == parameter_name.as_str() - }) - .unwrap(); - + .map(|parameter| { + // Add parameter description if it has been found in doc comments + if let Some((_, parameter_description)) = doc_comments + .parameters + .iter() + .find(|(name, _)| *name == ¶meter.name) + { Parameter { - name: parameter_name.clone(), - _type: parameter_type._type.clone(), - description: parameter_description.clone(), + name: parameter.name.clone(), + _type: parameter._type.clone(), + description: Some(parameter_description.clone()), } - }, - ) + } else { + Parameter { + name: parameter.name.clone(), + _type: parameter._type.clone(), + description: None, + } + } + }) .collect(); ToolInformation { @@ -221,44 +260,31 @@ fn get_tool_information(func: &ItemFn) -> ToolInformation { } } -fn quote_parameter_descriptions( - info: &ToolInformation -) -> Vec { - info.parameters - .iter() - .map(|parameter| { - let name = parameter.name.clone(); - let _type = parameter._type.clone(); - let description = parameter.description.clone(); - - quote! { - clust::messages::Parameter { - name: format!(r#"{}"#, #name), - _type: format!(r#"{}"#, #_type), - description: format!(r#"{}"#, #description), - } - } - }) - .collect() -} - -fn quote_description(info: &ToolInformation) -> proc_macro2::TokenStream { +fn quote_definition(info: &ToolInformation) -> proc_macro2::TokenStream { let name = info.name.clone(); let description = info.description.clone(); - let parameters = quote_parameter_descriptions(info); - - quote! { - fn description(&self) -> clust::messages::ToolDescription { - clust::messages::ToolDescription { - tool_name: format!(r#"{}"#, #name), - description: format!(r#"{}"#, #description), - parameters: clust::messages::Parameters { - inner: vec![ - #( - #parameters - ),* - ], - }, + let input_schema = info + .build_json_schema() + .to_string(); + + if let Some(description) = description { + quote! { + fn definition(&self) -> clust::messages::ToolDefinition { + clust::messages::ToolDefinition::new( + #name, + Some(#description), + serde_json::from_str(&#input_schema).expect("Failed to parse JSON schema of tool definition"), + ) + } + } + } else { + quote! { + fn definition(&self) -> clust::messages::ToolDefinition { + clust::messages::ToolDefinition::new( + #name, + None, + serde_json::json!(#input_schema), + ) } } } @@ -271,223 +297,221 @@ fn quote_invoke_parameters( .parameters .iter() .map(|parameter| parameter.name.clone()) - .map(|parameter| { + .map(|name| { quote! { - function_calls.invoke.parameters.get(#parameter) - .ok_or_else(|| clust::messages::ToolCallError::ParameterNotFound(#parameter.to_string()))? - .parse() - .map_err(|_| clust::messages::ToolCallError::ParameterParseFailed(#parameter.to_string()))? + tool_use.input.get(#name) + .ok_or_else(|| clust::messages::ToolCallError::ParameterNotFound(#name.to_string()))? + .to_string() + .parse() + .map_err(|_| clust::messages::ToolCallError::ParameterParseFailed(#name.to_string()))? } }) .collect() } -fn quote_result(name: String) -> proc_macro2::TokenStream { +fn quote_tool_call() -> proc_macro2::TokenStream { quote! { - Ok(clust::messages::FunctionResults::Result( - clust::messages::FunctionResult { - tool_name: #name.to_string(), - stdout: format!("{}", result), - } - )) + fn call(&self, tool_use: clust::messages::ToolUse) + -> std::result::Result } } -fn quote_result_with_match(name: String) -> proc_macro2::TokenStream { +fn quote_async_tool_call() -> proc_macro2::TokenStream { quote! { - match result { - | Ok(value) => { - Ok(clust::messages::FunctionResults::Result( - clust::messages::FunctionResult { - tool_name: #name.to_string(), - stdout: format!("{}", value), - } - )) - }, - | Err(error) => { - Ok(clust::messages::FunctionResults::Error( - format!("{}", error) - )) - }, - } + async fn call(&self, tool_use: clust::messages::ToolUse) + -> std::result::Result } } -fn quote_call( - func: &ItemFn, - info: &ToolInformation, -) -> proc_macro2::TokenStream { +fn quote_check_name(info: &ToolInformation) -> proc_macro2::TokenStream { let name = info.name.clone(); - let ident = func.sig.ident.clone(); - let parameters = quote_invoke_parameters(info); - let quote_result = quote_result(name.clone()); quote! { - fn call(&self, function_calls: clust::messages::FunctionCalls) - -> std::result::Result { - if function_calls.invoke.tool_name != #name { - return Err(clust::messages::ToolCallError::ToolNameMismatch); - } - - let result = #ident( - #( - #parameters - ),* - ); - - #quote_result + if tool_use.name != #name { + return Err(clust::messages::ToolCallError::ToolNameMismatch); } } } -fn quote_call_with_result( +fn quote_call_with_no_value( func: &ItemFn, info: &ToolInformation, ) -> proc_macro2::TokenStream { - let name = info.name.clone(); - let ident = func.sig.ident.clone(); - let parameters = quote_invoke_parameters(info); - let quote_result = quote_result_with_match(name.clone()); + let function = func.sig.ident.clone(); + let impl_invoke_parameters = quote_invoke_parameters(info); quote! { - fn call(&self, function_calls: clust::messages::FunctionCalls) - -> std::result::Result { - if function_calls.invoke.tool_name != #name { - return Err(clust::messages::ToolCallError::ToolNameMismatch); - } - - let result = #ident( - #( - #parameters - ),* - ); - - #quote_result - } + #function( + #( + #impl_invoke_parameters + ),* + ); } } -fn quote_call_async( +fn quote_call_with_value( func: &ItemFn, info: &ToolInformation, ) -> proc_macro2::TokenStream { - let name = info.name.clone(); - let ident = func.sig.ident.clone(); - let parameters = quote_invoke_parameters(info); - let quote_result = quote_result(name.clone()); + let function = func.sig.ident.clone(); + let impl_invoke_parameters = quote_invoke_parameters(info); quote! { - async fn call(&self, function_calls: clust::messages::FunctionCalls) - -> std::result::Result { - if function_calls.invoke.tool_name != #name { - return Err(clust::messages::ToolCallError::ToolNameMismatch); - } + let result = #function( + #( + #impl_invoke_parameters + ),* + ); + } +} - let result = #ident( - #( - #parameters - ),* - ).await; +fn quote_call_with_no_value_async( + func: &ItemFn, + info: &ToolInformation, +) -> proc_macro2::TokenStream { + let function = func.sig.ident.clone(); + let impl_invoke_parameters = quote_invoke_parameters(info); - #quote_result - } + quote! { + #function( + #( + #impl_invoke_parameters + ),* + ).await; } } -fn quote_call_async_with_result( +fn quote_call_with_value_async( func: &ItemFn, info: &ToolInformation, ) -> proc_macro2::TokenStream { - let name = info.name.clone(); - let ident = func.sig.ident.clone(); - let parameters = quote_invoke_parameters(info); - let quote_result = quote_result_with_match(name.clone()); + let function = func.sig.ident.clone(); + let impl_invoke_parameters = quote_invoke_parameters(info); quote! { - async fn call(&self, function_calls: clust::messages::FunctionCalls) - -> std::result::Result { - if function_calls.invoke.tool_name != #name { - return Err(clust::messages::ToolCallError::ToolNameMismatch); - } + let result = #function( + #( + #impl_invoke_parameters + ),* + ).await; + } +} - let result = #ident( - #( - #parameters - ),* - ).await; +fn quote_return_no_value() -> proc_macro2::TokenStream { + quote! { + Ok(clust::messages::ToolResult::success( + tool_use.id, + None, + )) + } +} - #quote_result +fn quote_return_value() -> proc_macro2::TokenStream { + quote! { + Ok(clust::messages::ToolResult::success( + tool_use.id, + Some(format!("{}", result)), + )) + } +} + +fn quote_return_value_with_result() -> proc_macro2::TokenStream { + quote! { + match result { + | Ok(value) => { + Ok(clust::messages::ToolResult::success( + tool_use.id, + Some(format!("{}", value)), + )) + }, + | Err(error) => { + Ok(clust::messages::ToolResult::error( + tool_use.id, + Some(format!("{}", error)), + )) + }, } } } -fn impl_tool_for_function( +fn quote_call( func: &ItemFn, - info: ToolInformation, + info: &ToolInformation, + return_type: ReturnType, + is_async: bool, ) -> proc_macro2::TokenStream { - let impl_description = quote_description(&info); - - let impl_call = match func.sig.output.clone() { - | syn::ReturnType::Default => { - panic!("Function must have a displayable return type") + let impl_tool_call = if !is_async { + quote_tool_call() + } else { + quote_async_tool_call() + }; + let impl_check_name = quote_check_name(info); + let impl_call = match return_type { + | ReturnType::None => { + if !is_async { + quote_call_with_no_value(func, info) + } else { + quote_call_with_no_value_async(func, info) + } }, - | syn::ReturnType::Type(_, _type) => { - let return_type = get_return_type(&_type); - - match return_type { - | ReturnType::Value => quote_call(func, &info), - | ReturnType::Result => quote_call_with_result(func, &info), + | ReturnType::Value | ReturnType::Result => { + if !is_async { + quote_call_with_value(func, info) + } else { + quote_call_with_value_async(func, info) } }, }; + let impl_return_value = match return_type { + | ReturnType::None => quote_return_no_value(), + | ReturnType::Value => quote_return_value(), + | ReturnType::Result => quote_return_value_with_result(), + }; - let struct_name = Ident::new( - &format!("ClustTool_{}", info.name), - Span::call_site(), - ); + quote! { + #impl_tool_call { + #impl_check_name + #impl_call + #impl_return_value + } + } +} + +fn quote_impl_tool(struct_name: &Ident) -> proc_macro2::TokenStream { + let struct_name = struct_name.clone(); quote! { - // Original function - #func + impl clust::messages::Tool for #struct_name + } +} - // Generated tool struct - pub struct #struct_name; +fn quote_impl_async_tool(struct_name: &Ident) -> proc_macro2::TokenStream { + let struct_name = struct_name.clone(); - // Implement Tool trait for generated tool struct - impl clust::messages::Tool for #struct_name { - #impl_description - #impl_call - } + quote! { + impl clust::messages::AsyncTool for #struct_name } } -fn impl_tool_for_async_function( +fn impl_tool_for_function( func: &ItemFn, info: ToolInformation, + return_type: ReturnType, + is_async: bool, ) -> proc_macro2::TokenStream { - let impl_description = quote_description(&info); - - let impl_call = match func.sig.output.clone() { - | syn::ReturnType::Default => { - panic!("Function must have a displayable return type") - }, - | syn::ReturnType::Type(_, _type) => { - let return_type = get_return_type(&_type); - - match return_type { - | ReturnType::Value => quote_call_async(func, &info), - | ReturnType::Result => { - quote_call_async_with_result(func, &info) - }, - } - }, - }; - let struct_name = Ident::new( &format!("ClustTool_{}", info.name), Span::call_site(), ); + let impl_impl_tool = if !is_async { + quote_impl_tool(&struct_name) + } else { + quote_impl_async_tool(&struct_name) + }; + let impl_definition = quote_definition(&info); + let impl_call = quote_call(func, &info, return_type, is_async); + quote! { // Original function #func @@ -495,9 +519,9 @@ fn impl_tool_for_async_function( // Generated tool struct pub struct #struct_name; - // Implement Tool trait for generated tool struct - impl clust::messages::AsyncTool for #struct_name { - #impl_description + // Implement Tool or AsyncTool trait for generated tool struct + #impl_impl_tool { + #impl_definition #impl_call } } @@ -505,12 +529,16 @@ fn impl_tool_for_async_function( pub(crate) fn impl_tool(func: &ItemFn) -> TokenStream { let tool_information = get_tool_information(func); - - if func.sig.asyncness.is_some() { - impl_tool_for_async_function(func, tool_information).into() - } else { - impl_tool_for_function(func, tool_information).into() - } + let is_async = func.sig.asyncness.is_some(); + let return_type = ReturnType::from_syn(&func.sig.output); + + impl_tool_for_function( + func, + tool_information, + return_type, + is_async, + ) + .into() } #[cfg(test)] @@ -596,7 +624,7 @@ mod test { assert_eq!( doc_comments.description, - "A function for testing." + Some("A function for testing.".to_string()) ); assert_eq!(doc_comments.parameters.len(), 1); assert_eq!( @@ -627,7 +655,7 @@ mod test { assert_eq!( doc_comments.description, - "A function for testing." + Some("A function for testing.".to_string()) ); assert_eq!(doc_comments.parameters.len(), 2); assert_eq!( @@ -664,7 +692,7 @@ mod test { assert_eq!(tool_information.name, "test_function"); assert_eq!( tool_information.description, - "A function for testing." + Some("A function for testing.".to_string()) ); assert_eq!( tool_information @@ -686,7 +714,7 @@ mod test { .get(0) .unwrap() ._type, - "i32" + ParameterType::Integer, ); assert_eq!( tool_information @@ -694,7 +722,7 @@ mod test { .get(0) .unwrap() .description, - "First argument." + Some("First argument.".to_string()) ); } @@ -717,7 +745,7 @@ mod test { assert_eq!(tool_information.name, "test_function"); assert_eq!( tool_information.description, - "A function for testing." + Some("A function for testing.".to_string()) ); assert_eq!( tool_information @@ -739,7 +767,7 @@ mod test { .get(0) .unwrap() ._type, - "i32" + ParameterType::Integer, ); assert_eq!( tool_information @@ -747,7 +775,7 @@ mod test { .get(0) .unwrap() .description, - "First argument." + Some("First argument.".to_string()) ); assert_eq!( tool_information @@ -763,7 +791,7 @@ mod test { .get(1) .unwrap() ._type, - "u32" + ParameterType::Integer, ); assert_eq!( tool_information @@ -771,7 +799,31 @@ mod test { .get(1) .unwrap() .description, - "Second argument." + Some("Second argument.".to_string()) + ); + } + + #[test] + fn test_is_optional_type() { + let input = quote! { + fn test_function(arg1: Option, arg2: u32) -> i32 { + arg1.unwrap() + } + }; + + let item_func = syn::parse_str::(&input.to_string()).unwrap(); + let parameter_types = get_parameter_types(&item_func); + + assert_eq!(parameter_types.len(), 2); + assert_eq!(parameter_types[0].name, "arg1"); + assert_eq!( + parameter_types[0]._type, + ParameterType::Option(Box::new(ParameterType::Integer)), + ); + assert_eq!(parameter_types[1].name, "arg2"); + assert_eq!( + parameter_types[1]._type, + ParameterType::Integer ); } } diff --git a/clust_macros/tests/tool.rs b/clust_macros/tests/tool.rs index df1d6e1..2a570d7 100644 --- a/clust_macros/tests/tool.rs +++ b/clust_macros/tests/tool.rs @@ -1,8 +1,4 @@ -use std::collections::BTreeMap; - -use clust::messages::{ - FunctionCalls, FunctionResults, Invoke, Tool, -}; +use clust::messages::{TextContentBlock, Tool, ToolResult, ToolUse}; use clust_macros::clust_tool; @@ -20,19 +16,24 @@ fn test_description() { let tool = ClustTool_test_function {}; assert_eq!( - tool.description().to_string(), - r#" - - test_function - A function for testing. - - - arg1 - i32 - First argument. - - -"# + tool.definition().to_string(), + r#"{ + "name": "test_function", + "description": "A function for testing.", + "input_schema": { + "description": "A function for testing.", + "properties": { + "arg1": { + "description": "First argument.", + "type": "integer" + } + }, + "required": [ + "arg1" + ], + "type": "object" + } +}"# ); } @@ -40,24 +41,15 @@ fn test_description() { fn test_call() { let tool = ClustTool_test_function {}; - let function_calls = FunctionCalls { - invoke: Invoke { - tool_name: String::from("test_function"), - parameters: BTreeMap::from_iter(vec![( - "arg1".to_string(), - "42".to_string(), - )]), - }, - }; - - let result = tool - .call(function_calls) - .unwrap(); - - if let FunctionResults::Result(result) = result { - assert_eq!(result.tool_name, "test_function"); - assert_eq!(result.stdout, "43"); - } else { - panic!("Expected FunctionResults::Result"); - } + let tool_use = ToolUse::new( + "toolu_XXXX", + "test_function", + serde_json::json!({"arg1": 42}), + ); + + let result = tool.call(tool_use).unwrap(); + + assert_eq!(result.tool_use_id, "toolu_XXXX"); + assert_eq!(result.is_error, None); + assert_eq!(result.content.unwrap().text, "43"); } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 0f4d775..725b30d 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.76.0" +channel = "1.77.0" components = ["rls", "rust-analysis", "rust-src", "rustfmt", "clippy"] diff --git a/src/lib.rs b/src/lib.rs index 056a7a7..7ba323d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ //! - [x] [Streaming Messages](https://docs.anthropic.com/claude/reference/messages-streaming) //! //! ## Feature flags -//! - `macros`: Enable the [`attributes::clust_tool`] attribute macro for generating [`messages::Tool`] +//! - `macros`: Enable the [`attributes::clust_tool`] attribute macro for generating [`messages::ToolDefinition`] //! or [`messages::AsyncTool`] from a Rust function. //! //! ## Usages diff --git a/src/messages.rs b/src/messages.rs index 65d226b..1d43f50 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -66,9 +66,11 @@ pub use stop_sequence::StopSequence; pub use stream_option::StreamOption; pub use system_prompt::SystemPrompt; pub use temperature::Temperature; -pub use tool::Tool; +pub use tool::ToolDefinition; pub use tool::ToolResult; pub use tool::ToolUse; +pub use tool::Tool; +pub use tool::AsyncTool; pub use top_k::TopK; pub use top_p::TopP; pub use usage::Usage; diff --git a/src/messages/content.rs b/src/messages/content.rs index c453fd4..272358a 100644 --- a/src/messages/content.rs +++ b/src/messages/content.rs @@ -30,7 +30,7 @@ use crate::messages::{ /// ))) /// ]); /// let content = Content::MultipleBlocks(vec![ -/// ContentBlock::ToolResult(ToolResultContentBlock::new(ToolResult::new( +/// ContentBlock::ToolResult(ToolResultContentBlock::new(ToolResult::success( /// "tool_use_id", /// Some("content"), /// ))), @@ -45,7 +45,7 @@ use crate::messages::{ /// let content = Content::from(vec![ContentBlock::from("text")]); /// let content = Content::from(ImageContentSource::base64(ImageMediaType::Png, "base64")); /// let content = Content::from(ToolUse::new("id", "name", serde_json::Value::Null)); -/// let content = Content::from(ToolResult::new("tool_use_id", Some("content"))); +/// let content = Content::from(ToolResult::success("tool_use_id", Some("content"))); /// let content = Content::from(vec![ /// ContentBlock::from("text"), /// ContentBlock::from(ImageContentSource::base64(ImageMediaType::Png, "base64")), @@ -56,7 +56,7 @@ use crate::messages::{ /// let content: Content = vec![ContentBlock::from("text")].into(); /// let content: Content = ImageContentSource::base64(ImageMediaType::Png, "base64").into(); /// let content: Content = ToolUse::new("id", "name", serde_json::Value::Null).into(); -/// let content: Content = ToolResult::new("tool_use_id", Some("content")).into(); +/// let content: Content = ToolResult::success("tool_use_id", Some("content")).into(); /// let content: Content = vec![ /// "text".into(), /// ImageContentSource::base64(ImageMediaType::Png, "base64").into(), @@ -593,9 +593,6 @@ pub struct ToolResultContentBlock { /// The tool result. #[serde(flatten)] pub tool_result: ToolResult, - /// Set to true if the tool execution resulted in an error. - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, } impl Default for ToolResultContentBlock { @@ -603,7 +600,6 @@ impl Default for ToolResultContentBlock { Self { _type: ContentType::ToolResult, tool_result: ToolResult::default(), - error: None, } } } @@ -622,16 +618,6 @@ impl ToolResultContentBlock { Self { _type: ContentType::ToolResult, tool_result, - error: None, - } - } - - /// Creates a new tool result content block as an error. - pub fn error(tool_result: ToolResult) -> Self { - Self { - _type: ContentType::ToolResult, - tool_result, - error: Some(true), } } } @@ -1101,26 +1087,27 @@ mod tests { #[test] fn new_tool_result_content_block() { let tool_result_content_block = ToolResultContentBlock::new( - ToolResult::new("tool_use_id", Some("content")), + ToolResult::success("tool_use_id", Some("content")), ); assert_eq!( tool_result_content_block, ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: None, + tool_result: ToolResult::success( + "tool_use_id", + Some("content"), + ), } ); - let tool_result_content_block = ToolResultContentBlock::error( - ToolResult::new("tool_use_id", Some("content")), + let tool_result_content_block = ToolResultContentBlock::new( + ToolResult::error("tool_use_id", Some("content")), ); assert_eq!( tool_result_content_block, ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: Some(true), + tool_result: ToolResult::error("tool_use_id", Some("content"),), } ); } @@ -1132,7 +1119,6 @@ mod tests { ToolResultContentBlock { _type: ContentType::ToolResult, tool_result: ToolResult::default(), - error: None, } ); } @@ -1141,8 +1127,7 @@ mod tests { fn display_tool_result_content_block() { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: None, + tool_result: ToolResult::success("tool_use_id", Some("content")), }; assert_eq!( tool_result_content_block.to_string(), @@ -1151,12 +1136,11 @@ mod tests { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: Some(true), + tool_result: ToolResult::error("tool_use_id", Some("content")), }; assert_eq!( tool_result_content_block.to_string(), - "{\n \"type\": \"tool_result\",\n \"tool_use_id\": \"tool_use_id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"content\"\n },\n \"error\": true\n}" + "{\n \"type\": \"tool_result\",\n \"tool_use_id\": \"tool_use_id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"content\"\n },\n \"is_error\": true\n}" ); } @@ -1164,8 +1148,7 @@ mod tests { fn serialize_tool_result_content_block() { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: None, + tool_result: ToolResult::success("tool_use_id", Some("content")), }; assert_eq!( serde_json::to_string(&tool_result_content_block).unwrap(), @@ -1174,12 +1157,11 @@ mod tests { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: Some(true), + tool_result: ToolResult::error("tool_use_id", Some("content")), }; assert_eq!( serde_json::to_string(&tool_result_content_block).unwrap(), - "{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"error\":true}" + "{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"is_error\":true}" ); } @@ -1187,8 +1169,7 @@ mod tests { fn deserialize_tool_result_content_block() { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: None, + tool_result: ToolResult::success("tool_use_id", Some("content")), }; assert_eq!( serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"}}").unwrap(), @@ -1197,11 +1178,10 @@ mod tests { let tool_result_content_block = ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: Some(true), + tool_result: ToolResult::error("tool_use_id", Some("content")), }; assert_eq!( - serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"error\":true}").unwrap(), + serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"is_error\":true}").unwrap(), tool_result_content_block ); } @@ -1243,14 +1223,13 @@ mod tests { let content_block = ContentBlock::ToolResult(ToolResultContentBlock::new( - ToolResult::new("tool_use_id", Some("content")), + ToolResult::error("tool_use_id", Some("content")), )); assert_eq!( content_block, ContentBlock::ToolResult(ToolResultContentBlock { _type: ContentType::ToolResult, - tool_result: ToolResult::new("tool_use_id", Some("content")), - error: None, + tool_result: ToolResult::error("tool_use_id", Some("content")), }) ); } @@ -1291,7 +1270,7 @@ mod tests { let content_block = ContentBlock::ToolResult(ToolResultContentBlock::new( - ToolResult::new("tool_use_id", Some("content")), + ToolResult::success("tool_use_id", Some("content")), )); assert_eq!( content_block.to_string(), @@ -1299,12 +1278,12 @@ mod tests { ); let content_block = - ContentBlock::ToolResult(ToolResultContentBlock::error( - ToolResult::new("tool_use_id", Some("content")), + ContentBlock::ToolResult(ToolResultContentBlock::new( + ToolResult::error("tool_use_id", Some("content")), )); assert_eq!( content_block.to_string(), - "{\n \"type\": \"tool_result\",\n \"tool_use_id\": \"tool_use_id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"content\"\n },\n \"error\": true\n}" + "{\n \"type\": \"tool_result\",\n \"tool_use_id\": \"tool_use_id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"content\"\n },\n \"is_error\": true\n}" ); } @@ -1336,7 +1315,7 @@ mod tests { let content_block = ContentBlock::ToolResult(ToolResultContentBlock::new( - ToolResult::new("tool_use_id", Some("content")), + ToolResult::success("tool_use_id", Some("content")), )); assert_eq!( serde_json::to_string(&content_block).unwrap(), @@ -1344,12 +1323,12 @@ mod tests { ); let content_block = - ContentBlock::ToolResult(ToolResultContentBlock::error( - ToolResult::new("tool_use_id", Some("content")), + ContentBlock::ToolResult(ToolResultContentBlock::new( + ToolResult::error("tool_use_id", Some("content")), )); assert_eq!( serde_json::to_string(&content_block).unwrap(), - "{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"error\":true}" + "{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"},\"is_error\":true}" ); } @@ -1384,7 +1363,7 @@ mod tests { let content_block = ContentBlock::ToolResult(ToolResultContentBlock::new( - ToolResult::new("tool_use_id", Some("content")), + ToolResult::success("tool_use_id", Some("content")), )); assert_eq!( serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\":{\"type\":\"text\",\"text\":\"content\"}}").unwrap(), @@ -1392,11 +1371,11 @@ mod tests { ); let content_block = - ContentBlock::ToolResult(ToolResultContentBlock::error( - ToolResult::new("tool_use_id", Some("content")), + ContentBlock::ToolResult(ToolResultContentBlock::new( + ToolResult::error("tool_use_id", Some("content")), )); assert_eq!( - serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\": {\"type\":\"text\",\"text\":\"content\"},\"error\":true}").unwrap(), + serde_json::from_str::("{\"type\":\"tool_result\",\"tool_use_id\":\"tool_use_id\",\"content\": {\"type\":\"text\",\"text\":\"content\"},\"is_error\":true}").unwrap(), content_block ); } @@ -1770,7 +1749,7 @@ mod tests { assert_eq!( Content::from(vec![ - ContentBlock::from(ToolResult::new( + ContentBlock::from(ToolResult::success( "tool_use_id", Some("content") )), @@ -1778,7 +1757,7 @@ mod tests { ]) .flatten_into_tool_result() .unwrap(), - ToolResult::new("tool_use_id", Some("content")) + ToolResult::success("tool_use_id", Some("content")) ); } } diff --git a/src/messages/messages_request_body.rs b/src/messages/messages_request_body.rs index c2d6fa9..1d307a7 100644 --- a/src/messages/messages_request_body.rs +++ b/src/messages/messages_request_body.rs @@ -1,7 +1,7 @@ use crate::macros::impl_display_for_serialize; use crate::messages::{ ClaudeModel, MaxTokens, Message, Metadata, StopSequence, StreamOption, - SystemPrompt, Temperature, Tool, TopK, TopP, + SystemPrompt, Temperature, ToolDefinition, TopK, TopP, }; use crate::ValidationError; @@ -67,7 +67,7 @@ pub struct MessagesRequestBody { /// - description: Optional, but strongly-recommended description of the tool. /// - input_schema: JSON schema for the tool input shape that the model will produce in tool_use output content blocks. #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, + pub tools: Option>, /// Use nucleus sampling. /// /// In nucleus sampling, we compute the cumulative distribution over all the options for each subsequent token in decreasing probability order and cut it off once it reaches a particular probability specified by top_p. You should either alter temperature or top_p, but not both. @@ -220,7 +220,7 @@ impl MessagesRequestBuilder { /// Sets the tools. pub fn tools( mut self, - tools: Vec, + tools: Vec, ) -> Self { self.request_body.tools = Some(tools); self @@ -400,7 +400,7 @@ mod tests { )]) .stream(StreamOption::ReturnOnce) .temperature(Temperature::new(0.5).unwrap()) - .tools(vec![Tool { + .tools(vec![ToolDefinition { name: "tool".into(), description: Some("tool description".into()), input_schema: serde_json::Value::Null, @@ -444,7 +444,7 @@ mod tests { ); assert_eq!( messages_request_body.tools, - Some(vec![Tool { + Some(vec![ToolDefinition { name: "tool".into(), description: Some("tool description".into()), input_schema: serde_json::Value::Null, @@ -478,7 +478,7 @@ mod tests { )]) .stream(StreamOption::ReturnOnce) .temperature(Temperature::new(0.5).unwrap()) - .tools(vec![Tool { + .tools(vec![ToolDefinition { name: "tool".into(), description: Some("tool description".into()), input_schema: serde_json::Value::Null, @@ -522,7 +522,7 @@ mod tests { ); assert_eq!( messages_request_body.tools, - Some(vec![Tool { + Some(vec![ToolDefinition { name: "tool".into(), description: Some("tool description".into()), input_schema: serde_json::Value::Null, diff --git a/src/messages/tool.rs b/src/messages/tool.rs index 2b33653..fc36ddd 100644 --- a/src/messages/tool.rs +++ b/src/messages/tool.rs @@ -1,11 +1,34 @@ use crate::macros::impl_display_for_serialize; -use crate::messages::TextContentBlock; +use crate::messages::{TextContentBlock, ToolCallError}; +use std::future::Future; /// A tool that can be used by assistant. +pub trait Tool { + /// Gets the definition of the tool. + fn definition(&self) -> ToolDefinition; + /// Calls the tool. + fn call( + &self, + tool_use: ToolUse, + ) -> Result; +} + +/// An asynchronous tool that can be used by assistant. +pub trait AsyncTool { + /// Gets the definition of the tool. + fn definition(&self) -> ToolDefinition; + /// Calls the tool asynchronously. + fn call( + &self, + tool_use: ToolUse, + ) -> impl Future> + Send; +} + +/// A tool definition that can be used by assistant. #[derive( Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize, )] -pub struct Tool { +pub struct ToolDefinition { /// Name of the tool. pub name: String, /// Optional, but strongly-recommended description of the tool. @@ -14,7 +37,26 @@ pub struct Tool { pub input_schema: serde_json::Value, } -impl_display_for_serialize!(Tool); +impl_display_for_serialize!(ToolDefinition); + +impl ToolDefinition { + /// Creates a new `ToolDefinition`. + pub fn new( + name: S, + description: Option, + input_schema: serde_json::Value, + ) -> Self + where + S: Into, + T: Into, + { + Self { + name: name.into(), + description: description.map(Into::into), + input_schema, + } + } +} /// A tool use request. #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] @@ -68,13 +110,16 @@ pub struct ToolResult { /// The result of the tool, as a string (e.g. "content": "65 degrees") or list of nested content blocks (e.g. "content": [{"type": "text", "text": "65 degrees"}]\). During beta, only the text type content blocks are supported for tool_result content. #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, + /// Set to true if the tool execution resulted in an error. + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, } impl_display_for_serialize!(ToolResult); impl ToolResult { - /// Creates a new `ToolResult`. - pub fn new( + /// Creates a new `ToolResult` as a success. + pub fn success( tool_use_id: S, content: Option, ) -> Self @@ -85,6 +130,23 @@ impl ToolResult { Self { tool_use_id: tool_use_id.into(), content: content.map(Into::into), + is_error: None, + } + } + + /// Creates a new `ToolResult` as an error. + pub fn error( + tool_use_id: S, + content: Option, + ) -> Self + where + S: Into, + T: Into, + { + Self { + tool_use_id: tool_use_id.into(), + content: content.map(Into::into), + is_error: Some(true), } } } @@ -94,8 +156,8 @@ mod tests { use super::*; #[test] - fn default_tool() { - let tool = Tool::default(); + fn default_tool_definition() { + let tool = ToolDefinition::default(); assert_eq!(tool.name, String::default()); assert_eq!(tool.description, None); assert_eq!( @@ -105,41 +167,83 @@ mod tests { } #[test] - fn display_tool() { - let tool = Tool { + fn display_tool_definition() { + let tool = ToolDefinition { name: "tool".to_string(), description: Some("tool description".to_string()), - input_schema: serde_json::json!({}), + input_schema: serde_json::json!({ + "properties": { + "arg1": { + "description": "First argument.", + "type": "integer", + }, + }, + "required": ["arg1"], + "type": "object", + }), }; assert_eq!( tool.to_string(), - "{\n \"name\": \"tool\",\n \"description\": \"tool description\",\n \"input_schema\": {}\n}" + r#"{ + "name": "tool", + "description": "tool description", + "input_schema": { + "properties": { + "arg1": { + "description": "First argument.", + "type": "integer" + } + }, + "required": [ + "arg1" + ], + "type": "object" + } +}"# ); } #[test] - fn serialize_tool() { - let tool = Tool { + fn serialize_tool_definition() { + let tool = ToolDefinition { name: "tool".to_string(), description: Some("tool description".to_string()), - input_schema: serde_json::json!({}), + input_schema: serde_json::json!({ + "properties": { + "arg1": { + "description": "First argument.", + "type": "integer", + }, + }, + "required": ["arg1"], + "type": "object", + }), }; assert_eq!( serde_json::to_string(&tool).unwrap(), - r#"{"name":"tool","description":"tool description","input_schema":{}}"# + r#"{"name":"tool","description":"tool description","input_schema":{"properties":{"arg1":{"description":"First argument.","type":"integer"}},"required":["arg1"],"type":"object"}}"#, ); } #[test] - fn deserialize_tool() { - let tool = Tool { + fn deserialize_tool_definition() { + let tool = ToolDefinition { name: "tool".to_string(), description: Some("tool description".to_string()), - input_schema: serde_json::json!({}), + input_schema: serde_json::json!({ + "properties": { + "arg1": { + "description": "First argument.", + "type": "integer", + }, + }, + "required": ["arg1"], + "type": "object", + }), }; assert_eq!( - serde_json::from_str::( - r#"{"name":"tool","description":"tool description","input_schema":{}}"# + serde_json::from_str::( + r#"{"name":"tool","description":"tool description","input_schema":{"properties":{"arg1":{"description":"First argument.","type":"integer"}},"required":["arg1"],"type":"object"}}"# ) .unwrap(), tool @@ -159,11 +263,19 @@ mod tests { let tool_use = ToolUse { id: "id".to_string(), name: "name".to_string(), - input: serde_json::json!({}), + input: serde_json::json!({ + "arg1": 42, + }), }; assert_eq!( tool_use.to_string(), - "{\n \"id\": \"id\",\n \"name\": \"name\",\n \"input\": {}\n}" + r#"{ + "id": "id", + "name": "name", + "input": { + "arg1": 42 + } +}"# ); } @@ -172,11 +284,11 @@ mod tests { let tool_use = ToolUse { id: "id".to_string(), name: "name".to_string(), - input: serde_json::json!({}), + input: serde_json::json!({"arg1": 42}), }; assert_eq!( serde_json::to_string(&tool_use).unwrap(), - r#"{"id":"id","name":"name","input":{}}"# + r#"{"id":"id","name":"name","input":{"arg1":42}}"# ); } @@ -185,11 +297,11 @@ mod tests { let tool_use = ToolUse { id: "id".to_string(), name: "name".to_string(), - input: serde_json::json!({}), + input: serde_json::json!({"arg1": 42}), }; assert_eq!( serde_json::from_str::( - r#"{"id":"id","name":"name","input":{}}"# + r#"{"id":"id","name":"name","input":{"arg1":42}}"# ) .unwrap(), tool_use @@ -198,10 +310,17 @@ mod tests { #[test] fn new_tool_use() { - let tool_use = ToolUse::new("id", "name", serde_json::json!({})); + let tool_use = ToolUse::new( + "id", + "name", + serde_json::json!({"arg1": 42}), + ); assert_eq!(tool_use.id, "id"); assert_eq!(tool_use.name, "name"); - assert_eq!(tool_use.input, serde_json::json!({})); + assert_eq!( + tool_use.input, + serde_json::json!({"arg1": 42}) + ); } #[test] @@ -219,11 +338,22 @@ mod tests { let tool_result = ToolResult { tool_use_id: "id".to_string(), content: Some(TextContentBlock::new("text")), + is_error: None, }; assert_eq!( tool_result.to_string(), "{\n \"tool_use_id\": \"id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"text\"\n }\n}" ); + + let tool_result = ToolResult { + tool_use_id: "id".to_string(), + content: Some(TextContentBlock::new("text")), + is_error: Some(true), + }; + assert_eq!( + tool_result.to_string(), + "{\n \"tool_use_id\": \"id\",\n \"content\": {\n \"type\": \"text\",\n \"text\": \"text\"\n },\n \"is_error\": true\n}" + ); } #[test] @@ -231,11 +361,22 @@ mod tests { let tool_result = ToolResult { tool_use_id: "id".to_string(), content: Some(TextContentBlock::new("text")), + is_error: None, }; assert_eq!( serde_json::to_string(&tool_result).unwrap(), r#"{"tool_use_id":"id","content":{"type":"text","text":"text"}}"# ); + + let tool_result = ToolResult { + tool_use_id: "id".to_string(), + content: Some(TextContentBlock::new("text")), + is_error: Some(true), + }; + assert_eq!( + serde_json::to_string(&tool_result).unwrap(), + r#"{"tool_use_id":"id","content":{"type":"text","text":"text"},"is_error":true}"# + ); } #[test] @@ -243,6 +384,7 @@ mod tests { let tool_result = ToolResult { tool_use_id: "id".to_string(), content: Some(TextContentBlock::new("text")), + is_error: None, }; assert_eq!( serde_json::from_str::( @@ -251,11 +393,34 @@ mod tests { .unwrap(), tool_result ); + + let tool_result = ToolResult { + tool_use_id: "id".to_string(), + content: Some(TextContentBlock::new("text")), + is_error: Some(true), + }; + assert_eq!( + serde_json::from_str::( + r#"{"tool_use_id":"id","content":{"type":"text","text":"text"},"is_error":true}"# + ) + .unwrap(), + tool_result + ); } #[test] fn new_tool_result() { - let tool_result = ToolResult::new( + let tool_result = ToolResult::success( + "id", + Some(TextContentBlock::new("text")), + ); + assert_eq!(tool_result.tool_use_id, "id"); + assert_eq!( + tool_result.content, + Some(TextContentBlock::new("text")) + ); + + let tool_result = ToolResult::error( "id", Some(TextContentBlock::new("text")), );