Skip to content

Commit

Permalink
fix(rust): implement database/connection constructors without options (
Browse files Browse the repository at this point in the history
…#2242)

Fixes #2241.
  • Loading branch information
mbrobbel authored Oct 12, 2024
1 parent 1619168 commit cd24bc0
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 45 deletions.
139 changes: 105 additions & 34 deletions rust/core/src/driver_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,11 @@ impl ManagedDriver {
check_status(status, error)?;
Ok(driver)
}
}

impl Driver for ManagedDriver {
type DatabaseType = ManagedDatabase;

fn new_database(&mut self) -> Result<Self::DatabaseType> {
self.new_database_with_opts(None)
}

fn new_database_with_opts(
&mut self,
opts: impl IntoIterator<Item = (<Self::DatabaseType as Optionable>::Option, OptionValue)>,
) -> Result<Self::DatabaseType> {
/// Returns a new database using the loaded driver.
///
/// This uses `&mut self` to prevent a deadlock.
fn database_new(&mut self) -> Result<ffi::FFI_AdbcDatabase> {
let driver = &self.inner.driver.lock().unwrap();
let mut database = ffi::FFI_AdbcDatabase::default();

Expand All @@ -262,17 +254,58 @@ impl Driver for ManagedDriver {
let status = unsafe { method(&mut database, &mut error) };
check_status(status, error)?;

// DatabaseSetOption
for (key, value) in opts {
set_option_database(driver, &mut database, self.inner.version, key, value)?;
}
Ok(database)
}

/// Initialize the given database using the loaded driver.
///
/// This uses `&mut self` to prevent a deadlock.
fn database_init(
&mut self,
mut database: ffi::FFI_AdbcDatabase,
) -> Result<ffi::FFI_AdbcDatabase> {
let driver = &self.inner.driver.lock().unwrap();

// DatabaseInit
let mut error = ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, DatabaseInit);
let status = unsafe { method(&mut database, &mut error) };
check_status(status, error)?;

Ok(database)
}
}

impl Driver for ManagedDriver {
type DatabaseType = ManagedDatabase;

fn new_database(&mut self) -> Result<Self::DatabaseType> {
// Construct a new database.
let database = self.database_new()?;
// Initialize the database.
let database = self.database_init(database)?;
let inner = Arc::new(ManagedDatabaseInner {
database: Mutex::new(database),
driver: self.inner.clone(),
});
Ok(Self::DatabaseType { inner })
}

fn new_database_with_opts(
&mut self,
opts: impl IntoIterator<Item = (<Self::DatabaseType as Optionable>::Option, OptionValue)>,
) -> Result<Self::DatabaseType> {
// Construct a new database.
let mut database = self.database_new()?;
// Set the options.
{
let driver = &self.inner.driver.lock().unwrap();
for (key, value) in opts {
set_option_database(driver, &mut database, self.inner.version, key, value)?;
}
}
// Initialize the database.
let database = self.database_init(database)?;
let inner = Arc::new(ManagedDatabaseInner {
database: Mutex::new(database),
driver: self.inner.clone(),
Expand Down Expand Up @@ -425,6 +458,41 @@ impl ManagedDatabase {
fn driver_version(&self) -> AdbcVersion {
self.inner.driver.version
}

/// Returns a new connection using the loaded driver.
///
/// This uses `&mut self` to prevent a deadlock.
fn connection_new(&mut self) -> Result<ffi::FFI_AdbcConnection> {
let driver = &self.inner.driver.driver.lock().unwrap();
let mut connection = ffi::FFI_AdbcConnection::default();

// ConnectionNew
let mut error = ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionNew);
let status = unsafe { method(&mut connection, &mut error) };
check_status(status, error)?;

Ok(connection)
}

/// Initialize the given connection using the loaded driver.
///
/// This uses `&mut self` to prevent a deadlock.
fn connection_init(
&mut self,
mut connection: ffi::FFI_AdbcConnection,
) -> Result<ffi::FFI_AdbcConnection> {
let driver = &self.inner.driver.driver.lock().unwrap();
let mut database = self.inner.database.lock().unwrap();

// ConnectionInit
let mut error = ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionInit);
let status = unsafe { method(&mut connection, &mut *database, &mut error) };
check_status(status, error)?;

Ok(connection)
}
}

impl Optionable for ManagedDatabase {
Expand Down Expand Up @@ -497,35 +565,38 @@ impl Database for ManagedDatabase {
type ConnectionType = ManagedConnection;

fn new_connection(&mut self) -> Result<Self::ConnectionType> {
self.new_connection_with_opts(None)
// Construct a new connection.
let connection = self.connection_new()?;
// Initialize the connection.
let connection = self.connection_init(connection)?;
let inner = ManagedConnectionInner {
connection: Mutex::new(connection),
database: self.inner.clone(),
};
Ok(Self::ConnectionType {
inner: Arc::new(inner),
})
}

fn new_connection_with_opts(
&mut self,
opts: impl IntoIterator<Item = (<Self::ConnectionType as Optionable>::Option, OptionValue)>,
) -> Result<Self::ConnectionType> {
let driver = &self.inner.driver.driver.lock().unwrap();
let mut database = self.inner.database.lock().unwrap();
let mut connection = ffi::FFI_AdbcConnection::default();
let mut error = ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionNew);
let status = unsafe { method(&mut connection, &mut error) };
check_status(status, error)?;

for (key, value) in opts {
set_option_connection(driver, &mut connection, self.driver_version(), key, value)?;
// Construct a new connection.
let mut connection = self.connection_new()?;
// Set the options.
{
let driver = &self.inner.driver.driver.lock().unwrap();
for (key, value) in opts {
set_option_connection(driver, &mut connection, self.driver_version(), key, value)?;
}
}

let mut error = ffi::FFI_AdbcError::with_driver(driver);
let method = driver_method!(driver, ConnectionInit);
let status = unsafe { method(&mut connection, database.deref_mut(), &mut error) };
check_status(status, error)?;

// Initialize the connection.
let connection = self.connection_init(connection)?;
let inner = ManagedConnectionInner {
connection: Mutex::new(connection),
database: self.inner.clone(),
};

Ok(Self::ConnectionType {
inner: Arc::new(inner),
})
Expand Down
19 changes: 8 additions & 11 deletions rust/drivers/dummy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,22 @@ impl Driver for DummyDriver {
type DatabaseType = DummyDatabase;

fn new_database(&mut self) -> Result<Self::DatabaseType> {
self.new_database_with_opts(None)
Ok(Self::DatabaseType::default())
}

fn new_database_with_opts(
&mut self,
opts: impl IntoIterator<Item = (<Self::DatabaseType as Optionable>::Option, OptionValue)>,
) -> Result<Self::DatabaseType> {
let mut database = Self::DatabaseType {
options: HashMap::new(),
};
let mut database = Self::DatabaseType::default();
for (key, value) in opts {
database.set_option(key, value)?;
}
Ok(database)
}
}

#[derive(Default)]
pub struct DummyDatabase {
options: HashMap<OptionDatabase, OptionValue>,
}
Expand Down Expand Up @@ -232,23 +231,22 @@ impl Database for DummyDatabase {
type ConnectionType = DummyConnection;

fn new_connection(&mut self) -> Result<Self::ConnectionType> {
self.new_connection_with_opts(None)
Ok(Self::ConnectionType::default())
}

fn new_connection_with_opts(
&mut self,
opts: impl IntoIterator<Item = (<Self::ConnectionType as Optionable>::Option, OptionValue)>,
) -> Result<Self::ConnectionType> {
let mut connection = Self::ConnectionType {
options: HashMap::new(),
};
let mut connection = Self::ConnectionType::default();
for (key, value) in opts {
connection.set_option(key, value)?;
}
Ok(connection)
}
}

#[derive(Default)]
pub struct DummyConnection {
options: HashMap<OptionConnection, OptionValue>,
}
Expand Down Expand Up @@ -281,9 +279,7 @@ impl Connection for DummyConnection {
type StatementType = DummyStatement;

fn new_statement(&mut self) -> Result<Self::StatementType> {
Ok(Self::StatementType {
options: HashMap::new(),
})
Ok(Self::StatementType::default())
}

// This method is used to test that errors round-trip correctly.
Expand Down Expand Up @@ -798,6 +794,7 @@ impl Connection for DummyConnection {
}
}

#[derive(Default)]
pub struct DummyStatement {
options: HashMap<OptionStatement, OptionValue>,
}
Expand Down

0 comments on commit cd24bc0

Please sign in to comment.