diff --git a/src/error.rs b/src/error.rs index 5d4ed62..f085ba8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,5 @@ -use reqwest::{Error as ReqwestError, StatusCode}; +use reqwest::{Error as ReqwestError, Response, StatusCode}; +use serde_json::Value; use thiserror::Error; #[derive(Debug, Error)] @@ -32,7 +33,13 @@ pub enum HeliusError { } impl HeliusError { - pub fn from_response_status(status: StatusCode, path: String, text: String) -> Self { + pub async fn from_response_status(status: StatusCode, path: String, response: Response) -> Self { + let body: String = response.text().await.unwrap_or_default(); + let v: Value = serde_json::from_str(&body).unwrap_or_default(); + + // Extract only the message part of the JSON + let text: String = v["message"].as_str().unwrap_or("").to_string(); + match status { StatusCode::BAD_REQUEST => HeliusError::BadRequest { path, text }, StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => HeliusError::Unauthorized { path, text }, diff --git a/src/request_handler.rs b/src/request_handler.rs index 1ab8f19..88b08e3 100644 --- a/src/request_handler.rs +++ b/src/request_handler.rs @@ -40,10 +40,7 @@ impl RequestHandler { match status { StatusCode::OK | StatusCode::CREATED => response.json::().await.map_err(HeliusError::SerdeJson), - _ => { - let error_text = response.text().await.unwrap_or_else(|_| "Failed to read response body".to_string()); - Err(HeliusError::from_response_status(status, path, error_text)) - } + _ => Err(HeliusError::from_response_status(status, path, response).await), } } } diff --git a/tests/test_request_handler.rs b/tests/test_request_handler.rs new file mode 100644 index 0000000..4f08c0b --- /dev/null +++ b/tests/test_request_handler.rs @@ -0,0 +1,65 @@ +use helius_sdk::error::{HeliusError, Result}; +use helius_sdk::request_handler::RequestHandler; + +use mockito::{self, Server}; +use reqwest::{Client, Method}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Serialize, Deserialize, Debug, Default)] +struct MockResponse { + message: String, +} + +#[tokio::test] +async fn test_successful_request() { + let mut server: Server = Server::new_with_opts_async(mockito::ServerOpts::default()).await; + let url: String = server.url(); + + server + .mock("GET", "/") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"message": "success"}"#) + .create(); + + let client: Arc = Arc::new(Client::new()); + let handler: RequestHandler = RequestHandler::new(client).unwrap(); + + let response: Result = handler + .send::<(), MockResponse>(Method::GET, url.parse().unwrap(), None) + .await; + + assert!(response.is_ok()); + assert_eq!(response.unwrap().message, "success"); + + server.reset(); +} + +#[tokio::test] +async fn test_bad_request_error() { + let mut server: Server = Server::new_with_opts_async(mockito::ServerOpts::default()).await; + let url: String = server.url(); + + server + .mock("GET", "/") + .with_status(400) + .with_header("content-type", "application/json") + .with_body(r#"{"message": "bad request"}"#) + .create(); + + let client: Arc = Arc::new(Client::new()); + let handler: RequestHandler = RequestHandler::new(client).unwrap(); + + let response: Result = handler + .send::<(), MockResponse>(Method::GET, url.parse().unwrap(), None) + .await; + + assert!(response.is_err()); + match response { + Err(HeliusError::BadRequest { text, .. }) => assert_eq!(text, "bad request"), + _ => panic!("Expected BadRequest error"), + } + + server.reset(); +}