diff --git a/src/async_traits.rs b/src/async_traits.rs index f70ba24..38c879a 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -1,6 +1,6 @@ //! Async versions of traits for issuing Diesel queries. -use crate::connection::Connection as SingleConnection; +use crate::connection::Connection; use async_trait::async_trait; use diesel::{ connection::{ @@ -12,6 +12,7 @@ use diesel::{ methods::{ExecuteDsl, LimitDsl, LoadQuery}, RunQueryDsl, }, + r2d2::R2D2Connection, result::Error as DieselError, }; use futures::future::BoxFuture; @@ -40,6 +41,25 @@ fn retryable_error(err: &DieselError) -> bool { } } +/// An async variant of [`diesel::r2d2::R2D2Connection`]. +#[async_trait] +pub trait AsyncR2D2Connection: AsyncConnection +where + Conn: 'static + DieselConnection + R2D2Connection, + Self: Send + Sized + 'static, +{ + async fn ping_async(&mut self) -> diesel::result::QueryResult<()> { + self.as_async_conn().run(|conn| conn.ping()).await + } + + async fn is_broken_async(&mut self) -> bool { + self.as_async_conn() + .run(|conn| Ok::(conn.is_broken())) + .await + .unwrap() + } +} + /// An async variant of [`diesel::connection::Connection`]. #[async_trait] pub trait AsyncConnection: AsyncSimpleConnection @@ -52,7 +72,7 @@ where #[doc(hidden)] fn as_sync_conn(&self) -> MutexGuard<'_, Conn>; #[doc(hidden)] - fn as_async_conn(&self) -> &SingleConnection; + fn as_async_conn(&self) -> &Connection; /// Runs the function `f` in an context where blocking is safe. async fn run(&self, f: Func) -> Result @@ -169,7 +189,7 @@ where where R: Any + Send + 'static, Fut: FutureExt> + Send, - Func: (Fn(SingleConnection) -> Fut) + Send + Sync, + Func: (Fn(Connection) -> Fut) + Send + Sync, RetryFut: FutureExt + Send, RetryFunc: Fn() -> RetryFut + Send + Sync, { @@ -201,7 +221,7 @@ where #[cfg(feature = "cockroach")] async fn transaction_async_with_retry_inner( &self, - f: &(dyn Fn(SingleConnection) -> BoxFuture<'_, Result, DieselError>> + f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, DieselError>> + Send + Sync), retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync), @@ -231,7 +251,7 @@ where // Add a SAVEPOINT to which we can later return. Self::add_retry_savepoint(&conn).await?; - let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone()); + let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { // The user-level operation succeeded: try to commit the @@ -288,7 +308,7 @@ where R: Send + 'static, E: From + Send + 'static, Fut: Future> + Send, - Func: FnOnce(SingleConnection) -> Fut + Send, + Func: FnOnce(Connection) -> Fut + Send, { // This function sure has a bunch of generic parameters, which can cause // a lot of code to be generated, and can slow down compile-time. @@ -314,7 +334,7 @@ where async fn transaction_async_inner<'a, E>( &'a self, f: Box< - dyn FnOnce(SingleConnection) -> BoxFuture<'a, Result, E>> + dyn FnOnce(Connection) -> BoxFuture<'a, Result, E>> + Send + 'a, >, @@ -348,7 +368,7 @@ where // enough to be referenceable by a Future, but short enough that we can // guarantee it doesn't live persist after this function returns, feel // free to make that change. - let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone()); + let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { conn.run_with_shared_connection(|conn| { diff --git a/src/connection.rs b/src/connection.rs index 1f00b10..8ad4c35 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -61,7 +61,6 @@ where self.inner() } - // TODO: Consider removing me. fn as_async_conn(&self) -> &Connection { self } diff --git a/src/lib.rs b/src/lib.rs index 58cf209..e4fbecd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,8 @@ mod connection_manager; mod error; pub use async_traits::{ - AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, + AsyncConnection, AsyncR2D2Connection, AsyncRunQueryDsl, AsyncSaveChangesDsl, + AsyncSimpleConnection, }; pub use connection::Connection; pub use connection_manager::ConnectionManager;