From bddf967ce046413322b824f3f55f1e7f0e346784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dawid=20Ci=C4=99=C5=BCarkiewicz?= Date: Fri, 20 Sep 2024 15:40:01 -0700 Subject: [PATCH] chore(db): improve `global_dbtx` when module uses nested `with_prefix` --- fedimint-core/src/db/mod.rs | 202 +++++++++++++++++++++++-------- gateway/ln-gateway/src/client.rs | 6 +- 2 files changed, 155 insertions(+), 53 deletions(-) diff --git a/fedimint-core/src/db/mod.rs b/fedimint-core/src/db/mod.rs index 357fcc3a5a7..caed7367f70 100644 --- a/fedimint-core/src/db/mod.rs +++ b/fedimint-core/src/db/mod.rs @@ -407,13 +407,31 @@ impl Database { } /// Create [`Database`] isolated to a partition with a given `prefix` - pub fn with_prefix(&self, prefix: Vec) -> (Self, GlobalDBTxAccessToken) { + pub fn with_prefix(&self, prefix: Vec) -> Self { + Self { + inner: Arc::new(PrefixDatabase { + inner: self.inner.clone(), + global_dbtx_access_token: None, + prefix, + }), + module_decoders: self.module_decoders.clone(), + } + } + + /// Create [`Database`] isolated to a partition with a prefix for a given + /// `module_instance_id`, allowing the module to access `global_dbtx` with + /// the right `access_token` + pub fn with_prefix_module_id( + &self, + module_instance_id: ModuleInstanceId, + ) -> (Self, GlobalDBTxAccessToken) { + let prefix = module_instance_id_to_byte_prefix(module_instance_id); let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix); ( Self { inner: Arc::new(PrefixDatabase { inner: self.inner.clone(), - global_dbtx_access_token, + global_dbtx_access_token: Some(global_dbtx_access_token), prefix, }), module_decoders: self.module_decoders.clone(), @@ -422,16 +440,6 @@ impl Database { ) } - /// Create [`Database`] isolated to a partition with a prefix for a given - /// `module_instance_id` - pub fn with_prefix_module_id( - &self, - module_instance_id: ModuleInstanceId, - ) -> (Self, GlobalDBTxAccessToken) { - let prefix = module_instance_id_to_byte_prefix(module_instance_id); - self.with_prefix(prefix) - } - pub fn with_decoders(&self, module_decoders: ModuleDecoderRegistry) -> Self { Self { inner: self.inner.clone(), @@ -650,7 +658,7 @@ where Inner: Debug, { prefix: Vec, - global_dbtx_access_token: GlobalDBTxAccessToken, + global_dbtx_access_token: Option, inner: Inner, } @@ -704,7 +712,7 @@ where #[derive(Debug)] struct PrefixDatabaseTransaction { inner: Inner, - global_dbtx_access_token: GlobalDBTxAccessToken, + global_dbtx_access_token: Option, prefix: Vec, } @@ -747,11 +755,15 @@ where &mut self, access_token: GlobalDBTxAccessToken, ) -> &mut dyn IDatabaseTransaction { - assert_eq!( - access_token, self.global_dbtx_access_token, - "Invalid access key used to access global_dbtx" - ); - &mut self.inner + if let Some(self_global_dbtx_access_token) = self.global_dbtx_access_token { + assert_eq!( + access_token, self_global_dbtx_access_token, + "Invalid access key used to access global_dbtx" + ); + &mut self.inner + } else { + self.inner.global_dbtx(access_token) + } } } @@ -1604,19 +1616,40 @@ impl<'tx, Cap> DatabaseTransaction<'tx, Cap> { } /// Get [`DatabaseTransaction`] isolated to a `prefix` - pub fn with_prefix<'a: 'tx>( + pub fn with_prefix<'a: 'tx>(self, prefix: Vec) -> DatabaseTransaction<'a, Cap> + where + 'tx: 'a, + { + DatabaseTransaction { + tx: Box::new(PrefixDatabaseTransaction { + inner: self.tx, + global_dbtx_access_token: None, + prefix, + }), + decoders: self.decoders, + commit_tracker: self.commit_tracker, + on_commit_hooks: self.on_commit_hooks, + capability: self.capability, + } + } + + /// Get [`DatabaseTransaction`] isolated to a prefix of a given + /// `module_instance_id`, allowing the module to access global_dbtx + /// with the right access token. + pub fn with_prefix_module_id<'a: 'tx>( self, - prefix: Vec, + module_instance_id: ModuleInstanceId, ) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken) where 'tx: 'a, { + let prefix = module_instance_id_to_byte_prefix(module_instance_id); let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix); ( DatabaseTransaction { tx: Box::new(PrefixDatabaseTransaction { inner: self.tx, - global_dbtx_access_token, + global_dbtx_access_token: Some(global_dbtx_access_token), prefix, }), decoders: self.decoders, @@ -1628,19 +1661,6 @@ impl<'tx, Cap> DatabaseTransaction<'tx, Cap> { ) } - /// Get [`DatabaseTransaction`] isolated to a prefix of a given - /// `module_instance_id` - pub fn with_prefix_module_id<'a: 'tx>( - self, - module_instance_id: ModuleInstanceId, - ) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken) - where - 'tx: 'a, - { - let prefix = module_instance_id_to_byte_prefix(module_instance_id); - self.with_prefix(prefix) - } - /// Get [`DatabaseTransaction`] to `self` pub fn to_ref<'s, 'a>(&'s mut self) -> DatabaseTransaction<'a, Cap> where @@ -1664,19 +1684,43 @@ impl<'tx, Cap> DatabaseTransaction<'tx, Cap> { } /// Get [`DatabaseTransaction`] isolated to a `prefix` of `self` - pub fn to_ref_with_prefix<'a>( + pub fn to_ref_with_prefix<'a>(&'a mut self, prefix: Vec) -> DatabaseTransaction<'a, Cap> + where + 'tx: 'a, + { + DatabaseTransaction { + tx: Box::new(PrefixDatabaseTransaction { + inner: &mut self.tx, + global_dbtx_access_token: None, + prefix, + }), + decoders: self.decoders.clone(), + commit_tracker: match self.commit_tracker { + MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o), + MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b), + }, + on_commit_hooks: match self.on_commit_hooks { + MaybeRef::Owned(ref mut o) => MaybeRef::Borrowed(o), + MaybeRef::Borrowed(ref mut b) => MaybeRef::Borrowed(b), + }, + capability: self.capability, + } + } + + pub fn to_ref_with_prefix_module_id<'a>( &'a mut self, - prefix: Vec, + module_instance_id: ModuleInstanceId, ) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken) where 'tx: 'a, { + let prefix = module_instance_id_to_byte_prefix(module_instance_id); let global_dbtx_access_token = GlobalDBTxAccessToken::from_prefix(&prefix); ( DatabaseTransaction { tx: Box::new(PrefixDatabaseTransaction { inner: &mut self.tx, - global_dbtx_access_token, + global_dbtx_access_token: Some(global_dbtx_access_token), prefix, }), decoders: self.decoders.clone(), @@ -1694,17 +1738,6 @@ impl<'tx, Cap> DatabaseTransaction<'tx, Cap> { ) } - pub fn to_ref_with_prefix_module_id<'a>( - &'a mut self, - module_instance_id: ModuleInstanceId, - ) -> (DatabaseTransaction<'a, Cap>, GlobalDBTxAccessToken) - where - 'tx: 'a, - { - let prefix = module_instance_id_to_byte_prefix(module_instance_id); - self.to_ref_with_prefix(prefix) - } - /// Is this `Database` a global, unpartitioned `Database` pub fn is_global(&self) -> bool { self.tx.prefix_len() == 0 @@ -3429,4 +3462,75 @@ mod tests { "should not notify" ); } + + #[tokio::test] + async fn test_prefix_global_dbtx() { + let module_instance_id = 10; + let db = MemDatabase::new().into_database(); + + { + // Plain module id prefix, can use `global_dbtx` to access global_dbtx + let (db, access_token) = db.with_prefix_module_id(module_instance_id); + + let mut tx = db.begin_transaction().await; + let mut tx = tx.global_dbtx(access_token); + tx.insert_new_entry(&TestKey(1), &TestVal(1)).await; + tx.commit_tx().await; + } + + assert_eq!( + db.begin_transaction_nc().await.get_value(&TestKey(1)).await, + Some(TestVal(1)) + ); + + { + // Additional non-module inner prefix, does not interfere with `global_dbtx` + let (db, access_token) = db.with_prefix_module_id(module_instance_id); + + let db = db.with_prefix(vec![3, 4]); + + let mut tx = db.begin_transaction().await; + let mut tx = tx.global_dbtx(access_token); + tx.insert_new_entry(&TestKey(2), &TestVal(2)).await; + tx.commit_tx().await; + } + + assert_eq!( + db.begin_transaction_nc().await.get_value(&TestKey(2)).await, + Some(TestVal(2)) + ); + } + + #[tokio::test] + #[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")] + async fn test_prefix_global_dbtx_panics_on_global_db() { + let db = MemDatabase::new().into_database(); + + let mut tx = db.begin_transaction().await; + let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&[1])); + } + + #[tokio::test] + #[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")] + async fn test_prefix_global_dbtx_panics_on_non_module_prefix() { + let db = MemDatabase::new().into_database(); + + let prefix = vec![3, 4]; + let db = db.with_prefix(prefix.clone()); + + let mut tx = db.begin_transaction().await; + let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&prefix)); + } + + #[tokio::test] + #[should_panic(expected = "Illegal to call global_dbtx on BaseDatabaseTransaction")] + async fn test_prefix_global_dbtx_panics_on_wrong_access_token() { + let db = MemDatabase::new().into_database(); + + let prefix = vec![3, 4]; + let db = db.with_prefix(prefix.clone()); + + let mut tx = db.begin_transaction().await; + let _tx = tx.global_dbtx(GlobalDBTxAccessToken::from_prefix(&[1])); + } } diff --git a/gateway/ln-gateway/src/client.rs b/gateway/ln-gateway/src/client.rs index 40815c72e1f..4f8fc360d65 100644 --- a/gateway/ln-gateway/src/client.rs +++ b/gateway/ln-gateway/src/client.rs @@ -99,8 +99,7 @@ impl GatewayClientBuilder { let federation_id = config.invite_code.federation_id(); let db = gateway .gateway_db - .with_prefix(config.federation_index.to_le_bytes().to_vec()) - .0; + .with_prefix(config.federation_index.to_le_bytes().to_vec()); let client_builder = self .create_client_builder(db, &config, gateway.clone()) .await?; @@ -148,8 +147,7 @@ impl GatewayClientBuilder { } else { let db = gateway .gateway_db - .with_prefix(config.federation_index.to_le_bytes().to_vec()) - .0; + .with_prefix(config.federation_index.to_le_bytes().to_vec()); let secret = Self::derive_federation_secret(mnemonic, &federation_id); (db, secret) };