diff --git a/concatsql/src/connection.rs b/concatsql/src/connection.rs index 53555ac..52e98ce 100644 --- a/concatsql/src/connection.rs +++ b/concatsql/src/connection.rs @@ -1,34 +1,51 @@ -use std::fmt; -use std::cell::Cell; use std::borrow::Cow; +use std::cell::Cell; +use std::fmt; -use crate::Result; -use crate::ErrorLevel; use crate::row::Row; -use crate::wrapstring::{WrapString, IntoWrapString}; use crate::value::Value; +use crate::wrapstring::{IntoWrapString, WrapString}; +use crate::ErrorLevel; +use crate::Result; #[allow(clippy::type_complexity)] pub(crate) trait ConcatsqlConn { - fn execute_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &crate::ErrorLevel) -> Result<()>; - fn iterate_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &crate::ErrorLevel, - callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool) -> Result<()>; - fn rows_inner<'a, 'r>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &crate::ErrorLevel) - -> Result>>; + fn execute_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &crate::ErrorLevel, + ) -> Result<()>; + fn iterate_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &crate::ErrorLevel, + callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool, + ) -> Result<()>; + fn rows_inner<'a, 'r>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &crate::ErrorLevel, + ) -> Result>>; fn close(&self); fn kind(&self) -> ConnKind; } #[doc(hidden)] pub enum ConnKind { - #[cfg(feature = "sqlite")] SQLite, - #[cfg(feature = "mysql")] MySQL, - #[cfg(feature = "postgres")] PostgreSQL, + #[cfg(feature = "sqlite")] + SQLite, + #[cfg(feature = "mysql")] + MySQL, + #[cfg(feature = "postgres")] + PostgreSQL, } /// A database connection. pub struct Connection { - pub(crate) conn: Box, + pub(crate) conn: Box, pub(crate) error_level: Cell, } @@ -67,7 +84,11 @@ impl<'a> Connection { /// ``` #[inline] pub fn execute>(&self, query: T) -> Result<()> { - self.conn.execute_inner(query.compile(self.conn.kind()), query.params(), &self.error_level.get()) + self.conn.execute_inner( + query.compile(self.conn.kind()), + query.params(), + &self.error_level.get(), + ) } /// Execute a statement and process the resulting rows as plain text. @@ -94,10 +115,15 @@ impl<'a> Connection { /// ``` #[inline] pub fn iterate, F>(&self, query: T, mut callback: F) -> Result<()> - where - F: FnMut(&[(&str, Option<&str>)]) -> bool, + where + F: FnMut(&[(&str, Option<&str>)]) -> bool, { - self.conn.iterate_inner(query.compile(self.conn.kind()), query.params(), &self.error_level.get(), &mut callback) + self.conn.iterate_inner( + query.compile(self.conn.kind()), + query.params(), + &self.error_level.get(), + &mut callback, + ) } /// Execute a statement and returns the rows. @@ -119,7 +145,11 @@ impl<'a> Connection { /// ``` #[inline] pub fn rows<'r, T: IntoWrapString<'a>>(&self, query: T) -> Result>> { - self.conn.rows_inner(query.compile(self.conn.kind()), query.params(), &self.error_level.get()) + self.conn.rows_inner( + query.compile(self.conn.kind()), + query.params(), + &self.error_level.get(), + ) } /// Sets the error level. @@ -170,4 +200,3 @@ impl Drop for Connection { pub unsafe fn without_escape(query: &T) -> WrapString { WrapString::new(query) } - diff --git a/concatsql/src/error.rs b/concatsql/src/error.rs index 8250d89..8ecf79e 100644 --- a/concatsql/src/error.rs +++ b/concatsql/src/error.rs @@ -41,28 +41,36 @@ impl Default for ErrorLevel { impl Error { #[allow(unused_variables)] - pub(crate) fn new(error_level: &ErrorLevel, err_msg: E1, detail_msg: E2) -> Result<(), Error> - where - E1: ToString, - E2: ToString, + pub(crate) fn new( + error_level: &ErrorLevel, + err_msg: E1, + detail_msg: E2, + ) -> Result<(), Error> + where + E1: ToString, + E2: ToString, { match error_level { ErrorLevel::AlwaysOk => Ok(()), - ErrorLevel::Release => Err(Error::AnyError), - ErrorLevel::Develop => Err(Error::Message(err_msg.to_string())), + ErrorLevel::Release => Err(Error::AnyError), + ErrorLevel::Develop => Err(Error::Message(err_msg.to_string())), #[cfg(debug_assertions)] - ErrorLevel::Debug => Err(Error::Message(err_msg.to_string() + ": " + &detail_msg.to_string())), + ErrorLevel::Debug => Err(Error::Message( + err_msg.to_string() + ": " + &detail_msg.to_string(), + )), } } } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", + write!( + f, + "{}", match self { - Error::Message(s) => s.to_string(), - Error::AnyError => String::from("AnyError"), - Error::ParseError => String::from("ParseError"), + Error::Message(s) => s.to_string(), + Error::AnyError => String::from("AnyError"), + Error::ParseError => String::from("ParseError"), Error::ColumnNotFound => String::from("ColumnNotFound"), } ) @@ -84,18 +92,19 @@ mod tests { fn errors() { assert_eq!(ErrorLevel::default(), ErrorLevel::Develop); assert_eq!(Error::Message("test".to_string()).to_string(), "test"); + assert_eq!(Error::new(&ErrorLevel::AlwaysOk, "test", "test"), Ok(())); assert_eq!( - Error::new(&ErrorLevel::AlwaysOk, "test", "test"), - Ok(())); + Error::new(&ErrorLevel::Release, "test", "test"), + Err(Error::AnyError) + ); assert_eq!( - Error::new(&ErrorLevel::Release, "test", "test"), - Err(Error::AnyError)); + Error::new(&ErrorLevel::Develop, "test", "test"), + Err(Error::Message("test".into())) + ); assert_eq!( - Error::new(&ErrorLevel::Develop, "test", "test"), - Err(Error::Message("test".into()))); - assert_eq!( - Error::new(&ErrorLevel::Debug, "test", "test"), - Err(Error::Message("test: test".into()))); + Error::new(&ErrorLevel::Debug, "test", "test"), + Err(Error::Message("test: test".into())) + ); } #[test] @@ -110,11 +119,11 @@ mod tests { conn.execute({ conn.error_level(ErrorLevel::Develop); "SELECT 1" - }).unwrap(); + }) + .unwrap(); conn.error_level({ conn.execute("SELECT 1").unwrap(); ErrorLevel::Develop }); } } - diff --git a/concatsql/src/lib.rs b/concatsql/src/lib.rs index c2f1a6d..c0b96c1 100644 --- a/concatsql/src/lib.rs +++ b/concatsql/src/lib.rs @@ -38,46 +38,46 @@ mod connection; mod error; mod parser; mod row; -mod wrapstring; mod value; +mod wrapstring; -#[cfg(feature = "sqlite")] -#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] -pub mod sqlite; #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] pub mod mysql; #[cfg(feature = "postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub mod postgres; +#[cfg(feature = "sqlite")] +#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] +pub mod sqlite; -pub use crate::connection::{Connection, without_escape}; +pub use crate::connection::{without_escape, Connection}; pub use crate::error::{Error, ErrorLevel}; -pub use crate::row::{Row, Get, FromSql}; -pub use crate::parser::{html_special_chars, _sanitize_like, invalid_literal}; -pub use crate::wrapstring::{WrapString, IntoWrapString}; -pub use crate::value::{Value, ToValue}; +pub use crate::parser::{_sanitize_like, html_special_chars, invalid_literal}; +pub use crate::row::{FromSql, Get, Row}; +pub use crate::value::{ToValue, Value}; +pub use crate::wrapstring::{IntoWrapString, WrapString}; pub use concatsql_macro::query; pub mod prelude { //! Re-exports important traits and types. - #[cfg(feature = "sqlite")] - #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] - pub use crate::sqlite; #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] pub use crate::mysql; #[cfg(feature = "postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub use crate::postgres; + #[cfg(feature = "sqlite")] + #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] + pub use crate::sqlite; - pub use crate::connection::{Connection, without_escape}; - pub use crate::row::{Row, Get, FromSql}; - pub use crate::{sanitize_like, params}; + pub use crate::connection::{without_escape, Connection}; + pub use crate::row::{FromSql, Get, Row}; + pub use crate::value::{ToValue, Value}; pub use crate::wrapstring::WrapString; - pub use crate::value::{Value, ToValue}; + pub use crate::{params, sanitize_like}; pub use concatsql_macro::query; } @@ -120,12 +120,16 @@ pub type Result = std::result::Result; /// prep!("INSERT INTO msg VALUES (\"I'm cat.\")"); /// prep!("INSERT INTO msg VALUES (") + "I'm cat." + prep!(")"); /// ``` -#[deprecated(note="please use `query!` instead")] +#[deprecated(note = "please use `query!` instead")] #[allow(deprecated)] #[macro_export] macro_rules! prep { - () => { $crate::WrapString::null() }; - ($query:expr) => { $crate::WrapString::init($query) }; + () => { + $crate::WrapString::null() + }; + ($query:expr) => { + $crate::WrapString::init($query) + }; } /// Prepare a SQL statement for execution. @@ -165,7 +169,7 @@ macro_rules! prep { /// prep("INSERT INTO msg VALUES (") + "I'm cat." + prep(")"); /// ``` #[inline] -#[deprecated(note="please use `query!` instead")] +#[deprecated(note = "please use `query!` instead")] #[allow(deprecated)] pub fn prep(query: &'static str) -> WrapString { WrapString::init(query) @@ -188,4 +192,3 @@ macro_rules! params { &[ $(&$param as &dyn $crate::ToValue),+ ] as &[&dyn $crate::ToValue] }; } - diff --git a/concatsql/src/mysql/connection.rs b/concatsql/src/mysql/connection.rs index f3a70b7..6bab49e 100644 --- a/concatsql/src/mysql/connection.rs +++ b/concatsql/src/mysql/connection.rs @@ -1,16 +1,16 @@ extern crate mysql_sys as mysql; -use mysql::{Opts, Conn}; use mysql::prelude::*; +use mysql::{Conn, Opts}; -use std::cell::{Cell, RefCell}; use std::borrow::Cow; +use std::cell::{Cell, RefCell}; -use crate::Result; +use crate::connection::{ConcatsqlConn, ConnKind, Connection}; +use crate::error::{Error, ErrorLevel}; use crate::parser::to_hex; use crate::row::Row; -use crate::connection::{Connection, ConcatsqlConn, ConnKind}; -use crate::error::{Error, ErrorLevel}; -use crate::value::{Value, SystemTimeToString}; +use crate::value::{SystemTimeToString, Value}; +use crate::Result; /// Open a read-write connection to a new or existing database. pub fn open(url: &str) -> Result { @@ -25,29 +25,34 @@ pub fn open(url: &str) -> Result { }; Ok(Connection { - conn: Box::new(RefCell::new(conn)), + conn: Box::new(RefCell::new(conn)), error_level: Cell::new(ErrorLevel::default()), }) } macro_rules! to_mysql_value { - ($value:expr) => ( + ($value:expr) => { match $value { - Value::Null => mysql::Value::from(None as Option), - Value::I32(value) => mysql::Value::from(value), - Value::I64(value) => mysql::Value::from(value), - Value::F32(value) => mysql::Value::from(value), - Value::F64(value) => mysql::Value::from(value), - Value::Text(value) => mysql::Value::from(value.as_ref()), - Value::Bytes(value) => mysql::Value::from(value), + Value::Null => mysql::Value::from(None as Option), + Value::I32(value) => mysql::Value::from(value), + Value::I64(value) => mysql::Value::from(value), + Value::F32(value) => mysql::Value::from(value), + Value::F64(value) => mysql::Value::from(value), + Value::Text(value) => mysql::Value::from(value.as_ref()), + Value::Bytes(value) => mysql::Value::from(value), Value::IpAddr(value) => mysql::Value::from(value.to_string()), - Value::Time(value) => mysql::Value::from(value.to_string()), + Value::Time(value) => mysql::Value::from(value.to_string()), } - ); + }; } impl ConcatsqlConn for RefCell { - fn execute_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) -> Result<()> { + fn execute_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result<()> { let mut conn = self.borrow_mut(); if params.is_empty() { match conn.query_drop(&query) { @@ -55,7 +60,10 @@ impl ConcatsqlConn for RefCell { Err(e) => Error::new(error_level, "exec error", &e), } } else { - let params = params.iter().map(|value| to_mysql_value!(value)).collect::>(); + let params = params + .iter() + .map(|value| to_mysql_value!(value)) + .collect::>(); match conn.exec_drop(&query, params) { Ok(_) => Ok(()), Err(e) => Error::new(error_level, "exec error", &e), @@ -63,9 +71,13 @@ impl ConcatsqlConn for RefCell { } } - fn iterate_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel, - callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool) -> Result<()> - { + fn iterate_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool, + ) -> Result<()> { macro_rules! run { ($result:expr) => { while let Some(result_set) = $result.next_set() { @@ -73,7 +85,8 @@ impl ConcatsqlConn for RefCell { Ok(result_set) => result_set, Err(e) => return Error::new(error_level, "exec error", &e), }; - let mut pairs: Vec<(String, Option)> = Vec::with_capacity(result_set.affected_rows() as usize); + let mut pairs: Vec<(String, Option)> = + Vec::with_capacity(result_set.affected_rows() as usize); for row in result_set { let row = match row { @@ -87,7 +100,8 @@ impl ConcatsqlConn for RefCell { } } - let pairs: Vec<(&str, Option<&str>)> = pairs.iter().map(|p| (&*p.0, p.1.as_deref())).collect(); + let pairs: Vec<(&str, Option<&str>)> = + pairs.iter().map(|p| (&*p.0, p.1.as_deref())).collect(); if !pairs.is_empty() && !callback(&pairs) { return Error::new(error_level, "exec error", "query aborted"); } @@ -104,7 +118,10 @@ impl ConcatsqlConn for RefCell { }; run!(result); } else { - let params = params.iter().map(|value| to_mysql_value!(value)).collect::>(); + let params = params + .iter() + .map(|value| to_mysql_value!(value)) + .collect::>(); let mut result = match conn.exec_iter(&query, params) { Ok(result) => result, Err(e) => return Error::new(error_level, "exec error", &e), @@ -115,9 +132,12 @@ impl ConcatsqlConn for RefCell { Ok(()) } - fn rows_inner<'a, 'r>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) - -> Result>> - { + fn rows_inner<'a, 'r>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result>> { let mut conn = self.borrow_mut(); macro_rules! run { @@ -125,7 +145,9 @@ impl ConcatsqlConn for RefCell { if let Some(result_set) = $result.next_set() { let result_set = match result_set { Ok(result_set) => result_set, - Err(e) => return Error::new(error_level, "exec error", &e).map(|_| Vec::new()), + Err(e) => { + return Error::new(error_level, "exec error", &e).map(|_| Vec::new()) + } }; let mut first_row = true; @@ -133,18 +155,28 @@ impl ConcatsqlConn for RefCell { for result_row in result_set { let result_row = match result_row { Ok(row) => row, - Err(e) => return Error::new(error_level, "exec error", &e).map(|_| Vec::new()), + Err(e) => { + return Error::new(error_level, "exec error", &e) + .map(|_| Vec::new()) + } }; let column_len = result_row.columns_ref().len(); if first_row { first_row = false; - let columns = result_row.columns_ref().iter().map(|col|col.name_str().to_string()).collect(); + let columns = result_row + .columns_ref() + .iter() + .map(|col| col.name_str().to_string()) + .collect(); let mut row = Row::new(columns); for index in 0..column_len { unsafe { - row.insert(&*(row.column(index) as *const str), result_row.get_to_string(index)); + row.insert( + &*(row.column(index) as *const str), + result_row.get_to_string(index), + ); } } $rows.push(row); @@ -152,7 +184,10 @@ impl ConcatsqlConn for RefCell { let mut row = Row::new($rows[0].columns()); for index in 0..column_len { unsafe { - row.insert(&*($rows[0].column(index) as *const str), result_row.get_to_string(index)); + row.insert( + &*($rows[0].column(index) as *const str), + result_row.get_to_string(index), + ); } } $rows.push(row); @@ -171,7 +206,10 @@ impl ConcatsqlConn for RefCell { }; run!(result, rows); } else { - let params = params.iter().map(|value| to_mysql_value!(value)).collect::>(); + let params = params + .iter() + .map(|value| to_mysql_value!(value)) + .collect::>(); let mut result = match conn.exec_iter(&query, params) { Ok(result) => result, Err(e) => return Error::new(error_level, "exec error", &e).map(|_| Vec::new()), @@ -198,37 +236,36 @@ trait GetToString { impl GetToString for mysql::Row { fn get_to_string(&self, index: usize) -> Option { match self[index] { - mysql::Value::NULL => None, - mysql::Value::Int(v) => Some(v.to_string()), - mysql::Value::UInt(v) => Some(v.to_string()), // unreachable ? - mysql::Value::Float(v) => Some(v.to_string()), // unreachable ? - mysql::Value::Double(v) => Some(v.to_string()), // unreachable ? + mysql::Value::NULL => None, + mysql::Value::Int(v) => Some(v.to_string()), + mysql::Value::UInt(v) => Some(v.to_string()), // unreachable ? + mysql::Value::Float(v) => Some(v.to_string()), // unreachable ? + mysql::Value::Double(v) => Some(v.to_string()), // unreachable ? mysql::Value::Bytes(ref bytes) => match String::from_utf8(bytes.to_vec()) { Ok(string) => Some(string), Err(_) => Some(to_hex(bytes)), - } + }, mysql::Value::Date(year, month, day, hour, minute, second, micros) => Some(format!( - "{:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:06}", year, month, day, hour, minute, second, micros - )), // unreachable ? - mysql::Value::Time(neg, days, hours, minutes, seconds, micros) => { - Some(if neg { - format!( - "-{:03}:{:02}:{:02}.{:06}", - days * 24 + u32::from(hours), - minutes, - seconds, - micros - ) - } else { - format!( - "{:03}:{:02}:{:02}.{:06}", - days * 24 + u32::from(hours), - minutes, - seconds, - micros - ) - }) - } // unreachable ? + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:06}", + year, month, day, hour, minute, second, micros + )), // unreachable ? + mysql::Value::Time(neg, days, hours, minutes, seconds, micros) => Some(if neg { + format!( + "-{:03}:{:02}:{:02}.{:06}", + days * 24 + u32::from(hours), + minutes, + seconds, + micros + ) + } else { + format!( + "{:03}:{:02}:{:02}.{:06}", + days * 24 + u32::from(hours), + minutes, + seconds, + micros + ) + }), // unreachable ? } } } @@ -237,19 +274,22 @@ impl GetToString for mysql::Row { mod tests { use crate as concatsql; use concatsql::error::*; - #[cfg(debug_assertions)] - use concatsql::prep; + use concatsql::prelude::*; #[test] fn open() { assert!(crate::mysql::open("mysql://localhost:3306/test").is_ok()); assert_eq!( crate::mysql::open(""), - Err(Error::Message("failed to open: URL ParseError { relative URL without a base }".into())) + Err(Error::Message( + "failed to open: URL ParseError { relative URL without a base }".into() + )) ); assert_eq!( crate::mysql::open("foo\0bar"), - Err(Error::Message("failed to open: URL ParseError { relative URL without a base }".into())) + Err(Error::Message( + "failed to open: URL ParseError { relative URL without a base }".into() + )) ); } @@ -264,11 +304,11 @@ mod tests { fn execute() { let conn = crate::mysql::open("mysql://localhost:3306/test").unwrap(); assert_eq!( - conn.execute(prep!("\0")), + conn.execute(query!("\0")), Err(Error::Message("exec error".into())), ); assert_eq!( - conn.execute(prep!("invalid query")), + conn.execute(query!("invalid query")), Err(Error::Message("exec error".into())), ); assert!(conn.execute("SELECT 1").is_ok()); @@ -279,29 +319,82 @@ mod tests { fn iterate() { let conn = crate::mysql::open("mysql://localhost:3306/test").unwrap(); assert_eq!( - conn.iterate(prep!("\0"), |_| { unreachable!(); }), + conn.iterate(query!("\0"), |_| { + unreachable!(); + }), Err(Error::Message("exec error".into())), ); assert_eq!( - conn.iterate(prep!("invalid query"), |_| { unreachable!(); }), + conn.iterate(query!("invalid query"), |_| { + unreachable!(); + }), Err(Error::Message("exec error".into())), ); - assert!(conn.iterate("SELECT 1", |_|{true}).is_ok()); + assert!(conn.iterate("SELECT 1", |_| { true }).is_ok()); } #[test] fn get_to_string() { let conn = crate::mysql::open("mysql://localhost:3306/test").unwrap(); - #[cfg(debug_assertions)] conn.error_level(ErrorLevel::Debug); + #[cfg(debug_assertions)] + conn.error_level(ErrorLevel::Debug); conn.execute(" CREATE TEMPORARY TABLE test (bytes BLOB, i32 INT, f32 FLOAT, f64 DOUBLE, date DATE, time TIME, none INT); INSERT INTO test(bytes, i32, f32, f64, date, time) VALUES(X'ABCD', 1, 2, 3, '1900-01-01', '123:00:00'); ").unwrap(); - assert_eq!(conn.rows("SELECT bytes FROM test").unwrap().first().unwrap().get(0).unwrap(), "ABCD"); - assert_eq!(conn.rows("SELECT i32 FROM test").unwrap().first().unwrap().get(0).unwrap(), "1"); - assert_eq!(conn.rows("SELECT f32 FROM test").unwrap().first().unwrap().get(0).unwrap(), "2"); - assert_eq!(conn.rows("SELECT f64 FROM test").unwrap().first().unwrap().get(0).unwrap(), "3"); - assert_eq!(conn.rows("SELECT date FROM test").unwrap().first().unwrap().get(0).unwrap(), "1900-01-01"); - assert_eq!(conn.rows("SELECT time FROM test").unwrap().first().unwrap().get(0).unwrap(), "123:00:00"); + assert_eq!( + conn.rows("SELECT bytes FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "ABCD" + ); + assert_eq!( + conn.rows("SELECT i32 FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "1" + ); + assert_eq!( + conn.rows("SELECT f32 FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "2" + ); + assert_eq!( + conn.rows("SELECT f64 FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "3" + ); + assert_eq!( + conn.rows("SELECT date FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "1900-01-01" + ); + assert_eq!( + conn.rows("SELECT time FROM test") + .unwrap() + .first() + .unwrap() + .get(0) + .unwrap(), + "123:00:00" + ); } } diff --git a/concatsql/src/mysql/mod.rs b/concatsql/src/mysql/mod.rs index 35d4c13..c659488 100644 --- a/concatsql/src/mysql/mod.rs +++ b/concatsql/src/mysql/mod.rs @@ -2,8 +2,8 @@ pub(crate) mod connection; -use crate::Result; use crate::connection::Connection; +use crate::Result; /// Open a read-write connection to a new or existing database. /// @@ -20,4 +20,3 @@ use crate::connection::Connection; pub fn open(url: &str) -> Result { connection::open(url) } - diff --git a/concatsql/src/parser.rs b/concatsql/src/parser.rs index 0edf1ce..7dd6bed 100644 --- a/concatsql/src/parser.rs +++ b/concatsql/src/parser.rs @@ -21,11 +21,11 @@ pub fn html_special_chars(input: &str) -> String { for c in input.chars() { match c { '\'' => escaped.push_str("'"), - '"' => escaped.push_str("""), - '&' => escaped.push_str("&"), - '<' => escaped.push_str("<"), - '>' => escaped.push_str(">"), - c => escaped.push(c), + '"' => escaped.push_str("""), + '&' => escaped.push_str("&"), + '<' => escaped.push_str("<"), + '>' => escaped.push_str(">"), + c => escaped.push(c), } } escaped @@ -55,8 +55,12 @@ pub fn html_special_chars(input: &str) -> String { /// ``` #[macro_export] macro_rules! sanitize_like { - ($pattern:tt) => { $crate::_sanitize_like($pattern, '\\') }; - ($pattern:tt, $escape:tt) => { $crate::_sanitize_like($pattern, $escape) }; + ($pattern:tt) => { + $crate::_sanitize_like($pattern, '\\') + }; + ($pattern:tt, $escape:tt) => { + $crate::_sanitize_like($pattern, $escape) + }; } #[doc(hidden)] pub fn _sanitize_like(pattern: T, escape_character: char) -> String { @@ -74,9 +78,13 @@ pub(crate) fn escape_string(s: &str) -> String { let mut escaped = String::new(); escaped.push('\''); for c in s.chars() { - if c == '\'' { escaped.push('\''); } + if c == '\'' { + escaped.push('\''); + } #[cfg(any(feature = "mysql", feature = "postgres"))] - if c == '\\' { escaped.push('\\'); } + if c == '\\' { + escaped.push('\\'); + } escaped.push(c); } escaped.push('\''); @@ -89,7 +97,10 @@ pub(crate) fn to_hex(bytes: &[u8]) -> String { static ref LUT: Vec = (0u8..=255).map(|n| format!("{:02X}", n)).collect(); } - bytes.iter().map(|&n| LUT.get(n as usize).unwrap().to_owned()).collect::() + bytes + .iter() + .map(|&n| LUT.get(n as usize).unwrap().to_owned()) + .collect::() } pub(crate) fn to_binary_literal(bytes: &[u8]) -> String { @@ -125,7 +136,7 @@ mod tests { #[cfg(feature = "sqlite")] #[cfg(not(all(feature = "sqlite", feature = "mysql", feature = "postgres")))] fn escape_string() { - assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); + assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); assert_eq!(super::escape_string("O\\'Reilly"), "'O\\''Reilly'"); } @@ -133,7 +144,7 @@ mod tests { #[cfg(feature = "mysql")] #[cfg(not(all(feature = "sqlite", feature = "mysql", feature = "postgres")))] fn escape_string() { - assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); + assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); assert_eq!(super::escape_string("O\\'Reilly"), "'O\\\\''Reilly'"); } @@ -141,7 +152,7 @@ mod tests { #[cfg(feature = "postgres")] #[cfg(not(all(feature = "sqlite", feature = "mysql", feature = "postgres")))] fn escape_string() { - assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); + assert_eq!(super::escape_string("O'Reilly"), "'O''Reilly'"); assert_eq!(super::escape_string("O\\'Reilly"), "'O\\\\''Reilly'"); } } diff --git a/concatsql/src/postgres/connection.rs b/concatsql/src/postgres/connection.rs index 32751ab..f4369bb 100644 --- a/concatsql/src/postgres/connection.rs +++ b/concatsql/src/postgres/connection.rs @@ -3,15 +3,15 @@ extern crate postgres_sys as postgres; use postgres::{Client, NoTls}; use uuid::Uuid; +use std::borrow::Cow; use std::cell::{Cell, RefCell}; use std::time::SystemTime; -use std::borrow::Cow; -use crate::Result; -use crate::row::Row; -use crate::connection::{Connection, ConcatsqlConn, ConnKind}; +use crate::connection::{ConcatsqlConn, ConnKind, Connection}; use crate::error::{Error, ErrorLevel}; -use crate::value::{Value, SystemTimeToString}; +use crate::row::Row; +use crate::value::{SystemTimeToString, Value}; +use crate::Result; /// Open a read-write connection to a new or existing database. pub fn open(params: &str) -> Result { @@ -21,36 +21,44 @@ pub fn open(params: &str) -> Result { }; Ok(Connection { - conn: Box::new(RefCell::new(conn)), + conn: Box::new(RefCell::new(conn)), error_level: Cell::new(ErrorLevel::default()), }) } macro_rules! to_sql { - ($value:expr) => ( + ($value:expr) => { match $value { - Value::Null => &"NULL" as &(dyn postgres::types::ToSql + Sync), - Value::I32(value) => value, - Value::I64(value) => value, - Value::F32(value) => value, - Value::F64(value) => value, - Value::Text(value) => value, - Value::Bytes(value) => value, + Value::Null => &"NULL" as &(dyn postgres::types::ToSql + Sync), + Value::I32(value) => value, + Value::I64(value) => value, + Value::F32(value) => value, + Value::F64(value) => value, + Value::Text(value) => value, + Value::Bytes(value) => value, Value::IpAddr(value) => value, - Value::Time(value) => value, + Value::Time(value) => value, } - ); + }; } impl ConcatsqlConn for RefCell { - fn execute_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) -> Result<()> { + fn execute_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result<()> { if params.is_empty() { match self.borrow_mut().batch_execute(&query) { Ok(_) => Ok(()), Err(e) => Error::new(error_level, "exec error", &e), } } else { - let params = params.iter().map(|value| to_sql!(value)).collect::>(); + let params = params + .iter() + .map(|value| to_sql!(value)) + .collect::>(); match self.borrow_mut().execute(&query as &str, ¶ms[..]) { Ok(_) => Ok(()), Err(e) => Error::new(error_level, "exec error", &e), @@ -58,10 +66,17 @@ impl ConcatsqlConn for RefCell { } } - fn iterate_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel, - callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool) -> Result<()> - { - let params = params.iter().map(|value| to_sql!(value)).collect::>(); + fn iterate_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool, + ) -> Result<()> { + let params = params + .iter() + .map(|value| to_sql!(value)) + .collect::>(); let rows = match self.borrow_mut().query(&query as &str, ¶ms[..]) { Ok(result) => result, Err(e) => return Error::new(error_level, "exec error", &e), @@ -75,7 +90,8 @@ impl ConcatsqlConn for RefCell { } } - let pairs: Vec<(&str, Option<&str>)> = pairs.iter().map(|p| (&*p.0, p.1.as_deref())).collect(); + let pairs: Vec<(&str, Option<&str>)> = + pairs.iter().map(|p| (&*p.0, p.1.as_deref())).collect(); if !pairs.is_empty() && !callback(&pairs) { return Error::new(error_level, "exec error", "query aborted"); } @@ -83,10 +99,16 @@ impl ConcatsqlConn for RefCell { Ok(()) } - fn rows_inner<'a, 'r>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) - -> Result>> - { - let params = params.iter().map(|value| to_sql!(value)).collect::>(); + fn rows_inner<'a, 'r>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result>> { + let params = params + .iter() + .map(|value| to_sql!(value)) + .collect::>(); let result = match self.borrow_mut().query(&query as &str, ¶ms[..]) { Ok(result) => result, Err(e) => return Error::new(error_level, "exec error", &e).map(|_| Vec::new()), @@ -96,10 +118,19 @@ impl ConcatsqlConn for RefCell { // First row if let Some(first_row) = result.first() { let column_len = first_row.columns().len(); - let columns = first_row.columns().iter().map(|col|col.name().to_string()).collect(); + let columns = first_row + .columns() + .iter() + .map(|col| col.name().to_string()) + .collect(); let mut row = Row::new(columns); for index in 0..column_len { - unsafe { row.insert(&*(row.column(index) as *const str), first_row.get_to_string(index)); } + unsafe { + row.insert( + &*(row.column(index) as *const str), + first_row.get_to_string(index), + ); + } } rows.push(row); } @@ -109,7 +140,12 @@ impl ConcatsqlConn for RefCell { let column_len = result_row.columns().len(); let mut row = Row::new(rows[0].columns()); for index in 0..column_len { - unsafe { row.insert(&*(rows[0].column(index) as *const str), result_row.get_to_string(index)); } + unsafe { + row.insert( + &*(rows[0].column(index) as *const str), + result_row.get_to_string(index), + ); + } } rows.push(row); } @@ -164,24 +200,26 @@ impl GetToString for postgres::row::Row { } } - #[cfg(test)] mod tests { use crate as concatsql; use concatsql::error::*; - #[cfg(debug_assertions)] - use concatsql::prep; + use concatsql::prelude::*; #[test] fn open() { assert!(crate::postgres::open("postgresql://postgres:postgres@localhost").is_ok()); assert_eq!( crate::postgres::open(""), - Err(Error::Message("failed to open: invalid configuration: host missing".into())) + Err(Error::Message( + "failed to open: invalid configuration: host missing".into() + )) ); assert_eq!( crate::postgres::open("foo\0bar"), - Err(Error::Message("failed to open: invalid connection string: unexpected EOF".into())) + Err(Error::Message( + "failed to open: invalid connection string: unexpected EOF".into() + )) ); } @@ -196,11 +234,11 @@ mod tests { fn execute() { let conn = crate::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); assert_eq!( - conn.execute(prep!("\0")), + conn.execute(query!("\0")), Err(Error::Message("exec error".into())), ); assert_eq!( - conn.execute(prep!("invalid query")), + conn.execute(query!("invalid query")), Err(Error::Message("exec error".into())), ); assert!(conn.execute("SELECT 1").is_ok()); @@ -211,13 +249,17 @@ mod tests { fn iterate() { let conn = crate::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); assert_eq!( - conn.iterate(prep!("\0"), |_| { unreachable!(); }), + conn.iterate(query!("\0"), |_| { + unreachable!(); + }), Err(Error::Message("exec error".into())), ); assert_eq!( - conn.iterate(prep!("invalid query"), |_| { unreachable!(); }), + conn.iterate(query!("invalid query"), |_| { + unreachable!(); + }), Err(Error::Message("exec error".into())), ); - assert!(conn.iterate("SELECT 1", |_|{true}).is_ok()); + assert!(conn.iterate("SELECT 1", |_| { true }).is_ok()); } } diff --git a/concatsql/src/postgres/mod.rs b/concatsql/src/postgres/mod.rs index 5dca6d7..7534bd8 100644 --- a/concatsql/src/postgres/mod.rs +++ b/concatsql/src/postgres/mod.rs @@ -2,8 +2,8 @@ pub(crate) mod connection; -use crate::Result; use crate::connection::Connection; +use crate::Result; /// Open a read-write connection to a new or existing database. /// @@ -19,4 +19,3 @@ use crate::connection::Connection; pub fn open(params: &str) -> Result { connection::open(params) } - diff --git a/concatsql/src/row.rs b/concatsql/src/row.rs index 221a77b..f57817f 100644 --- a/concatsql/src/row.rs +++ b/concatsql/src/row.rs @@ -1,8 +1,8 @@ use std::str::FromStr; use std::sync::Arc; -use indexmap::map::IndexMap; use crate::error::Error; +use indexmap::map::IndexMap; type IndexMapPairs<'a> = IndexMap<&'a str, Option>; @@ -10,7 +10,7 @@ type IndexMapPairs<'a> = IndexMap<&'a str, Option>; #[derive(Debug, PartialEq)] pub struct Row<'a> { columns: Arc<[String]>, - pairs: IndexMapPairs<'a>, + pairs: IndexMapPairs<'a>, } impl<'a> Row<'a> { @@ -90,7 +90,7 @@ impl<'a> Row<'a> { /// Get the column name. #[inline] - pub fn column_name(&self, key: T) -> Option<&str> { + pub fn column_name(&self, key: T) -> Option<&str> { key.get_key(&self.pairs) } @@ -103,10 +103,7 @@ impl<'a> Row<'a> { #[inline] pub fn iter(&self) -> RowIter { - RowIter { - row: self, - now: 0, - } + RowIter { row: self, now: 0 } } } @@ -128,7 +125,7 @@ impl<'a> Iterator for RowIter<'a> { fn next(&mut self) -> Option<&'a str> { if self.now < self.row.column_count() { self.now += 1; - self.row.get(self.now-1) + self.row.get(self.now - 1) } else { None } @@ -156,7 +153,13 @@ impl Get for str { } fn get_into(&self, pairs: &IndexMapPairs) -> Result { - U::from_sql(pairs.get(self).ok_or(Error::ColumnNotFound)?.as_deref().unwrap_or("")) + U::from_sql( + pairs + .get(self) + .ok_or(Error::ColumnNotFound)? + .as_deref() + .unwrap_or(""), + ) } fn get_key<'a>(&self, pairs: &'a IndexMapPairs) -> Option<&'a str> { @@ -170,7 +173,13 @@ impl Get for String { } fn get_into(&self, pairs: &IndexMapPairs) -> Result { - U::from_sql(pairs.get(&**self).ok_or(Error::ColumnNotFound)?.as_deref().unwrap_or("")) + U::from_sql( + pairs + .get(&**self) + .ok_or(Error::ColumnNotFound)? + .as_deref() + .unwrap_or(""), + ) } fn get_key<'a>(&self, pairs: &'a IndexMapPairs) -> Option<&'a str> { @@ -184,7 +193,14 @@ impl Get for usize { } fn get_into(&self, pairs: &IndexMapPairs) -> Result { - U::from_sql(pairs.get_index(*self).ok_or(Error::ColumnNotFound)?.1.as_deref().unwrap_or("")) + U::from_sql( + pairs + .get_index(*self) + .ok_or(Error::ColumnNotFound)? + .1 + .as_deref() + .unwrap_or(""), + ) } fn get_key<'a>(&self, pairs: &'a IndexMapPairs) -> Option<&'a str> { @@ -192,7 +208,10 @@ impl Get for usize { } } -impl<'b, T> Get for &'b T where T: Get + ?Sized { +impl<'b, T> Get for &'b T +where + T: Get + ?Sized, +{ fn get<'a>(&self, pairs: &'a IndexMapPairs) -> Option<&'a str> { T::get(self, pairs) } @@ -256,8 +275,9 @@ impl FromSql for Vec { fn from_sql(s: &str) -> Result { (0..s.len()) .step_by(2) - .map(|i| u8::from_str_radix(&s[i..i+2], 16).map_err(|_|())) - .collect::, ()>>().map_err(|_|Error::ParseError) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16).map_err(|_| ())) + .collect::, ()>>() + .map_err(|_| Error::ParseError) } } @@ -270,11 +290,14 @@ mod tests { #[cfg(feature = "sqlite")] fn column_names() { let conn = crate::sqlite::open(":memory:").unwrap(); - conn.execute(r#" + conn.execute( + r#" CREATE TABLE users (name TEXT, age INTEGER); INSERT INTO users (name, age) VALUES ('Alice', 42); INSERT INTO users (name, age) VALUES ('Bob', 69); - "#).unwrap(); + "#, + ) + .unwrap(); for row in conn.rows("SELECT * FROM users").unwrap() { assert_eq!(row.column_names(), ["name", "age"]); @@ -284,7 +307,12 @@ mod tests { #[test] #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)] fn row() { - let mut row = Row::new(["key1","key2","key3","ABC"].iter().map(ToString::to_string).collect()); + let mut row = Row::new( + ["key1", "key2", "key3", "ABC"] + .iter() + .map(ToString::to_string) + .collect(), + ); row.insert("key1", Some("value".to_string())); row.insert("key2", None); row.insert("key3", Some("42".to_string())); @@ -300,7 +328,10 @@ mod tests { assert_eq!(row.get(2), Some("42")); assert_eq!(row.get(3), None); - assert_eq!(row.get_into::<&str, String>("key1"), Ok(String::from("value"))); + assert_eq!( + row.get_into::<&str, String>("key1"), + Ok(String::from("value")) + ); assert_eq!(row.get_into::<&str, i32>("key3"), Ok(42)); assert_eq!(row.get_into::<&str, usize>("key3"), Ok(42)); assert_eq!(row.get_into("key3"), Ok(42)); @@ -362,26 +393,29 @@ mod tests { assert_eq!(row.get(&&String::from("key1")), Some("value")); row.insert("ABC", Some("414243".to_string())); - assert_eq!(row.get_into::<_, Vec>("ABC"), Ok(vec![b'A',b'B',b'C'])); + assert_eq!( + row.get_into::<_, Vec>("ABC"), + Ok(vec![b'A', b'B', b'C']) + ); assert!(row.get_into::<_, i8>("ABC").is_err()); assert!(row.get_into::<_, u8>("ABC").is_err()); assert!(row.get_into::<_, i16>("ABC").is_err()); assert!(row.get_into::<_, u16>("ABC").is_err()); - assert_eq!(row.get_into::<_, i32>("ABC"), Ok(414243)); - assert_eq!(row.get_into::<_, u32>("ABC"), Ok(414243)); - assert_eq!(row.get_into::<_, i64>("ABC"), Ok(414243)); - assert_eq!(row.get_into::<_, u64>("ABC"), Ok(414243)); - assert_eq!(row.get_into::<_, i128>("ABC"), Ok(414243)); - assert_eq!(row.get_into::<_, u128>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, i32>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, u32>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, i64>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, u64>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, i128>("ABC"), Ok(414243)); + assert_eq!(row.get_into::<_, u128>("ABC"), Ok(414243)); assert_eq!(row.get_into::<_, isize>("ABC"), Ok(414243)); assert_eq!(row.get_into::<_, usize>("ABC"), Ok(414243)); assert_eq!(row.get_into::<_, u8>("ABC"), Err(Error::ParseError)); assert_eq!(row.get_into::<_, u8>("def"), Err(Error::ColumnNotFound)); - assert_eq!(row.column_name(0), Some("key1")); - assert_eq!(row.column_name(99), None); - assert_eq!(row.column_name("key1"), Some("key1")); + assert_eq!(row.column_name(0), Some("key1")); + assert_eq!(row.column_name(99), None); + assert_eq!(row.column_name("key1"), Some("key1")); assert_eq!(row.column_name("key99"), None); } @@ -389,14 +423,20 @@ mod tests { #[cfg(feature = "sqlite")] fn iter() { let conn = crate::sqlite::open(":memory:").unwrap(); - conn.execute(r#" + conn.execute( + r#" CREATE TABLE users (name TEXT, age INTEGER); INSERT INTO users (name, age) VALUES ('Alice', 42); INSERT INTO users (name, age) VALUES ('Bob', 69); - "#).unwrap(); + "#, + ) + .unwrap(); let mut cnt = 0; - for row in conn.rows("SELECT * FROM users WHERE name = 'Alice'").unwrap() { + for row in conn + .rows("SELECT * FROM users WHERE name = 'Alice'") + .unwrap() + { for (index, value) in row.iter().enumerate() { cnt += 1; assert_eq!(value, ["Alice", "42"][index]); @@ -409,14 +449,20 @@ mod tests { #[cfg(feature = "sqlite")] fn into_iter() { let conn = crate::sqlite::open(":memory:").unwrap(); - conn.execute(r#" + conn.execute( + r#" CREATE TABLE users (name TEXT, age INTEGER); INSERT INTO users (name, age) VALUES ('Alice', 42); INSERT INTO users (name, age) VALUES ('Bob', 69); - "#).unwrap(); + "#, + ) + .unwrap(); let mut cnt = 0; - for row in conn.rows("SELECT * FROM users WHERE name = 'Alice'").unwrap() { + for row in conn + .rows("SELECT * FROM users WHERE name = 'Alice'") + .unwrap() + { for value in &row { assert_eq!(value, ["Alice", "42"][cnt]); cnt += 1; @@ -429,25 +475,30 @@ mod tests { #[cfg(feature = "sqlite")] fn index() { let conn = crate::sqlite::open(":memory:").unwrap(); - conn.execute(r#" + conn.execute( + r#" CREATE TABLE users (name TEXT, age INTEGER); INSERT INTO users (name, age) VALUES ('Alice', 42); INSERT INTO users (name, age) VALUES ('Bob', 69); - "#).unwrap(); + "#, + ) + .unwrap(); let mut cnt = 0; - for row in conn.rows("SELECT * FROM users WHERE name = 'Alice'").unwrap() { + for row in conn + .rows("SELECT * FROM users WHERE name = 'Alice'") + .unwrap() + { cnt += 1; assert_eq!(&row[0], "Alice"); assert_eq!(&row[1], "42"); assert_eq!(&row["name"], "Alice"); - assert_eq!(&row["age"], "42"); + assert_eq!(&row["age"], "42"); assert_eq!(row[0], *"Alice"); assert_eq!(row[1], *"42"); assert_eq!(row["name"], *"Alice"); - assert_eq!(row["age"], *"42"); + assert_eq!(row["age"], *"42"); } assert_eq!(cnt, 1); } } - diff --git a/concatsql/src/sqlite/connection.rs b/concatsql/src/sqlite/connection.rs index 0a90461..1cfef0c 100644 --- a/concatsql/src/sqlite/connection.rs +++ b/concatsql/src/sqlite/connection.rs @@ -1,52 +1,57 @@ extern crate sqlite3_sys as ffi; -use std::ffi::{CStr, CString, c_void}; -use std::ptr::{self, NonNull}; -use std::path::Path; -use std::cell::Cell; use std::borrow::Cow; +use std::cell::Cell; +use std::ffi::{c_void, CStr, CString}; +use std::path::Path; +use std::ptr::{self, NonNull}; -use crate::Result; -use crate::row::Row; -use crate::connection::{Connection, ConcatsqlConn, ConnKind}; +use crate::connection::{ConcatsqlConn, ConnKind, Connection}; use crate::error::{Error, ErrorLevel}; -use crate::value::{Value, SystemTimeToString}; +use crate::row::Row; +use crate::value::{SystemTimeToString, Value}; +use crate::Result; /// Open a read-write connection to a new or existing database. pub fn open>(path: T, openflags: i32) -> Result { let path = match path.as_ref().to_str() { - Some(path) => { - match CString::new(path) { - Ok(string) => string, - _ => return Err(Error::Message(format!("invalid path: {}", path))), - } + Some(path) => match CString::new(path) { + Ok(string) => string, + _ => return Err(Error::Message(format!("invalid path: {}", path))), }, - _ => return Err(Error::Message(format!("failed to open path: {:?}", path.as_ref()))), + _ => { + return Err(Error::Message(format!( + "failed to open path: {:?}", + path.as_ref() + ))) + } }; let mut conn_ptr = ptr::null_mut(); - let open_result = unsafe { ffi::sqlite3_open_v2( - path.as_ptr(), - &mut conn_ptr, - openflags, - ptr::null()) - }; + let open_result = + unsafe { ffi::sqlite3_open_v2(path.as_ptr(), &mut conn_ptr, openflags, ptr::null()) }; match open_result { - ffi::SQLITE_OK => - Ok(Connection { - conn: Box::new(unsafe { NonNull::new_unchecked(conn_ptr) }), - error_level: Cell::new(ErrorLevel::default()), - }), + ffi::SQLITE_OK => Ok(Connection { + conn: Box::new(unsafe { NonNull::new_unchecked(conn_ptr) }), + error_level: Cell::new(ErrorLevel::default()), + }), _ => { - unsafe { ffi::sqlite3_close(conn_ptr); } + unsafe { + ffi::sqlite3_close(conn_ptr); + } Err(Error::Message("failed to connect".into())) } } } impl ConcatsqlConn for NonNull { - fn execute_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) -> Result<()> { + fn execute_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result<()> { let query = match CString::new(query.as_bytes()) { Ok(string) => string, _ => return Error::new(error_level, "invalid query", query), @@ -60,8 +65,8 @@ impl ConcatsqlConn for NonNull { ffi::sqlite3_exec( self.as_ptr(), query.as_ptr(), - None, // callback fn - ptr::null_mut(), // callback arg + None, // callback fn + ptr::null_mut(), // callback arg &mut errmsg, ); } @@ -72,8 +77,11 @@ impl ConcatsqlConn for NonNull { unsafe { ffi::sqlite3_finalize(stmt); ffi::sqlite3_free(errmsg as *mut _); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ); } } } @@ -89,8 +97,11 @@ impl ConcatsqlConn for NonNull { if result != ffi::SQLITE_OK { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ); } bind_all(stmt, params, error_level)?; @@ -98,11 +109,14 @@ impl ConcatsqlConn for NonNull { loop { match ffi::sqlite3_step(stmt) { ffi::SQLITE_DONE => break, - ffi::SQLITE_ROW => (), // Do nothing + ffi::SQLITE_ROW => (), // Do nothing _ => { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ); } } } @@ -112,9 +126,13 @@ impl ConcatsqlConn for NonNull { } } - fn iterate_inner<'a>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel, - callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool) -> Result<()> - { + fn iterate_inner<'a>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + callback: &mut dyn FnMut(&[(&str, Option<&str>)]) -> bool, + ) -> Result<()> { let query = match CString::new(query.as_bytes()) { Ok(string) => string, _ => return Error::new(error_level, "invalid query", query), @@ -132,8 +150,11 @@ impl ConcatsqlConn for NonNull { if result != ffi::SQLITE_OK { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ); } bind_all(stmt, params, error_level)?; @@ -146,15 +167,19 @@ impl ConcatsqlConn for NonNull { ffi::SQLITE_ROW => { let mut pairs = Vec::with_capacity(column_count as usize); pairs.storing(stmt, column_count); - let pairs: Vec<(&str, Option<&str>)> = pairs.iter().map(|p| (p.0, p.1.as_deref())).collect(); + let pairs: Vec<(&str, Option<&str>)> = + pairs.iter().map(|p| (p.0, p.1.as_deref())).collect(); if !callback(&pairs) { break; } } _ => { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ); } } } @@ -164,9 +189,12 @@ impl ConcatsqlConn for NonNull { } } - fn rows_inner<'a, 'r>(&self, query: Cow<'a, str>, params: &[Value<'a>], error_level: &ErrorLevel) - -> Result>> - { + fn rows_inner<'a, 'r>( + &self, + query: Cow<'a, str>, + params: &[Value<'a>], + error_level: &ErrorLevel, + ) -> Result>> { let mut rows: Vec = Vec::new(); let query = match CString::new(query.as_bytes()) { Ok(string) => string, @@ -185,9 +213,12 @@ impl ConcatsqlConn for NonNull { if result != ffi::SQLITE_OK { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()) - .map(|_| Vec::new()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ) + .map(|_| Vec::new()); } bind_all(stmt, params, error_level)?; @@ -203,7 +234,7 @@ impl ConcatsqlConn for NonNull { ffi::SQLITE_ROW => { let mut pairs = Vec::with_capacity(column_count as usize); pairs.storing(stmt, column_count); - let columns = pairs.iter().map(|(column, _)|column.to_string()).collect(); + let columns = pairs.iter().map(|(column, _)| column.to_string()).collect(); let mut row = Row::new(columns); for (index, (_, value)) in pairs.into_iter().enumerate() { row.insert(&*(row.column(index) as *const str), value); @@ -212,9 +243,12 @@ impl ConcatsqlConn for NonNull { } _ => { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()) - .map(|_| Vec::new()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ) + .map(|_| Vec::new()); } } @@ -233,9 +267,12 @@ impl ConcatsqlConn for NonNull { } _ => { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "exec error", - CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy()) - .map(|_| Vec::new()); + return Error::new( + error_level, + "exec error", + CStr::from_ptr(ffi::sqlite3_errmsg(self.as_ptr())).to_string_lossy(), + ) + .map(|_| Vec::new()); } } } @@ -247,7 +284,11 @@ impl ConcatsqlConn for NonNull { fn close(&self) { unsafe { - ffi::sqlite3_busy_handler(self.as_ptr() as *const _ as *mut ffi::sqlite3, None, std::ptr::null_mut()); + ffi::sqlite3_busy_handler( + self.as_ptr() as *const _ as *mut ffi::sqlite3, + None, + std::ptr::null_mut(), + ); let close_result = ffi::sqlite3_close(self.as_ptr() as *const _ as *mut ffi::sqlite3); std::ptr::drop_in_place(self.as_ptr() as *const _ as *mut ffi::sqlite3); if close_result != ffi::SQLITE_OK { @@ -294,42 +335,38 @@ impl Storing for Vec<(&str, Option)> { } } -unsafe fn bind_all(stmt: *mut ffi::sqlite3_stmt, params: &[Value<'_>], error_level: &ErrorLevel) -> Result<()> { +unsafe fn bind_all( + stmt: *mut ffi::sqlite3_stmt, + params: &[Value<'_>], + error_level: &ErrorLevel, +) -> Result<()> { for (index, param) in (1i32..).zip(params.iter()) { let result = match param { - Value::Null => { - ffi::sqlite3_bind_null(stmt, index) - } - Value::I32(value) => { - ffi::sqlite3_bind_int(stmt, index, *value) - } - Value::I64(value) => { - ffi::sqlite3_bind_int64(stmt, index, *value) - } - Value::F32(value) => { - ffi::sqlite3_bind_double(stmt, index, *value as f64) - } - Value::F64(value) => { - ffi::sqlite3_bind_double(stmt, index, *value) - } - Value::Text(value) => { - ffi::sqlite3_bind_text( - stmt, - index, - value.as_ptr() as *const _, - value.len() as i32, - Some(std::mem::transmute::<*const c_void, extern "C" fn(*mut c_void)>(ffi::SQLITE_TRANSIENT as *const c_void)), - ) - } - Value::Bytes(value) => { - ffi::sqlite3_bind_blob( - stmt, - index, - value.as_ptr() as *const _, - value.len() as i32, - Some(std::mem::transmute::<*const c_void, extern "C" fn(*mut c_void)>(ffi::SQLITE_TRANSIENT as *const c_void)), - ) - } + Value::Null => ffi::sqlite3_bind_null(stmt, index), + Value::I32(value) => ffi::sqlite3_bind_int(stmt, index, *value), + Value::I64(value) => ffi::sqlite3_bind_int64(stmt, index, *value), + Value::F32(value) => ffi::sqlite3_bind_double(stmt, index, *value as f64), + Value::F64(value) => ffi::sqlite3_bind_double(stmt, index, *value), + Value::Text(value) => ffi::sqlite3_bind_text( + stmt, + index, + value.as_ptr() as *const _, + value.len() as i32, + Some(std::mem::transmute::< + *const c_void, + extern "C" fn(*mut c_void), + >(ffi::SQLITE_TRANSIENT as *const c_void)), + ), + Value::Bytes(value) => ffi::sqlite3_bind_blob( + stmt, + index, + value.as_ptr() as *const _, + value.len() as i32, + Some(std::mem::transmute::< + *const c_void, + extern "C" fn(*mut c_void), + >(ffi::SQLITE_TRANSIENT as *const c_void)), + ), Value::IpAddr(value) => { let value = value.to_string(); ffi::sqlite3_bind_text( @@ -337,7 +374,10 @@ unsafe fn bind_all(stmt: *mut ffi::sqlite3_stmt, params: &[Value<'_>], error_lev index, value.as_ptr() as *const _, value.len() as i32, - Some(std::mem::transmute::<*const c_void, extern "C" fn(*mut c_void)>(ffi::SQLITE_TRANSIENT as *const c_void)), + Some(std::mem::transmute::< + *const c_void, + extern "C" fn(*mut c_void), + >(ffi::SQLITE_TRANSIENT as *const c_void)), ) } Value::Time(value) => { @@ -347,13 +387,20 @@ unsafe fn bind_all(stmt: *mut ffi::sqlite3_stmt, params: &[Value<'_>], error_lev index, value.as_ptr() as *const _, value.len() as i32, - Some(std::mem::transmute::<*const c_void, extern "C" fn(*mut c_void)>(ffi::SQLITE_TRANSIENT as *const c_void)), + Some(std::mem::transmute::< + *const c_void, + extern "C" fn(*mut c_void), + >(ffi::SQLITE_TRANSIENT as *const c_void)), ) } }; if result != ffi::SQLITE_OK { ffi::sqlite3_finalize(stmt); - return Error::new(error_level, "bind error", CStr::from_ptr(ffi::sqlite3_errstr(result)).to_string_lossy()); + return Error::new( + error_level, + "bind error", + CStr::from_ptr(ffi::sqlite3_errstr(result)).to_string_lossy(), + ); } } @@ -364,16 +411,18 @@ unsafe fn bind_all(stmt: *mut ffi::sqlite3_stmt, params: &[Value<'_>], error_lev mod tests { use crate as concatsql; use concatsql::error::*; + use concatsql::prelude::*; use temporary::Folder; - #[cfg(debug_assertions)] - use concatsql::prep; #[test] fn open() { let dir = Folder::new("sqlite").unwrap(); let path = dir.path().join("test.db"); assert_ne!(crate::sqlite::open(""), crate::sqlite::open("")); - assert_ne!(crate::sqlite::open(":memory:"), crate::sqlite::open(":memory:")); + assert_ne!( + crate::sqlite::open(":memory:"), + crate::sqlite::open(":memory:") + ); assert_ne!(crate::sqlite::open(&path), crate::sqlite::open(&path)); assert_eq!( crate::sqlite::open("foo\0bar"), @@ -386,11 +435,11 @@ mod tests { fn execute() { let conn = crate::sqlite::open(":memory:").unwrap(); assert_eq!( - conn.execute(prep!("\0")), + conn.execute(query!("\0")), Err(Error::Message("invalid query".into())), ); assert_eq!( - conn.execute(prep!("invalid query")), + conn.execute(query!("invalid query")), Err(Error::Message("exec error".into())), ); assert!(conn.execute("SELECT 1").is_ok()); @@ -401,23 +450,27 @@ mod tests { fn iterate() { let conn = crate::sqlite::open(":memory:").unwrap(); assert_eq!( - conn.iterate(prep!("\0"), |_| { unreachable!(); }), + conn.iterate(query!("\0"), |_| { + unreachable!(); + }), Err(Error::Message("invalid query".into())), ); assert_eq!( - conn.iterate(prep!("invalid query"), |_| { unreachable!(); }), + conn.iterate(query!("invalid query"), |_| { + unreachable!(); + }), Err(Error::Message("exec error".into())), ); - assert!(conn.iterate("SELECT 1", |_|{true}).is_ok()); + assert!(conn.iterate("SELECT 1", |_| { true }).is_ok()); } #[test] #[cfg(debug_assertions)] fn simulate() { - assert_eq!(prep!("SELECT").simulate(), "SELECT"); - assert_eq!(prep!("O''Reilly").simulate(), "O''Reilly"); - assert_eq!(prep!("\"O'Reilly\"").simulate(), "\"O'Reilly\""); - assert_eq!(prep!("O'Reilly").simulate(), "O'Reilly"); + assert_eq!(query!("SELECT").simulate(), "SELECT"); + assert_eq!(query!("O''Reilly").simulate(), "O''Reilly"); + assert_eq!(query!("\"O'Reilly\"").simulate(), "\"O'Reilly\""); + assert_eq!(query!("O'Reilly").simulate(), "O'Reilly"); } #[test] @@ -426,4 +479,3 @@ mod tests { assert_eq!(format!("{:?}", &conn), format!("{:?}", &conn)); } } - diff --git a/concatsql/src/sqlite/mod.rs b/concatsql/src/sqlite/mod.rs index 5a1a459..07233b8 100644 --- a/concatsql/src/sqlite/mod.rs +++ b/concatsql/src/sqlite/mod.rs @@ -1,8 +1,8 @@ //! Interface to [SQLite](https://www.sqlite.org) of ConcatSQL. -use std::path::Path; -use crate::Result; use crate::connection::Connection; +use crate::Result; +use std::path::Path; pub(crate) mod connection; @@ -29,7 +29,10 @@ pub(crate) mod connection; /// ``` #[inline] pub fn open>(path: T) -> Result { - connection::open(path, sqlite3_sys::SQLITE_OPEN_CREATE | sqlite3_sys::SQLITE_OPEN_READWRITE) + connection::open( + path, + sqlite3_sys::SQLITE_OPEN_CREATE | sqlite3_sys::SQLITE_OPEN_READWRITE, + ) } /// Open a readonly connection to a new or existing database. @@ -46,11 +49,10 @@ pub fn version() -> usize { unsafe { sqlite3_sys::sqlite3_libversion_number() as usize } } - #[cfg(test)] mod tests { use crate as concatsql; - use concatsql::prep; + use concatsql::prelude::*; use temporary::Folder; #[test] @@ -69,7 +71,8 @@ mod tests { let path = dir.path().join("test.db"); { let conn = crate::sqlite::open(&path).unwrap(); - conn.execute(prep!("CREATE TABLE users(id INTEGER, name TEXT);")).unwrap(); + conn.execute(query!("CREATE TABLE users(id INTEGER, name TEXT);")) + .unwrap(); } crate::sqlite::open_readonly(path).unwrap(); } @@ -82,13 +85,16 @@ mod tests { let path = dir.path().join("test.db"); { let conn = crate::sqlite::open(&path).unwrap(); - conn.execute(prep!("CREATE TABLE users(id INTEGER, name TEXT);")).unwrap(); + conn.execute(query!("CREATE TABLE users(id INTEGER, name TEXT);")) + .unwrap(); } let conn = crate::sqlite::open_readonly(path).unwrap(); conn.error_level(ErrorLevel::Debug); assert_eq!( - conn.execute(prep!("INSERT INTO users VALUES(42, 'Alice');")), - Err(Error::Message("exec error: attempt to write a readonly database".to_string())) + conn.execute(query!("INSERT INTO users VALUES(42, 'Alice');")), + Err(Error::Message( + "exec error: attempt to write a readonly database".to_string() + )) ); } diff --git a/concatsql/src/value.rs b/concatsql/src/value.rs index 7edd0ac..a636570 100644 --- a/concatsql/src/value.rs +++ b/concatsql/src/value.rs @@ -115,7 +115,7 @@ pub trait SystemTimeToString { fn to_string(&self) -> String; } -impl SystemTimeToString for SystemTime { +impl SystemTimeToString for SystemTime { fn to_string(&self) -> String { let datetime: DateTime = (*self).into(); datetime.format("%Y-%m-%d %H:%M:%S.%f").to_string() @@ -128,4 +128,3 @@ impl SystemTimeToString for &SystemTime { datetime.format("%Y-%m-%d %H:%M:%S.%f").to_string() } } - diff --git a/concatsql/src/wrapstring.rs b/concatsql/src/wrapstring.rs index ba0631c..68ca92f 100644 --- a/concatsql/src/wrapstring.rs +++ b/concatsql/src/wrapstring.rs @@ -1,17 +1,17 @@ -use std::ops::Add; use std::borrow::Cow; use std::net::IpAddr; +use std::ops::Add; use std::time::SystemTime; use uuid::Uuid; -use crate::parser::{escape_string, to_binary_literal}; -use crate::value::{Value, ToValue, SystemTimeToString}; use crate::connection::ConnKind; +use crate::parser::{escape_string, to_binary_literal}; +use crate::value::{SystemTimeToString, ToValue, Value}; /// Wraps a [String](https://doc.rust-lang.org/std/string/struct.String.html) type. #[derive(Clone, Debug, PartialEq)] pub struct WrapString<'a> { - pub(crate) query: Vec>>, + pub(crate) query: Vec>>, pub(crate) params: Vec>, } @@ -20,9 +20,7 @@ impl<'a> WrapString<'a> { #[inline] pub fn _init(query: Vec>, params: Vec>) -> Self { Self { - query: query.iter() - .map(|q| q.map(Cow::from)) - .collect(), + query: query.iter().map(|q| q.map(Cow::from)).collect(), params, } } @@ -31,7 +29,7 @@ impl<'a> WrapString<'a> { #[inline] pub fn init(s: &'static str) -> Self { Self { - query: vec![ Some(Cow::Borrowed(s)) ], + query: vec![Some(Cow::Borrowed(s))], params: Vec::new(), } } @@ -40,7 +38,7 @@ impl<'a> WrapString<'a> { #[inline] pub const fn null() -> Self { Self { - query: Vec::new(), + query: Vec::new(), params: Vec::new(), } } @@ -48,7 +46,7 @@ impl<'a> WrapString<'a> { #[inline] pub(crate) fn new(s: &T) -> Self { Self { - query: vec![ Some(Cow::Owned(s.to_string())) ], + query: vec![Some(Cow::Owned(s.to_string()))], params: Vec::new(), } } @@ -78,15 +76,15 @@ impl<'a> WrapString<'a> { Some(s) => query.push_str(s), None => { match &self.params[index] { - Value::Null => query.push_str("NULL"), - Value::I32(value) => query.push_str(&value.to_string()), - Value::I64(value) => query.push_str(&value.to_string()), - Value::F32(value) => query.push_str(&value.to_string()), - Value::F64(value) => query.push_str(&value.to_string()), - Value::Text(value) => query.push_str(&escape_string(value)), - Value::Bytes(value) => query.push_str(&to_binary_literal(value)), + Value::Null => query.push_str("NULL"), + Value::I32(value) => query.push_str(&value.to_string()), + Value::I64(value) => query.push_str(&value.to_string()), + Value::F32(value) => query.push_str(&value.to_string()), + Value::F64(value) => query.push_str(&value.to_string()), + Value::Text(value) => query.push_str(&escape_string(value)), + Value::Bytes(value) => query.push_str(&to_binary_literal(value)), Value::IpAddr(value) => query.push_str(&format!("'{}'", value)), - Value::Time(value) => query.push_str(&format!("'{}'", value.to_string())), + Value::Time(value) => query.push_str(&format!("'{}'", value.to_string())), } index += 1; } @@ -97,11 +95,7 @@ impl<'a> WrapString<'a> { /// Returns the length of a string other than a placeholders. pub fn len(&self) -> usize { - self.query - .iter() - .flatten() - .map(|part|part.len()) - .sum() + self.query.iter().flatten().map(|part| part.len()).sum() } /// Returns the query's vector length. @@ -144,7 +138,7 @@ impl<'a> WrapString<'a> { /// ``` pub fn squash(&mut self) { let mut new_query = Vec::new(); - let mut new_part = String::new(); + let mut new_part = String::new(); for part in &self.query { if let Some(part) = part { new_part.push_str(part); @@ -164,7 +158,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: WrapString<'a>) -> WrapString<'a> { - self.query .extend_from_slice(&other.query); + self.query.extend_from_slice(&other.query); self.params.extend_from_slice(&other.params); self } @@ -174,7 +168,7 @@ impl<'a, 'b> Add<&'b WrapString<'a>> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &'b WrapString<'a>) -> WrapString<'a> { - self.query .extend_from_slice(&other.query); + self.query.extend_from_slice(&other.query); self.params.extend_from_slice(&other.params); self } @@ -184,7 +178,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: String) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(Cow::Owned(other))); self } @@ -194,7 +188,7 @@ impl<'a> Add<&'a String> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &'a String) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(Cow::Borrowed(other))); self } @@ -204,7 +198,7 @@ impl<'a> Add<&'a str> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &'a str) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(Cow::Borrowed(other))); self } @@ -214,7 +208,7 @@ impl<'a> Add<&'a &str> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &'a &str) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(Cow::Borrowed(other))); self } @@ -224,7 +218,7 @@ impl<'a> Add> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: std::borrow::Cow<'a, str>) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(other)); self } @@ -234,7 +228,7 @@ impl<'a> Add<&'a std::borrow::Cow<'a, str>> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &'a std::borrow::Cow<'a, str>) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Text(Cow::Borrowed(other))); self } @@ -244,7 +238,7 @@ impl<'a> Add> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: Vec) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Bytes(other)); self } @@ -254,7 +248,7 @@ impl<'a> Add<&Vec> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &Vec) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Bytes(other.clone())); self } @@ -264,7 +258,7 @@ impl<'a> Add<&[u8]> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &[u8]) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Bytes(other.to_vec())); self } @@ -320,8 +314,9 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: Uuid) -> WrapString<'a> { - self.query .push(None); - self.params.push(Value::Text(Cow::Owned(format!("{:X}", other.simple())))); + self.query.push(None); + self.params + .push(Value::Text(Cow::Owned(format!("{:X}", other.simple())))); self } } @@ -331,8 +326,9 @@ impl<'a> Add<&Uuid> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &Uuid) -> WrapString<'a> { - self.query .push(None); - self.params.push(Value::Text(Cow::Owned(format!("{:X}", other.simple())))); + self.query.push(None); + self.params + .push(Value::Text(Cow::Owned(format!("{:X}", other.simple())))); self } } @@ -341,7 +337,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: IpAddr) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::IpAddr(other)); self } @@ -351,7 +347,7 @@ impl<'a> Add<&IpAddr> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &IpAddr) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::IpAddr(*other)); self } @@ -361,7 +357,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: SystemTime) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Time(other)); self } @@ -371,7 +367,7 @@ impl<'a> Add<&SystemTime> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: &SystemTime) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Time(*other)); self } @@ -391,7 +387,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: f32) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::F32(other)); self } @@ -401,7 +397,7 @@ impl<'a> Add for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, other: f64) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::F64(other)); self } @@ -444,7 +440,7 @@ impl<'a> Add<()> for WrapString<'a> { type Output = WrapString<'a>; #[inline] fn add(mut self, _other: ()) -> WrapString<'a> { - self.query .push(None); + self.query.push(None); self.params.push(Value::Null); self } @@ -467,7 +463,7 @@ impl<'a> Add> for WrapString<'a> { #[inline] fn add(mut self, other: Vec) -> WrapString<'a> { if other.is_empty() { - self.query .push(None); + self.query.push(None); self.params.push(Value::Null); return self; } @@ -486,7 +482,7 @@ impl<'a> Add> for WrapString<'a> { macro_rules! impl_add_arrays_borrowed_for_WrapString { ( $($t:ty),* ) => {$( - /// In operator with string arrays. + /// In operator with string arrays. /// If the array is empty, it will be ignored. /// /// # Examples @@ -523,7 +519,7 @@ macro_rules! impl_add_arrays_borrowed_for_WrapString { ( $($t:ty,)* ) => { impl_add_arrays_borrowed_for_WrapString!{ $( $t ),* } } } -impl_add_arrays_borrowed_for_WrapString!{ +impl_add_arrays_borrowed_for_WrapString! { Vec<&'a str>, &'a Vec, &'a Vec<&'a str>, @@ -531,7 +527,6 @@ impl_add_arrays_borrowed_for_WrapString!{ &'a [String], } - /// A trait for converting that can be converted to [`WrapString`]. pub trait IntoWrapString<'a> { #[doc(hidden)] @@ -545,29 +540,47 @@ macro_rules! compile { match $kind { #[cfg(feature = "sqlite")] ConnKind::SQLite => { - let mut query = String::with_capacity($self.query.iter().map(|q|q.as_ref().map_or(1, |q|q.len())).sum()); + let mut query = String::with_capacity( + $self + .query + .iter() + .map(|q| q.as_ref().map_or(1, |q| q.len())) + .sum(), + ); for part in &$self.query { match part { Some(s) => query.push_str(s), - None => query.push('?'), + None => query.push('?'), } } Cow::Owned(query) } #[cfg(feature = "mysql")] ConnKind::MySQL => { - let mut query = String::with_capacity($self.query.iter().map(|q|q.as_ref().map_or(1, |q|q.len())).sum()); + let mut query = String::with_capacity( + $self + .query + .iter() + .map(|q| q.as_ref().map_or(1, |q| q.len())) + .sum(), + ); for part in &$self.query { match part { Some(s) => query.push_str(s), - None => query.push('?'), + None => query.push('?'), } } Cow::Owned(query) } #[cfg(feature = "postgres")] ConnKind::PostgreSQL => { - let mut query = String::with_capacity($self.query.iter().map(|q|q.as_ref().map_or(3, |q|q.len())).sum()); + let mut query = String::with_capacity( + $self + .query + .iter() + .map(|q| q.as_ref().map_or(3, |q| q.len())) + .sum(), + ); let mut index = 1; for part in &$self.query { match part { @@ -581,7 +594,7 @@ macro_rules! compile { Cow::Owned(query) } } - } + }; } impl<'a> IntoWrapString<'a> for WrapString<'a> { @@ -630,7 +643,6 @@ impl<'a> IntoWrapString<'a> for &'static str { mod tests { use crate as concatsql; use concatsql::prelude::*; - use concatsql::prep; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::time::UNIX_EPOCH; @@ -640,7 +652,7 @@ mod tests { clippy::deref_addrof, clippy::identity_op, clippy::approx_constant, - clippy::many_single_char_names, + clippy::many_single_char_names )] fn concat_anything_type() { use std::borrow::Cow; @@ -649,146 +661,175 @@ mod tests { let c = &**&&String::from("C"); let d = &***&&&String::from("D"); let e = String::from("E"); - let sql: WrapString = prep!("A") + prep!("B") + "C" + String::from("D") + &e + &prep!("F") + 42 + 3.14; + let sql: WrapString = + query!("A") + query!("B") + "C" + String::from("D") + &e + &query!("F") + 42 + 3.14; assert_eq!(sql.simulate(), "AB'C''D''E'F423.14"); - let sql = prep!() + a + b + c + d; + let sql = query!("") + a + b + c + d; assert_eq!(sql.simulate(), "'A''B''C''D'"); - let sql = prep!() + "A" + &"B" + *&&"C" + **&&&"D"; + let sql = query!("") + "A" + &"B" + *&&"C" + **&&&"D"; assert_eq!(sql.simulate(), "'A''B''C''D'"); - let sql = prep!() + 0usize + 1u8 + 2u16 + 3u32 + 4u64 + 5isize + 6i8 + 7i16 + 8i32 + 9i64 + 0f32 + 1f64; + let sql = query!("") + + 0usize + + 1u8 + + 2u16 + + 3u32 + + 4u64 + + 5isize + + 6i8 + + 7i16 + + 8i32 + + 9i64 + + 0f32 + + 1f64; assert_eq!(sql.simulate(), "012345678901"); - let sql = prep!() + f32::MAX + f32::INFINITY + f32::NAN; - assert_eq!(sql.simulate(), "340282350000000000000000000000000000000infNaN"); - let sql = prep!() + vec![b'A',b'B',b'C'] + &vec![0,1,2]; + let sql = query!("") + f32::MAX + f32::INFINITY + f32::NAN; + assert_eq!( + sql.simulate(), + "340282350000000000000000000000000000000infNaN" + ); + let sql = query!("") + vec![b'A', b'B', b'C'] + &vec![0, 1, 2]; if cfg!(feature = "sqlite") || cfg!(feature = "mysql") { assert_eq!(sql.simulate(), "X'414243'X'000102'"); } else { assert_eq!(sql.simulate(), "'\\x414243''\\x000102'"); } - let sql = prep!() + Cow::Borrowed("A") + &Cow::Borrowed("B") + Cow::Owned("C".to_string()); + let sql = + query!("") + Cow::Borrowed("A") + &Cow::Borrowed("B") + Cow::Owned("C".to_string()); assert_eq!(sql.simulate(), "'A''B''C'"); - let sql = prep!("A") + Some("B") + Some(String::from("C")) + Some(0i32) + Some(3.14f32) + Some(42i32) + None as Option + (); + let sql = query!("A") + + Some("B") + + Some(String::from("C")) + + Some(0i32) + + Some(3.14f32) + + Some(42i32) + + None as Option + + (); assert_eq!(sql.simulate(), "A'B''C'03.1442NULLNULL"); let vec: Vec = Vec::new(); - let sql = prep!("(") + vec + prep!(")"); + let sql = query!("(") + vec + query!(")"); assert_eq!(sql.simulate(), "(NULL)"); - let sql = prep!("(") + vec!["A"] + prep!(")"); + let sql = query!("(") + vec!["A"] + query!(")"); assert_eq!(sql.simulate(), "('A')"); - let sql = prep!("(") + vec!["A","B"] + prep!(")"); + let sql = query!("(") + vec!["A", "B"] + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let sql = prep!("(") + vec![String::from("A"),String::from("B")] + prep!(")"); + let sql = query!("(") + vec![String::from("A"), String::from("B")] + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let vec = vec!["A","B"]; - let sql = prep!("(") + &vec + prep!(")"); + let vec = vec!["A", "B"]; + let sql = query!("(") + &vec + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let vec = vec![String::from("A"),String::from("B")]; - let sql = prep!("(") + &vec + prep!(")"); + let vec = vec![String::from("A"), String::from("B")]; + let sql = query!("(") + &vec + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let sql = prep!("(") + &["A","B"][..] + prep!(")"); + let sql = query!("(") + &["A", "B"][..] + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let sli = &[String::from("A"),String::from("B")][..]; - let sql = prep!("(") + sli + prep!(")"); + let sli = &[String::from("A"), String::from("B")][..]; + let sql = query!("(") + sli + query!(")"); assert_eq!(sql.simulate(), "('A','B')"); - let sql = prep!() + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + let sql = query!("") + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); assert_eq!(sql.simulate(), "'127.0.0.1'"); - let sql = prep!() + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + let sql = query!("") + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); assert_eq!(sql.simulate(), "'::1'"); - let sql = prep!() + UNIX_EPOCH; + let sql = query!("") + UNIX_EPOCH; assert_eq!(sql.simulate(), "'1970-01-01 00:00:00.000000000'"); } #[test] fn params() { - let sql = prep!() + params![ - (), - 42i8, - 42i16, - 42i32, - 0.1f32, - 2.3f64, - String::from("A"), - "B", - IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), - UNIX_EPOCH, - ]; - assert_eq!(sql.simulate(), "NULL,42,42,42,0.1,2.3,'A','B','::1','1970-01-01 00:00:00.000000000'"); + let sql = query!("") + + params![ + (), + 42i8, + 42i16, + 42i32, + 0.1f32, + 2.3f64, + String::from("A"), + "B", + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), + UNIX_EPOCH, + ]; + assert_eq!( + sql.simulate(), + "NULL,42,42,42,0.1,2.3,'A','B','::1','1970-01-01 00:00:00.000000000'" + ); } #[test] #[allow(clippy::op_ref)] fn uuid() { use uuid::Uuid; - let uuid = prep!() + Uuid::nil(); + let uuid = query!("") + Uuid::nil(); assert_eq!(uuid.simulate(), "'00000000000000000000000000000000'"); - let uuid = prep!() + &Uuid::nil(); + let uuid = query!("") + &Uuid::nil(); assert_eq!(uuid.simulate(), "'00000000000000000000000000000000'"); - let uuid = prep!() + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); + let uuid = query!("") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); assert_eq!(uuid.simulate(), "'936DA01F9ABD4D9D80C702AF85C822A8'"); - let uuid = prep!() + Uuid::new_v4(); - assert_eq!(uuid.simulate().len(), 32+2); + let uuid = query!("") + Uuid::new_v4(); + assert_eq!(uuid.simulate().len(), 32 + 2); } #[test] fn len() { - assert_eq!((prep!("ABC") + prep!("123")).len(), 6); - let sql: WrapString = prep!("ABC") + 42 + prep!("123"); + assert_eq!((query!("ABC") + query!("123")).len(), 6); + let sql: WrapString = query!("ABC") + 42 + query!("123"); assert_eq!(sql.len(), 6); - assert_eq!(prep!().len(), 0); + assert_eq!(query!("").len(), 0); } #[test] fn query_len() { - assert_eq!((prep!("ABC") + prep!("123")).query_len(), 2); - let sql: WrapString = prep!("ABC") + 42 + prep!("123"); + assert_eq!((query!("ABC") + query!("123")).query_len(), 2); + let sql: WrapString = query!("ABC") + 42 + query!("123"); assert_eq!(sql.query_len(), 3); - assert_eq!(prep!().query_len(), 0); + assert_eq!(query!("").query_len(), 0); } #[test] fn params_len() { - assert_eq!((prep!("ABC") + prep!("123")).params_len(), 0); - let sql: WrapString = prep!("ABC") + 42 + prep!("123"); + assert_eq!((query!("ABC") + query!("123")).params_len(), 0); + let sql: WrapString = query!("ABC") + 42 + query!("123"); assert_eq!(sql.params_len(), 1); - assert_eq!(prep!().params_len(), 0); + assert_eq!(query!("").params_len(), 0); } #[test] fn clear() { - let mut sql: WrapString = prep!("ABC") + 42 + prep!("123"); - assert_eq!(sql.query_len(), 3); + let mut sql: WrapString = query!("ABC") + 42 + query!("123"); + assert_eq!(sql.query_len(), 3); assert_eq!(sql.params_len(), 1); sql.clear(); - assert_eq!(sql.query_len(), 0); + assert_eq!(sql.query_len(), 0); assert_eq!(sql.params_len(), 0); } #[test] fn is_empty() { - assert!(prep!().is_empty()); + assert!(query!("").is_empty()); } #[test] fn squash() { - let mut sql: WrapString = prep!("A") + prep!("B") + 42 + prep!("1") + prep!("2") + prep!("3"); - assert_eq!(sql.query_len(), 6); + let mut sql: WrapString = + query!("A") + query!("B") + 42 + query!("1") + query!("2") + query!("3"); + assert_eq!(sql.query_len(), 6); assert_eq!(sql.params_len(), 1); sql.squash(); - assert_eq!(sql.query_len(), 3); + assert_eq!(sql.query_len(), 3); assert_eq!(sql.params_len(), 1); } mod simulate { use crate as concatsql; - use concatsql::prep; + use concatsql::prelude::*; #[test] fn double_quotaion_inside_double_quote() { assert_eq!( - (prep!() + r#"".ow(""inside str"") -> String""#).simulate(), + (query!("") + r#"".ow(""inside str"") -> String""#).simulate(), r#"'".ow(""inside str"") -> String"'"# ); assert_eq!( - (prep!() + r#"".ow("inside str") -> String""#).simulate(), + (query!("") + r#"".ow("inside str") -> String""#).simulate(), r#"'".ow("inside str") -> String"'"# ); } @@ -796,11 +837,11 @@ mod tests { #[test] fn double_quotaion_inside_sigle_quote() { assert_eq!( - (prep!() + r#""I'm Alice""#).simulate(), + (query!("") + r#""I'm Alice""#).simulate(), r#"'"I''m Alice"'"# ); assert_eq!( - (prep!() + r#""I''m Alice""#).simulate(), + (query!("") + r#""I''m Alice""#).simulate(), r#"'"I''''m Alice"'"# ); } @@ -808,7 +849,7 @@ mod tests { #[test] fn single_quotaion_inside_double_quote() { assert_eq!( - (prep!() + r#"'.ow("inside str") -> String'"#).simulate(), + (query!("") + r#"'.ow("inside str") -> String'"#).simulate(), r#"'''.ow("inside str") -> String'''"# ); } @@ -816,7 +857,7 @@ mod tests { #[test] fn single_quotaion_inside_sigle_quote() { assert_eq!( - (prep!() + "'I''m Alice'").simulate(), + (query!("") + "'I''m Alice'").simulate(), r#"'''I''''m Alice'''"# ); } @@ -824,7 +865,7 @@ mod tests { #[test] fn non_quotaion_inside_sigle_quote() { assert_eq!( - (prep!() + "foo'bar'foo").simulate(), + (query!("") + "foo'bar'foo").simulate(), r#"'foo''bar''foo'"# ); } @@ -832,16 +873,16 @@ mod tests { #[test] fn non_quotaion_inside_double_quote() { assert_eq!( - (prep!() + r#"foo"bar"foo"#).simulate(), + (query!("") + r#"foo"bar"foo"#).simulate(), r#"'foo"bar"foo'"# ); } #[test] fn empty_string() { - assert_eq!(prep!().simulate(), ""); - assert_eq!(prep!("").simulate(), ""); - assert_eq!((prep!("") + "").simulate(), "''"); + assert_eq!(query!("").simulate(), ""); + assert_eq!(query!("").simulate(), ""); + assert_eq!((query!("") + "").simulate(), "''"); } } } diff --git a/concatsql/tests/mysql.rs b/concatsql/tests/mysql.rs index 8e37d7a..83ccf6a 100644 --- a/concatsql/tests/mysql.rs +++ b/concatsql/tests/mysql.rs @@ -4,29 +4,30 @@ #[cfg(debug_assertions)] mod mysql { use concatsql::prelude::*; - use concatsql::prep; use concatsql::{Error, ErrorLevel}; macro_rules! err { - () => { Err(Error::AnyError) }; - ($msg:expr) => { Err(Error::Message($msg.to_string())) }; + () => { + Err(Error::AnyError) + }; + ($msg:expr) => { + Err(Error::Message($msg.to_string())) + }; } + const STMT: &str = r#"CREATE TEMPORARY TABLE users (name TEXT, age INTEGER); + INSERT INTO users (name, age) VALUES ('Alice', 42); + INSERT INTO users (name, age) VALUES ('Bob', 69); + INSERT INTO users (name, age) VALUES ('Carol', 50);"#; + pub fn prepare() -> concatsql::Connection { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); conn.error_level(ErrorLevel::Debug); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); conn } - fn stmt() -> &'static str { - r#"CREATE TEMPORARY TABLE users (name TEXT, age INTEGER); - INSERT INTO users (name, age) VALUES ('Alice', 42); - INSERT INTO users (name, age) VALUES ('Bob', 69); - INSERT INTO users (name, age) VALUES ('Carol', 50);"# - } - #[test] fn open() { let _conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); @@ -35,36 +36,38 @@ mod mysql { #[test] fn execute() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); } #[test] fn iterate() { let conn = prepare(); let expects = ["Alice", "Bob", "Carol"]; - let sql = prep!("SELECT name FROM users;"); + let sql = query!("SELECT name FROM users;"); conn.iterate(sql, |pairs| { for (i, (_, value)) in pairs.iter().enumerate() { assert_eq!(*value.as_ref().unwrap(), expects[i]); } true - }).unwrap(); + }) + .unwrap(); } #[test] fn iterate_2sets() { let conn = prepare(); let expects = ["Alice", "Bob", "Carol", "Alice", "Bob", "Carol"]; - let sql = prep!("SELECT name FROM users; SELECT name FROM users;"); + let sql = query!("SELECT name FROM users; SELECT name FROM users;"); conn.iterate(sql, |pairs| { for (i, (_, value)) in pairs.iter().enumerate() { assert_eq!(*value.as_ref().unwrap(), expects[i]); } true - }).unwrap(); + }) + .unwrap(); } #[test] @@ -72,29 +75,34 @@ mod mysql { let conn = prepare(); let expects = ["Alice", "Bob"]; let age = "50"; - let sql = prep!("SELECT name FROM users WHERE ") + - &prep!("age < ") + age + &prep!(" OR ") + age + &prep!(" < age"); + let sql = query!("SELECT name FROM users WHERE ") + + &query!("age < ") + + age + + &query!(" OR ") + + age + + &query!(" < age"); conn.iterate(sql, |pairs| { for (i, (_, value)) in pairs.iter().enumerate() { assert_eq!(*value.as_ref().unwrap(), expects[i]); } true - }).unwrap(); + }) + .unwrap(); } #[test] fn rows() { let conn = prepare(); - let expects = [ ("Alice", 42), ("Bob", 69), ("Carol", 50) ]; - let sql = prep!("SELECT * FROM users;"); + let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; + let sql = query!("SELECT * FROM users;"); let mut cnt = 0; let rows = conn.rows(&sql).unwrap(); for (i, row) in rows.iter().enumerate() { cnt += 1; assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); } assert!(cnt == expects.len()); } @@ -102,14 +110,18 @@ mod mysql { #[test] fn rows_foreach() { let conn = prepare(); - let expects = [ ("Alice", 42), ("Bob", 69), ("Carol", 50) ]; + let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; let mut cnt = 0; - conn.rows(&prep!("SELECT * FROM users;")).unwrap().iter().enumerate().for_each(|(i, row)| { - cnt += 1; - assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); - }); + conn.rows(query!("SELECT * FROM users;")) + .unwrap() + .iter() + .enumerate() + .for_each(|(i, row)| { + cnt += 1; + assert_eq!(row.get("name").unwrap(), expects[i].0); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + }); assert!(cnt == expects.len()); } @@ -117,47 +129,56 @@ mod mysql { fn start_with_quotation_and_end_with_anything_else() { let conn = prepare(); let name = "'Alice'; DROP TABLE users; --"; - let sql = prep!("select age from users where name = ") + name; + let sql = query!("select age from users where name = ") + name; assert_eq!( sql.simulate(), "select age from users where name = '''Alice''; DROP TABLE users; --'" ); - conn.iterate(&sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(&sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn whitespace() { let conn = prepare(); - let sql = prep!("select\n*\rfrom\nusers;"); + let sql = query!("select\n*\rfrom\nusers;"); - conn.iterate(sql, |_| { true }).unwrap(); + conn.iterate(sql, |_| true).unwrap(); } #[test] fn sqli_eq_nonquote() { let conn = prepare(); let name = "Alice' or '1'='1"; - let sql = prep!("select age from users where name =") + name + &prep!(";"); + let sql = query!("select age from users where name =") + name + &query!(";"); // "select age from users where name = 'Alice'' or ''1''=''1';" - conn.iterate(sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn sanitizing() { let conn = prepare(); let name = r#""#; - let sql = prep!("INSERT INTO users VALUES(") + name + &prep!(", 12345);"); + let sql = query!("INSERT INTO users VALUES(") + name + &query!(", 12345);"); conn.execute(sql).unwrap(); - conn.rows(prep!("SELECT name FROM users WHERE age = 12345;")).unwrap().iter() .all(|row| { - assert_eq!( - concatsql::html_special_chars(row.get("name").unwrap()), - "<script>alert("&1");</script>" - ); - true - }); + conn.rows(query!("SELECT name FROM users WHERE age = 12345;")) + .unwrap() + .iter() + .all(|row| { + assert_eq!( + concatsql::html_special_chars(row.get("name").unwrap()), + "<script>alert("&1");</script>" + ); + true + }); } #[test] @@ -176,9 +197,9 @@ mod mysql { conn.error_level(ErrorLevel::AlwaysOk); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), Ok(())); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); - assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); + assert_eq!(conn.execute(invalid_sql), Ok(())); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); + assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); } #[test] @@ -187,9 +208,9 @@ mod mysql { conn.error_level(ErrorLevel::Release); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!()); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); - assert_eq!(conn.rows(invalid_sql), err!()); + assert_eq!(conn.execute(invalid_sql), err!()); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); + assert_eq!(conn.rows(invalid_sql), err!()); } #[test] @@ -198,9 +219,12 @@ mod mysql { conn.error_level(ErrorLevel::Develop); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!("exec error")); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!("exec error")); - assert_eq!(conn.rows(invalid_sql), err!("exec error")); + assert_eq!(conn.execute(invalid_sql), err!("exec error")); + assert_eq!( + conn.iterate(invalid_sql, |_| unreachable!()), + err!("exec error") + ); + assert_eq!(conn.rows(invalid_sql), err!("exec error")); } #[test] @@ -221,7 +245,7 @@ mod mysql { fn integer() { let conn = prepare(); let age = 50; - let sql = prep!("select name from users where age <") + age; + let sql = query!("select name from users where age <") + age; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get("name").unwrap(), "Alice"); @@ -231,20 +255,20 @@ mod mysql { #[test] fn prep_into_execute() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - conn.execute(prep!("SELECT ") + 1).unwrap(); + conn.execute(query!("SELECT ") + 1).unwrap(); } #[test] fn prep_into_iterate() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - conn.iterate(prep!("SELECT ") + 1, |_| true ).unwrap(); + conn.iterate(query!("SELECT ") + 1, |_| true).unwrap(); } #[test] fn prep_into_rows() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - for row in conn.rows(prep!("SELECT ") + 1).unwrap() { - assert_eq!(row.get(0).unwrap(), "1"); + for row in conn.rows(query!("SELECT ") + 1).unwrap() { + assert_eq!(row.get(0).unwrap(), "1"); assert_eq!(row.get("?").unwrap(), "1"); } } @@ -253,20 +277,25 @@ mod mysql { fn executable_comment_syntax() { let conn = prepare(); let sqls = vec![ - //(prep!("SELECT 1 ") + "/*! +1 */", "SELECT 1 '/*! +1 */'", "1"), <- syntax error - (prep!("SELECT 1 /*! +1 */"), "SELECT 1 /*! +1 */", "2"), - (prep!("SELECT /*! 42 */"), "SELECT /*! 42 */", "42"), - (prep!("SELECT ") + "/*! 42 */", "SELECT '/*! 42 */'", "/*! 42 */"), + //(query!("SELECT 1 ") + "/*! +1 */", "SELECT 1 '/*! +1 */'", "1"), <- syntax error + (query!("SELECT 1 /*! +1 */"), "SELECT 1 /*! +1 */", "2"), + (query!("SELECT /*! 42 */"), "SELECT /*! 42 */", "42"), + ( + query!("SELECT ") + "/*! 42 */", + "SELECT '/*! 42 */'", + "/*! 42 */", + ), ]; for (sql, simulate, result) in sqls { assert_eq!(sql.simulate(), simulate); conn.iterate(&sql, |pairs| { - for (_, (_, value)) in pairs.iter().enumerate() { + for (_, value) in pairs.iter() { assert_eq!(*value.as_ref().unwrap(), result); } true - }).unwrap(); + }) + .unwrap(); } } @@ -275,7 +304,7 @@ mod mysql { let conn = prepare(); let name = "A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE") + name + prep!(";"); + let sql = query!("SELECT * FROM users WHERE name LIKE") + name + query!(";"); let mut executed = false; conn.rows(&sql).unwrap().iter().all(|row| { @@ -286,18 +315,26 @@ mod mysql { assert!(executed); let name = "A"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%A%'"); conn.execute(&sql).unwrap(); let name = "%A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name) + "%"); - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name) + "%"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'" + ); conn.execute(&sql).unwrap(); let name = String::from("%A%"); - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%$%A$%%'"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%$%A$%%'" + ); conn.execute(&sql).unwrap(); } @@ -305,17 +342,27 @@ mod mysql { fn multiple_stmt() { let conn = prepare(); let mut cnt = 0; - for (i, row) in conn.rows("SELECT 1; SELECT 2, 3;").unwrap().iter().enumerate() { - /*^^^^^^^^*/// <- only first statement + for (i, row) in conn + .rows("SELECT 1; SELECT 2, 3;") + .unwrap() + .iter() + .enumerate() + { + /*^^^^^^^^*/// <- only first statement cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 1, 2, 3 ][i]); - }; + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [1, 2, 3][i]); + } let conn = prepare(); - for (i, row) in conn.rows("SELECT age FROM users;").unwrap().iter().enumerate() { + for (i, row) in conn + .rows("SELECT age FROM users;") + .unwrap() + .iter() + .enumerate() + { cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 42, 69, 50 ][i]); - }; + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [42, 69, 50][i]); + } assert_eq!(cnt, 4); } @@ -336,9 +383,10 @@ mod mysql { #[test] fn blob() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - conn.execute("CREATE TEMPORARY TABLE b (data blob)").unwrap(); + conn.execute("CREATE TEMPORARY TABLE b (data blob)") + .unwrap(); let data = vec![0x1, 0xA, 0xFF, 0x00, 0x7F]; - let sql = prep!("INSERT INTO b VALUES (") + &data + prep!(")"); + let sql = query!("INSERT INTO b VALUES (") + &data + query!(")"); conn.execute(sql).unwrap(); for row in conn.rows("SELECT data FROM b").unwrap() { assert_eq!(row.get_into::<_, Vec>(0).unwrap(), data); @@ -348,19 +396,24 @@ mod mysql { #[test] fn question() { let conn = prepare(); - let sql = prep!("SELECT name FROM users WHERE name=") + "?"; - for _ in conn.rows(&sql).unwrap() { unreachable!(); } + let sql = query!("SELECT name FROM users WHERE name=") + "?"; + for _ in conn.rows(&sql).unwrap() { + unreachable!(); + } } #[test] fn map_collect() { let conn = prepare(); let rows = conn.rows("SELECT * FROM users").unwrap(); - let names = rows.iter().map(|row| row.get("name")).collect::>>(); + let names = rows + .iter() + .map(|row| row.get("name")) + .collect::>>(); let mut cnt = 0; for (i, name) in names.iter().enumerate() { cnt += 1; - assert_eq!(name.unwrap(), ["Alice","Bob","Carol"][i]) + assert_eq!(name.unwrap(), ["Alice", "Bob", "Carol"][i]) } assert_eq!(cnt, 3); } @@ -368,11 +421,11 @@ mod mysql { #[test] fn in_array() { let conn = prepare(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam","Eve"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam", "Eve"] + query!(")"); conn.rows(&sql).unwrap(); } @@ -380,11 +433,12 @@ mod mysql { fn uuid() { use uuid::Uuid; let conn = prepare(); - let sql = prep!("SELECT ") + Uuid::nil(); + let sql = query!("SELECT ") + Uuid::nil(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "00000000000000000000000000000000"); } - let sql = prep!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); + let sql = + query!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "936DA01F9ABD4D9D80C702AF85C822A8"); } @@ -395,89 +449,98 @@ mod mysql { let conn = prepare(); let name = "' OR 1=2; SELECT 1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("exec error: DriverError { Statement takes 0 parameters but 1 was supplied }".to_string())) + Err(Error::Message( + "exec error: DriverError { Statement takes 0 parameters but 1 was supplied }" + .to_string() + )) ); let name = "' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("exec error: DriverError { Statement takes 0 parameters but 1 was supplied }".to_string())) + Err(Error::Message( + "exec error: DriverError { Statement takes 0 parameters but 1 was supplied }" + .to_string() + )) ); let name = "Alice"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("exec error: DriverError { Statement takes 0 parameters but 1 was supplied }".to_string())) + Err(Error::Message( + "exec error: DriverError { Statement takes 0 parameters but 1 was supplied }" + .to_string() + )) ); let name = "'' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } let name = "''; DROP TABLE users; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT ") + "0x50 + 0x45"; + let sql = query!("SELECT ") + "0x50 + 0x45"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x50 + 0x45"); } - let sql = prep!("SELECT ") + "0x414243"; + let sql = query!("SELECT ") + "0x414243"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x414243"); } - let sql = prep!("SELECT ") + "CHAR(0x66)"; + let sql = query!("SELECT ") + "CHAR(0x66)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(0x66)"); } - let sql = prep!("SELECT ") + "IF(1=1, 'true', 'false')"; + let sql = query!("SELECT ") + "IF(1=1, 'true', 'false')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "IF(1=1, 'true', 'false')"); } - let sql = prep!("SELECT ") + "na + '-' + me FROM users"; + let sql = query!("SELECT ") + "na + '-' + me FROM users"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "na + '-' + me FROM users"); } - let sql = prep!("SELECT ") + "CONCAT(login, password)"; + let sql = query!("SELECT ") + "CONCAT(login, password)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CONCAT(login, password)"); } - let sql = prep!("SELECT ") + "CONCAT('0x',HEX('c:\\\\boot.ini'))"; + let sql = query!("SELECT ") + "CONCAT('0x',HEX('c:\\\\boot.ini'))"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CONCAT('0x',HEX('c:\\\\boot.ini'))"); } - let sql = prep!("SELECT ") + "CONCAT(CHAR(75),CHAR(76),CHAR(77))"; + let sql = query!("SELECT ") + "CONCAT(CHAR(75),CHAR(76),CHAR(77))"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CONCAT(CHAR(75),CHAR(76),CHAR(77))"); } - let sql = prep!("SELECT ") + "LOAD_FILE(0x633A5C626F6F742E696E69)"; + let sql = query!("SELECT ") + "LOAD_FILE(0x633A5C626F6F742E696E69)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "LOAD_FILE(0x633A5C626F6F742E696E69)"); } - let sql = prep!("SELECT ") + "ASCII('a')"; + let sql = query!("SELECT ") + "ASCII('a')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "ASCII('a')"); } - let sql = prep!("SELECT ") + "CHAR(64)"; + let sql = query!("SELECT ") + "CHAR(64)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(64)"); } @@ -486,14 +549,14 @@ mod mysql { #[cfg(feature = "mysql")] mod anti_patterns { - use concatsql::prep; + use concatsql::prelude::*; // Although it becomes possible, I do not believe it is less useful // because its real advantage is that it still makes it harder to do the wrong thing. #[test] fn string_to_static_str() { let conn = concatsql::mysql::open("mysql://localhost:3306/test").unwrap(); - let sql: &'static str = Box::leak(String::from("SELECT 1").into_boxed_str()); // Leak! + let sql: &'static str = Box::leak(String::from("SELECT 1").into_boxed_str()); // Leak! conn.execute(sql).unwrap(); } @@ -502,62 +565,62 @@ mod anti_patterns { let conn = super::mysql::prepare(); let mut cnt = 0; - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MIN; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } @@ -565,4 +628,3 @@ mod anti_patterns { assert_eq!(cnt, 12); } } - diff --git a/concatsql/tests/postgres.rs b/concatsql/tests/postgres.rs index e95236a..06c5f4f 100644 --- a/concatsql/tests/postgres.rs +++ b/concatsql/tests/postgres.rs @@ -4,47 +4,49 @@ #[cfg(debug_assertions)] mod postgres { use concatsql::prelude::*; - use concatsql::prep; use concatsql::{Error, ErrorLevel}; macro_rules! err { - () => { Err(Error::AnyError) }; - ($msg:expr) => { Err(Error::Message($msg.to_string())) }; + () => { + Err(Error::AnyError) + }; + ($msg:expr) => { + Err(Error::Message($msg.to_string())) + }; } + const STMT: &str = r#"CREATE TEMPORARY TABLE users (name TEXT, age INTEGER); + INSERT INTO users (name, age) VALUES ('Alice', 42); + INSERT INTO users (name, age) VALUES ('Bob', 69); + INSERT INTO users (name, age) VALUES ('Carol', 50);"#; + pub fn prepare() -> concatsql::Connection { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); conn.error_level(ErrorLevel::Debug); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); conn } - fn stmt() -> &'static str { - r#"CREATE TEMPORARY TABLE users (name TEXT, age INTEGER); - INSERT INTO users (name, age) VALUES ('Alice', 42); - INSERT INTO users (name, age) VALUES ('Bob', 69); - INSERT INTO users (name, age) VALUES ('Carol', 50);"# - } - #[test] fn open() { - let _conn = concatsql::postgres::open("host=localhost user=postgres password=postgres").unwrap(); + let _conn = + concatsql::postgres::open("host=localhost user=postgres password=postgres").unwrap(); let _conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); } #[test] fn execute() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); } #[test] fn iterate() { let conn = prepare(); let expects = ["Alice", "Bob", "Carol"]; - let sql = prep!("SELECT name FROM users;"); + let sql = query!("SELECT name FROM users;"); let mut i = 0; conn.iterate(sql, |pairs| { @@ -53,25 +55,33 @@ mod postgres { i += 1; } true - }).unwrap(); + }) + .unwrap(); } #[test] #[should_panic = "exec error"] fn multiple_stm_should_errort() { let conn = prepare(); - let sql = prep!("SELECT name FROM users; SELECT name FROM users;"); + let sql = query!("SELECT name FROM users; SELECT name FROM users;"); - conn.iterate(sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn iterate_or() { let conn = prepare(); let expects = ["Alice", "Bob"]; - let age = 50; // "50" error - let sql = prep!("SELECT name FROM users WHERE ") + - &prep!("age < ") + age + &prep!(" OR ") + age + &prep!(" < age"); + let age = 50; // "50" error + let sql = query!("SELECT name FROM users WHERE ") + + &query!("age < ") + + age + + &query!(" OR ") + + age + + &query!(" < age"); let mut i = 0; conn.iterate(sql, |pairs| { @@ -80,21 +90,22 @@ mod postgres { i += 1; } true - }).unwrap(); + }) + .unwrap(); } #[test] fn rows() { let conn = prepare(); let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; - let sql = prep!("SELECT * FROM users;"); + let sql = query!("SELECT * FROM users;"); let mut cnt = 0; let rows = conn.rows(&sql).unwrap(); for (i, row) in rows.iter().enumerate() { cnt += 1; assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); } assert!(cnt == expects.len()); } @@ -105,11 +116,15 @@ mod postgres { let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; let mut cnt = 0; - conn.rows(&prep!("SELECT * FROM users;")).unwrap().iter().enumerate().for_each(|(i, row)| { - cnt += 1; - assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); - }); + conn.rows(query!("SELECT * FROM users;")) + .unwrap() + .iter() + .enumerate() + .for_each(|(i, row)| { + cnt += 1; + assert_eq!(row.get("name").unwrap(), expects[i].0); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + }); assert!(cnt == expects.len()); } @@ -117,47 +132,56 @@ mod postgres { fn start_with_quotation_and_end_with_anything_else() { let conn = prepare(); let name = "'Alice'; DROP TABLE users; --"; - let sql = prep!("select age from users where name = ") + name + &prep!(""); + let sql = query!("select age from users where name = ") + name + &query!(""); assert_eq!( sql.simulate(), "select age from users where name = '''Alice''; DROP TABLE users; --'" ); - conn.iterate(&sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(&sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn whitespace() { let conn = prepare(); - let sql = prep!("select\n*\rfrom\nusers;"); + let sql = query!("select\n*\rfrom\nusers;"); - conn.iterate(sql, |_| { true }).unwrap(); + conn.iterate(sql, |_| true).unwrap(); } #[test] fn sqli_eq_nonquote() { let conn = prepare(); let name = "Alice' or '1'='1"; - let sql = prep!("select age from users where name =") + name + &prep!(";"); + let sql = query!("select age from users where name =") + name + &query!(";"); // "select age from users where name = 'Alice'' or ''1''=''1';" - conn.iterate(sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn sanitizing() { let conn = prepare(); let name = r#""#; - let sql = prep!("INSERT INTO users VALUES(") + name + &prep!(", 12345);"); + let sql = query!("INSERT INTO users VALUES(") + name + &query!(", 12345);"); conn.execute(sql).unwrap(); - conn.rows(prep!("SELECT name FROM users WHERE age = 12345;")).unwrap().iter() .all(|row| { - assert_eq!( - concatsql::html_special_chars(row.get("name").unwrap()), - "<script>alert("&1");</script>" - ); - true - }); + conn.rows(query!("SELECT name FROM users WHERE age = 12345;")) + .unwrap() + .iter() + .all(|row| { + assert_eq!( + concatsql::html_special_chars(row.get("name").unwrap()), + "<script>alert("&1");</script>" + ); + true + }); } #[test] @@ -176,9 +200,9 @@ mod postgres { conn.error_level(ErrorLevel::AlwaysOk); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), Ok(())); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); - assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); + assert_eq!(conn.execute(invalid_sql), Ok(())); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); + assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); } #[test] @@ -187,9 +211,9 @@ mod postgres { conn.error_level(ErrorLevel::Release); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!()); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); - assert_eq!(conn.rows(invalid_sql), err!()); + assert_eq!(conn.execute(invalid_sql), err!()); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); + assert_eq!(conn.rows(invalid_sql), err!()); } #[test] @@ -198,9 +222,12 @@ mod postgres { conn.error_level(ErrorLevel::Develop); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!("exec error")); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!("exec error")); - assert_eq!(conn.rows(invalid_sql), err!("exec error")); + assert_eq!(conn.execute(invalid_sql), err!("exec error")); + assert_eq!( + conn.iterate(invalid_sql, |_| unreachable!()), + err!("exec error") + ); + assert_eq!(conn.rows(invalid_sql), err!("exec error")); } #[test] @@ -209,19 +236,25 @@ mod postgres { conn.error_level(ErrorLevel::Debug); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), - err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー")); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), - err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー")); - assert_eq!(conn.rows(invalid_sql), - err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー")); + assert_eq!( + conn.execute(invalid_sql), + err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー") + ); + assert_eq!( + conn.iterate(invalid_sql, |_| unreachable!()), + err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー") + ); + assert_eq!( + conn.rows(invalid_sql), + err!("exec error: db error: ERROR: \"INVALID_SQL\"またはその近辺で構文エラー") + ); } #[test] fn integer() { let conn = prepare(); let age = 50; - let sql = prep!("select name from users where age < ") + age; + let sql = query!("select name from users where age < ") + age; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get("name").unwrap(), "Alice"); @@ -231,29 +264,37 @@ mod postgres { #[test] fn prep_into_execute() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - conn.execute(prep!("SELECT ") + 1 + prep!("::INTEGER")).unwrap(); + conn.execute(query!("SELECT ") + 1 + query!("::INTEGER")) + .unwrap(); } #[test] fn prep_into_iterate() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - conn.iterate(prep!("SELECT ") + 1 + prep!("::INTEGER"), |_| true ).unwrap(); + conn.iterate(query!("SELECT ") + 1 + query!("::INTEGER"), |_| true) + .unwrap(); } #[test] fn prep_into_rows() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - for row in conn.rows(prep!("SELECT ") + 1 + prep!("::INTEGER")).unwrap() { + for row in conn + .rows(query!("SELECT ") + 1 + query!("::INTEGER")) + .unwrap() + { assert_eq!(row.column_name(0).unwrap(), "int4"); assert_eq!(row.get("int4").unwrap(), "1"); assert_eq!(row.get(0).unwrap(), "1"); } - for row in conn.rows(prep!("SELECT ") + "1" + prep!("::TEXT")).unwrap() { + for row in conn + .rows(query!("SELECT ") + "1" + query!("::TEXT")) + .unwrap() + { assert_eq!(row.column_name(0).unwrap(), "text"); assert_eq!(row.get("text").unwrap(), "1"); assert_eq!(row.get(0).unwrap(), "1"); } - for row in conn.rows(prep!("SELECT 1")).unwrap() { + for row in conn.rows(query!("SELECT 1")).unwrap() { assert_eq!(row.column_name(0).unwrap(), "?column?"); assert_eq!(row.get("?column?").unwrap(), "1"); assert_eq!(row.get(0).unwrap(), "1"); @@ -265,7 +306,7 @@ mod postgres { let conn = prepare(); let name = "A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + name; + let sql = query!("SELECT * FROM users WHERE name LIKE ") + name; let mut executed = false; conn.rows(&sql).unwrap().iter().all(|row| { @@ -276,18 +317,26 @@ mod postgres { assert!(executed); let name = "A"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%A%'"); conn.execute(&sql).unwrap(); let name = "%A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name) + "%"); - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name) + "%"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'" + ); conn.execute(&sql).unwrap(); let name = String::from("%A%"); - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%$%A$%%'"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%$%A$%%'" + ); conn.execute(&sql).unwrap(); } @@ -295,15 +344,25 @@ mod postgres { fn multiple_stmt() { let conn = prepare(); let mut cnt = 0; - for (i, row) in conn.rows("SELECT 1 UNION SELECT 2;").unwrap().iter().enumerate() { + for (i, row) in conn + .rows("SELECT 1 UNION SELECT 2;") + .unwrap() + .iter() + .enumerate() + { cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 1, 2 ][i]); - }; + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [1, 2][i]); + } - for (i, row) in conn.rows("SELECT age FROM users;").unwrap().iter().enumerate() { + for (i, row) in conn + .rows("SELECT age FROM users;") + .unwrap() + .iter() + .enumerate() + { cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 42, 69, 50 ][i]); - }; + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [42, 69, 50][i]); + } assert_eq!(cnt, 5); } @@ -314,7 +373,9 @@ mod postgres { let capacity = 64; let mut conns = Vec::with_capacity(capacity); for _ in 0..capacity { - conns.push(concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap()); + conns.push( + concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(), + ); } for i in 1..capacity { assert_ne!(conns[0], conns[i]); @@ -324,9 +385,10 @@ mod postgres { #[test] fn blob() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - conn.execute("CREATE TEMPORARY TABLE b (data bytea)").unwrap(); + conn.execute("CREATE TEMPORARY TABLE b (data bytea)") + .unwrap(); let data = vec![0x1, 0xA, 0xFF, 0x00, 0x7F]; - let sql = prep!("INSERT INTO b VALUES (") + &data + prep!(")"); + let sql = query!("INSERT INTO b VALUES (") + &data + query!(")"); conn.execute(sql).unwrap(); for row in conn.rows("SELECT data FROM b").unwrap() { assert_eq!(row.get_into::<_, Vec>(0).unwrap(), data); @@ -336,19 +398,24 @@ mod postgres { #[test] fn question() { let conn = prepare(); - let sql = prep!("SELECT name FROM users WHERE name=") + "?"; - for _ in conn.rows(&sql).unwrap() { unreachable!(); } + let sql = query!("SELECT name FROM users WHERE name=") + "?"; + for _ in conn.rows(&sql).unwrap() { + unreachable!(); + } } #[test] fn map_collect() { let conn = prepare(); let rows = conn.rows("SELECT * FROM users").unwrap(); - let names = rows.iter().map(|row| row.get("name")).collect::>>(); + let names = rows + .iter() + .map(|row| row.get("name")) + .collect::>>(); let mut cnt = 0; for (i, name) in names.iter().enumerate() { cnt += 1; - assert_eq!(name.unwrap(), ["Alice","Bob","Carol"][i]) + assert_eq!(name.unwrap(), ["Alice", "Bob", "Carol"][i]) } assert_eq!(cnt, 3); } @@ -356,11 +423,11 @@ mod postgres { #[test] fn in_array() { let conn = prepare(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam","Eve"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam", "Eve"] + query!(")"); conn.rows(&sql).unwrap(); } @@ -368,11 +435,12 @@ mod postgres { fn uuid() { use uuid::Uuid; let conn = prepare(); - let sql = prep!("SELECT ") + Uuid::nil(); + let sql = query!("SELECT ") + Uuid::nil(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "00000000000000000000000000000000"); } - let sql = prep!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); + let sql = + query!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "936DA01F9ABD4D9D80C702AF85C822A8"); } @@ -383,48 +451,48 @@ mod postgres { let conn = prepare(); let name = "'' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } let name = "''; DROP TABLE users; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT ") + "0x50 + 0x45"; + let sql = query!("SELECT ") + "0x50 + 0x45"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x50 + 0x45"); } - let sql = prep!("SELECT ") + "0x414243"; + let sql = query!("SELECT ") + "0x414243"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x414243"); } - let sql = prep!("SELECT ") + "CHAR(0x66)"; + let sql = query!("SELECT ") + "CHAR(0x66)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(0x66)"); } - let sql = prep!("SELECT ") + "IF(1=1, 'true', 'false')"; + let sql = query!("SELECT ") + "IF(1=1, 'true', 'false')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "IF(1=1, 'true', 'false')"); } - let sql = prep!("SELECT ") + "na + '-' + me FROM users"; + let sql = query!("SELECT ") + "na + '-' + me FROM users"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "na + '-' + me FROM users"); } - let sql = prep!("SELECT ") + "ASCII('a')"; + let sql = query!("SELECT ") + "ASCII('a')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "ASCII('a')"); } - let sql = prep!("SELECT ") + "CHAR(64)"; + let sql = query!("SELECT ") + "CHAR(64)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(64)"); } @@ -433,14 +501,14 @@ mod postgres { #[cfg(feature = "postgres")] mod anti_patterns { - use concatsql::prep; + use concatsql::prelude::*; // Although it becomes possible, I do not believe it is less useful // because its real advantage is that it still makes it harder to do the wrong thing. #[test] fn string_to_static_str() { let conn = concatsql::postgres::open("postgresql://postgres:postgres@localhost").unwrap(); - let sql: &'static str = Box::leak(String::from("SELECT 1").into_boxed_str()); // Leak! + let sql: &'static str = Box::leak(String::from("SELECT 1").into_boxed_str()); // Leak! conn.execute(sql).unwrap(); } @@ -448,40 +516,40 @@ mod anti_patterns { fn text_op_integer() { let conn = super::postgres::prepare(); - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MIN; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MIN; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MIN; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MAX; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MIN; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MIN; assert!(conn.rows(&sql).is_err()); - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MIN; assert!(conn.rows(&sql).is_err()); } @@ -491,16 +559,15 @@ mod anti_patterns { let conn = super::postgres::prepare(); let name = "' OR 1=2; SELECT 1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder conn.execute(sql).ok(); let name = "' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder conn.execute(sql).ok(); let name = "Alice"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder conn.execute(sql).ok(); } } - diff --git a/concatsql/tests/sqlite.rs b/concatsql/tests/sqlite.rs index 12560f1..cc62171 100644 --- a/concatsql/tests/sqlite.rs +++ b/concatsql/tests/sqlite.rs @@ -4,29 +4,30 @@ #[cfg(debug_assertions)] mod sqlite { use concatsql::prelude::*; - use concatsql::prep; use concatsql::{Error, ErrorLevel}; macro_rules! err { - () => { Err(Error::AnyError) }; - ($msg:expr) => { Err(Error::Message($msg.to_string())) }; + () => { + Err(Error::AnyError) + }; + ($msg:expr) => { + Err(Error::Message($msg.to_string())) + }; } + const STMT: &str = r#"CREATE TABLE users (name TEXT, age INTEGER); + INSERT INTO users (name, age) VALUES ('Alice', 42); + INSERT INTO users (name, age) VALUES ('Bob', 69); + INSERT INTO users (name, age) VALUES ('Carol', 50);"#; + pub fn prepare() -> concatsql::Connection { let conn = concatsql::sqlite::open(":memory:").unwrap(); conn.error_level(ErrorLevel::Debug); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); conn } - fn stmt() -> &'static str { - r#"CREATE TABLE users (name TEXT, age INTEGER); - INSERT INTO users (name, age) VALUES ('Alice', 42); - INSERT INTO users (name, age) VALUES ('Bob', 69); - INSERT INTO users (name, age) VALUES ('Carol', 50);"# - } - #[test] fn open() { let _conn = concatsql::sqlite::open(":memory:").unwrap(); @@ -49,8 +50,8 @@ mod sqlite { )} let conn = concatsql::sqlite::open(":memory:").unwrap(); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); static_strings! { select = "SELECT "; cols = "name "; @@ -58,21 +59,21 @@ mod sqlite { table = "users"; sql = select!(), cols!(), from!(), table!(); } - assert_eq!(prep!(sql).simulate(), "SELECT name FROM users"); + assert_eq!(query!("{sql}").simulate(), "SELECT name FROM users"); } #[test] fn execute() { let conn = concatsql::sqlite::open(":memory:").unwrap(); - let stmt = prep!(stmt()); - conn.execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.execute(query).unwrap(); } #[test] fn iterate() { let conn = prepare(); let expects = ["Alice", "Bob", "Carol"]; - let sql = prep!("SELECT name FROM users;"); + let sql = query!("SELECT name FROM users;"); let mut i = 0; conn.iterate(sql, |pairs| { @@ -81,14 +82,15 @@ mod sqlite { } i += 1; true - }).unwrap(); + }) + .unwrap(); } #[test] fn iterate_2sets() { let conn = prepare(); let expects = ["Alice", "Bob", "Carol", "Alice", "Bob", "Carol"]; - let sql = prep!("SELECT name FROM users; SELECT name FROM users;"); + let sql = query!("SELECT name FROM users; SELECT name FROM users;"); let mut i = 0; conn.iterate(sql, |pairs| { @@ -97,7 +99,8 @@ mod sqlite { } i += 1; true - }).unwrap(); + }) + .unwrap(); } #[test] @@ -105,8 +108,12 @@ mod sqlite { let conn = prepare(); let expects = ["Alice", "Bob"]; let age = "50"; - let sql = prep!("SELECT name FROM users WHERE ") + - &prep!("age < ") + age + &prep!(" OR ") + age + &prep!(" < age"); + let sql = query!("SELECT name FROM users WHERE ") + + &query!("age < ") + + age + + &query!(" OR ") + + age + + &query!(" < age"); let mut i = 0; conn.iterate(sql, |pairs| { @@ -115,21 +122,22 @@ mod sqlite { } i += 1; true - }).unwrap(); + }) + .unwrap(); } #[test] fn rows() { let conn = prepare(); let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; - let sql = prep!("SELECT * FROM users;"); + let sql = query!("SELECT * FROM users;"); let mut cnt = 0; let rows = conn.rows(&sql).unwrap(); for (i, row) in rows.iter().enumerate() { cnt += 1; assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); } assert!(cnt == expects.len()); } @@ -140,11 +148,15 @@ mod sqlite { let expects = [("Alice", 42), ("Bob", 69), ("Carol", 50)]; let mut cnt = 0; - conn.rows(&prep!("SELECT * FROM users;")).unwrap().iter().enumerate().for_each(|(i, row)| { - cnt += 1; - assert_eq!(row.get("name").unwrap(), expects[i].0); - assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); - }); + conn.rows(query!("SELECT * FROM users;")) + .unwrap() + .iter() + .enumerate() + .for_each(|(i, row)| { + cnt += 1; + assert_eq!(row.get("name").unwrap(), expects[i].0); + assert_eq!(row.get("age").unwrap(), expects[i].1.to_string()); + }); assert!(cnt == expects.len()); } @@ -152,47 +164,56 @@ mod sqlite { fn start_with_quotation_and_end_with_anything_else() { let conn = prepare(); let name = "'Alice'; DROP TABLE users; --"; - let sql = prep!("select age from users where name = ") + name + &prep!(""); + let sql = query!("select age from users where name = ") + name + &query!(""); assert_eq!( sql.simulate(), "select age from users where name = '''Alice''; DROP TABLE users; --'" ); - conn.iterate(&sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(&sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn whitespace() { let conn = prepare(); - let sql = prep!("select\n*\rfrom\nusers;"); + let sql = query!("select\n*\rfrom\nusers;"); - conn.iterate(sql, |_| { true }).unwrap(); + conn.iterate(sql, |_| true).unwrap(); } #[test] fn sqli_eq_nonquote() { let conn = prepare(); let name = "Alice' or '1'='1"; - let sql = prep!("select age from users where name =") + name + &prep!(";"); + let sql = query!("select age from users where name =") + name + &query!(";"); // "select age from users where name = 'Alice'' or ''1''=''1';" - conn.iterate(sql, |_| { unreachable!(); }).unwrap(); + conn.iterate(sql, |_| { + unreachable!(); + }) + .unwrap(); } #[test] fn sanitizing() { let conn = prepare(); let name = r#""#; - let sql = prep!("INSERT INTO users VALUES(") + name + &prep!(", 12345);"); + let sql = query!("INSERT INTO users VALUES(") + name + &query!(", 12345);"); conn.execute(sql).unwrap(); - conn.rows(prep!("SELECT name FROM users WHERE age = 12345;")).unwrap().iter() .all(|row| { - assert_eq!( - concatsql::html_special_chars(row.get("name").unwrap()), - "<script>alert("&1");</script>" - ); - true - }); + conn.rows(query!("SELECT name FROM users WHERE age = 12345;")) + .unwrap() + .iter() + .all(|row| { + assert_eq!( + concatsql::html_special_chars(row.get("name").unwrap()), + "<script>alert("&1");</script>" + ); + true + }); } #[test] @@ -211,9 +232,9 @@ mod sqlite { conn.error_level(ErrorLevel::AlwaysOk); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), Ok(())); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); - assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); + assert_eq!(conn.execute(invalid_sql), Ok(())); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), Ok(())); + assert_eq!(conn.rows(invalid_sql), Ok(Vec::new())); } #[test] @@ -222,9 +243,9 @@ mod sqlite { conn.error_level(ErrorLevel::Release); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!()); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); - assert_eq!(conn.rows(invalid_sql), err!()); + assert_eq!(conn.execute(invalid_sql), err!()); + assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!()); + assert_eq!(conn.rows(invalid_sql), err!()); } #[test] @@ -233,9 +254,12 @@ mod sqlite { conn.error_level(ErrorLevel::Develop); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), err!("exec error")); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), err!("exec error")); - assert_eq!(conn.rows(invalid_sql), err!("exec error")); + assert_eq!(conn.execute(invalid_sql), err!("exec error")); + assert_eq!( + conn.iterate(invalid_sql, |_| unreachable!()), + err!("exec error") + ); + assert_eq!(conn.rows(invalid_sql), err!("exec error")); } #[test] @@ -244,31 +268,37 @@ mod sqlite { conn.error_level(ErrorLevel::Debug); let invalid_sql = "INVALID_SQL"; - assert_eq!(conn.execute(invalid_sql), - err!("exec error: near \"INVALID_SQL\": syntax error")); - assert_eq!(conn.iterate(invalid_sql, |_| unreachable!()), - err!("exec error: near \"INVALID_SQL\": syntax error")); - assert_eq!(conn.rows(invalid_sql), - err!("exec error: near \"INVALID_SQL\": syntax error")); + assert_eq!( + conn.execute(invalid_sql), + err!("exec error: near \"INVALID_SQL\": syntax error") + ); + assert_eq!( + conn.iterate(invalid_sql, |_| unreachable!()), + err!("exec error: near \"INVALID_SQL\": syntax error") + ); + assert_eq!( + conn.rows(invalid_sql), + err!("exec error: near \"INVALID_SQL\": syntax error") + ); } #[test] fn prep_into_execute() { let conn = concatsql::sqlite::open(":memory:").unwrap(); - conn.execute(prep!("SELECT ") + 1).unwrap(); + conn.execute(query!("SELECT ") + 1).unwrap(); } #[test] fn prep_into_iterate() { let conn = concatsql::sqlite::open(":memory:").unwrap(); - conn.iterate(prep!("SELECT ") + 1, |_| true ).unwrap(); + conn.iterate(query!("SELECT ") + 1, |_| true).unwrap(); } #[test] fn prep_into_rows() { let conn = concatsql::sqlite::open(":memory:").unwrap(); let mut executed = false; - for row in &conn.rows(prep!("SELECT ") + 1).unwrap() { + for row in &conn.rows(query!("SELECT ") + 1).unwrap() { executed = true; assert_eq!(row.get(0).unwrap(), "1"); } @@ -277,12 +307,12 @@ mod sqlite { #[test] fn multi_thread() { - use std::thread; use std::sync::{Arc, Mutex}; + use std::thread; let conn = Arc::new(Mutex::new(concatsql::sqlite::open(":memory:").unwrap())); - let stmt = prep!(stmt()); - conn.lock().unwrap().execute(stmt).unwrap(); + let query = query!("{STMT}"); + conn.lock().unwrap().execute(query).unwrap(); let mut handles = vec![]; @@ -290,20 +320,33 @@ mod sqlite { let conn_clone = conn.clone(); let handle = thread::spawn(move || { let conn = &*conn_clone.lock().unwrap(); - let sql = prep!("INSERT INTO users VALUES ('Thread', ") + i + prep!(");"); + let sql = query!("INSERT INTO users VALUES ('Thread', ") + i + query!(");"); conn.execute(sql).unwrap(); }); handles.push(handle); } - for handle in handles { handle.join().unwrap(); } + for handle in handles { + handle.join().unwrap(); + } let conn = &*conn.lock().unwrap(); - assert_eq!(90, (0..10).map(|mut i| { - conn.iterate(prep!("SELECT age FROM users WHERE age = ") + i, |pairs| { - pairs.iter().for_each(|(_, v)| { assert_eq!(i.to_string(), v.unwrap()); i*=2; }); true - }).unwrap(); i - }).sum::()); + assert_eq!( + 90, + (0..10) + .map(|mut i| { + conn.iterate(query!("SELECT age FROM users WHERE age = ") + i, |pairs| { + pairs.iter().for_each(|(_, v)| { + assert_eq!(i.to_string(), v.unwrap()); + i *= 2; + }); + true + }) + .unwrap(); + i + }) + .sum::() + ); } #[test] @@ -311,7 +354,7 @@ mod sqlite { let conn = prepare(); let name = "A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + name + prep!(";"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + name + query!(";"); let mut executed = false; conn.rows(&sql).unwrap().iter().all(|row| { @@ -322,22 +365,33 @@ mod sqlite { assert!(executed); let name = "A"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + name + "%"); assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%A%'"); conn.execute(&sql).unwrap(); let name = "%A%"; - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name) + "%"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name) + "%"); if cfg!(feature = "mysql") || cfg!(feature = "postgres") { - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%\\\\%A\\\\%%'" + ); } else { - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%\\%A\\%%'"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%\\%A\\%%'" + ); } conn.execute(&sql).unwrap(); let name = String::from("%A%"); - let sql = prep!("SELECT * FROM users WHERE name LIKE ") + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); - assert_eq!(sql.simulate(), "SELECT * FROM users WHERE name LIKE '%$%A$%%'"); + let sql = query!("SELECT * FROM users WHERE name LIKE ") + + ("%".to_owned() + &sanitize_like!(name, '$') + "%"); + assert_eq!( + sql.simulate(), + "SELECT * FROM users WHERE name LIKE '%$%A$%%'" + ); conn.execute(&sql).unwrap(); } @@ -346,7 +400,7 @@ mod sqlite { let conn = prepare(); let name = "A?['i]*"; - let sql = prep!("SELECT * FROM users WHERE name GLOB ") + name; + let sql = query!("SELECT * FROM users WHERE name GLOB ") + name; let mut executed = false; conn.rows(&sql).unwrap().iter().all(|row| { @@ -362,14 +416,19 @@ mod sqlite { let conn = prepare(); let mut cnt = 0; for (i, row) in conn.rows("SELECT 1; SELECT 2;").unwrap().iter().enumerate() { - /*^^^^^^^^*/// <- only first statement + /*^^^^^^^^*/// <- only first statement cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 1, 2 ][i]); - }; - for (i, row) in conn.rows("SELECT age FROM users;").unwrap().iter().enumerate() { + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [1, 2][i]); + } + for (i, row) in conn + .rows("SELECT age FROM users;") + .unwrap() + .iter() + .enumerate() + { cnt += 1; - assert_eq!(row.get_into::<_, i32>(0).unwrap(), [ 42, 69, 50 ][i]); - }; + assert_eq!(row.get_into::<_, i32>(0).unwrap(), [42, 69, 50][i]); + } assert_eq!(cnt, 4); } @@ -391,7 +450,7 @@ mod sqlite { let conn = concatsql::sqlite::open(":memory:").unwrap(); conn.execute("CREATE TABLE b (data blob)").unwrap(); let data = vec![0x1, 0xA, 0xFF, 0x00, 0x7F]; - let sql = prep!("INSERT INTO b VALUES (") + &data + prep!(")"); + let sql = query!("INSERT INTO b VALUES (") + &data + query!(")"); conn.execute(sql).unwrap(); for row in conn.rows("SELECT data FROM b").unwrap() { assert_eq!(row.get_into::<_, Vec>(0).unwrap(), data); @@ -401,28 +460,39 @@ mod sqlite { #[test] fn question() { let conn = prepare(); - let sql = prep!("SELECT name FROM users WHERE name=") + "?"; - for _ in conn.rows(&sql).unwrap() { unreachable!(); } + let sql = query!("SELECT name FROM users WHERE name=") + "?"; + for _ in conn.rows(&sql).unwrap() { + unreachable!(); + } } #[test] fn iterator() { let conn = prepare(); - let sql = prep!("SELECT name FROM users WHERE name=") + "?"; - for _ in conn.rows(&sql).unwrap() { unreachable!(); } - for _ in conn.rows(&sql).unwrap().iter() { unreachable!(); } - for _ in &conn.rows(&sql).unwrap() { unreachable!(); } + let sql = query!("SELECT name FROM users WHERE name=") + "?"; + for _ in conn.rows(&sql).unwrap() { + unreachable!(); + } + for _ in conn.rows(&sql).unwrap().iter() { + unreachable!(); + } + for _ in &conn.rows(&sql).unwrap() { + unreachable!(); + } } #[test] fn map_collect() { let conn = prepare(); let rows = conn.rows("SELECT * FROM users").unwrap(); - let names = rows.iter().map(|row| row.get("name")).collect::>>(); + let names = rows + .iter() + .map(|row| row.get("name")) + .collect::>>(); let mut cnt = 0; for (i, name) in names.iter().enumerate() { cnt += 1; - assert_eq!(name.unwrap(), ["Alice","Bob","Carol"][i]) + assert_eq!(name.unwrap(), ["Alice", "Bob", "Carol"][i]) } assert_eq!(cnt, 3); } @@ -430,23 +500,32 @@ mod sqlite { #[test] fn without_escape() { unsafe { - assert_eq!((prep!() + concatsql::without_escape(&String::from("42")) ).simulate(), "42"); - assert_eq!((prep!() + concatsql::without_escape(&String::from("foo"))).simulate(), "foo"); - assert_eq!((prep!() + concatsql::without_escape(&String::from("")) ).simulate(), ""); - assert_eq!((prep!() + String::from("42") ).simulate(), "'42'"); - assert_eq!((prep!() + String::from("foo") ).simulate(), "'foo'"); - assert_eq!((prep!() + String::from("") ).simulate(), "''"); + assert_eq!( + (query!("") + concatsql::without_escape(&String::from("42"))).simulate(), + "42" + ); + assert_eq!( + (query!("") + concatsql::without_escape(&String::from("foo"))).simulate(), + "foo" + ); + assert_eq!( + (query!("") + concatsql::without_escape(&String::from(""))).simulate(), + "" + ); + assert_eq!((query!("") + String::from("42")).simulate(), "'42'"); + assert_eq!((query!("") + String::from("foo")).simulate(), "'foo'"); + assert_eq!((query!("") + String::from("")).simulate(), "''"); } } #[test] fn in_array() { let conn = prepare(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec![] as Vec<&str> + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam"] + query!(")"); conn.rows(&sql).unwrap(); - let sql = prep!("SELECT * FROM users WHERE name IN (") + vec!["Adam","Eve"] + prep!(")"); + let sql = query!("SELECT * FROM users WHERE name IN (") + vec!["Adam", "Eve"] + query!(")"); conn.rows(&sql).unwrap(); } @@ -454,11 +533,12 @@ mod sqlite { fn uuid() { use uuid::Uuid; let conn = prepare(); - let sql = prep!("SELECT ") + Uuid::nil(); + let sql = query!("SELECT ") + Uuid::nil(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "00000000000000000000000000000000"); } - let sql = prep!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); + let sql = + query!("SELECT ") + Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(); for row in conn.rows(&sql).unwrap() { assert_eq!(&row[0], "936DA01F9ABD4D9D80C702AF85C822A8"); } @@ -469,69 +549,75 @@ mod sqlite { let conn = prepare(); let name = "' OR 1=2; SELECT 1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("bind error: column index out of range".to_string())) + Err(Error::Message( + "bind error: column index out of range".to_string() + )) ); let name = "' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("bind error: column index out of range".to_string())) + Err(Error::Message( + "bind error: column index out of range".to_string() + )) ); let name = "Alice"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); // '?' is not placeholder + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); // '?' is not placeholder assert_eq!( conn.rows(&sql), - Err(Error::Message("bind error: column index out of range".to_string())) + Err(Error::Message( + "bind error: column index out of range".to_string() + )) ); let name = "'' OR 1=1; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } let name = "''; DROP TABLE users; --"; - let sql = prep!("SELECT age FROM users WHERE name = ") + name; + let sql = query!("SELECT age FROM users WHERE name = ") + name; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT ") + "0x50 + 0x45"; + let sql = query!("SELECT ") + "0x50 + 0x45"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x50 + 0x45"); } - let sql = prep!("SELECT ") + "0x414243"; + let sql = query!("SELECT ") + "0x414243"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "0x414243"); } - let sql = prep!("SELECT ") + "CHAR(0x66)"; + let sql = query!("SELECT ") + "CHAR(0x66)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(0x66)"); } - let sql = prep!("SELECT ") + "IF(1=1, 'true', 'false')"; + let sql = query!("SELECT ") + "IF(1=1, 'true', 'false')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "IF(1=1, 'true', 'false')"); } - let sql = prep!("SELECT ") + "na + '-' + me FROM users"; + let sql = query!("SELECT ") + "na + '-' + me FROM users"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "na + '-' + me FROM users"); } - let sql = prep!("SELECT ") + "ASCII('a')"; + let sql = query!("SELECT ") + "ASCII('a')"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "ASCII('a')"); } - let sql = prep!("SELECT ") + "CHAR(64)"; + let sql = query!("SELECT ") + "CHAR(64)"; for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "CHAR(64)"); } @@ -542,15 +628,15 @@ mod sqlite { #[cfg(not(debug_assertions))] mod sqlite_release_build { use concatsql::prelude::*; - use concatsql::prep; #[test] fn sqli_enable() { let conn = concatsql::sqlite::open(":memory:").unwrap(); - conn.execute("CREATE TABLE users (name TEXT, age INTEGER);").unwrap(); + conn.execute("CREATE TABLE users (name TEXT, age INTEGER);") + .unwrap(); let name = "OR 1=2; SELECT 1; --"; - let sql = prep!("SELECT age FROM users WHERE name = '") + name + &prep!("';"); + let sql = query!("SELECT age FROM users WHERE name = '") + name + &query!("';"); for row in conn.rows(&sql).unwrap() { assert_eq!(row.get(0).unwrap(), "1"); @@ -561,7 +647,6 @@ mod sqlite_release_build { #[cfg(feature = "sqlite")] mod anti_patterns { use concatsql::prelude::*; - use concatsql::prep; // Although it becomes possible, I do not believe it is less useful // because its real advantage is that it still makes it harder to do the wrong thing. @@ -570,7 +655,9 @@ mod anti_patterns { let conn = sqlite::open(":memory:").unwrap(); let sql: &'static str = Box::leak(String::from("SELECT 1").into_boxed_str()); conn.execute(sql).unwrap(); - unsafe { drop(Box::from_raw(sql.as_ptr() as *mut u8)); } + unsafe { + drop(Box::from_raw(sql.as_ptr() as *mut u8)); + } } #[test] @@ -578,63 +665,63 @@ mod anti_patterns { let conn = super::sqlite::prepare(); let mut cnt = 0; - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MAX; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name = ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + i32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + i32::MIN; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MAX; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MAX; for _ in conn.rows(&sql).unwrap() { cnt += 1; } - let sql = prep!("SELECT age FROM users WHERE name = ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name = ") + u32::MIN; for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name < ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name < ") + u32::MIN; #[allow(clippy::never_loop)] for _ in conn.rows(&sql).unwrap() { unreachable!(); } - let sql = prep!("SELECT age FROM users WHERE name > ") + u32::MIN; + let sql = query!("SELECT age FROM users WHERE name > ") + u32::MIN; for _ in conn.rows(&sql).unwrap() { cnt += 1; } @@ -642,4 +729,3 @@ mod anti_patterns { assert_eq!(cnt, 12); } } - diff --git a/concatsql_macro/src/lib.rs b/concatsql_macro/src/lib.rs index 288a574..b871b7c 100644 --- a/concatsql_macro/src/lib.rs +++ b/concatsql_macro/src/lib.rs @@ -1,20 +1,17 @@ extern crate proc_macro; -use proc_macro::TokenStream; -use proc_macro2::{Ident, Span}; -use proc_macro_error::{ - proc_macro_error, - abort_call_site, -}; -use quote::quote; -use syn::LitStr; use nom::{ - IResult, branch::alt, bytes::complete::tag, character::complete::{char, none_of}, multi::{many0, many1}, + IResult, }; +use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; +use proc_macro_error::{abort_call_site, proc_macro_error}; +use quote::quote; +use syn::LitStr; #[derive(Debug)] enum Query { @@ -48,10 +45,10 @@ impl FormatParser { for q in query.into_iter() { match q { Query::Lit(s) => { - lits.push(quote!{ Some( #s ) }); + lits.push(quote! { Some( #s ) }); } Query::Param(p) => { - lits.push(quote!{ None }); + lits.push(quote! { None }); params.push(Ident::new(&p, Span::call_site())); } } @@ -61,11 +58,17 @@ impl FormatParser { vec![ #(#lits),* ], vec![ #(#params.to_value()),* ], ) - }.into()) + } + .into()) } fn format(input: &str) -> IResult<&str, Vec> { - many0(alt((FormatParser::brace_open, FormatParser::brace_close, FormatParser::param, FormatParser::lit)))(input) + many0(alt(( + FormatParser::brace_open, + FormatParser::brace_close, + FormatParser::param, + FormatParser::lit, + )))(input) } fn lit(input: &str) -> IResult<&str, Query> { diff --git a/concatsql_macro/tests/macro.rs b/concatsql_macro/tests/macro.rs index 64e4c6b..48ef8a0 100644 --- a/concatsql_macro/tests/macro.rs +++ b/concatsql_macro/tests/macro.rs @@ -1,6 +1,6 @@ mod macros { - use concatsql_macro::query; use concatsql::prelude::*; + use concatsql_macro::query; #[test] fn query_test() { @@ -9,7 +9,10 @@ mod macros { assert_eq!(sql.simulate(), "SELECT 'foo'"); let age = "42 OR 1=1; --"; let sql = query!(r#"SELECT name FROM users WHERE age = {age}"#); - assert_eq!(sql.simulate(), "SELECT name FROM users WHERE age = '42 OR 1=1; --'"); + assert_eq!( + sql.simulate(), + "SELECT name FROM users WHERE age = '42 OR 1=1; --'" + ); let sql = query!(r#"{name}{name};"#); assert_eq!(sql.simulate(), "'foo''foo';"); let sql = query!(r#"{{name}}"#);