Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive Hash #7

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ all-features = true
default = []
macros = ["dep:clust_macros"]
full = ["macros"]
hash = []

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand All @@ -35,6 +36,7 @@ thiserror = "1.0.*"
pin-project = "1.1.*"
futures-core = "0.3.*"
clust_macros = { version = "0.9.0", optional = true }
tracing = { version = "0.1.*", optional = true, features = ["attributes"] }

[dev-dependencies]
anyhow = "1.0.86"
Expand All @@ -43,3 +45,4 @@ tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread", "fs"] }
futures-util = "0.3.30"
tokio-stream = "0.1.15"
base64 = "0.22.1"
tracing-subscriber = { version = "0.3.18" , features = ["env-filter"] }
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ clust = "0.9.0"

- `macros`: Enable the `clust::attributse::clust_tool` attribute macro for generating `clust::messages::Tool`
or `clust::messages::AsyncTool` from a Rust function.
- `hash`: Derive `Hash` for more types.
- Some types already implement `Hash` even without this feature flag.
- Some types can't implement `Hash` because they contain fields of `f32` type, which doesn't implement `Hash`.

## Usages

Expand Down
12 changes: 12 additions & 0 deletions examples/create_a_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ struct Arguments {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
#[cfg(feature = "tracing")]
init_tracing_subscriber();

// 0. Parse the command-line arguments.
let arguments = Arguments::parse();

Expand Down Expand Up @@ -68,3 +71,12 @@ async fn main() -> anyhow::Result<()> {

Ok(())
}

#[cfg(feature = "tracing")]
fn init_tracing_subscriber() {
use tracing_subscriber::util::SubscriberInitExt;
let subscriber = tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
subscriber.init();
}
12 changes: 12 additions & 0 deletions examples/streaming_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ struct Arguments {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
#[cfg(feature = "tracing")]
init_tracing_subscriber();

// 0. Parse the command-line arguments.
let arguments = Arguments::parse();

Expand Down Expand Up @@ -86,3 +89,12 @@ async fn main() -> anyhow::Result<()> {

Ok(())
}

#[cfg(feature = "tracing")]
fn init_tracing_subscriber() {
use tracing_subscriber::util::SubscriberInitExt;
let subscriber = tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
subscriber.init();
}
2 changes: 1 addition & 1 deletion src/api_key.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::env::VarError;

/// The API key of the Anthropic API.
#[derive(Clone, Eq, PartialEq, Hash)]
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct ApiKey {
value: String,
}
Expand Down
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::messages::{
use crate::{ApiKey, Beta, Version};

/// The API client.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Client {
/// The API key.
api_key: ApiKey,
Expand Down
2 changes: 2 additions & 0 deletions src/messages/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::ClientError;

use futures_core::Stream;

#[cfg_attr(feature = "tracing", tracing::instrument)]
pub(crate) async fn create_a_message(
client: &Client,
request_body: MessagesRequestBody,
Expand Down Expand Up @@ -65,6 +66,7 @@ pub(crate) async fn create_a_message(
}
}

#[cfg_attr(feature = "tracing", tracing::instrument)]
pub(crate) async fn create_a_message_stream(
client: &Client,
request_body: MessagesRequestBody,
Expand Down
7 changes: 7 additions & 0 deletions src/messages/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ use crate::messages::{
/// ].into();
/// ```
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub enum Content {
/// The single text content.
SingleText(String),
Expand Down Expand Up @@ -206,6 +207,7 @@ impl Content {

/// The content block of the message.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub enum ContentBlock {
/// The text content block.
Text(TextContentBlock),
Expand Down Expand Up @@ -266,6 +268,7 @@ impl_display_for_serialize!(ContentBlock);

/// The text content block.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct TextContentBlock {
/// The content type. It is always `text`.
#[serde(rename = "type")]
Expand Down Expand Up @@ -312,6 +315,7 @@ impl TextContentBlock {

/// The image content block.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ImageContentBlock {
/// The content type. It is always `image`.
#[serde(rename = "type")]
Expand Down Expand Up @@ -404,6 +408,7 @@ impl_enum_string_serialization!(

/// The image content source.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ImageContentSource {
/// The source type.
#[serde(rename = "type")]
Expand Down Expand Up @@ -548,6 +553,7 @@ impl ImageMediaType {

/// The tool use content block.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ToolUseContentBlock {
/// The content type. It is always `tool_use`.
#[serde(rename = "type")]
Expand Down Expand Up @@ -586,6 +592,7 @@ impl ToolUseContentBlock {

/// The tool result content block.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ToolResultContentBlock {
/// The content type. It is always `tool_result`.
#[serde(rename = "type")]
Expand Down
9 changes: 5 additions & 4 deletions src/messages/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::messages::{Content, Role};
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct Message {
/// The role of the message.
pub role: Role,
Expand Down Expand Up @@ -44,7 +45,7 @@ impl Message {
/// ## Example
/// ```rust
/// use clust::messages::{Content, Message};
///
///
/// let message = Message::assistant(Content::from("assistant message"));
/// ```
pub fn assistant<T>(content: T) -> Self
Expand All @@ -66,7 +67,7 @@ impl Message {
/// ## Example
/// ```rust
/// use clust::messages::{Content, Message, Role};
///
///
/// let message = Message::new(Role::User, Content::from("user message"));
/// let message = Message::new(Role::Assistant, Content::from("assistant message"));
/// ```
Expand Down Expand Up @@ -104,13 +105,13 @@ mod tests {
"assistant-message".into()
);
}

#[test]
fn new() {
let message = Message::new(Role::User, "user-message");
assert_eq!(message.role, Role::User);
assert_eq!(message.content, "user-message".into());

let message = Message::new(Role::Assistant, "assistant-message");
assert_eq!(message.role, Role::Assistant);
assert_eq!(
Expand Down
10 changes: 10 additions & 0 deletions src/messages/message_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::messages::{

/// The stream chunk of messages.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub enum MessageChunk {
/// Message start chunk.
MessageStart(MessageStartChunk),
Expand Down Expand Up @@ -279,6 +280,7 @@ impl_enum_string_serialization!(

/// The message start chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct MessageStartChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -310,6 +312,7 @@ impl MessageStartChunk {

/// The content block start chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ContentBlockStartChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -348,6 +351,7 @@ impl ContentBlockStartChunk {

/// The ping chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct PingChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -375,6 +379,7 @@ impl PingChunk {

/// The content block delta chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ContentBlockDeltaChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -413,6 +418,7 @@ impl ContentBlockDeltaChunk {

/// The content block stop chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ContentBlockStopChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -444,6 +450,7 @@ impl ContentBlockStopChunk {

/// The message delta chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct MessageDeltaChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -482,6 +489,7 @@ impl MessageDeltaChunk {

/// The message stop chunk.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct MessageStopChunk {
/// The type of stream chunk.
#[serde(rename = "type")]
Expand Down Expand Up @@ -509,6 +517,7 @@ impl MessageStopChunk {

/// The text delta content block.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct TextDeltaContentBlock {
/// The content type. It is always `text_delta`.
#[serde(rename = "type")]
Expand Down Expand Up @@ -557,6 +566,7 @@ impl TextDeltaContentBlock {
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct StreamStop {
/// The stop reason of this stream.
pub stop_reason: Option<StopReason>,
Expand Down
5 changes: 3 additions & 2 deletions src/messages/messages_response_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::messages::{
///
/// See also [the Messages API](https://docs.anthropic.com/claude/reference/messages_post).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct MessagesResponseBody {
/// Unique object identifier.
///
Expand Down Expand Up @@ -85,7 +86,7 @@ impl MessagesResponseBody {
}

/// The object type for message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageObjectType {
/// message
Message,
Expand Down Expand Up @@ -215,7 +216,7 @@ mod tests {
MessageObjectType::Message
);
}

#[test]
fn create_message() {
let response = MessagesResponseBody {
Expand Down
2 changes: 2 additions & 0 deletions src/messages/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::Display;

/// An object describing metadata about the request.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct Metadata {
/// An external identifier for the user who is associated with the request.
pub user_id: UserId,
Expand All @@ -15,6 +16,7 @@ impl_display_for_serialize!(Metadata);
/// This should be an uuid, hash value, or other opaque identifier. Anthropic may use this id to help detect abuse.
/// Do not include any identifying information such as name, email address, or phone number.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
#[serde(transparent)]
pub struct UserId {
value: String,
Expand Down
4 changes: 2 additions & 2 deletions src/messages/stop_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::fmt::Display;

/// The stop sequence.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
#[serde(transparent)]
pub struct StopSequence {
value: String,
Expand Down Expand Up @@ -52,7 +53,7 @@ mod tests {
"\"stop-sequence\""
);
}

#[test]
fn deserialize() {
let stop_sequence = StopSequence::new("stop-sequence");
Expand All @@ -62,4 +63,3 @@ mod tests {
);
}
}

1 change: 1 addition & 0 deletions src/messages/system_prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::fmt::Display;
/// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role.
/// See our [guide to system prompts](https://docs.anthropic.com/claude/docs/system-prompts).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
#[serde(transparent)]
pub struct SystemPrompt {
value: String,
Expand Down
3 changes: 3 additions & 0 deletions src/messages/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub trait AsyncTool {
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
Expand Down Expand Up @@ -60,6 +61,7 @@ impl ToolDefinition {

/// A tool use request.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ToolUse {
/// The ID of the used tool.
pub id: String,
Expand Down Expand Up @@ -104,6 +106,7 @@ impl ToolUse {
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
#[cfg_attr(feature = "hash", derive(Hash))]
pub struct ToolResult {
/// The id of the tool use request this is a result for.
pub tool_use_id: String,
Expand Down