Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mochi-neko committed Mar 13, 2024
1 parent 19dccc1 commit d203bff
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 166 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn main() -> anyhow::Result<()> {
};

// 3. Call the streaming API.
let stream = client
let mut stream = client
.create_a_message_stream(request_body)
.await?;

Expand Down
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl Client {
/// ..Default::default()
/// };
///
/// let stream = client
/// let mut stream = client
/// .create_a_message_stream(request_body)
/// .await?;
///
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
//! };
//!
//! // 3. Call the streaming API.
//! let stream = client
//! let mut stream = client
//! .create_a_message_stream(request_body)
//! .await?;
//!
Expand Down
17 changes: 9 additions & 8 deletions src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,25 @@ pub use error::StreamError;
pub use max_tokens::MaxTokens;
pub use message::Message;
pub use messages_request_body::MessagesRequestBody;
pub use messages_response_body::MessageObjectType;
pub use messages_response_body::MessagesResponseBody;
pub use metadata::Metadata;
pub use metadata::UserId;
pub use result::MessagesResult;
pub use role::Role;
pub use stop_reason::StopReason;
pub use stop_sequence::StopSequence;
pub use stream_chunk::ContentBlockDelta;
pub use stream_chunk::ContentBlockStart;
pub use stream_chunk::ContentBlockStop;
pub use stream_chunk::ContentBlockDeltaChunk;
pub use stream_chunk::ContentBlockStartChunk;
pub use stream_chunk::ContentBlockStopChunk;
pub use stream_chunk::DeltaUsage;
pub use stream_chunk::MessageDelta;
pub use stream_chunk::MessageStart;
pub use stream_chunk::MessageStop;
pub use stream_chunk::Ping;
pub use stream_chunk::MessageDeltaChunk;
pub use stream_chunk::MessageStartChunk;
pub use stream_chunk::MessageStopChunk;
pub use stream_chunk::PingChunk;
pub use stream_chunk::StreamChunk;
pub use stream_chunk::StreamChunkType;
pub use stream_chunk::StreamResult;
pub use stream_chunk::StreamStop;
pub use stream_option::StreamOption;
pub use system_prompt::SystemPrompt;
pub use temperature::Temperature;
Expand Down
18 changes: 9 additions & 9 deletions src/messages/chunk_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::MessageStart(message_start) => {
assert_eq!(
message_start,
MessageStart::new(MessagesResponseBody {
MessageStartChunk::new(MessagesResponseBody {
id: "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY"
.to_string(),
_type: "message".to_string(),
_type: MessageObjectType::Message,
role: Role::Assistant,
content: vec![].into(),
model: ClaudeModel::Claude3Opus20240229,
Expand All @@ -185,7 +185,7 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::ContentBlockStart(content_block_start) => {
assert_eq!(
content_block_start,
ContentBlockStart::new(0, "".into()),
ContentBlockStartChunk::new(0, "".into()),
);
},
| _ => panic!("unexpected chunk type"),
Expand All @@ -198,7 +198,7 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
.unwrap();
match chunk {
| StreamChunk::Ping(ping) => {
assert_eq!(ping, Ping::new());
assert_eq!(ping, PingChunk::new());
},
| _ => panic!("unexpected chunk type"),
}
Expand All @@ -212,7 +212,7 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::ContentBlockDelta(content_block_delta) => {
assert_eq!(
content_block_delta,
ContentBlockDelta::new(0, "Hello".into()),
ContentBlockDeltaChunk::new(0, "Hello".into()),
);
},
| _ => panic!("unexpected chunk type"),
Expand All @@ -227,7 +227,7 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::ContentBlockDelta(content_block_delta) => {
assert_eq!(
content_block_delta,
ContentBlockDelta::new(0, "!".into()),
ContentBlockDeltaChunk::new(0, "!".into()),
);
},
| _ => panic!("unexpected chunk type"),
Expand All @@ -242,7 +242,7 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::ContentBlockStop(content_block_stop) => {
assert_eq!(
content_block_stop,
ContentBlockStop::new(0),
ContentBlockStopChunk::new(0),
);
},
| _ => panic!("unexpected chunk type"),
Expand All @@ -257,8 +257,8 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_seque
| StreamChunk::MessageDelta(message_delta) => {
assert_eq!(
message_delta,
MessageDelta::new(
StreamResult {
MessageDeltaChunk::new(
StreamStop {
stop_reason: Some(StopReason::EndTurn),
stop_sequence: None,
},
Expand Down
2 changes: 1 addition & 1 deletion src/messages/message.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::macros::impl_display_for_serialize;
use crate::messages::{Content, Role};

/// The input message.
/// The message.
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
Expand Down
91 changes: 71 additions & 20 deletions src/messages/messages_response_body.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use crate::macros::impl_display_for_serialize;
use crate::macros::{
impl_display_for_serialize, impl_enum_string_serialization,
};
use crate::messages::{
ClaudeModel, Content, Role, StopReason, StopSequence, Usage,
};
use std::fmt::{Display, Formatter};

/// The response body for the Messages API.
///
/// See also [the Messages API](https://docs.anthropic.com/claude/reference/messages_post).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(
Debug, Clone, PartialEq, Default, serde::Serialize, serde::Deserialize,
)]
pub struct MessagesResponseBody {
/// Unique object identifier.
///
Expand All @@ -16,7 +21,7 @@ pub struct MessagesResponseBody {
///
/// For Messages, this is always "message".
#[serde(rename = "type")]
pub _type: String,
pub _type: MessageObjectType,
/// Conversational role of the generated message.
///
/// This will always be "assistant".
Expand Down Expand Up @@ -52,22 +57,36 @@ pub struct MessagesResponseBody {
pub usage: Usage,
}

impl Default for MessagesResponseBody {
impl_display_for_serialize!(MessagesResponseBody);

/// The object type for message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageObjectType {
/// message
Message,
}

impl Default for MessageObjectType {
fn default() -> Self {
Self {
id: Default::default(),
_type: "message".to_string(),
role: Default::default(),
content: Default::default(),
model: Default::default(),
stop_reason: Default::default(),
stop_sequence: Default::default(),
usage: Default::default(),
Self::Message
}
}

impl Display for MessageObjectType {
fn fmt(
&self,
f: &mut Formatter<'_>,
) -> std::fmt::Result {
match self {
| MessageObjectType::Message => write!(f, "{}", "message"),
}
}
}

impl_display_for_serialize!(MessagesResponseBody);
impl_enum_string_serialization!(
MessageObjectType,
Message => "message"
);

#[cfg(test)]
mod tests {
Expand All @@ -77,7 +96,7 @@ mod tests {
fn serialize() {
let response = MessagesResponseBody {
id: "id".to_string(),
_type: "type".to_string(),
_type: MessageObjectType::Message,
role: Role::Assistant,
content: "content".into(),
model: ClaudeModel::Claude3Sonnet20240229,
Expand All @@ -90,15 +109,15 @@ mod tests {
};
assert_eq!(
serde_json::to_string(&response).unwrap(),
"{\"id\":\"id\",\"type\":\"type\",\"role\":\"assistant\",\"content\":\"content\",\"model\":\"claude-3-sonnet-20240229\",\"stop_reason\":\"end_turn\",\"stop_sequence\":\"stop_sequence\",\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}"
"{\"id\":\"id\",\"type\":\"message\",\"role\":\"assistant\",\"content\":\"content\",\"model\":\"claude-3-sonnet-20240229\",\"stop_reason\":\"end_turn\",\"stop_sequence\":\"stop_sequence\",\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}"
);
}

#[test]
fn deserialize() {
let response = MessagesResponseBody {
id: "id".to_string(),
_type: "type".to_string(),
_type: MessageObjectType::Message,
role: Role::Assistant,
content: "content".into(),
model: ClaudeModel::Claude3Sonnet20240229,
Expand All @@ -111,7 +130,7 @@ mod tests {
};
assert_eq!(
serde_json::from_str::<MessagesResponseBody>(
"{\"id\":\"id\",\"type\":\"type\",\"role\":\"assistant\",\"content\":\"content\",\"model\":\"claude-3-sonnet-20240229\",\"stop_reason\":\"end_turn\",\"stop_sequence\":\"stop_sequence\",\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}"
"{\"id\":\"id\",\"type\":\"message\",\"role\":\"assistant\",\"content\":\"content\",\"model\":\"claude-3-sonnet-20240229\",\"stop_reason\":\"end_turn\",\"stop_sequence\":\"stop_sequence\",\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}"
).unwrap(),
response
);
Expand All @@ -121,7 +140,7 @@ mod tests {
fn display() {
let response = MessagesResponseBody {
id: "id".to_string(),
_type: "type".to_string(),
_type: MessageObjectType::Message,
role: Role::Assistant,
content: "content".into(),
model: ClaudeModel::Claude3Sonnet20240229,
Expand All @@ -134,7 +153,39 @@ mod tests {
};
assert_eq!(
response.to_string(),
"{\n \"id\": \"id\",\n \"type\": \"type\",\n \"role\": \"assistant\",\n \"content\": \"content\",\n \"model\": \"claude-3-sonnet-20240229\",\n \"stop_reason\": \"end_turn\",\n \"stop_sequence\": \"stop_sequence\",\n \"usage\": {\n \"input_tokens\": 1,\n \"output_tokens\": 2\n }\n}"
"{\n \"id\": \"id\",\n \"type\": \"message\",\n \"role\": \"assistant\",\n \"content\": \"content\",\n \"model\": \"claude-3-sonnet-20240229\",\n \"stop_reason\": \"end_turn\",\n \"stop_sequence\": \"stop_sequence\",\n \"usage\": {\n \"input_tokens\": 1,\n \"output_tokens\": 2\n }\n}"
);
}

#[test]
fn default_message_object_type() {
assert_eq!(
MessageObjectType::default(),
MessageObjectType::Message
);
}

#[test]
fn message_object_type_display() {
assert_eq!(
MessageObjectType::Message.to_string(),
"message"
);
}

#[test]
fn message_object_type_serialize() {
assert_eq!(
serde_json::to_string(&MessageObjectType::Message).unwrap(),
"\"message\""
);
}

#[test]
fn message_object_type_deserialize() {
assert_eq!(
serde_json::from_str::<MessageObjectType>("\"message\"").unwrap(),
MessageObjectType::Message
);
}
}
Loading

0 comments on commit d203bff

Please sign in to comment.