Skip to content

Commit

Permalink
Support Option<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
mochi-neko committed Apr 30, 2024
1 parent a789b0a commit 8bff71c
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 52 deletions.
2 changes: 2 additions & 0 deletions clust_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ mod tool;
/// - `Vec<T>` where `T` is supported type.
/// - `&[T]` where `T` is supported type.
/// - `&[T; N]` where `T` is supported type and `N` is a constant.
/// - Option
/// - `Option<T>` where `T` is supported type.
/// - e.g. `fn function(arg1: i32, arg2: String, arg3: Vec<f64>) -> T`
///
/// ## Supported return values
Expand Down
85 changes: 42 additions & 43 deletions clust_macros/src/parameter_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) enum ParameterType {
Number,
String,
Array(Box<ParameterType>),
//Option(Box<ParameterType>), // TODO: Implement function argument parsing for Option that does not implement FromStr(= cannot parse from string).
Option(Box<ParameterType>),
//Enum(Vec<String>),
Object,
}
Expand All @@ -37,9 +37,9 @@ impl Display for ParameterType {
| ParameterType::Array(inner) => {
write!(f, "array of {}", inner)
},
// | ParameterType::Option(inner) => {
// write!(f, "option of {}", inner)
// },
| ParameterType::Option(inner) => {
write!(f, "option of {}", inner)
},
| ParameterType::Object => write!(f, "object"),
}
}
Expand All @@ -51,20 +51,19 @@ impl ParameterType {
| Type::Path(type_path) => {
let path_segments = &type_path.path.segments;
if let Some(first) = path_segments.first() {
// TODO: For Option
// if first.ident == "Option" {
// if let PathArguments::AngleBracketed(args) =
// first.arguments.clone()
// {
// if let Some(arg) = args.args.last() {
// if let GenericArgument::Type(ty) = arg {
// return Self::Option(Box::new(
// ParameterType::from_syn_type(ty),
// ));
// }
// }
// }
// }
if first.ident == "Option" {
if let PathArguments::AngleBracketed(args) =
first.arguments.clone()
{
if let Some(arg) = args.args.last() {
if let GenericArgument::Type(ty) = arg {
return Self::Option(Box::new(
ParameterType::from_syn_type(ty),
));
}
}
}
}

if first.ident == "Vec" {
if let PathArguments::AngleBracketed(args) =
Expand Down Expand Up @@ -158,14 +157,14 @@ impl ParameterType {
| ParameterType::Number => PrimitiveType::Number,
| ParameterType::String => PrimitiveType::String,
| ParameterType::Array(_) => PrimitiveType::Array,
// | ParameterType::Option(inner) => inner.to_primitive_type(),
| ParameterType::Option(inner) => inner.to_primitive_type(),
| ParameterType::Object => PrimitiveType::Object,
}
}

pub(crate) fn optional(&self) -> bool {
match self {
// | ParameterType::Option(_) => true,
| ParameterType::Option(_) => true,
| _ => false,
}
}
Expand Down Expand Up @@ -323,29 +322,29 @@ mod tests {
);
}

// #[test]
// fn option() {
// assert_eq!(
// ParameterType::from_syn_type(
// &syn::parse_str::<Type>("Option<i32>").unwrap()
// ),
// ParameterType::Option(Box::new(ParameterType::Integer))
// );
//
// assert_eq!(
// ParameterType::from_syn_type(
// &syn::parse_str::<Type>("Option<bool>").unwrap()
// ),
// ParameterType::Option(Box::new(ParameterType::Boolean))
// );
//
// assert_eq!(
// ParameterType::from_syn_type(
// &syn::parse_str::<Type>("Option<String>").unwrap()
// ),
// ParameterType::Option(Box::new(ParameterType::String))
// );
// }
#[test]
fn option() {
assert_eq!(
ParameterType::from_syn_type(
&syn::parse_str::<Type>("Option<i32>").unwrap()
),
ParameterType::Option(Box::new(ParameterType::Integer))
);

assert_eq!(
ParameterType::from_syn_type(
&syn::parse_str::<Type>("Option<bool>").unwrap()
),
ParameterType::Option(Box::new(ParameterType::Boolean))
);

assert_eq!(
ParameterType::from_syn_type(
&syn::parse_str::<Type>("Option<String>").unwrap()
),
ParameterType::Option(Box::new(ParameterType::String))
);
}

#[test]
fn object() {
Expand Down
31 changes: 22 additions & 9 deletions clust_macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,30 @@ fn quote_invoke_parameters(
info
.parameters
.iter()
.map(|parameter| parameter.name.clone())
.map(|name| {
quote! {
serde_json::from_value(
tool_use
.input
.get(#name)
.ok_or_else(|| clust::messages::ToolCallError::ParameterNotFound(#name.to_string()))?
.clone()
.map(|parameter| {
let name = parameter.name.clone();
if !parameter._type.optional() {
quote! {
serde_json::from_value(
tool_use
.input
.get(#name)
.ok_or_else(|| clust::messages::ToolCallError::ParameterNotFound(#name.to_string()))?
.clone()
)
.map_err(|_| clust::messages::ToolCallError::ParameterParseFailed(#name.to_string()))?
}
} else {
quote! {
serde_json::from_value(
tool_use
.input
.get(#name)
.unwrap_or(&serde_json::Value::Null)
.clone()
)
.map_err(|_| clust::messages::ToolCallError::ParameterParseFailed(#name.to_string()))?
}
}
})
.collect()
Expand Down
71 changes: 71 additions & 0 deletions clust_macros/tests/tool_with_optional_arg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use clust::messages::{Tool, ToolUse};

use clust_macros::clust_tool;

/// A function for testing.
///
/// ## Arguments
/// - `arg1` - First argument.
#[clust_tool]
fn test_function(arg1: Option<i32>) -> i32 {
if let Some(arg1) = arg1 {
arg1
} else {
0
}
}

#[test]
fn test_description() {
let tool = ClustTool_test_function {};

assert_eq!(
tool.definition().to_string(),
r#"{
"name": "test_function",
"description": "A function for testing.",
"input_schema": {
"description": "A function for testing.",
"properties": {
"arg1": {
"description": "First argument.",
"type": "integer"
}
},
"required": [],
"type": "object"
}
}"#
);
}

#[test]
fn test_call() {
let tool = ClustTool_test_function {};

let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({}),
);

let result = tool.call(tool_use).unwrap();

assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, None);
assert_eq!(result.content.unwrap().text, "0");

let tool_use = ToolUse::new(
"toolu_XXXX",
"test_function",
serde_json::json!({
"arg1": 42
}),
);

let result = tool.call(tool_use).unwrap();

assert_eq!(result.tool_use_id, "toolu_XXXX");
assert_eq!(result.is_error, None);
assert_eq!(result.content.unwrap().text, "42");
}

0 comments on commit 8bff71c

Please sign in to comment.